Switch from SQLAlchemy to asyncpg/aiosqlite

This commit is contained in:
Tulir Asokan
2021-12-20 22:39:09 +02:00
parent f12f3fe007
commit 89ab29ea5f
61 changed files with 4681 additions and 4628 deletions
+91 -68
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -13,96 +13,119 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Iterable, Tuple
from __future__ import annotations
from sqlalchemy import Column, ForeignKey, ForeignKeyConstraint, BigInteger, Integer, String, func
from typing import Iterable, ClassVar, TYPE_CHECKING
from asyncpg import Record
from attr import dataclass
from mautrix.types import UserID
from mautrix.util.db import Base
from mautrix.util.async_db import Database
from ..types import TelegramID
fake_db = Database.create("") if TYPE_CHECKING else None
class User(Base):
__tablename__ = "user"
mxid: UserID = Column(String, primary_key=True)
tgid: Optional[TelegramID] = Column(BigInteger, nullable=True, unique=True)
tg_username: str = Column(String, nullable=True)
tg_phone: str = Column(String, nullable=True)
saved_contacts: int = Column(Integer, default=0, nullable=False)
@dataclass
class User:
db: ClassVar[Database] = fake_db
mxid: UserID
tgid: TelegramID | None
tg_username: str | None
tg_phone: str | None
is_bot: bool
saved_contacts: int
@classmethod
def all_with_tgid(cls) -> Iterable['User']:
return cls._select_all(cls.c.tgid != None)
def _from_row(cls, row: Record | None) -> User | None:
if row is None:
return None
return cls(**row)
columns: ClassVar[str] = "mxid, tgid, tg_username, tg_phone, is_bot, saved_contacts"
@classmethod
def get_by_tgid(cls, tgid: TelegramID) -> Optional['User']:
return cls._select_one_or_none(cls.c.tgid == tgid)
async def get_by_tgid(cls, tgid: TelegramID) -> User | None:
q = f'SELECT {cls.columns} FROM "user" WHERE tgid=$1'
return cls._from_row(await cls.db.fetchrow(q, tgid))
@classmethod
def get_by_mxid(cls, mxid: UserID) -> Optional['User']:
return cls._select_one_or_none(cls.c.mxid == mxid)
async def get_by_mxid(cls, mxid: UserID) -> User | None:
q = f'SELECT {cls.columns} FROM "user" WHERE mxid=$1'
return cls._from_row(await cls.db.fetchrow(q, mxid))
@classmethod
def get_by_username(cls, username: str) -> Optional['User']:
return cls._select_one_or_none(func.lower(cls.c.tg_username) == username)
async def find_by_username(cls, username: str) -> User | None:
q = f'SELECT {cls.columns} FROM "user" WHERE lower(tg_username)=$1'
return cls._from_row(await cls.db.fetchrow(q, username.lower()))
@classmethod
async def all_with_tgid(cls) -> list[User]:
q = f'SELECT {cls.columns} FROM "user" WHERE tgid IS NOT NULL'
return [cls._from_row(row) for row in await cls.db.fetch(q)]
async def delete(self) -> None:
await self.db.execute('DELETE FROM "user" WHERE mxid=$1', self.mxid)
@property
def contacts(self) -> Iterable[TelegramID]:
rows = self.db.execute(Contact.t.select().where(Contact.c.user == self.tgid))
for row in rows:
user, contact = row
yield contact
def _values(self):
return (
self.mxid, self.tgid, self.tg_username, self.tg_phone, self.is_bot, self.saved_contacts
)
@contacts.setter
def contacts(self, puppets: Iterable[TelegramID]) -> None:
with self.db.begin() as conn:
conn.execute(Contact.t.delete().where(Contact.c.user == self.tgid))
insert_puppets = [{"user": self.tgid, "contact": tgid} for tgid in puppets]
if insert_puppets:
conn.execute(Contact.t.insert(), insert_puppets)
async def save(self) -> None:
q = (
'UPDATE "user" SET tgid=$2, tg_username=$3, tg_phone=$4, is_bot=$5, saved_contacts=$6 '
'WHERE mxid=$1'
)
await self.db.execute(q, *self._values)
@property
def portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]:
rows = self.db.execute(UserPortal.t.select().where(UserPortal.c.user == self.tgid))
for row in rows:
user, portal, portal_receiver = row
yield (portal, portal_receiver)
async def insert(self) -> None:
q = (
'INSERT INTO "user" (mxid, tgid, tg_username, tg_phone, is_bot, saved_contacts) '
'VALUES ($1, $2, $3, $4, $5, $6)'
)
await self.db.execute(q, *self._values)
@portals.setter
def portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None:
with self.db.begin() as conn:
conn.execute(UserPortal.t.delete().where(UserPortal.c.user == self.tgid))
insert_portals = [{
"user": self.tgid,
"portal": tgid,
"portal_receiver": tg_receiver
} for tgid, tg_receiver in portals]
if insert_portals:
conn.execute(UserPortal.t.insert(), insert_portals)
async def get_contacts(self) -> list[TelegramID]:
rows = await self.db.fetch('SELECT contact FROM contact WHERE "user"=$1', self.tgid)
return [TelegramID(row["contact"]) for row in rows]
def delete(self) -> None:
super().delete()
self.portals = []
self.contacts = []
async def set_contacts(self, puppets: Iterable[TelegramID]) -> None:
columns = ["user", "contact"]
records = [(self.tgid, puppet_id) for puppet_id in puppets]
async with self.db.acquire() as conn, conn.transaction():
await conn.execute('DELETE FROM contact WHERE "user"=$1', self.tgid)
if self.db.scheme == "postgres":
await conn.copy_records_to_table("contact", records=records, columns=columns)
else:
q = 'INSERT INTO contact ("user", contact) VALUES ($1, $2)'
await conn.executemany(q, records)
async def get_portals(self) -> list[tuple[TelegramID, TelegramID]]:
q = 'SELECT portal, portal_receiver FROM user_portal WHERE "user"=$1'
rows = await self.db.fetch(q, self.tgid)
return [(TelegramID(row["portal"]), TelegramID(row["portal_receiver"])) for row in rows]
class UserPortal(Base):
__tablename__ = "user_portal"
async def set_portals(self, portals: Iterable[tuple[TelegramID, TelegramID]]) -> None:
columns = ["user", "portal", "portal_receiver"]
records = [(self.tgid, tgid, tg_receiver) for tgid, tg_receiver in portals]
async with self.db.acquire() as conn, conn.transaction():
await conn.execute('DELETE FROM user_portal WHERE "user"=$1', self.tgid)
if self.db.scheme == "postgres":
await conn.copy_records_to_table("user_portal", records=records, columns=columns)
else:
q = 'INSERT INTO user_portal ("user", portal, portal_receiver) VALUES ($1, $2, $3)'
await conn.executemany(q, records)
user: TelegramID = Column(BigInteger, ForeignKey("user.tgid", onupdate="CASCADE",
ondelete="CASCADE"), primary_key=True)
portal: TelegramID = Column(BigInteger, primary_key=True)
portal_receiver: TelegramID = Column(BigInteger, primary_key=True)
async def register_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None:
q = ('INSERT INTO user_portal ("user", portal, portal_receiver) VALUES ($1, $2, $3) '
'ON CONFLICT ("user", portal, portal_receiver) DO NOTHING')
await self.db.execute(q, self.tgid, tgid, tg_receiver)
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
("portal.tgid", "portal.tg_receiver"),
onupdate="CASCADE", ondelete="CASCADE"),)
class Contact(Base):
__tablename__ = "contact"
user: TelegramID = Column(BigInteger, ForeignKey("user.tgid"), primary_key=True)
contact: TelegramID = Column(BigInteger, ForeignKey("puppet.id"), primary_key=True)
async def unregister_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None:
q = 'DELETE FROM user_portal WHERE "user"=$1 AND portal=$2 AND portal_receiver=$3'
await self.db.execute(q, self.tgid, tgid, tg_receiver)