Add basic support for bridging custom emojis from Telegram

This commit is contained in:
Tulir Asokan
2022-08-12 21:35:50 +03:00
parent 473ab17fe7
commit 76eafbf48c
5 changed files with 113 additions and 21 deletions
+33 -13
View File
@@ -17,10 +17,11 @@ from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record
from attr import dataclass
from mautrix.types import ContentURI, EncryptedFile
from mautrix.util.async_db import Database
from mautrix.util.async_db import Database, Scheme
fake_db = Database.create("") if TYPE_CHECKING else None
@@ -40,28 +41,47 @@ class TelegramFile:
decryption_info: EncryptedFile | None
thumbnail: TelegramFile | None = None
columns: ClassVar[str] = (
"id, mxc, mime_type, was_converted, timestamp, size, width, height, thumbnail, "
"decryption_info"
)
@classmethod
async def get(cls, loc_id: str, *, _thumbnail: bool = False) -> TelegramFile | None:
q = (
"SELECT id, mxc, mime_type, was_converted, timestamp, size, width, height, thumbnail,"
" decryption_info "
"FROM telegram_file WHERE id=$1"
)
row = await cls.db.fetchrow(q, loc_id)
def _from_row(cls, row: Record | None) -> TelegramFile | None:
if row is None:
return None
data = {**row}
thumbnail_id = data.pop("thumbnail", None)
if _thumbnail:
# Don't allow more than one level of recursion
thumbnail_id = None
data.pop("thumbnail", None)
decryption_info = data.pop("decryption_info", None)
return cls(
**data,
thumbnail=(await cls.get(thumbnail_id, _thumbnail=True)) if thumbnail_id else None,
thumbnail=None,
decryption_info=EncryptedFile.parse_json(decryption_info) if decryption_info else None,
)
@classmethod
async def get_many(cls, loc_ids: list[str]) -> list[TelegramFile]:
if cls.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
q = f"SELECT {cls.columns} FROM telegram_file WHERE id=ANY($1)"
rows = await cls.db.fetch(q, loc_ids)
else:
tgid_placeholders = ("?," * len(loc_ids)).rstrip(",")
q = f"SELECT {cls.columns} FROM telegram_file WHERE id IN ({tgid_placeholders})"
rows = await cls.db.fetch(q, *loc_ids)
return [cls._from_row(row) for row in rows]
@classmethod
async def get(cls, loc_id: str, *, _thumbnail: bool = False) -> TelegramFile | None:
q = f"SELECT {cls.columns} FROM telegram_file WHERE id=$1"
row = await cls.db.fetchrow(q, loc_id)
file = cls._from_row(row)
if file is None:
return None
thumbnail_id = row.get("thumbnail", None)
if thumbnail_id and not _thumbnail:
file.thumbnail = await cls.get(thumbnail_id, _thumbnail=True)
return file
async def insert(self) -> None:
q = (
"INSERT INTO telegram_file (id, mxc, mime_type, was_converted, size, width, height, "