pkg/store -> pkg/connector/store

Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
Sumner Evans
2024-07-01 15:55:12 -06:00
parent cbba340da6
commit 0921168b91
6 changed files with 5 additions and 3 deletions
+1 -1
View File
@@ -25,7 +25,7 @@ import (
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/bridgev2"
"go.mau.fi/mautrix-telegram/pkg/store"
"go.mau.fi/mautrix-telegram/pkg/connector/store"
)
type TelegramConfig struct {
+47
View File
@@ -0,0 +1,47 @@
// mautrix-telegram - A Matrix-Telegram puppeting bridge.
// Copyright (C) 2024 Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"context"
"go.mau.fi/util/dbutil"
"go.mau.fi/mautrix-telegram/pkg/connector/store/upgrades"
)
type Container struct {
*dbutil.Database
TelegramFile *TelegramFileQuery
}
func NewStore(db *dbutil.Database, log dbutil.DatabaseLogger) *Container {
return &Container{
Database: db.Child("telegram_version", upgrades.Table, log),
TelegramFile: &TelegramFileQuery{dbutil.MakeQueryHelper(db, newTelegramFile)},
}
}
func (c *Container) Upgrade(ctx context.Context) error {
return c.Database.Upgrade(ctx)
}
func (c *Container) GetScopedStore(telegramUserID int64) *scopedStore {
return &scopedStore{c.Database, telegramUserID}
}
+179
View File
@@ -0,0 +1,179 @@
package store
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/gotd/td/session"
"github.com/gotd/td/telegram/updates"
"go.mau.fi/util/dbutil"
)
// scopedStore is a wrapper around a database that implements
// [session.Storage] scoped to a specific Telegram user ID.
type scopedStore struct {
db *dbutil.Database
telegramUserID int64
}
const (
// Session Storage Queries
loadSessionQuery = `SELECT session_data FROM telegram_session WHERE user_id=$1`
storeSessionQuery = `
INSERT INTO telegram_session (user_id, session_data)
VALUES ($1, $2)
ON CONFLICT (user_id) DO UPDATE SET session_data=excluded.session_data
`
// State Storage Queries
allChannelsQuery = "SELECT channel_id, pts FROM telegram_channel_state WHERE user_id=$1"
getChannelPtsQuery = "SELECT pts FROM telegram_channel_state WHERE user_id=$1 AND channel_id=$2"
setChannelPtsQuery = `
INSERT INTO telegram_channel_state (user_id, channel_id, pts)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, channel_id) DO UPDATE SET pts=excluded.pts
`
getStateQuery = "SELECT pts, qts, date, seq from telegram_user_state WHERE user_id=$1"
setStateQuery = `
INSERT INTO telegram_user_state (user_id, pts, qts, date, seq)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (user_id) DO UPDATE SET
pts=excluded.pts,
qts=excluded.qts,
date=excluded.date,
seq=excluded.seq
`
setPtsQuery = "UPDATE telegram_user_state SET pts=$1 WHERE user_id=$2"
setQtsQuery = "UPDATE telegram_user_state SET qts=$1 WHERE user_id=$2"
setDateQuery = "UPDATE telegram_user_state SET date=$1 WHERE user_id=$2"
setSeqQuery = "UPDATE telegram_user_state SET seq=$1 WHERE user_id=$2"
setDateSeqQuery = "UPDATE telegram_user_state SET date=$1, seq=$2 WHERE user_id=$3"
// Channel Access Hasher Queries
getChannelAccessHashQuery = "SELECT access_hash FROM telegram_channel_access_hashes WHERE user_id=$1 AND channel_id=$2"
setChannelAccessHashQuery = `
INSERT INTO telegram_channel_access_hashes (user_id, channel_id, access_hash)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, channel_id) DO UPDATE SET access_hash=excluded.access_hash
`
)
var _ session.Storage = (*scopedStore)(nil)
func (s *scopedStore) LoadSession(ctx context.Context) (sessionData []byte, err error) {
row := s.db.QueryRow(ctx, loadSessionQuery, s.telegramUserID)
err = row.Scan(&sessionData)
return
}
func (s *scopedStore) StoreSession(ctx context.Context, data []byte) error {
_, err := s.db.Exec(ctx, storeSessionQuery, s.telegramUserID, data)
return err
}
var _ updates.StateStorage = (*scopedStore)(nil)
func (s *scopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
s.assertUserIDMatches(userID)
rows, err := s.db.Query(ctx, allChannelsQuery, userID)
if err != nil {
return err
}
var channelID int64
var pts int
for rows.Next() {
if err = rows.Scan(&channelID, &pts); err != nil {
return err
} else if err = f(ctx, channelID, pts); err != nil {
return err
}
}
return nil
}
func (s *scopedStore) GetChannelPts(ctx context.Context, userID int64, channelID int64) (pts int, found bool, err error) {
s.assertUserIDMatches(userID)
err = s.db.QueryRow(ctx, getChannelPtsQuery, userID, channelID).Scan(&pts)
if errors.Is(err, sql.ErrNoRows) {
return 0, false, nil
}
return pts, err == nil, err
}
func (s *scopedStore) SetChannelPts(ctx context.Context, userID int64, channelID int64, pts int) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setChannelPtsQuery, userID, channelID, pts)
return
}
func (s *scopedStore) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) {
s.assertUserIDMatches(userID)
err = s.db.QueryRow(ctx, getStateQuery, userID).Scan(&state.Pts, &state.Qts, &state.Date, &state.Seq)
if errors.Is(err, sql.ErrNoRows) {
return state, false, nil
}
return state, err == nil, err
}
func (s *scopedStore) SetState(ctx context.Context, userID int64, state updates.State) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setStateQuery, userID, state.Pts, state.Qts, state.Date, state.Seq)
return
}
func (s *scopedStore) SetPts(ctx context.Context, userID int64, pts int) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setPtsQuery, userID, pts)
return
}
func (s *scopedStore) SetQts(ctx context.Context, userID int64, qts int) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setQtsQuery, userID, qts)
return
}
func (s *scopedStore) SetSeq(ctx context.Context, userID int64, seq int) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setSeqQuery, userID, seq)
return
}
func (s *scopedStore) SetDate(ctx context.Context, userID int64, date int) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setDateQuery, userID, date)
return
}
func (s *scopedStore) SetDateSeq(ctx context.Context, userID int64, date int, seq int) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setDateSeqQuery, userID, date, seq)
return
}
var _ updates.ChannelAccessHasher = (*scopedStore)(nil)
func (s *scopedStore) GetChannelAccessHash(ctx context.Context, userID int64, channelID int64) (accessHash int64, found bool, err error) {
s.assertUserIDMatches(userID)
err = s.db.QueryRow(ctx, getChannelAccessHashQuery, userID, channelID).Scan(&accessHash)
if errors.Is(err, sql.ErrNoRows) {
return 0, false, nil
}
return accessHash, err == nil, err
}
func (s *scopedStore) SetChannelAccessHash(ctx context.Context, userID int64, channelID int64, accessHash int64) (err error) {
s.assertUserIDMatches(userID)
_, err = s.db.Exec(ctx, setChannelAccessHashQuery, userID, channelID, accessHash)
return
}
// Helper Functions
func (s *scopedStore) assertUserIDMatches(userID int64) {
if s.telegramUserID != userID {
panic(fmt.Sprintf("scoped store for %d function called with user ID %d", s.telegramUserID, userID))
}
}
+98
View File
@@ -0,0 +1,98 @@
package store
import (
"context"
"database/sql"
"encoding/json"
"time"
"go.mau.fi/util/dbutil"
)
const (
insertTelegramFileQuery = `
INSERT INTO telegram_file (
id, mxc, mime_type, was_converted, timestamp, size, width, height, thumbnail, decryption_info)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
`
getTelegramFileSelect = `
SELECT id, mxc, mime_type, was_converted, timestamp, size, width, height, thumbnail, decryption_info
FROM telegram_file
`
getTelegramFileByLocationIDQuery = getTelegramFileSelect + "WHERE id=$1"
getTelegramFileByMXCQuery = getTelegramFileSelect + "WHERE mxc=$1"
)
type TelegramFileQuery struct {
*dbutil.QueryHelper[*TelegramFile]
}
type TelegramFileLocationID string
type TelegramFile struct {
qh *dbutil.QueryHelper[*TelegramFile]
LocationID TelegramFileLocationID
MXC string
MimeType string
WasConverted bool
Timestamp time.Time
Size int64
Width int
Height int
ThumbnailID string
DecryptionInfo json.RawMessage
}
var _ dbutil.DataStruct[*TelegramFile] = (*TelegramFile)(nil)
func newTelegramFile(qh *dbutil.QueryHelper[*TelegramFile]) *TelegramFile {
return &TelegramFile{qh: qh}
}
func (fq *TelegramFileQuery) GetByLocationID(ctx context.Context, locationID string) (*TelegramFile, error) {
return fq.QueryOne(ctx, getTelegramFileByLocationIDQuery, locationID)
}
func (fq *TelegramFileQuery) GetByMXC(ctx context.Context, mxc string) (*TelegramFile, error) {
return fq.QueryOne(ctx, getTelegramFileByMXCQuery, mxc)
}
func (f *TelegramFile) sqlVariables() []any {
return []any{
f.LocationID,
f.MXC,
f.MimeType,
f.WasConverted,
f.Timestamp.UnixMilli(),
f.Size,
f.Width,
f.Height,
f.ThumbnailID,
f.DecryptionInfo,
}
}
func (f *TelegramFile) Insert(ctx context.Context) error {
return f.qh.Exec(ctx, insertTelegramFileQuery, f.sqlVariables()...)
}
func (f *TelegramFile) Scan(row dbutil.Scannable) (*TelegramFile, error) {
var thumbnailID sql.NullString
var timestamp int64
err := row.Scan(
&f.LocationID,
&f.MXC,
&f.MimeType,
&f.WasConverted,
&timestamp,
&f.Size,
&f.Width,
&f.Height,
&thumbnailID,
&f.DecryptionInfo,
)
f.Timestamp = time.UnixMilli(timestamp)
f.ThumbnailID = thumbnailID.String
return f, err
}
@@ -0,0 +1,47 @@
-- v0 -> v1: Latest revision
CREATE TABLE telegram_session (
user_id INTEGER PRIMARY KEY,
session_data BYTEA NOT NULL
);
CREATE TABLE telegram_user_state (
user_id INTEGER PRIMARY KEY,
pts INTEGER NOT NULL,
qts INTEGER NOT NULL,
date INTEGER NOT NULL,
seq INTEGER NOT NULL
);
CREATE TABLE telegram_channel_state (
user_id INTEGER,
channel_id INTEGER,
pts INTEGER NOT NULL,
PRIMARY KEY (user_id, channel_id)
);
CREATE INDEX idx_telegram_channel_state_user_id ON telegram_channel_state (user_id);
CREATE TABLE telegram_channel_access_hashes (
user_id INTEGER,
channel_id INTEGER,
access_hash INTEGER NOT NULL,
PRIMARY KEY (user_id, channel_id)
);
CREATE TABLE telegram_file (
id TEXT PRIMARY KEY,
mxc TEXT NOT NULL,
mime_type TEXT,
was_converted BOOLEAN NOT NULL DEFAULT false,
timestamp BIGINT NOT NULL DEFAULT 0,
size BIGINT,
width INTEGER,
height INTEGER,
thumbnail TEXT,
decryption_info jsonb,
FOREIGN KEY (thumbnail) REFERENCES telegram_file(id)
ON UPDATE CASCADE ON DELETE SET NULL
);
+32
View File
@@ -0,0 +1,32 @@
// mautrix-telegram - A Matrix-Telegram puppeting bridge.
// Copyright (C) 2024 Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package upgrades
import (
"embed"
"go.mau.fi/util/dbutil"
)
var Table dbutil.UpgradeTable
//go:embed *.sql
var rawUpgrades embed.FS
func init() {
Table.RegisterFS(rawUpgrades)
}