Switch from SQLAlchemy to asyncpg/aiosqlite
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# mautrix-telegram - A Matrix-Telegram puppeting bridge
|
||||
# Copyright (C) 2019 Tulir Asokan
|
||||
# Copyright (C) 2021 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
@@ -13,19 +13,23 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from mautrix.client.state_store.sqlalchemy import UserProfile, RoomState
|
||||
from .upgrade import upgrade_table
|
||||
|
||||
from .bot_chat import BotChat
|
||||
from .message import Message
|
||||
from .portal import Portal
|
||||
from .puppet import Puppet
|
||||
from .telegram_file import TelegramFile
|
||||
from .user import User, UserPortal, Contact
|
||||
from .user import User
|
||||
from .telethon_session import PgSession
|
||||
|
||||
|
||||
def init(db_engine: Engine) -> None:
|
||||
for table in (Portal, Message, User, Contact, UserPortal, Puppet, TelegramFile, UserProfile,
|
||||
RoomState, BotChat):
|
||||
table.bind(db_engine)
|
||||
def init(db: Database) -> None:
|
||||
for table in (Portal, Message, User, Puppet, TelegramFile, BotChat, PgSession):
|
||||
table.db = db
|
||||
|
||||
|
||||
__all__ = ["upgrade_table", "init", "Portal", "Message", "User", "Puppet", "TelegramFile",
|
||||
"BotChat", "PgSession"]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# mautrix-telegram - A Matrix-Telegram puppeting bridge
|
||||
# Copyright (C) 2019 Tulir Asokan
|
||||
# Copyright (C) 2021 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
@@ -13,26 +13,43 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Iterable
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Column, BigInteger, String
|
||||
from typing import ClassVar, TYPE_CHECKING
|
||||
|
||||
from mautrix.util.db import Base
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from ..types import TelegramID
|
||||
|
||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
|
||||
|
||||
# Fucking Telegram not telling bots what chats they are in 3:<
|
||||
class BotChat(Base):
|
||||
__tablename__ = "bot_chat"
|
||||
id: TelegramID = Column(BigInteger, primary_key=True)
|
||||
type: str = Column(String, nullable=False)
|
||||
@dataclass
|
||||
class BotChat:
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
id: TelegramID
|
||||
type: str
|
||||
|
||||
@classmethod
|
||||
def delete_by_id(cls, chat_id: TelegramID) -> None:
|
||||
with cls.db.begin() as conn:
|
||||
conn.execute(cls.t.delete().where(cls.c.id == chat_id))
|
||||
def _from_row(cls, row: Record | None) -> BotChat | None:
|
||||
if row is None:
|
||||
return None
|
||||
return cls(**row)
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable['BotChat']:
|
||||
return cls._select_all()
|
||||
async def delete_by_id(cls, chat_id: TelegramID) -> None:
|
||||
await cls.db.execute("DELETE FROM bot_chat WHERE id=$1", chat_id)
|
||||
|
||||
@classmethod
|
||||
async def all(cls) -> list[BotChat]:
|
||||
rows = await cls.db.fetch("SELECT id, type FROM bot_chat")
|
||||
return [cls._from_row(row) for row in rows]
|
||||
|
||||
async def insert(self) -> None:
|
||||
q = "INSERT INTO bot_chat (id, type) VALUES ($1, $2)"
|
||||
await self.db.execute(q, self.id, self.type)
|
||||
|
||||
+114
-64
@@ -1,5 +1,5 @@
|
||||
# mautrix-telegram - A Matrix-Telegram puppeting bridge
|
||||
# Copyright (C) 2019 Tulir Asokan
|
||||
# Copyright (C) 2021 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
@@ -13,96 +13,146 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Iterator, List
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import (Column, UniqueConstraint, BigInteger, Integer, String, Boolean, and_, func,
|
||||
desc, select, false)
|
||||
from typing import ClassVar, TYPE_CHECKING
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
|
||||
from mautrix.types import RoomID, EventID
|
||||
from mautrix.util.db import Base
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from ..types import TelegramID
|
||||
|
||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "message"
|
||||
|
||||
mxid: EventID = Column(String)
|
||||
mx_room: RoomID = Column(String)
|
||||
tgid: TelegramID = Column(BigInteger, primary_key=True)
|
||||
tg_space: TelegramID = Column(BigInteger, primary_key=True)
|
||||
edit_index: int = Column(Integer, primary_key=True)
|
||||
redacted: bool = Column(Boolean, server_default=false())
|
||||
@dataclass
|
||||
class Message:
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room_2"),)
|
||||
mxid: EventID
|
||||
mx_room: RoomID
|
||||
tgid: TelegramID
|
||||
tg_space: TelegramID
|
||||
edit_index: int
|
||||
redacted: bool = False
|
||||
|
||||
@classmethod
|
||||
def get_all_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Iterator['Message']:
|
||||
return cls._select_all(cls.c.tgid == tgid, cls.c.tg_space == tg_space)
|
||||
def _from_row(cls, row: Record | None) -> Message | None:
|
||||
if row is None:
|
||||
return None
|
||||
return cls(**row)
|
||||
|
||||
columns: ClassVar[str] = "mxid, mx_room, tgid, tg_space, edit_index, redacted"
|
||||
|
||||
@classmethod
|
||||
def get_one_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID, edit_index: int = 0
|
||||
) -> Optional['Message']:
|
||||
async def get_all_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> list[Message]:
|
||||
q = f"SELECT {cls.columns} FROM message WHERE tgid=$1 AND tg_space=$2"
|
||||
rows = await cls.db.fetch(q, tgid, tg_space)
|
||||
return [cls._from_row(row) for row in rows]
|
||||
|
||||
@classmethod
|
||||
async def get_one_by_tgid(
|
||||
cls, tgid: TelegramID, tg_space: TelegramID, edit_index: int = 0
|
||||
) -> Message | None:
|
||||
if edit_index < 0:
|
||||
return cls._one_or_none(cls.db.execute(
|
||||
cls.t.select()
|
||||
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space))
|
||||
.order_by(desc(cls.c.edit_index))
|
||||
.limit(1).offset(-edit_index - 1)))
|
||||
q = (
|
||||
f"SELECT {cls.columns} FROM message WHERE tgid=$1 AND tg_space=$2 "
|
||||
f"ORDER BY edit_index DESC LIMIT 1 OFFSET {-edit_index - 1}"
|
||||
)
|
||||
row = await cls.db.fetchrow(q, tgid, tg_space)
|
||||
else:
|
||||
return cls._select_one_or_none(cls.c.tgid == tgid, cls.c.tg_space == tg_space,
|
||||
cls.c.edit_index == edit_index)
|
||||
q = (
|
||||
f"SELECT {cls.columns} FROM message"
|
||||
" WHERE tgid=$1 AND tg_space=$2 AND edit_index=$3"
|
||||
)
|
||||
row = await cls.db.fetchrow(q, tgid, tg_space, edit_index)
|
||||
return cls._from_row(row)
|
||||
|
||||
@classmethod
|
||||
def get_first_by_tgids(cls, tgids: List[TelegramID], tg_space: TelegramID
|
||||
) -> Iterator['Message']:
|
||||
return cls._select_all(cls.c.tgid.in_(tgids), cls.c.tg_space == tg_space,
|
||||
cls.c.edit_index == 0)
|
||||
async def get_first_by_tgids(
|
||||
cls, tgids: list[TelegramID], tg_space: TelegramID
|
||||
) -> list[Message]:
|
||||
if cls.db.scheme == "postgres":
|
||||
q = (
|
||||
f"SELECT {cls.columns} FROM message"
|
||||
" WHERE tgid=ANY($1) AND tg_space=$2 AND edit_index=0"
|
||||
)
|
||||
rows = await cls.db.fetch(q, tgids, tg_space)
|
||||
else:
|
||||
tgid_placeholders = ("?," * len(tgids)).rstrip(",")
|
||||
q = (
|
||||
f"SELECT {cls.columns} FROM message "
|
||||
f"WHERE tg_space=? AND edit_index=0 AND tgid IN ({tgid_placeholders})"
|
||||
)
|
||||
rows = await cls.db.fetch(q, tg_space, *tgids)
|
||||
return [cls._from_row(row) for row in rows]
|
||||
|
||||
@classmethod
|
||||
def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int:
|
||||
rows = cls.db.execute(select([func.count(cls.c.tg_space)])
|
||||
.where(and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room)))
|
||||
try:
|
||||
count, = next(rows)
|
||||
return count
|
||||
except StopIteration:
|
||||
return 0
|
||||
async def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int:
|
||||
return await cls.db.fetchval(
|
||||
"SELECT COUNT(tg_space) FROM message WHERE mxid=$1 AND mx_room=$2", mxid, mx_room
|
||||
) or 0
|
||||
|
||||
@classmethod
|
||||
def find_last(cls, mx_room: RoomID, tg_space: TelegramID) -> Optional['Message']:
|
||||
return cls._one_or_none(cls.db.execute(
|
||||
cls._make_simple_select(cls.c.mx_room == mx_room, cls.c.tg_space == tg_space)
|
||||
.order_by(desc(cls.c.tgid)).limit(1)))
|
||||
async def find_last(cls, mx_room: RoomID, tg_space: TelegramID) -> Message | None:
|
||||
q = (
|
||||
f"SELECT {cls.columns} FROM message WHERE mx_room=$1 AND tg_space=$2 "
|
||||
f"ORDER BY tgid DESC LIMIT 1"
|
||||
)
|
||||
return cls._from_row(await cls.db.fetchrow(q, mx_room, tg_space))
|
||||
|
||||
@classmethod
|
||||
def delete_all(cls, mx_room: RoomID) -> None:
|
||||
cls.db.execute(cls.t.delete().where(cls.c.mx_room == mx_room))
|
||||
async def delete_all(cls, mx_room: RoomID) -> None:
|
||||
await cls.db.execute("DELETE FROM message WHERE mx_room=$1", mx_room)
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: EventID, mx_room: RoomID, tg_space: TelegramID
|
||||
) -> Optional['Message']:
|
||||
return cls._select_one_or_none(cls.c.mxid == mxid, cls.c.mx_room == mx_room,
|
||||
cls.c.tg_space == tg_space)
|
||||
async def get_by_mxid(
|
||||
cls, mxid: EventID, mx_room: RoomID, tg_space: TelegramID
|
||||
) -> Message | None:
|
||||
q = f"SELECT {cls.columns} FROM message WHERE mxid=$1 AND mx_room=$2 AND tg_space=$3"
|
||||
return cls._from_row(await cls.db.fetchrow(q, mxid, mx_room, tg_space))
|
||||
|
||||
@classmethod
|
||||
def get_by_mxids(cls, mxids: List[EventID], mx_room: RoomID, tg_space: TelegramID
|
||||
) -> Iterator['Message']:
|
||||
return cls._select_all(cls.c.mxid.in_(mxids), cls.c.mx_room == mx_room,
|
||||
cls.c.tg_space == tg_space)
|
||||
async def get_by_mxids(
|
||||
cls, mxids: list[EventID], mx_room: RoomID, tg_space: TelegramID
|
||||
) -> list[Message]:
|
||||
if cls.db.scheme == "postgres":
|
||||
q = (
|
||||
f"SELECT {cls.columns} FROM message"
|
||||
" WHERE mxid=ANY($1) AND mx_room=$2 AND tg_space=$3"
|
||||
)
|
||||
rows = await cls.db.fetch(q, mxids, mx_room, tg_space)
|
||||
else:
|
||||
mxid_placeholders = ("?," * len(mxids)).rstrip(",")
|
||||
q = (
|
||||
f"SELECT {cls.columns} FROM message "
|
||||
f"WHERE mx_room=? AND tg_space=? AND mxid IN ({mxid_placeholders})"
|
||||
)
|
||||
rows = await cls.db.fetch(q, mx_room, tg_space, *mxids)
|
||||
return [cls._from_row(row) for row in rows]
|
||||
|
||||
@classmethod
|
||||
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, s_edit_index: int,
|
||||
**values) -> None:
|
||||
with cls.db.begin() as conn:
|
||||
conn.execute(cls.t.update()
|
||||
.where(and_(cls.c.tgid == s_tgid, cls.c.tg_space == s_tg_space,
|
||||
cls.c.edit_index == s_edit_index))
|
||||
.values(**values))
|
||||
async def replace_temp_mxid(cls, temp_mxid: str, mx_room: RoomID, real_mxid: EventID) -> None:
|
||||
q = "UPDATE message SET mxid=$1 WHERE mxid=$2 AND mx_room=$3"
|
||||
await cls.db.execute(q, real_mxid, temp_mxid, mx_room)
|
||||
|
||||
@classmethod
|
||||
def update_by_mxid(cls, s_mxid: EventID, s_mx_room: RoomID, **values) -> None:
|
||||
with cls.db.begin() as conn:
|
||||
conn.execute(cls.t.update()
|
||||
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
|
||||
.values(**values))
|
||||
async def insert(self) -> None:
|
||||
q = (
|
||||
"INSERT INTO message (mxid, mx_room, tgid, tg_space, edit_index, redacted) "
|
||||
"VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
)
|
||||
await self.db.execute(
|
||||
q, self.mxid, self.mx_room, self.tgid, self.tg_space, self.edit_index, self.redacted
|
||||
)
|
||||
|
||||
async def delete(self) -> None:
|
||||
q = "DELETE FROM message WHERE mxid=$1 AND mx_room=$2 AND tg_space=$3"
|
||||
await self.db.execute(q, self.mxid, self.mx_room, self.tg_space)
|
||||
|
||||
async def mark_redacted(self) -> None:
|
||||
self.redacted = True
|
||||
q = "UPDATE message SET redacted=true WHERE mxid=$1 AND mx_room=$2"
|
||||
await self.db.execute(q, self.mxid, self.mx_room)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# mautrix-telegram - A Matrix-Telegram puppeting bridge
|
||||
# Copyright (C) 2019 Tulir Asokan
|
||||
# Copyright (C) 2021 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
@@ -13,54 +13,116 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Iterable
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Column, BigInteger, String, Boolean, Text, func, sql
|
||||
from typing import ClassVar, Any, TYPE_CHECKING
|
||||
import json
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
import attr
|
||||
|
||||
from mautrix.types import RoomID, ContentURI
|
||||
from mautrix.util.db import Base
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from ..types import TelegramID
|
||||
|
||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
|
||||
class Portal(Base):
|
||||
__tablename__ = "portal"
|
||||
|
||||
@dataclass
|
||||
class Portal:
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
# Telegram chat information
|
||||
tgid: TelegramID = Column(BigInteger, primary_key=True)
|
||||
tg_receiver: TelegramID = Column(BigInteger, primary_key=True)
|
||||
peer_type: str = Column(String, nullable=False)
|
||||
megagroup: bool = Column(Boolean)
|
||||
tgid: TelegramID
|
||||
tg_receiver: TelegramID
|
||||
peer_type: str
|
||||
megagroup: bool
|
||||
|
||||
# Matrix portal information
|
||||
mxid: Optional[RoomID] = Column(String, unique=True, nullable=True)
|
||||
avatar_url: Optional[ContentURI] = Column(String, nullable=True)
|
||||
encrypted: bool = Column(Boolean, nullable=False, server_default=sql.expression.false())
|
||||
|
||||
config: str = Column(Text, nullable=True)
|
||||
mxid: RoomID | None
|
||||
avatar_url: ContentURI | None
|
||||
encrypted: bool
|
||||
|
||||
# Telegram chat metadata
|
||||
username: str = Column(String, nullable=True)
|
||||
title: str = Column(String, nullable=True)
|
||||
about: str = Column(String, nullable=True)
|
||||
photo_id: str = Column(String, nullable=True)
|
||||
username: str | None
|
||||
title: str | None
|
||||
about: str | None
|
||||
photo_id: str | None
|
||||
|
||||
local_config: dict[str, Any] = attr.ib(factory=lambda: {})
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Optional['Portal']:
|
||||
return cls._select_one_or_none(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver)
|
||||
def _from_row(cls, row: Record | None) -> Portal | None:
|
||||
if row is None:
|
||||
return None
|
||||
data = {**row}
|
||||
data["local_config"] = json.loads(data.pop("config", None) or "{}")
|
||||
return cls(**data)
|
||||
|
||||
columns: ClassVar[str] = (
|
||||
"tgid, tg_receiver, peer_type, megagroup, mxid, avatar_url, encrypted, config, "
|
||||
"username, title, about, photo_id"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def find_private_chats(cls, tg_receiver: TelegramID) -> Iterable['Portal']:
|
||||
yield from cls._select_all(cls.c.tg_receiver == tg_receiver, cls.c.peer_type == "user")
|
||||
async def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Portal | None:
|
||||
q = f"SELECT {cls.columns} FROM portal WHERE tgid=$1 AND tg_receiver=$2"
|
||||
return cls._from_row(await cls.db.fetchrow(q, tgid, tg_receiver))
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
|
||||
return cls._select_one_or_none(cls.c.mxid == mxid)
|
||||
async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
|
||||
q = f"SELECT {cls.columns} FROM portal WHERE mxid=$1"
|
||||
return cls._from_row(await cls.db.fetchrow(q, mxid))
|
||||
|
||||
@classmethod
|
||||
def get_by_username(cls, username: str) -> Optional['Portal']:
|
||||
return cls._select_one_or_none(func.lower(cls.c.username) == username)
|
||||
async def find_by_username(cls, username: str) -> Portal | None:
|
||||
q = f"SELECT {cls.columns} FROM portal WHERE lower(username)=$1"
|
||||
return cls._from_row(await cls.db.fetchrow(q, username.lower()))
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable['Portal']:
|
||||
yield from cls._select_all()
|
||||
async def find_private_chats(cls, tg_receiver: TelegramID) -> list[Portal]:
|
||||
q = f"SELECT {cls.columns} FROM portal WHERE tg_receiver=$1 AND peer_type='user'"
|
||||
return [cls._from_row(row) for row in await cls.db.fetch(q, tg_receiver)]
|
||||
|
||||
@classmethod
|
||||
async def all(cls) -> list[Portal]:
|
||||
rows = await cls.db.fetch(f"SELECT {cls.columns} FROM portal")
|
||||
return [cls._from_row(row) for row in rows]
|
||||
|
||||
@property
|
||||
def _values(self):
|
||||
return (self.tgid, self.tg_receiver, self.peer_type, self.mxid, self.avatar_url,
|
||||
self.encrypted, self.username, self.title, self.about, self.photo_id,
|
||||
self.megagroup, json.dumps(self.local_config) if self.local_config else None)
|
||||
|
||||
async def save(self) -> None:
|
||||
q = (
|
||||
"UPDATE portal SET mxid=$4, avatar_url=$5, encrypted=$6, username=$7, title=$8,"
|
||||
" about=$9, photo_id=$10, megagroup=$11, config=$12 "
|
||||
"WHERE tgid=$1 AND tg_receiver=$2 AND (peer_type=$3 OR true)"
|
||||
)
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
async def update_id(self, id: TelegramID, peer_type: str) -> None:
|
||||
q = (
|
||||
"UPDATE portal SET tgid=$1, tg_receiver=$1, peer_type=$2 "
|
||||
"WHERE tgid=$3 AND tg_receiver=$3"
|
||||
)
|
||||
await self.db.execute(q, id, peer_type, self.tgid)
|
||||
self.tgid = id
|
||||
self.tg_receiver = id
|
||||
self.peer_type = peer_type
|
||||
|
||||
async def insert(self) -> None:
|
||||
q = (
|
||||
"INSERT INTO portal (tgid, tg_receiver, peer_type, mxid, avatar_url, encrypted,"
|
||||
" username, title, about, photo_id, megagroup, config) "
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)"
|
||||
)
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
async def delete(self) -> None:
|
||||
q = "DELETE FROM portal WHERE tgid=$1 AND tg_receiver=$2"
|
||||
await self.db.execute(q, self.tgid, self.tg_receiver)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# mautrix-telegram - A Matrix-Telegram puppeting bridge
|
||||
# Copyright (C) 2019 Tulir Asokan
|
||||
# Copyright (C) 2021 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
@@ -13,51 +13,106 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Iterable
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Column, Integer, BigInteger, String, Text, Boolean
|
||||
from sqlalchemy.sql import expression, func
|
||||
from typing import ClassVar, TYPE_CHECKING
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
from yarl import URL
|
||||
|
||||
from mautrix.types import UserID, SyncToken
|
||||
from mautrix.util.db import Base
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from ..types import TelegramID
|
||||
|
||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
|
||||
class Puppet(Base):
|
||||
__tablename__ = "puppet"
|
||||
|
||||
id: TelegramID = Column(BigInteger, primary_key=True)
|
||||
custom_mxid: UserID = Column(String, nullable=True)
|
||||
access_token: str = Column(String, nullable=True)
|
||||
next_batch: SyncToken = Column(String, nullable=True)
|
||||
base_url: str = Column(Text, nullable=True)
|
||||
displayname: str = Column(String, nullable=True)
|
||||
displayname_source: TelegramID = Column(BigInteger, nullable=True)
|
||||
displayname_contact: bool = Column(Boolean, nullable=False, server_default=expression.true())
|
||||
displayname_quality: int = Column(Integer, nullable=False, server_default="0")
|
||||
username: str = Column(String, nullable=True)
|
||||
photo_id: str = Column(String, nullable=True)
|
||||
is_bot: bool = Column(Boolean, nullable=True)
|
||||
matrix_registered: bool = Column(Boolean, nullable=False, server_default=expression.false())
|
||||
disable_updates: bool = Column(Boolean, nullable=False, server_default=expression.false())
|
||||
@dataclass
|
||||
class Puppet:
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
id: TelegramID
|
||||
|
||||
is_registered: bool
|
||||
|
||||
displayname: str | None
|
||||
displayname_source: TelegramID | None
|
||||
displayname_contact: bool
|
||||
displayname_quality: int
|
||||
disable_updates: bool
|
||||
username: str | None
|
||||
photo_id: str | None
|
||||
is_bot: bool | None
|
||||
|
||||
custom_mxid: UserID | None
|
||||
access_token: str | None
|
||||
next_batch: SyncToken | None
|
||||
base_url: URL | None
|
||||
|
||||
@classmethod
|
||||
def all_with_custom_mxid(cls) -> Iterable['Puppet']:
|
||||
yield from cls._select_all(cls.c.custom_mxid != None)
|
||||
def _from_row(cls, row: Record | None) -> Puppet | None:
|
||||
if row is None:
|
||||
return None
|
||||
data = {**row}
|
||||
base_url = data.pop("base_url", None)
|
||||
return cls(**data, base_url=URL(base_url) if base_url else None)
|
||||
|
||||
columns: ClassVar[str] = (
|
||||
"id, is_registered, displayname, displayname_source, displayname_contact, "
|
||||
"displayname_quality, disable_updates, username, photo_id, is_bot, "
|
||||
"custom_mxid, access_token, next_batch, base_url"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID) -> Optional['Puppet']:
|
||||
return cls._select_one_or_none(cls.c.id == tgid)
|
||||
async def all_with_custom_mxid(cls) -> list[Puppet]:
|
||||
q = f"SELECT {cls.columns} FROM puppet WHERE custom_mxid<>''"
|
||||
return [cls._from_row(row) for row in await cls.db.fetch(q)]
|
||||
|
||||
@classmethod
|
||||
def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
|
||||
return cls._select_one_or_none(cls.c.custom_mxid == mxid)
|
||||
async def get_by_tgid(cls, tgid: TelegramID) -> Puppet | None:
|
||||
q = f"SELECT {cls.columns} FROM puppet WHERE id=$1"
|
||||
return cls._from_row(await cls.db.fetchrow(q, tgid))
|
||||
|
||||
@classmethod
|
||||
def get_by_username(cls, username: str) -> Optional['Puppet']:
|
||||
return cls._select_one_or_none(func.lower(cls.c.username) == username)
|
||||
async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
|
||||
q = f"SELECT {cls.columns} FROM puppet WHERE custom_mxid=$1"
|
||||
return cls._from_row(await cls.db.fetchrow(q, mxid))
|
||||
|
||||
@classmethod
|
||||
def get_by_displayname(cls, displayname: str) -> Optional['Puppet']:
|
||||
return cls._select_one_or_none(cls.c.displayname == displayname)
|
||||
async def find_by_username(cls, username: str) -> Puppet | None:
|
||||
q = f"SELECT {cls.columns} FROM puppet WHERE lower(username)=$1"
|
||||
return cls._from_row(await cls.db.fetchrow(q, username.lower()))
|
||||
|
||||
@classmethod
|
||||
async def find_by_displayname(cls, displayname: str) -> Puppet | None:
|
||||
q = f"SELECT {cls.columns} FROM puppet WHERE displayname=$1"
|
||||
return cls._from_row(await cls.db.fetchrow(q, displayname))
|
||||
|
||||
@property
|
||||
def _values(self):
|
||||
return (self.id, self.is_registered, self.displayname, self.displayname_source,
|
||||
self.displayname_contact, self.displayname_quality, self.disable_updates,
|
||||
self.username, self.photo_id, self.is_bot, self.custom_mxid, self.access_token,
|
||||
self.next_batch, str(self.base_url) if self.base_url else None)
|
||||
|
||||
async def save(self) -> None:
|
||||
q = (
|
||||
"UPDATE puppet "
|
||||
"SET is_registered=$2, displayname=$3, displayname_source=$4, displayname_contact=$5,"
|
||||
" displayname_quality=$6, disable_updates=$7, username=$8, photo_id=$9, is_bot=$10,"
|
||||
" custom_mxid=$11, access_token=$12, next_batch=$13, base_url=$14 "
|
||||
"WHERE id=$1"
|
||||
)
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
async def insert(self) -> None:
|
||||
q = (
|
||||
"INSERT INTO puppet ("
|
||||
" id, is_registered, displayname, displayname_source, displayname_contact,"
|
||||
" displayname_quality, disable_updates, username, photo_id, is_bot,"
|
||||
" custom_mxid, access_token, next_batch, base_url"
|
||||
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)"
|
||||
)
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# mautrix-telegram - A Matrix-Telegram puppeting bridge
|
||||
# Copyright (C) 2019 Tulir Asokan
|
||||
# Copyright (C) 2021 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
@@ -13,69 +13,62 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, cast, Dict, Any, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import (Column, ForeignKey, Integer, BigInteger, String, Boolean, Text,
|
||||
TypeDecorator)
|
||||
from typing import ClassVar, TYPE_CHECKING
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
from mautrix.types import ContentURI, EncryptedFile
|
||||
from mautrix.util.db import Base
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
|
||||
|
||||
class DBEncryptedFile(TypeDecorator):
|
||||
impl = Text
|
||||
@dataclass
|
||||
class TelegramFile:
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
@property
|
||||
def python_type(self):
|
||||
return EncryptedFile
|
||||
|
||||
def process_bind_param(self, value: EncryptedFile, dialect) -> Optional[str]:
|
||||
if value is not None:
|
||||
return value.json()
|
||||
return None
|
||||
|
||||
def process_result_value(self, value: str, dialect) -> Optional[EncryptedFile]:
|
||||
if value is not None:
|
||||
return EncryptedFile.parse_json(value)
|
||||
return None
|
||||
|
||||
def process_literal_param(self, value, dialect):
|
||||
return value
|
||||
|
||||
|
||||
class TelegramFile(Base):
|
||||
__tablename__ = "telegram_file"
|
||||
|
||||
id: str = Column(String, primary_key=True)
|
||||
mxc: ContentURI = Column(String)
|
||||
mime_type: str = Column(String)
|
||||
was_converted: bool = Column(Boolean)
|
||||
timestamp: int = Column(BigInteger)
|
||||
size: Optional[int] = Column(Integer, nullable=True)
|
||||
width: Optional[int] = Column(Integer, nullable=True)
|
||||
height: Optional[int] = Column(Integer, nullable=True)
|
||||
decryption_info: Optional[Dict[str, Any]] = Column(DBEncryptedFile, nullable=True)
|
||||
thumbnail_id: str = Column("thumbnail", String, ForeignKey("telegram_file.id"), nullable=True)
|
||||
thumbnail: Optional['TelegramFile'] = None
|
||||
id: str
|
||||
mxc: ContentURI
|
||||
mime_type: str
|
||||
was_converted: bool
|
||||
timestamp: int
|
||||
size: int | None
|
||||
width: int | None
|
||||
height: int | None
|
||||
decryption_info: EncryptedFile | None
|
||||
thumbnail: TelegramFile | None = None
|
||||
|
||||
@classmethod
|
||||
def scan(cls, row: 'RowProxy') -> 'TelegramFile':
|
||||
telegram_file = cast(TelegramFile, super().scan(row))
|
||||
if isinstance(telegram_file.thumbnail, str):
|
||||
telegram_file.thumbnail = cls.get(telegram_file.thumbnail)
|
||||
return telegram_file
|
||||
async def get(cls, loc_id: str, *, _thumbnail: bool = False) -> TelegramFile | None:
|
||||
q = (
|
||||
"SELECT id, mxc, mime_type, was_converted, timestamp, size, width, height, thumbnail,"
|
||||
" decryption_info "
|
||||
"FROM telegram_file WHERE id=$1"
|
||||
)
|
||||
row = await cls.db.fetchrow(q, loc_id)
|
||||
if row is None:
|
||||
return None
|
||||
data = {**row}
|
||||
thumbnail_id = data.pop("thumbnail", None)
|
||||
if _thumbnail:
|
||||
# Don't allow more than one level of recursion
|
||||
thumbnail_id = None
|
||||
decryption_info = data.pop("decryption_info", None)
|
||||
return cls(
|
||||
**data,
|
||||
thumbnail=(await cls.get(thumbnail_id, _thumbnail=True)) if thumbnail_id else None,
|
||||
decryption_info=EncryptedFile.parse_json(decryption_info) if decryption_info else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get(cls, loc_id: str) -> Optional['TelegramFile']:
|
||||
return cls._select_one_or_none(cls.c.id == loc_id)
|
||||
|
||||
def insert(self) -> None:
|
||||
with self.db.begin() as conn:
|
||||
conn.execute(self.t.insert().values(
|
||||
id=self.id, mxc=self.mxc, mime_type=self.mime_type,
|
||||
was_converted=self.was_converted, timestamp=self.timestamp, size=self.size,
|
||||
width=self.width, height=self.height, decryption_info=self.decryption_info,
|
||||
thumbnail=self.thumbnail.id if self.thumbnail else self.thumbnail_id))
|
||||
async def insert(self) -> None:
|
||||
q = (
|
||||
"INSERT INTO telegram_file (id, mxc, mime_type, was_converted, size, width, height, "
|
||||
" thumbnail, decryption_info) "
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||
)
|
||||
await self.db.execute(q, self.id, self.mxc, self.mime_type, self.was_converted, self.size,
|
||||
self.width, self.height,
|
||||
self.thumbnail.id if self.thumbnail else None,
|
||||
self.decryption_info.json() if self.decryption_info else None)
|
||||
|
||||
@@ -0,0 +1,204 @@
|
||||
# mautrix-telegram - A Matrix-Telegram puppeting bridge
|
||||
# Copyright (C) 2021 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, TYPE_CHECKING
|
||||
import datetime
|
||||
import asyncio
|
||||
|
||||
from telethon.sessions import MemorySession
|
||||
from telethon.tl.types import updates, PeerUser, PeerChat, PeerChannel
|
||||
from telethon.crypto import AuthKey
|
||||
from telethon import utils
|
||||
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
|
||||
|
||||
class PgSession(MemorySession):
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
session_id: str
|
||||
_dc_id: int
|
||||
_server_address: str | None
|
||||
_port: int | None
|
||||
_auth_key: AuthKey | None
|
||||
_takeout_id: int | None
|
||||
_process_entities_lock: asyncio.Lock
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
dc_id: int = 0,
|
||||
server_address: str | None = None,
|
||||
port: int | None = None,
|
||||
auth_key: AuthKey | None = None,
|
||||
takeout_id: int | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.session_id = session_id
|
||||
self._dc_id = dc_id
|
||||
self._server_address = server_address
|
||||
self._port = port
|
||||
self._auth_key = auth_key
|
||||
self._takeout_id = takeout_id
|
||||
self._process_entities_lock = asyncio.Lock()
|
||||
|
||||
def clone(self, to_instance=None) -> MemorySession:
|
||||
# We don't want to store data of clones
|
||||
# (which are used for temporarily connecting to different DCs)
|
||||
return super().clone(MemorySession())
|
||||
|
||||
@property
|
||||
def auth_key_bytes(self) -> bytes | None:
|
||||
return self._auth_key.key if self._auth_key else None
|
||||
|
||||
@classmethod
|
||||
async def get(cls, session_id: str) -> PgSession:
|
||||
q = (
|
||||
"SELECT session_id, dc_id, server_address, port, auth_key FROM telethon_sessions "
|
||||
"WHERE session_id=$1"
|
||||
)
|
||||
row = await cls.db.fetchrow(q, session_id)
|
||||
if row is None:
|
||||
return cls(session_id)
|
||||
data = {**row}
|
||||
auth_key = AuthKey(data.pop("auth_key", None))
|
||||
return cls(**data, auth_key=auth_key)
|
||||
|
||||
@classmethod
|
||||
async def has(cls, session_id: str) -> bool:
|
||||
q = "SELECT COUNT(*) FROM telethon_sessions WHERE session_id=$1"
|
||||
count = await cls.db.fetchval(q, session_id)
|
||||
return count > 0
|
||||
|
||||
async def save(self) -> None:
|
||||
q = (
|
||||
"INSERT INTO telethon_sessions (session_id, dc_id, server_address, port, auth_key) "
|
||||
"VALUES ($1, $2, $3, $4, $5) ON CONFLICT (session_id) "
|
||||
"DO UPDATE SET dc_id=$2, server_address=$3, port=$4, auth_key=$5"
|
||||
)
|
||||
await self.db.execute(
|
||||
q, self.session_id, self.dc_id, self.server_address, self.port, self.auth_key_bytes
|
||||
)
|
||||
|
||||
_tables: ClassVar[tuple[str, ...]] = (
|
||||
"telethon_sessions", "telethon_entities", "telethon_sent_files", "telethon_update_state"
|
||||
)
|
||||
|
||||
async def delete(self) -> None:
|
||||
async with self.db.acquire() as conn, conn.transaction():
|
||||
for table in self._tables:
|
||||
await conn.execute(f"DELETE FROM {table} WHERE session_id=$1", self.session_id)
|
||||
|
||||
async def close(self) -> None:
|
||||
# Nothing to do here, DB connection is global
|
||||
pass
|
||||
|
||||
async def get_update_state(self, entity_id: int) -> updates.State | None:
|
||||
q = (
|
||||
"SELECT pts, qts, date, seq, unread_count FROM telethon_update_state "
|
||||
"WHERE session_id=$1 AND entity_id=$2"
|
||||
)
|
||||
row = await self.db.fetchrow(q, self.session_id, entity_id)
|
||||
if row is None:
|
||||
return None
|
||||
date = datetime.datetime.utcfromtimestamp(row["date"])
|
||||
return updates.State(row["pts"], row["qts"], date, row["seq"], row["unread_count"])
|
||||
|
||||
async def set_update_state(self, entity_id: int, row: updates.State) -> None:
|
||||
q = (
|
||||
"INSERT INTO telethon_update_state"
|
||||
" (session_id, entity_id, pts, qts, date, seq, unread_count) "
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||
"ON CONFLICT (session_id, entity_id) DO UPDATE"
|
||||
" SET pts=$3, qts=$4, date=$5, seq=$6, unread_count=$7"
|
||||
)
|
||||
ts = row.date.timestamp()
|
||||
await self.db.execute(
|
||||
q, self.session_id, entity_id, row.pts, row.qts, ts, row.seq, row.unread_count
|
||||
)
|
||||
|
||||
def _entity_values_to_row(
|
||||
self, id: int, hash: int, username: str | None, phone: str | int | None, name: str | None
|
||||
) -> tuple[str, int, int, str | None, str | None, str | None]:
|
||||
return self.session_id, id, hash, username, str(phone) if phone else None, name
|
||||
|
||||
async def process_entities(self, tlo) -> None:
|
||||
# Postgres likes to deadlock on simultaneous upserts, so just lock the whole thing here
|
||||
# TODO: make sure postgres doesn't deadlock on upserts when session_id is different
|
||||
async with self._process_entities_lock:
|
||||
await self._locked_process_entities(tlo)
|
||||
|
||||
async def _locked_process_entities(self, tlo) -> None:
|
||||
rows: list[
|
||||
tuple[str, int, int, str | None, str | None, str | None]
|
||||
] = self._entities_to_rows(tlo)
|
||||
if not rows:
|
||||
return
|
||||
if self.db.scheme == "postgres":
|
||||
q = (
|
||||
"INSERT INTO telethon_entities (session_id, id, hash, username, phone, name) "
|
||||
"VALUES ($1, unnest($2::bigint[]), unnest($3::bigint[]), "
|
||||
" unnest($4::text[]), unnest($5::text[]), unnest($6::text[])) "
|
||||
"ON CONFLICT (session_id, id) DO UPDATE"
|
||||
" SET hash=excluded.hash, username=excluded.username,"
|
||||
" phone=excluded.phone, name=excluded.name"
|
||||
)
|
||||
_, ids, hashes, usernames, phones, names = zip(*rows)
|
||||
await self.db.execute(q, self.session_id, ids, hashes, usernames, phones, names)
|
||||
else:
|
||||
q = (
|
||||
"INSERT INTO telethon_entities (session_id, id, hash, username, phone, name) "
|
||||
"VALUES ($1, $2, $3, $4, $5, $6) "
|
||||
"ON CONFLICT (session_id, id) DO UPDATE "
|
||||
" SET hash=$3, username=$4, phone=$5, name=$6"
|
||||
)
|
||||
await self.db.executemany(q, rows)
|
||||
|
||||
async def _select_entity(
|
||||
self, constraint: str, *args: str | int | tuple[int, ...]
|
||||
) -> tuple[int, int] | None:
|
||||
row = await self.db.fetchrow(
|
||||
f"SELECT id, hash FROM telethon_entities WHERE {constraint}", *args
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return row["id"], row["hash"]
|
||||
|
||||
async def get_entity_rows_by_phone(self, key: str | int) -> tuple[int, int] | None:
|
||||
return await self._select_entity("phone=$1", str(key))
|
||||
|
||||
async def get_entity_rows_by_username(self, key: str) -> tuple[int, int] | None:
|
||||
return await self._select_entity("username=$1", key)
|
||||
|
||||
async def get_entity_rows_by_name(self, key: str) -> tuple[int, int] | None:
|
||||
return await self._select_entity("name=$1", key)
|
||||
|
||||
async def get_entity_rows_by_id(self, key: int, exact: bool = True) -> tuple[int, int] | None:
|
||||
if exact:
|
||||
return await self._select_entity("id=$1", key)
|
||||
|
||||
ids = (
|
||||
utils.get_peer_id(PeerUser(key)),
|
||||
utils.get_peer_id(PeerChat(key)),
|
||||
utils.get_peer_id(PeerChannel(key))
|
||||
)
|
||||
if self.db.scheme == "postgres":
|
||||
return await self._select_entity("id=ANY($1)", ids)
|
||||
else:
|
||||
return await self._select_entity(f"id IN ($1, $2, $3)", *ids)
|
||||
@@ -0,0 +1,5 @@
|
||||
from mautrix.util.async_db import UpgradeTable
|
||||
|
||||
upgrade_table = UpgradeTable()
|
||||
|
||||
from . import v01_initial_revision
|
||||
@@ -0,0 +1,300 @@
|
||||
# mautrix-telegram - A Matrix-Telegram puppeting bridge
|
||||
# Copyright (C) 2021 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from asyncpg import Connection
|
||||
from . import upgrade_table
|
||||
|
||||
legacy_version_query = "SELECT version_num FROM alembic_version"
|
||||
last_legacy_version = "bfc0a39bfe02"
|
||||
|
||||
|
||||
def table_exists(scheme: str, name: str) -> str:
|
||||
if scheme == "sqlite":
|
||||
return f"SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name='{name}')"
|
||||
elif scheme == "postgres":
|
||||
return f"SELECT EXISTS(SELECT FROM information_schema.tables WHERE table_name='{name}')"
|
||||
raise RuntimeError("unsupported database scheme")
|
||||
|
||||
|
||||
@upgrade_table.register(description="Initial asyncpg revision")
|
||||
async def upgrade_v1(conn: Connection, scheme: str) -> None:
|
||||
is_legacy = await conn.fetchval(table_exists(scheme, "alembic_version"))
|
||||
if is_legacy:
|
||||
await migrate_legacy_to_v1(conn, scheme)
|
||||
else:
|
||||
await create_v1_tables(conn)
|
||||
|
||||
|
||||
async def migrate_legacy_to_v1(conn: Connection, scheme: str) -> None:
|
||||
legacy_version = await conn.fetchval(legacy_version_query)
|
||||
if legacy_version != last_legacy_version:
|
||||
raise RuntimeError("Legacy database is not on last version. Please upgrade the old "
|
||||
"database with alembic or drop it completely first.")
|
||||
if scheme != "sqlite":
|
||||
await conn.execute(
|
||||
"""
|
||||
ALTER TABLE contact
|
||||
DROP CONSTRAINT contact_user_fkey,
|
||||
DROP CONSTRAINT contact_contact_fkey,
|
||||
ADD CONSTRAINT contact_user_fkey FOREIGN KEY (contact) REFERENCES puppet(id)
|
||||
ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
ADD CONSTRAINT contact_contact_fkey FOREIGN KEY ("user") REFERENCES "user"(tgid)
|
||||
ON DELETE CASCADE ON UPDATE CASCADE
|
||||
"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""
|
||||
ALTER TABLE telethon_sessions
|
||||
DROP CONSTRAINT telethon_sessions_pkey,
|
||||
ADD CONSTRAINT telethon_sessions_pkey PRIMARY KEY (session_id)
|
||||
"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""
|
||||
ALTER TABLE telegram_file
|
||||
DROP CONSTRAINT fk_file_thumbnail,
|
||||
ADD CONSTRAINT fk_file_thumbnail
|
||||
FOREIGN KEY (thumbnail) REFERENCES telegram_file(id)
|
||||
ON UPDATE CASCADE ON DELETE SET NULL
|
||||
"""
|
||||
)
|
||||
await conn.execute("ALTER TABLE puppet ALTER COLUMN id DROP DEFAULT")
|
||||
await conn.execute("DROP SEQUENCE puppet_id_seq")
|
||||
await conn.execute("ALTER TABLE bot_chat ALTER COLUMN id DROP DEFAULT")
|
||||
await conn.execute("DROP SEQUENCE bot_chat_id_seq")
|
||||
await conn.execute("ALTER TABLE portal ALTER COLUMN config TYPE jsonb USING config::jsonb")
|
||||
await conn.execute(
|
||||
"ALTER TABLE telegram_file ALTER COLUMN decryption_info TYPE jsonb "
|
||||
"USING decryption_info::jsonb"
|
||||
)
|
||||
await varchar_to_text(conn)
|
||||
else:
|
||||
await conn.execute(
|
||||
"""CREATE TABLE telethon_sessions_new (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
dc_id INTEGER,
|
||||
server_address TEXT,
|
||||
port INTEGER,
|
||||
auth_key bytea
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO telethon_sessions_new (session_id, dc_id, server_address, port, auth_key)
|
||||
SELECT session_id, dc_id, server_address, port, auth_key FROM telethon_sessions
|
||||
"""
|
||||
)
|
||||
await conn.execute("DROP TABLE telethon_sessions")
|
||||
await conn.execute("ALTER TABLE telethon_sessions_new RENAME TO telethon_sessions")
|
||||
|
||||
await update_state_store(conn, scheme)
|
||||
await conn.execute('ALTER TABLE "user" ADD COLUMN is_bot BOOLEAN NOT NULL DEFAULT false')
|
||||
await conn.execute("ALTER TABLE puppet RENAME COLUMN matrix_registered TO is_registered")
|
||||
await conn.execute("DROP TABLE telethon_version")
|
||||
await conn.execute("DROP TABLE alembic_version")
|
||||
|
||||
|
||||
async def update_state_store(conn: Connection, scheme: str) -> None:
|
||||
# The Matrix state store already has more or less the correct schema, so set the version
|
||||
await conn.execute("CREATE TABLE mx_version (version INTEGER PRIMARY KEY)")
|
||||
await conn.execute("INSERT INTO mx_version (version) VALUES (2)")
|
||||
if scheme != "sqlite":
|
||||
# Also add the membership type on postgres
|
||||
await conn.execute(
|
||||
"CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')"
|
||||
)
|
||||
await conn.execute(
|
||||
"ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership "
|
||||
"USING LOWER(membership)::membership"
|
||||
)
|
||||
else:
|
||||
# On SQLite there's no custom type, but we still want to lowercase everything
|
||||
await conn.execute("UPDATE mx_user_profile SET membership=LOWER(membership)")
|
||||
|
||||
|
||||
async def varchar_to_text(conn: Connection) -> None:
|
||||
columns_to_adjust = {
|
||||
"user": ("mxid", "tg_username", "tg_phone"),
|
||||
"portal": (
|
||||
"peer_type", "mxid", "username", "title", "about", "photo_id", "avatar_url", "config"
|
||||
),
|
||||
"message": ("mxid", "mx_room"),
|
||||
"puppet": (
|
||||
"displayname", "username", "photo_id",
|
||||
) + (
|
||||
"access_token", "custom_mxid", "next_batch", "base_url"
|
||||
),
|
||||
"bot_chat": ("type",),
|
||||
"telegram_file": ("id", "mxc", "mime_type", "thumbnail"),
|
||||
# Phone is a bigint in the old schema, which is safe, but we don't do math on it,
|
||||
# so let's change it to a string
|
||||
"telethon_entities": ("session_id", "username", "name", "phone"),
|
||||
"telethon_sent_files": ("session_id",),
|
||||
"telethon_sessions": ("session_id", "server_address"),
|
||||
"telethon_update_state": ("session_id",),
|
||||
"mx_room_state": ("room_id",),
|
||||
"mx_user_profile": ("room_id", "user_id", "displayname", "avatar_url"),
|
||||
}
|
||||
for table, columns in columns_to_adjust.items():
|
||||
for column in columns:
|
||||
await conn.execute(f'ALTER TABLE "{table}" ALTER COLUMN {column} TYPE TEXT')
|
||||
|
||||
|
||||
async def create_v1_tables(conn: Connection) -> None:
|
||||
await conn.execute(
|
||||
"""CREATE TABLE "user" (
|
||||
mxid TEXT PRIMARY KEY,
|
||||
tgid BIGINT UNIQUE,
|
||||
tg_username TEXT,
|
||||
tg_phone TEXT,
|
||||
is_bot BOOLEAN NOT NULL DEFAULT false,
|
||||
saved_contacts INTEGER NOT NULL DEFAULT 0
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE portal (
|
||||
tgid BIGINT,
|
||||
tg_receiver BIGINT,
|
||||
peer_type TEXT NOT NULL,
|
||||
mxid TEXT UNIQUE,
|
||||
avatar_url TEXT,
|
||||
encrypted BOOLEAN NOT NULL DEFAULT false,
|
||||
username TEXT,
|
||||
title TEXT,
|
||||
about TEXT,
|
||||
photo_id TEXT,
|
||||
megagroup BOOLEAN,
|
||||
config jsonb,
|
||||
PRIMARY KEY (tgid, tg_receiver)
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE message (
|
||||
mxid TEXT,
|
||||
mx_room TEXT,
|
||||
tgid BIGINT NOT NULL,
|
||||
tg_space BIGINT NOT NULL,
|
||||
edit_index INTEGER NOT NULL,
|
||||
redacted BOOLEAN NOT NULL DEFAULT false,
|
||||
PRIMARY KEY (tgid, tg_space, edit_index),
|
||||
UNIQUE (mxid, mx_room, tg_space)
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE puppet (
|
||||
id BIGINT PRIMARY KEY,
|
||||
|
||||
is_registered BOOLEAN NOT NULL DEFAULT false,
|
||||
|
||||
displayname TEXT,
|
||||
displayname_source BIGINT,
|
||||
displayname_contact BOOLEAN NOT NULL DEFAULT true,
|
||||
displayname_quality INTEGER NOT NULL DEFAULT 0,
|
||||
disable_updates BOOLEAN NOT NULL DEFAULT false,
|
||||
username TEXT,
|
||||
photo_id TEXT,
|
||||
is_bot BOOLEAN,
|
||||
|
||||
access_token TEXT,
|
||||
custom_mxid TEXT,
|
||||
next_batch TEXT,
|
||||
base_url TEXT
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE telegram_file (
|
||||
id TEXT PRIMARY KEY,
|
||||
mxc TEXT NOT NULL,
|
||||
mime_type TEXT,
|
||||
was_converted BOOLEAN NOT NULL DEFAULT false,
|
||||
timestamp BIGINT NOT NULL DEFAULT 0,
|
||||
size BIGINT,
|
||||
width INTEGER,
|
||||
height INTEGER,
|
||||
thumbnail TEXT,
|
||||
decryption_info jsonb,
|
||||
FOREIGN KEY (thumbnail) REFERENCES telegram_file(id)
|
||||
ON UPDATE CASCADE ON DELETE SET NULL
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE bot_chat (
|
||||
id BIGINT PRIMARY KEY,
|
||||
type TEXT NOT NULL
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE user_portal (
|
||||
"user" BIGINT,
|
||||
portal BIGINT,
|
||||
portal_receiver BIGINT,
|
||||
PRIMARY KEY ("user", portal, portal_receiver),
|
||||
FOREIGN KEY ("user") REFERENCES "user"(tgid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
FOREIGN KEY (portal, portal_receiver) REFERENCES portal(tgid, tg_receiver)
|
||||
ON DELETE CASCADE ON UPDATE CASCADE
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE contact (
|
||||
"user" BIGINT,
|
||||
contact BIGINT,
|
||||
PRIMARY KEY ("user", contact),
|
||||
FOREIGN KEY ("user") REFERENCES "user"(tgid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
FOREIGN KEY (contact) REFERENCES puppet(id) ON DELETE CASCADE ON UPDATE CASCADE
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE telethon_sessions (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
dc_id INTEGER,
|
||||
server_address TEXT,
|
||||
port INTEGER,
|
||||
auth_key bytea
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE telethon_entities (
|
||||
session_id TEXT,
|
||||
id BIGINT,
|
||||
hash BIGINT NOT NULL,
|
||||
username TEXT,
|
||||
phone TEXT,
|
||||
name TEXT,
|
||||
PRIMARY KEY (session_id, id)
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE telethon_sent_files (
|
||||
session_id TEXT,
|
||||
md5_digest bytea,
|
||||
file_size INTEGER,
|
||||
type INTEGER,
|
||||
id BIGINT,
|
||||
hash BIGINT,
|
||||
PRIMARY KEY (session_id, md5_digest, file_size, type)
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE telethon_update_state (
|
||||
session_id TEXT,
|
||||
entity_id BIGINT,
|
||||
pts BIGINT,
|
||||
qts BIGINT,
|
||||
date BIGINT,
|
||||
seq BIGINT,
|
||||
unread_count INTEGER,
|
||||
PRIMARY KEY (session_id, entity_id)
|
||||
)"""
|
||||
)
|
||||
+91
-68
@@ -1,5 +1,5 @@
|
||||
# mautrix-telegram - A Matrix-Telegram puppeting bridge
|
||||
# Copyright (C) 2019 Tulir Asokan
|
||||
# Copyright (C) 2021 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
@@ -13,96 +13,119 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Iterable, Tuple
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, ForeignKeyConstraint, BigInteger, Integer, String, func
|
||||
from typing import Iterable, ClassVar, TYPE_CHECKING
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
|
||||
from mautrix.types import UserID
|
||||
from mautrix.util.db import Base
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from ..types import TelegramID
|
||||
|
||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "user"
|
||||
|
||||
mxid: UserID = Column(String, primary_key=True)
|
||||
tgid: Optional[TelegramID] = Column(BigInteger, nullable=True, unique=True)
|
||||
tg_username: str = Column(String, nullable=True)
|
||||
tg_phone: str = Column(String, nullable=True)
|
||||
saved_contacts: int = Column(Integer, default=0, nullable=False)
|
||||
@dataclass
|
||||
class User:
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
mxid: UserID
|
||||
tgid: TelegramID | None
|
||||
tg_username: str | None
|
||||
tg_phone: str | None
|
||||
is_bot: bool
|
||||
saved_contacts: int
|
||||
|
||||
@classmethod
|
||||
def all_with_tgid(cls) -> Iterable['User']:
|
||||
return cls._select_all(cls.c.tgid != None)
|
||||
def _from_row(cls, row: Record | None) -> User | None:
|
||||
if row is None:
|
||||
return None
|
||||
return cls(**row)
|
||||
|
||||
columns: ClassVar[str] = "mxid, tgid, tg_username, tg_phone, is_bot, saved_contacts"
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID) -> Optional['User']:
|
||||
return cls._select_one_or_none(cls.c.tgid == tgid)
|
||||
async def get_by_tgid(cls, tgid: TelegramID) -> User | None:
|
||||
q = f'SELECT {cls.columns} FROM "user" WHERE tgid=$1'
|
||||
return cls._from_row(await cls.db.fetchrow(q, tgid))
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: UserID) -> Optional['User']:
|
||||
return cls._select_one_or_none(cls.c.mxid == mxid)
|
||||
async def get_by_mxid(cls, mxid: UserID) -> User | None:
|
||||
q = f'SELECT {cls.columns} FROM "user" WHERE mxid=$1'
|
||||
return cls._from_row(await cls.db.fetchrow(q, mxid))
|
||||
|
||||
@classmethod
|
||||
def get_by_username(cls, username: str) -> Optional['User']:
|
||||
return cls._select_one_or_none(func.lower(cls.c.tg_username) == username)
|
||||
async def find_by_username(cls, username: str) -> User | None:
|
||||
q = f'SELECT {cls.columns} FROM "user" WHERE lower(tg_username)=$1'
|
||||
return cls._from_row(await cls.db.fetchrow(q, username.lower()))
|
||||
|
||||
@classmethod
|
||||
async def all_with_tgid(cls) -> list[User]:
|
||||
q = f'SELECT {cls.columns} FROM "user" WHERE tgid IS NOT NULL'
|
||||
return [cls._from_row(row) for row in await cls.db.fetch(q)]
|
||||
|
||||
async def delete(self) -> None:
|
||||
await self.db.execute('DELETE FROM "user" WHERE mxid=$1', self.mxid)
|
||||
|
||||
@property
|
||||
def contacts(self) -> Iterable[TelegramID]:
|
||||
rows = self.db.execute(Contact.t.select().where(Contact.c.user == self.tgid))
|
||||
for row in rows:
|
||||
user, contact = row
|
||||
yield contact
|
||||
def _values(self):
|
||||
return (
|
||||
self.mxid, self.tgid, self.tg_username, self.tg_phone, self.is_bot, self.saved_contacts
|
||||
)
|
||||
|
||||
@contacts.setter
|
||||
def contacts(self, puppets: Iterable[TelegramID]) -> None:
|
||||
with self.db.begin() as conn:
|
||||
conn.execute(Contact.t.delete().where(Contact.c.user == self.tgid))
|
||||
insert_puppets = [{"user": self.tgid, "contact": tgid} for tgid in puppets]
|
||||
if insert_puppets:
|
||||
conn.execute(Contact.t.insert(), insert_puppets)
|
||||
async def save(self) -> None:
|
||||
q = (
|
||||
'UPDATE "user" SET tgid=$2, tg_username=$3, tg_phone=$4, is_bot=$5, saved_contacts=$6 '
|
||||
'WHERE mxid=$1'
|
||||
)
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
@property
|
||||
def portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]:
|
||||
rows = self.db.execute(UserPortal.t.select().where(UserPortal.c.user == self.tgid))
|
||||
for row in rows:
|
||||
user, portal, portal_receiver = row
|
||||
yield (portal, portal_receiver)
|
||||
async def insert(self) -> None:
|
||||
q = (
|
||||
'INSERT INTO "user" (mxid, tgid, tg_username, tg_phone, is_bot, saved_contacts) '
|
||||
'VALUES ($1, $2, $3, $4, $5, $6)'
|
||||
)
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
@portals.setter
|
||||
def portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None:
|
||||
with self.db.begin() as conn:
|
||||
conn.execute(UserPortal.t.delete().where(UserPortal.c.user == self.tgid))
|
||||
insert_portals = [{
|
||||
"user": self.tgid,
|
||||
"portal": tgid,
|
||||
"portal_receiver": tg_receiver
|
||||
} for tgid, tg_receiver in portals]
|
||||
if insert_portals:
|
||||
conn.execute(UserPortal.t.insert(), insert_portals)
|
||||
async def get_contacts(self) -> list[TelegramID]:
|
||||
rows = await self.db.fetch('SELECT contact FROM contact WHERE "user"=$1', self.tgid)
|
||||
return [TelegramID(row["contact"]) for row in rows]
|
||||
|
||||
def delete(self) -> None:
|
||||
super().delete()
|
||||
self.portals = []
|
||||
self.contacts = []
|
||||
async def set_contacts(self, puppets: Iterable[TelegramID]) -> None:
|
||||
columns = ["user", "contact"]
|
||||
records = [(self.tgid, puppet_id) for puppet_id in puppets]
|
||||
async with self.db.acquire() as conn, conn.transaction():
|
||||
await conn.execute('DELETE FROM contact WHERE "user"=$1', self.tgid)
|
||||
if self.db.scheme == "postgres":
|
||||
await conn.copy_records_to_table("contact", records=records, columns=columns)
|
||||
else:
|
||||
q = 'INSERT INTO contact ("user", contact) VALUES ($1, $2)'
|
||||
await conn.executemany(q, records)
|
||||
|
||||
async def get_portals(self) -> list[tuple[TelegramID, TelegramID]]:
|
||||
q = 'SELECT portal, portal_receiver FROM user_portal WHERE "user"=$1'
|
||||
rows = await self.db.fetch(q, self.tgid)
|
||||
return [(TelegramID(row["portal"]), TelegramID(row["portal_receiver"])) for row in rows]
|
||||
|
||||
class UserPortal(Base):
|
||||
__tablename__ = "user_portal"
|
||||
async def set_portals(self, portals: Iterable[tuple[TelegramID, TelegramID]]) -> None:
|
||||
columns = ["user", "portal", "portal_receiver"]
|
||||
records = [(self.tgid, tgid, tg_receiver) for tgid, tg_receiver in portals]
|
||||
async with self.db.acquire() as conn, conn.transaction():
|
||||
await conn.execute('DELETE FROM user_portal WHERE "user"=$1', self.tgid)
|
||||
if self.db.scheme == "postgres":
|
||||
await conn.copy_records_to_table("user_portal", records=records, columns=columns)
|
||||
else:
|
||||
q = 'INSERT INTO user_portal ("user", portal, portal_receiver) VALUES ($1, $2, $3)'
|
||||
await conn.executemany(q, records)
|
||||
|
||||
user: TelegramID = Column(BigInteger, ForeignKey("user.tgid", onupdate="CASCADE",
|
||||
ondelete="CASCADE"), primary_key=True)
|
||||
portal: TelegramID = Column(BigInteger, primary_key=True)
|
||||
portal_receiver: TelegramID = Column(BigInteger, primary_key=True)
|
||||
async def register_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None:
|
||||
q = ('INSERT INTO user_portal ("user", portal, portal_receiver) VALUES ($1, $2, $3) '
|
||||
'ON CONFLICT ("user", portal, portal_receiver) DO NOTHING')
|
||||
await self.db.execute(q, self.tgid, tgid, tg_receiver)
|
||||
|
||||
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
|
||||
("portal.tgid", "portal.tg_receiver"),
|
||||
onupdate="CASCADE", ondelete="CASCADE"),)
|
||||
|
||||
|
||||
class Contact(Base):
|
||||
__tablename__ = "contact"
|
||||
|
||||
user: TelegramID = Column(BigInteger, ForeignKey("user.tgid"), primary_key=True)
|
||||
contact: TelegramID = Column(BigInteger, ForeignKey("puppet.id"), primary_key=True)
|
||||
async def unregister_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None:
|
||||
q = 'DELETE FROM user_portal WHERE "user"=$1 AND portal=$2 AND portal_receiver=$3'
|
||||
await self.db.execute(q, self.tgid, tgid, tg_receiver)
|
||||
|
||||
Reference in New Issue
Block a user