pkg/store -> pkg/connector/store
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
@@ -25,7 +25,7 @@ import (
|
||||
"go.mau.fi/util/dbutil"
|
||||
"maunium.net/go/mautrix/bridgev2"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/store"
|
||||
"go.mau.fi/mautrix-telegram/pkg/connector/store"
|
||||
)
|
||||
|
||||
type TelegramConfig struct {
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
// mautrix-telegram - A Matrix-Telegram puppeting bridge.
|
||||
// Copyright (C) 2024 Sumner Evans
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/connector/store/upgrades"
|
||||
)
|
||||
|
||||
type Container struct {
|
||||
*dbutil.Database
|
||||
|
||||
TelegramFile *TelegramFileQuery
|
||||
}
|
||||
|
||||
func NewStore(db *dbutil.Database, log dbutil.DatabaseLogger) *Container {
|
||||
return &Container{
|
||||
Database: db.Child("telegram_version", upgrades.Table, log),
|
||||
|
||||
TelegramFile: &TelegramFileQuery{dbutil.MakeQueryHelper(db, newTelegramFile)},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Container) Upgrade(ctx context.Context) error {
|
||||
return c.Database.Upgrade(ctx)
|
||||
}
|
||||
|
||||
func (c *Container) GetScopedStore(telegramUserID int64) *scopedStore {
|
||||
return &scopedStore{c.Database, telegramUserID}
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/gotd/td/session"
|
||||
"github.com/gotd/td/telegram/updates"
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
// scopedStore is a wrapper around a database that implements
|
||||
// [session.Storage] scoped to a specific Telegram user ID.
|
||||
type scopedStore struct {
|
||||
db *dbutil.Database
|
||||
telegramUserID int64
|
||||
}
|
||||
|
||||
const (
|
||||
// Session Storage Queries
|
||||
loadSessionQuery = `SELECT session_data FROM telegram_session WHERE user_id=$1`
|
||||
storeSessionQuery = `
|
||||
INSERT INTO telegram_session (user_id, session_data)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (user_id) DO UPDATE SET session_data=excluded.session_data
|
||||
`
|
||||
|
||||
// State Storage Queries
|
||||
allChannelsQuery = "SELECT channel_id, pts FROM telegram_channel_state WHERE user_id=$1"
|
||||
getChannelPtsQuery = "SELECT pts FROM telegram_channel_state WHERE user_id=$1 AND channel_id=$2"
|
||||
setChannelPtsQuery = `
|
||||
INSERT INTO telegram_channel_state (user_id, channel_id, pts)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, channel_id) DO UPDATE SET pts=excluded.pts
|
||||
`
|
||||
getStateQuery = "SELECT pts, qts, date, seq from telegram_user_state WHERE user_id=$1"
|
||||
setStateQuery = `
|
||||
INSERT INTO telegram_user_state (user_id, pts, qts, date, seq)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (user_id) DO UPDATE SET
|
||||
pts=excluded.pts,
|
||||
qts=excluded.qts,
|
||||
date=excluded.date,
|
||||
seq=excluded.seq
|
||||
`
|
||||
setPtsQuery = "UPDATE telegram_user_state SET pts=$1 WHERE user_id=$2"
|
||||
setQtsQuery = "UPDATE telegram_user_state SET qts=$1 WHERE user_id=$2"
|
||||
setDateQuery = "UPDATE telegram_user_state SET date=$1 WHERE user_id=$2"
|
||||
setSeqQuery = "UPDATE telegram_user_state SET seq=$1 WHERE user_id=$2"
|
||||
setDateSeqQuery = "UPDATE telegram_user_state SET date=$1, seq=$2 WHERE user_id=$3"
|
||||
|
||||
// Channel Access Hasher Queries
|
||||
getChannelAccessHashQuery = "SELECT access_hash FROM telegram_channel_access_hashes WHERE user_id=$1 AND channel_id=$2"
|
||||
setChannelAccessHashQuery = `
|
||||
INSERT INTO telegram_channel_access_hashes (user_id, channel_id, access_hash)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, channel_id) DO UPDATE SET access_hash=excluded.access_hash
|
||||
`
|
||||
)
|
||||
|
||||
var _ session.Storage = (*scopedStore)(nil)
|
||||
|
||||
func (s *scopedStore) LoadSession(ctx context.Context) (sessionData []byte, err error) {
|
||||
row := s.db.QueryRow(ctx, loadSessionQuery, s.telegramUserID)
|
||||
err = row.Scan(&sessionData)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) StoreSession(ctx context.Context, data []byte) error {
|
||||
_, err := s.db.Exec(ctx, storeSessionQuery, s.telegramUserID, data)
|
||||
return err
|
||||
}
|
||||
|
||||
var _ updates.StateStorage = (*scopedStore)(nil)
|
||||
|
||||
func (s *scopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
|
||||
s.assertUserIDMatches(userID)
|
||||
rows, err := s.db.Query(ctx, allChannelsQuery, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var channelID int64
|
||||
var pts int
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&channelID, &pts); err != nil {
|
||||
return err
|
||||
} else if err = f(ctx, channelID, pts); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scopedStore) GetChannelPts(ctx context.Context, userID int64, channelID int64) (pts int, found bool, err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
err = s.db.QueryRow(ctx, getChannelPtsQuery, userID, channelID).Scan(&pts)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, false, nil
|
||||
}
|
||||
return pts, err == nil, err
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetChannelPts(ctx context.Context, userID int64, channelID int64, pts int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setChannelPtsQuery, userID, channelID, pts)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
err = s.db.QueryRow(ctx, getStateQuery, userID).Scan(&state.Pts, &state.Qts, &state.Date, &state.Seq)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return state, false, nil
|
||||
}
|
||||
return state, err == nil, err
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetState(ctx context.Context, userID int64, state updates.State) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setStateQuery, userID, state.Pts, state.Qts, state.Date, state.Seq)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetPts(ctx context.Context, userID int64, pts int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setPtsQuery, userID, pts)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetQts(ctx context.Context, userID int64, qts int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setQtsQuery, userID, qts)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetSeq(ctx context.Context, userID int64, seq int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setSeqQuery, userID, seq)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetDate(ctx context.Context, userID int64, date int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setDateQuery, userID, date)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetDateSeq(ctx context.Context, userID int64, date int, seq int) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setDateSeqQuery, userID, date, seq)
|
||||
return
|
||||
}
|
||||
|
||||
var _ updates.ChannelAccessHasher = (*scopedStore)(nil)
|
||||
|
||||
func (s *scopedStore) GetChannelAccessHash(ctx context.Context, userID int64, channelID int64) (accessHash int64, found bool, err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
err = s.db.QueryRow(ctx, getChannelAccessHashQuery, userID, channelID).Scan(&accessHash)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, false, nil
|
||||
}
|
||||
return accessHash, err == nil, err
|
||||
}
|
||||
|
||||
func (s *scopedStore) SetChannelAccessHash(ctx context.Context, userID int64, channelID int64, accessHash int64) (err error) {
|
||||
s.assertUserIDMatches(userID)
|
||||
_, err = s.db.Exec(ctx, setChannelAccessHashQuery, userID, channelID, accessHash)
|
||||
return
|
||||
}
|
||||
|
||||
// Helper Functions
|
||||
|
||||
func (s *scopedStore) assertUserIDMatches(userID int64) {
|
||||
if s.telegramUserID != userID {
|
||||
panic(fmt.Sprintf("scoped store for %d function called with user ID %d", s.telegramUserID, userID))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
const (
|
||||
insertTelegramFileQuery = `
|
||||
INSERT INTO telegram_file (
|
||||
id, mxc, mime_type, was_converted, timestamp, size, width, height, thumbnail, decryption_info)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
`
|
||||
getTelegramFileSelect = `
|
||||
SELECT id, mxc, mime_type, was_converted, timestamp, size, width, height, thumbnail, decryption_info
|
||||
FROM telegram_file
|
||||
`
|
||||
getTelegramFileByLocationIDQuery = getTelegramFileSelect + "WHERE id=$1"
|
||||
getTelegramFileByMXCQuery = getTelegramFileSelect + "WHERE mxc=$1"
|
||||
)
|
||||
|
||||
type TelegramFileQuery struct {
|
||||
*dbutil.QueryHelper[*TelegramFile]
|
||||
}
|
||||
|
||||
type TelegramFileLocationID string
|
||||
|
||||
type TelegramFile struct {
|
||||
qh *dbutil.QueryHelper[*TelegramFile]
|
||||
|
||||
LocationID TelegramFileLocationID
|
||||
MXC string
|
||||
MimeType string
|
||||
WasConverted bool
|
||||
Timestamp time.Time
|
||||
Size int64
|
||||
Width int
|
||||
Height int
|
||||
ThumbnailID string
|
||||
DecryptionInfo json.RawMessage
|
||||
}
|
||||
|
||||
var _ dbutil.DataStruct[*TelegramFile] = (*TelegramFile)(nil)
|
||||
|
||||
func newTelegramFile(qh *dbutil.QueryHelper[*TelegramFile]) *TelegramFile {
|
||||
return &TelegramFile{qh: qh}
|
||||
}
|
||||
|
||||
func (fq *TelegramFileQuery) GetByLocationID(ctx context.Context, locationID string) (*TelegramFile, error) {
|
||||
return fq.QueryOne(ctx, getTelegramFileByLocationIDQuery, locationID)
|
||||
}
|
||||
|
||||
func (fq *TelegramFileQuery) GetByMXC(ctx context.Context, mxc string) (*TelegramFile, error) {
|
||||
return fq.QueryOne(ctx, getTelegramFileByMXCQuery, mxc)
|
||||
}
|
||||
|
||||
func (f *TelegramFile) sqlVariables() []any {
|
||||
return []any{
|
||||
f.LocationID,
|
||||
f.MXC,
|
||||
f.MimeType,
|
||||
f.WasConverted,
|
||||
f.Timestamp.UnixMilli(),
|
||||
f.Size,
|
||||
f.Width,
|
||||
f.Height,
|
||||
f.ThumbnailID,
|
||||
f.DecryptionInfo,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *TelegramFile) Insert(ctx context.Context) error {
|
||||
return f.qh.Exec(ctx, insertTelegramFileQuery, f.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (f *TelegramFile) Scan(row dbutil.Scannable) (*TelegramFile, error) {
|
||||
var thumbnailID sql.NullString
|
||||
var timestamp int64
|
||||
err := row.Scan(
|
||||
&f.LocationID,
|
||||
&f.MXC,
|
||||
&f.MimeType,
|
||||
&f.WasConverted,
|
||||
×tamp,
|
||||
&f.Size,
|
||||
&f.Width,
|
||||
&f.Height,
|
||||
&thumbnailID,
|
||||
&f.DecryptionInfo,
|
||||
)
|
||||
f.Timestamp = time.UnixMilli(timestamp)
|
||||
f.ThumbnailID = thumbnailID.String
|
||||
return f, err
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
-- v0 -> v1: Latest revision
|
||||
|
||||
CREATE TABLE telegram_session (
|
||||
user_id INTEGER PRIMARY KEY,
|
||||
session_data BYTEA NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE telegram_user_state (
|
||||
user_id INTEGER PRIMARY KEY,
|
||||
pts INTEGER NOT NULL,
|
||||
qts INTEGER NOT NULL,
|
||||
date INTEGER NOT NULL,
|
||||
seq INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE telegram_channel_state (
|
||||
user_id INTEGER,
|
||||
channel_id INTEGER,
|
||||
pts INTEGER NOT NULL,
|
||||
|
||||
PRIMARY KEY (user_id, channel_id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_telegram_channel_state_user_id ON telegram_channel_state (user_id);
|
||||
|
||||
CREATE TABLE telegram_channel_access_hashes (
|
||||
user_id INTEGER,
|
||||
channel_id INTEGER,
|
||||
access_hash INTEGER NOT NULL,
|
||||
|
||||
PRIMARY KEY (user_id, channel_id)
|
||||
);
|
||||
|
||||
CREATE TABLE telegram_file (
|
||||
id TEXT PRIMARY KEY,
|
||||
mxc TEXT NOT NULL,
|
||||
mime_type TEXT,
|
||||
was_converted BOOLEAN NOT NULL DEFAULT false,
|
||||
timestamp BIGINT NOT NULL DEFAULT 0,
|
||||
size BIGINT,
|
||||
width INTEGER,
|
||||
height INTEGER,
|
||||
thumbnail TEXT,
|
||||
decryption_info jsonb,
|
||||
FOREIGN KEY (thumbnail) REFERENCES telegram_file(id)
|
||||
ON UPDATE CASCADE ON DELETE SET NULL
|
||||
);
|
||||
@@ -0,0 +1,32 @@
|
||||
// mautrix-telegram - A Matrix-Telegram puppeting bridge.
|
||||
// Copyright (C) 2024 Sumner Evans
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package upgrades
|
||||
|
||||
import (
|
||||
"embed"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
var Table dbutil.UpgradeTable
|
||||
|
||||
//go:embed *.sql
|
||||
var rawUpgrades embed.FS
|
||||
|
||||
func init() {
|
||||
Table.RegisterFS(rawUpgrades)
|
||||
}
|
||||
Reference in New Issue
Block a user