Switch from SQLAlchemy to asyncpg/aiosqlite

This commit is contained in:
Tulir Asokan
2021-12-20 22:39:09 +02:00
parent f12f3fe007
commit 89ab29ea5f
61 changed files with 4681 additions and 4628 deletions
+143 -204
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2020 Tulir Asokan
# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -13,111 +13,79 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Any, Dict, Iterable, Optional, Union, Tuple, TYPE_CHECKING
from __future__ import annotations
from typing import Awaitable, AsyncGenerator, AsyncIterable, TYPE_CHECKING, cast
from difflib import SequenceMatcher
import unicodedata
import asyncio
import logging
from telethon.tl.types import (UserProfilePhoto, User, UpdateUserName, PeerUser, TypeInputPeer,
InputPeerPhotoFileLocation, UserProfilePhotoEmpty, TypeInputUser)
from yarl import URL
from mautrix.appservice import AppService, IntentAPI
from mautrix.errors import MatrixRequestError, MatrixError
from mautrix.bridge import BasePuppet
from mautrix.appservice import IntentAPI
from mautrix.errors import MatrixError
from mautrix.bridge import BasePuppet, async_getter_lock
from mautrix.types import UserID, SyncToken, RoomID, ContentURI
from mautrix.util.simple_template import SimpleTemplate
from mautrix.util.logging import TraceLogger
from .config import Config
from .types import TelegramID
from .db import Puppet as DBPuppet
from . import util, portal as p
from . import util, portal as p, abstract_user as au
if TYPE_CHECKING:
from .matrix import MatrixHandler
from .config import Config
from .context import Context
from .abstract_user import AbstractUser
config: Optional['Config'] = None
from .__main__ import TelegramBridge
class Puppet(BasePuppet):
log: TraceLogger = logging.getLogger("mau.puppet")
az: AppService
mx: 'MatrixHandler'
loop: asyncio.AbstractEventLoop
class Puppet(DBPuppet, BasePuppet):
config: Config
hs_domain: str
mxid_template: SimpleTemplate[TelegramID]
displayname_template: SimpleTemplate[str]
cache: Dict[TelegramID, 'Puppet'] = {}
by_custom_mxid: Dict[UserID, 'Puppet'] = {}
by_tgid: dict[TelegramID, Puppet] = {}
by_custom_mxid: dict[UserID, Puppet] = {}
id: TelegramID
access_token: Optional[str]
custom_mxid: Optional[UserID]
_next_batch: Optional[SyncToken]
base_url: Optional[URL]
default_mxid: UserID
def __init__(
self,
id: TelegramID,
is_registered: bool = False,
displayname: str | None = None,
displayname_source: TelegramID | None = None,
displayname_contact: bool = True,
displayname_quality: int = 0,
disable_updates: bool = False,
username: str | None = None,
photo_id: str | None = None,
is_bot: bool = False,
custom_mxid: UserID | None = None,
access_token: str | None = None,
next_batch: SyncToken | None = None,
base_url: str | None = None
) -> None:
super().__init__(
id=id,
is_registered=is_registered,
displayname=displayname,
displayname_source=displayname_source,
displayname_contact=displayname_contact,
displayname_quality=displayname_quality,
disable_updates=disable_updates,
username=username,
photo_id=photo_id,
is_bot=is_bot,
custom_mxid=custom_mxid,
access_token=access_token,
next_batch=next_batch,
base_url=base_url,
)
username: Optional[str]
displayname: Optional[str]
displayname_source: Optional[TelegramID]
displayname_contact: bool
displayname_quality: int
photo_id: Optional[str]
is_bot: bool
is_registered: bool
disable_updates: bool
default_mxid_intent: IntentAPI
intent: IntentAPI
sync_task: Optional[asyncio.Future]
_db_instance: Optional[DBPuppet]
def __init__(self,
id: TelegramID,
access_token: Optional[str] = None,
custom_mxid: Optional[UserID] = None,
next_batch: Optional[SyncToken] = None,
base_url: Optional[str] = None,
username: Optional[str] = None,
displayname: Optional[str] = None,
displayname_source: Optional[TelegramID] = None,
displayname_contact: bool = True,
displayname_quality: int = 0,
photo_id: Optional[str] = None,
is_bot: bool = False,
is_registered: bool = False,
disable_updates: bool = False,
db_instance: Optional[DBPuppet] = None) -> None:
self.id = id
self.access_token = access_token
self.custom_mxid = custom_mxid
self._next_batch = next_batch
self.base_url = URL(base_url) if base_url else None
self.default_mxid = self.get_mxid_from_id(self.id)
self.username = username
self.displayname = displayname
self.displayname_source = displayname_source
self.displayname_contact = displayname_contact
self.displayname_quality = displayname_quality
self.photo_id = photo_id
self.is_bot = is_bot
self.is_registered = is_registered
self.disable_updates = disable_updates
self._db_instance = db_instance
self.default_mxid_intent = self.az.intent.user(self.default_mxid)
self.intent = self._fresh_intent()
self.sync_task = None
self.cache[id] = self
self.by_tgid[id] = self
if self.custom_mxid:
self.by_custom_mxid[self.custom_mxid] = self
@@ -128,76 +96,59 @@ class Puppet(BasePuppet):
return self.id
@property
def peer(self) -> PeerUser:
return PeerUser(user_id=self.tgid)
def tg_username(self) -> str | None:
return self.username
@property
def next_batch(self) -> SyncToken:
return self._next_batch
@next_batch.setter
def next_batch(self, value: SyncToken) -> None:
self._next_batch = value
self.db_instance.edit(next_batch=self._next_batch)
@staticmethod
async def is_logged_in() -> bool:
""" Is True if the puppet is logged in. """
return True
def peer(self) -> PeerUser:
return PeerUser(user_id=self.tgid)
@property
def plain_displayname(self) -> str:
return self.displayname_template.parse(self.displayname) or self.displayname
def get_input_entity(self, user: 'AbstractUser'
) -> Awaitable[Union[TypeInputPeer, TypeInputUser]]:
def get_input_entity(self, user: au.AbstractUser) -> Awaitable[TypeInputPeer | TypeInputUser]:
return user.client.get_input_entity(self.peer)
def intent_for(self, portal: 'p.Portal') -> IntentAPI:
def intent_for(self, portal: p.Portal) -> IntentAPI:
if portal.tgid == self.tgid:
return self.default_mxid_intent
return self.intent
# region DB conversion
@property
def db_instance(self) -> DBPuppet:
if not self._db_instance:
self._db_instance = self.new_db_instance()
return self._db_instance
@property
def _fields(self) -> Dict[str, Any]:
return dict(access_token=self.access_token, next_batch=self._next_batch,
custom_mxid=self.custom_mxid, username=self.username, is_bot=self.is_bot,
displayname=self.displayname, displayname_source=self.displayname_source,
displayname_contact=self.displayname_contact,
displayname_quality=self.displayname_quality, photo_id=self.photo_id,
matrix_registered=self.is_registered, disable_updates=self.disable_updates,
base_url=str(self.base_url) if self.base_url else None)
def new_db_instance(self) -> DBPuppet:
return DBPuppet(id=self.id, **self._fields)
async def save(self) -> None:
self.db_instance.edit(**self._fields)
@classmethod
def from_db(cls, db_puppet: DBPuppet) -> 'Puppet':
return Puppet(db_puppet.id, db_puppet.access_token, db_puppet.custom_mxid,
db_puppet.next_batch, db_puppet.base_url, db_puppet.username,
db_puppet.displayname, db_puppet.displayname_source,
db_puppet.displayname_contact, db_puppet.displayname_quality,
db_puppet.photo_id, db_puppet.is_bot, db_puppet.matrix_registered,
db_puppet.disable_updates, db_instance=db_puppet)
def init_cls(cls, bridge: 'TelegramBridge') -> AsyncIterable[Awaitable[None]]:
cls.config = bridge.config
cls.loop = bridge.loop
cls.mx = bridge.matrix
cls.az = bridge.az
cls.hs_domain = cls.config["homeserver.domain"]
mxid_tpl = SimpleTemplate(
cls.config["bridge.username_template"],
"userid",
prefix="@",
suffix=f":{Puppet.hs_domain}",
type=int,
)
cls.mxid_template = cast(SimpleTemplate[TelegramID], mxid_tpl)
cls.displayname_template = SimpleTemplate(
cls.config["bridge.displayname_template"], "displayname"
)
cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
cls.homeserver_url_map = {server: URL(url) for server, url
in cls.config["bridge.double_puppet_server_map"].items()}
cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
cls.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret
in cls.config["bridge.login_shared_secret_map"].items()}
cls.login_device_name = "Telegram Bridge"
return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
# endregion
# region Info updating
def similarity(self, query: str) -> int:
username_similarity = (SequenceMatcher(None, self.username, query).ratio()
if self.username else 0)
displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio()
displayname_similarity = (SequenceMatcher(None, self.plain_displayname, query).ratio()
if self.displayname else 0)
similarity = max(username_similarity, displayname_similarity)
return int(round(similarity * 100))
@@ -211,11 +162,11 @@ class Puppet(BasePuppet):
"\u200c\u200d\u200e\u200f\ufe0f")
allowed_other_format = ("\u200d", "\u200c")
name = "".join(c for c in name.strip(whitespace) if unicodedata.category(c) != 'Cf'
or c in allowed_other_format)
or c in allowed_other_format)
return name
@classmethod
def get_displayname(cls, info: User, enable_format: bool = True) -> Tuple[str, int]:
def get_displayname(cls, info: User, enable_format: bool = True) -> tuple[str, int]:
fn = cls._filter_name(info.first_name)
ln = cls._filter_name(info.last_name)
data = {
@@ -226,7 +177,7 @@ class Puppet(BasePuppet):
"first name": fn,
"last name": ln,
}
preferences = config["bridge.displayname_preference"]
preferences = cls.config["bridge.displayname_preference"]
name = None
quality = 99
for preference in preferences:
@@ -244,13 +195,13 @@ class Puppet(BasePuppet):
return (cls.displayname_template.format_full(name) if enable_format else name), quality
async def try_update_info(self, source: 'AbstractUser', info: User) -> None:
async def try_update_info(self, source: au.AbstractUser, info: User) -> None:
try:
await self.update_info(source, info)
except Exception:
source.log.exception(f"Failed to update info of {self.tgid}")
async def update_info(self, source: 'AbstractUser', info: User) -> None:
async def update_info(self, source: au.AbstractUser, info: User) -> None:
changed = False
if self.username != info.username:
self.username = info.username
@@ -268,7 +219,7 @@ class Puppet(BasePuppet):
if changed:
await self.save()
async def update_displayname(self, source: 'AbstractUser', info: Union[User, UpdateUserName]
async def update_displayname(self, source: au.AbstractUser, info: User | UpdateUserName
) -> bool:
if self.disable_updates:
return False
@@ -306,7 +257,7 @@ class Puppet(BasePuppet):
self.displayname_quality = quality
try:
await self.default_mxid_intent.set_displayname(
displayname[:config["bridge.displayname_max_length"]])
displayname[:self.config["bridge.displayname_max_length"]])
except MatrixError:
self.log.exception("Failed to set displayname")
self.displayname = ""
@@ -318,8 +269,8 @@ class Puppet(BasePuppet):
return True
return False
async def update_avatar(self, source: 'AbstractUser',
photo: Union[UserProfilePhoto, UserProfilePhotoEmpty]) -> bool:
async def update_avatar(self, source: au.AbstractUser,
photo: UserProfilePhoto | UserProfilePhotoEmpty) -> bool:
if self.disable_updates:
return False
@@ -330,7 +281,7 @@ class Puppet(BasePuppet):
else:
self.log.warning(f"Unknown user profile photo type: {type(photo)}")
return False
if not photo_id and not config["bridge.allow_avatar_remove"]:
if not photo_id and not self.config["bridge.allow_avatar_remove"]:
return False
if self.photo_id != photo_id:
if not photo_id:
@@ -359,72 +310,73 @@ class Puppet(BasePuppet):
return False
async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
portal: p.Portal = p.Portal.get_by_mxid(room_id)
portal: p.Portal = await p.Portal.get_by_mxid(room_id)
return portal and not portal.backfill_lock.locked and portal.peer_type != "user"
# endregion
# region Getters
def _add_to_cache(self) -> None:
self.by_tgid[self.id] = self
if self.custom_mxid:
self.by_custom_mxid[self.custom_mxid] = self
@classmethod
def get(cls, tgid: TelegramID, create: bool = True) -> Optional['Puppet']:
@async_getter_lock
async def get_by_tgid(cls, tgid: TelegramID, *, create: bool = True) -> Puppet | None:
if tgid is None:
return None
try:
return cls.cache[tgid]
return cls.by_tgid[tgid]
except KeyError:
pass
puppet = DBPuppet.get_by_tgid(tgid)
puppet = cast(cls, await super().get_by_tgid(tgid))
if puppet:
return cls.from_db(puppet)
puppet._add_to_cache()
return puppet
if create:
puppet = cls(tgid)
puppet.db_instance.insert()
await puppet.insert()
puppet._add_to_cache()
return puppet
return None
@classmethod
def deprecated_sync_get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
tgid = cls.get_id_from_mxid(mxid)
if tgid:
return cls.get(tgid, create)
return None
def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Awaitable[Puppet | None]:
return cls.get_by_tgid(cls.get_id_from_mxid(mxid), create=create)
@classmethod
async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
return cls.deprecated_sync_get_by_mxid(mxid, create)
@classmethod
def deprecated_sync_get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
if not mxid:
raise ValueError("Matrix ID can't be empty")
@async_getter_lock
async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
try:
return cls.by_custom_mxid[mxid]
except KeyError:
pass
puppet = DBPuppet.get_by_custom_mxid(mxid)
puppet = cast(cls, await super().get_by_custom_mxid(mxid))
if puppet:
puppet = cls.from_db(puppet)
puppet._add_to_cache()
return puppet
return None
@classmethod
async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
return cls.deprecated_sync_get_by_custom_mxid(mxid)
async def all_with_custom_mxid(cls) -> AsyncGenerator[Puppet, None]:
puppets = await super().all_with_custom_mxid()
puppet: cls
for puppet in puppets:
try:
yield cls.by_tgid[puppet.tgid]
except KeyError:
puppet._add_to_cache()
yield puppet
@classmethod
def all_with_custom_mxid(cls) -> Iterable['Puppet']:
return (cls.by_custom_mxid[puppet.custom_mxid]
if puppet.custom_mxid in cls.by_custom_mxid
else cls.from_db(puppet)
for puppet in DBPuppet.all_with_custom_mxid())
@classmethod
def get_id_from_mxid(cls, mxid: UserID) -> Optional[TelegramID]:
def get_id_from_mxid(cls, mxid: UserID) -> TelegramID | None:
return cls.mxid_template.parse(mxid)
@classmethod
@@ -432,56 +384,43 @@ class Puppet(BasePuppet):
return UserID(cls.mxid_template.format_full(tgid))
@classmethod
def find_by_username(cls, username: str) -> Optional['Puppet']:
async def find_by_username(cls, username: str) -> Puppet | None:
if not username:
return None
username = username.lower()
for _, puppet in cls.cache.items():
for _, puppet in cls.by_tgid.items():
if puppet.username and puppet.username.lower() == username:
return puppet
dbpuppet = DBPuppet.get_by_username(username)
if dbpuppet:
return cls.from_db(dbpuppet)
puppet = cast(cls, await super().find_by_username(username))
if puppet:
try:
return cls.by_tgid[puppet.tgid]
except KeyError:
puppet._add_to_cache()
return puppet
return None
@classmethod
def find_by_displayname(cls, displayname: str) -> Optional['Puppet']:
async def find_by_displayname(cls, displayname: str) -> Puppet | None:
if not displayname:
return None
for _, puppet in cls.cache.items():
for _, puppet in cls.by_tgid.items():
if puppet.displayname and puppet.displayname == displayname:
return puppet
dbpuppet = DBPuppet.get_by_displayname(displayname)
if dbpuppet:
return cls.from_db(dbpuppet)
puppet = cast(cls, await super().find_by_displayname(displayname))
if puppet:
try:
return cls.by_tgid[puppet.tgid]
except KeyError:
puppet._add_to_cache()
return puppet
return None
# endregion
def init(context: 'Context') -> Iterable[Awaitable[Any]]:
global config
Puppet.az, config, Puppet.loop, _ = context.core
Puppet.mx = context.mx
Puppet.hs_domain = config["homeserver"]["domain"]
Puppet.mxid_template = SimpleTemplate(config["bridge.username_template"], "userid",
prefix="@", suffix=f":{Puppet.hs_domain}", type=int)
Puppet.displayname_template = SimpleTemplate(config["bridge.displayname_template"],
"displayname")
Puppet.sync_with_custom_puppets = config["bridge.sync_with_custom_puppets"]
Puppet.homeserver_url_map = {server: URL(url) for server, url
in config["bridge.double_puppet_server_map"].items()}
Puppet.allow_discover_url = config["bridge.double_puppet_allow_discovery"]
Puppet.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret
in config["bridge.login_shared_secret_map"].items()}
Puppet.login_device_name = "Telegram Bridge"
return (puppet.try_start() for puppet in Puppet.all_with_custom_mxid())