Switch from SQLAlchemy to asyncpg/aiosqlite

This commit is contained in:
Tulir Asokan
2021-12-20 22:39:09 +02:00
parent f12f3fe007
commit 89ab29ea5f
61 changed files with 4681 additions and 4628 deletions
+210 -216
View File
@@ -13,10 +13,10 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (Awaitable, Dict, List, Iterable, NamedTuple, Optional, Tuple, Any, cast,
TYPE_CHECKING)
from __future__ import annotations
from typing import Awaitable, AsyncIterable, NamedTuple, AsyncGenerator, TYPE_CHECKING, cast
from datetime import datetime, timezone
import logging
import asyncio
from telethon.tl.types import (TypeUpdate, UpdateNewMessage, UpdateNewChannelMessage,
@@ -35,21 +35,17 @@ from mautrix.client import Client
from mautrix.errors import MatrixRequestError, MNotFound
from mautrix.types import UserID, RoomID, PushRuleScope, PushRuleKind, PushActionType, RoomTagInfo
from mautrix.appservice import DOUBLE_PUPPET_SOURCE_KEY
from mautrix.bridge import BaseUser
from mautrix.bridge import BaseUser, async_getter_lock
from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
from mautrix.util.logging import TraceLogger
from mautrix.util.opt_prometheus import Gauge
from .types import TelegramID
from .db import User as DBUser, Portal as DBPortal, Message as DBMessage
from .db import User as DBUser, Message as DBMessage, PgSession
from .abstract_user import AbstractUser
from . import portal as po, puppet as pu
if TYPE_CHECKING:
from .config import Config
from .context import Context
config: Optional['Config'] = None
from .__main__ import TelegramBridge
SearchResult = NamedTuple('SearchResult', puppet='pu.Puppet', similarity=int)
@@ -64,54 +60,46 @@ BridgeState.human_readable_errors.update({
})
class User(AbstractUser, BaseUser):
log: TraceLogger = logging.getLogger("mau.user")
by_mxid: Dict[str, 'User'] = {}
by_tgid: Dict[int, 'User'] = {}
class User(DBUser, AbstractUser, BaseUser):
by_mxid: dict[str, User] = {}
by_tgid: dict[int, User] = {}
phone: Optional[str]
contacts: List['pu.Puppet']
saved_contacts: int
portals: Dict[Tuple[TelegramID, TelegramID], 'po.Portal']
command_status: Optional[Dict[str, Any]]
_portals_cache: dict[tuple[TelegramID, TelegramID], po.Portal] | None
_db_instance: Optional[DBUser]
_ensure_started_lock: asyncio.Lock
_track_connection_task: Optional[asyncio.Task]
_track_connection_task: asyncio.Task | None
_is_backfilling: bool
def __init__(self, mxid: UserID, tgid: Optional[TelegramID] = None,
username: Optional[str] = None, phone: Optional[str] = None,
db_contacts: Optional[Iterable[TelegramID]] = None,
saved_contacts: int = 0, is_bot: bool = False,
db_portals: Optional[Iterable[Tuple[TelegramID, TelegramID]]] = None,
db_instance: Optional[DBUser] = None) -> None:
def __init__(
self,
mxid: UserID,
tgid: TelegramID | None = None,
tg_username: str | None = None,
tg_phone: str | None = None,
is_bot: bool = False,
saved_contacts: int = 0,
) -> None:
super().__init__(
mxid=mxid,
tgid=tgid,
tg_username=tg_username,
tg_phone=tg_phone,
is_bot=is_bot,
saved_contacts=saved_contacts,
)
AbstractUser.__init__(self)
self.mxid = mxid
BaseUser.__init__(self)
self.tgid = tgid
self.is_bot = is_bot
self.username = username
self.phone = phone
self.contacts = []
self.saved_contacts = saved_contacts
self.db_contacts = db_contacts
self.portals = {}
self.db_portals = db_portals or []
self._db_instance = db_instance
self._ensure_started_lock = asyncio.Lock()
self._track_connection_task = None
self._is_backfilling = False
self._portals_cache = None
(self.relaybot_whitelisted,
self.whitelisted,
self.puppet_whitelisted,
self.matrix_puppet_whitelisted,
self.is_admin,
self.permissions) = config.get_permissions(self.mxid)
self.by_mxid[mxid] = self
if tgid:
self.by_tgid[tgid] = self
self.permissions) = self.config.get_permissions(self.mxid)
@property
def name(self) -> str:
@@ -124,7 +112,7 @@ class User(AbstractUser, BaseUser):
@property
def human_tg_id(self) -> str:
return f"@{self.username}" if self.username else f"+{self.phone}" or None
return f"@{self.tg_username}" if self.tg_username else f"+{self.tg_phone}" or None
# TODO replace with proper displayname getting everywhere
@property
@@ -135,65 +123,15 @@ class User(AbstractUser, BaseUser):
def plain_displayname(self) -> str:
return self.displayname
@property
def db_contacts(self) -> Iterable[TelegramID]:
return (puppet.id
for puppet in self.contacts
if puppet)
@db_contacts.setter
def db_contacts(self, contacts: Iterable[TelegramID]) -> None:
self.contacts = [pu.Puppet.get(entry) for entry in contacts] if contacts else []
@property
def db_portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]:
return (portal.tgid_full
for portal in self.portals.values()
if portal and not portal.deleted)
@db_portals.setter
def db_portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None:
self.portals = {
tgid_full: po.Portal.get_by_tgid(*tgid_full)
for tgid_full in portals
} if portals else {}
# region Database conversion
@property
def db_instance(self) -> DBUser:
if not self._db_instance:
self._db_instance = self.new_db_instance()
return self._db_instance
def new_db_instance(self) -> DBUser:
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
saved_contacts=self.saved_contacts, portals=self.db_portals)
async def save(self, contacts: bool = False, portals: bool = False) -> None:
self.db_instance.edit(tgid=self.tgid, tg_username=self.username, tg_phone=self.phone,
saved_contacts=self.saved_contacts)
if contacts:
self.db_instance.contacts = self.db_contacts
if portals:
self.db_instance.portals = self.db_portals
def delete(self, delete_db: bool = True) -> None:
try:
del self.by_mxid[self.mxid]
del self.by_tgid[self.tgid]
except KeyError:
pass
if delete_db and self._db_instance:
self._db_instance.delete()
@classmethod
def from_db(cls, db_user: DBUser) -> 'User':
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.tg_phone,
db_user.contacts, db_user.saved_contacts, False, db_user.portals,
db_instance=db_user)
def init_cls(cls, bridge: 'TelegramBridge') -> AsyncIterable[Awaitable[User]]:
cls.config = bridge.config
cls.bridge = bridge
cls.az = bridge.az
cls.loop = bridge.loop
return (user.try_ensure_started() async for user in cls.all_with_tgid())
# endregion
# region Telegram connection management
async def try_ensure_started(self) -> None:
@@ -202,19 +140,19 @@ class User(AbstractUser, BaseUser):
except Exception:
self.log.exception("Exception in ensure_started")
else:
if not self.client and not self.session_container.has_session(self.mxid):
if not self.client and not await PgSession.has(self.mxid):
self.log.warning("Didn't start user: no session stored")
if self.tgid:
await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS,
error="tg-no-auth")
async def ensure_started(self, even_if_no_session=False) -> 'User':
async def ensure_started(self, even_if_no_session=False) -> User:
if not self.puppet_whitelisted or self.connected:
return self
async with self._ensure_started_lock:
return cast(User, await super().ensure_started(even_if_no_session))
async def start(self, delete_unless_authenticated: bool = False) -> 'User':
async def start(self, delete_unless_authenticated: bool = False) -> User:
try:
await super().start()
except AuthKeyDuplicatedError:
@@ -222,7 +160,7 @@ class User(AbstractUser, BaseUser):
await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS,
error="tg-auth-key-duplicated")
await self.client.disconnect()
self.client.session.delete()
await self.client.session.delete()
self.client = None
if not delete_unless_authenticated:
# The caller wants the client to be connected, so restart the connection.
@@ -257,7 +195,7 @@ class User(AbstractUser, BaseUser):
if delete_unless_authenticated:
self.log.debug(f"Unauthenticated user {self.name} start()ed, deleting session...")
await self.client.disconnect()
self.client.session.delete()
await self.client.session.delete()
return self
@property
@@ -283,7 +221,7 @@ class User(AbstractUser, BaseUser):
state.remote_id = str(self.tgid)
state.remote_name = self.human_tg_id
async def get_bridge_states(self) -> List[BridgeState]:
async def get_bridge_states(self) -> list[BridgeState]:
if not self.tgid:
return []
if self._is_connected and await self.is_logged_in():
@@ -295,10 +233,10 @@ class User(AbstractUser, BaseUser):
ttl = 240
return [BridgeState(state_event=state_event, ttl=ttl)]
async def get_puppet(self) -> Optional['pu.Puppet']:
async def get_puppet(self) -> pu.Puppet | None:
if not self.tgid:
return None
return pu.Puppet.get(self.tgid)
return await pu.Puppet.get_by_tgid(self.tgid)
async def stop(self) -> None:
if self._track_connection_task:
@@ -308,7 +246,7 @@ class User(AbstractUser, BaseUser):
self._track_metric(METRIC_CONNECTED, False)
async def post_login(self, info: TLUser = None, first_login: bool = False) -> None:
if config["metrics.enabled"] and not self._track_connection_task:
if self.config["metrics.enabled"] and not self._track_connection_task:
self._track_connection_task = self.loop.create_task(self._track_connection())
try:
@@ -320,14 +258,14 @@ class User(AbstractUser, BaseUser):
self._track_metric(METRIC_LOGGED_IN, True)
try:
puppet = pu.Puppet.get(self.tgid)
puppet = await pu.Puppet.get_by_tgid(self.tgid)
if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
self.log.info(f"Automatically enabling custom puppet")
await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
except Exception:
self.log.exception("Failed to automatically enable custom puppet")
if not self.is_bot and config["bridge.startup_sync"]:
if not self.is_bot and self.config["bridge.startup_sync"]:
try:
self._is_backfilling = True
await self.sync_dialogs()
@@ -342,11 +280,13 @@ class User(AbstractUser, BaseUser):
return False
if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
portal = po.Portal.get_by_entity(update.message.peer_id, receiver_id=self.tgid)
portal = await po.Portal.get_by_entity(update.message.peer_id, tg_receiver=self.tgid)
elif isinstance(update, UpdateShortChatMessage):
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id))
portal = await po.Portal.get_by_tgid(TelegramID(update.chat_id))
elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
portal = await po.Portal.get_by_tgid(
TelegramID(update.user_id), tg_receiver=self.tgid, peer_type="user"
)
else:
return False
@@ -364,7 +304,7 @@ class User(AbstractUser, BaseUser):
if not self.is_bot:
await self.client(UpdateStatusRequest(offline=not online))
async def get_me(self) -> Optional[TLUser]:
async def get_me(self) -> TLUser | None:
try:
return (await self.client(GetUsersRequest([InputUserSelf()])))[0]
except UnauthorizedError as e:
@@ -384,11 +324,11 @@ class User(AbstractUser, BaseUser):
if self.is_bot != info.bot:
self.is_bot = info.bot
changed = True
if self.username != info.username:
self.username = info.username
if self.tg_username != info.username:
self.tg_username = info.username
changed = True
if self.phone != info.phone:
self.phone = info.phone
if self.tg_phone != info.phone:
self.tg_phone = info.phone
changed = True
if self.tgid != info.id:
self.tgid = TelegramID(info.id)
@@ -396,11 +336,11 @@ class User(AbstractUser, BaseUser):
if changed:
await self.save()
async def log_out(self) -> bool:
puppet = pu.Puppet.get(self.tgid)
if puppet.is_real_user:
await puppet.switch_mxid(None, None)
for _, portal in self.portals.items():
async def kick_from_portals(self) -> None:
if not self.config["bridge.kick_on_logout"]:
return
portals = await self.get_cached_portals()
for _, portal in portals.values():
if not portal or portal.deleted or not portal.mxid or portal.has_bot:
continue
if portal.peer_type == "user":
@@ -411,9 +351,15 @@ class User(AbstractUser, BaseUser):
"Logged out of Telegram.")
except MatrixRequestError:
pass
self.portals = {}
self.contacts = []
await self.save(portals=True, contacts=True)
async def log_out(self) -> bool:
puppet = await pu.Puppet.get_by_tgid(self.tgid)
if puppet.is_real_user:
await puppet.switch_mxid(None, None)
try:
await self.kick_from_portals()
except Exception:
self.log.exception("Failed to kick user from portals on logout")
await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
if self.tgid:
try:
@@ -421,51 +367,54 @@ class User(AbstractUser, BaseUser):
except KeyError:
pass
self.tgid = None
await self.save()
ok = await self.client.log_out()
self.client.session.delete()
self.delete()
await self.client.session.delete()
await self.delete()
self.by_mxid.pop(self.mxid, None)
await self.stop()
self._track_metric(METRIC_LOGGED_IN, False)
return ok
def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45
) -> List[SearchResult]:
results: List[SearchResult] = []
for contact in self.contacts:
async def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45
) -> list[SearchResult]:
results: list[SearchResult] = []
for contact_id in await self.get_contacts():
contact = await pu.Puppet.get_by_tgid(contact_id, create=False)
if not contact:
continue
similarity = contact.similarity(query)
if similarity >= min_similarity:
results.append(SearchResult(contact, similarity))
results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results]
async def _search_remote(self, query: str, max_results: int = 5) -> List[SearchResult]:
async def _search_remote(self, query: str, max_results: int = 5) -> list[SearchResult]:
if len(query) < 5:
return []
server_results = await self.client(SearchRequest(q=query, limit=max_results))
results: List[SearchResult] = []
results: list[SearchResult] = []
for user in server_results.users:
puppet = pu.Puppet.get(user.id)
puppet = await pu.Puppet.get_by_tgid(user.id)
await puppet.update_info(self, user)
results.append(SearchResult(puppet, puppet.similarity(query)))
results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results]
async def search(self, query: str, force_remote: bool = False
) -> Tuple[List[SearchResult], bool]:
) -> tuple[list[SearchResult], bool]:
if force_remote:
return await self._search_remote(query), True
results = self._search_local(query)
results = await self._search_local(query)
if results:
return results, False
return await self._search_remote(query), True
async def get_direct_chats(self) -> Dict[UserID, List[RoomID]]:
async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
return {
pu.Puppet.get_mxid_from_id(portal.tgid): [portal.mxid]
for portal in DBPortal.find_private_chats(self.tgid)
async for portal in po.Portal.find_private_chats(self.tgid)
if portal.mxid
}
@@ -478,12 +427,14 @@ class User(AbstractUser, BaseUser):
tag_info = RoomTagInfo(order=0.5)
tag_info[DOUBLE_PUPPET_SOURCE_KEY] = self.bridge.name
await puppet.intent.set_room_tag(portal.mxid, tag, tag_info)
elif not active and tag_info and tag_info.get(DOUBLE_PUPPET_SOURCE_KEY) == self.bridge.name:
elif (
not active and tag_info
and tag_info.get(DOUBLE_PUPPET_SOURCE_KEY) == self.bridge.name
):
await puppet.intent.remove_room_tag(portal.mxid, tag)
@staticmethod
async def _mute_room(puppet: pu.Puppet, portal: po.Portal, mute_until: datetime) -> None:
if not config["bridge.mute_bridging"] or not portal or not portal.mxid:
async def _mute_room(cls, puppet: pu.Puppet, portal: po.Portal, mute_until: datetime) -> None:
if not cls.config["bridge.mute_bridging"] or not portal or not portal.mxid:
return
now = datetime.utcnow().replace(tzinfo=timezone.utc)
if mute_until is not None and mute_until > now:
@@ -497,29 +448,31 @@ class User(AbstractUser, BaseUser):
pass
async def update_folder_peers(self, update: UpdateFolderPeers) -> None:
if config["bridge.tag_only_on_create"]:
if self.config["bridge.tag_only_on_create"]:
return
puppet = await pu.Puppet.get_by_custom_mxid(self.mxid)
if not puppet or not puppet.is_real_user:
return
for peer in update.folder_peers:
portal = po.Portal.get_by_entity(peer.peer, receiver_id=self.tgid, create=False)
await self._tag_room(puppet, portal, config["bridge.archive_tag"],
portal = await po.Portal.get_by_entity(peer.peer, tg_receiver=self.tgid, create=False)
await self._tag_room(puppet, portal, self.config["bridge.archive_tag"],
peer.folder_id == 1)
async def update_pinned_dialogs(self, update: UpdatePinnedDialogs) -> None:
if config["bridge.tag_only_on_create"]:
if self.config["bridge.tag_only_on_create"]:
return
puppet = await pu.Puppet.get_by_custom_mxid(self.mxid)
if not puppet or not puppet.is_real_user:
return
# TODO bridge unpinning properly
for pinned in update.order:
portal = po.Portal.get_by_entity(pinned.peer, receiver_id=self.tgid, create=False)
await self._tag_room(puppet, portal, config["bridge.pinned_tag"], True)
portal = await po.Portal.get_by_entity(
pinned.peer, tg_receiver=self.tgid, create=False
)
await self._tag_room(puppet, portal, self.config["bridge.pinned_tag"], True)
async def update_notify_settings(self, update: UpdateNotifySettings) -> None:
if config["bridge.tag_only_on_create"]:
if self.config["bridge.tag_only_on_create"]:
return
elif not isinstance(update.peer, NotifyPeer):
# TODO handle global notification setting changes?
@@ -527,11 +480,13 @@ class User(AbstractUser, BaseUser):
puppet = await pu.Puppet.get_by_custom_mxid(self.mxid)
if not puppet or not puppet.is_real_user:
return
portal = po.Portal.get_by_entity(update.peer.peer, receiver_id=self.tgid, create=False)
portal = await po.Portal.get_by_entity(
update.peer.peer, tg_receiver=self.tgid, create=False
)
await self._mute_room(puppet, portal, update.notify_settings.mute_until)
async def _sync_dialog(self, portal: po.Portal, dialog: Dialog, should_create: bool,
puppet: Optional[pu.Puppet]) -> None:
puppet: pu.Puppet | None) -> None:
was_created = False
if portal.mxid:
try:
@@ -553,29 +508,41 @@ class User(AbstractUser, BaseUser):
if dialog.unread_count == 0:
# This is usually more reliable than finding a specific message
# e.g. if the last read message is a service message that isn't in the message db
last_read = DBMessage.find_last(portal.mxid, tg_space)
last_read = await DBMessage.find_last(portal.mxid, tg_space)
else:
last_read = DBMessage.get_one_by_tgid(portal.tgid, tg_space,
dialog.dialog.read_inbox_max_id)
last_read = await DBMessage.get_one_by_tgid(portal.tgid, tg_space,
dialog.dialog.read_inbox_max_id)
if last_read:
await puppet.intent.mark_read(last_read.mx_room, last_read.mxid)
if was_created or not config["bridge.tag_only_on_create"]:
if was_created or not self.config["bridge.tag_only_on_create"]:
await self._mute_room(puppet, portal, dialog.dialog.notify_settings.mute_until)
await self._tag_room(puppet, portal, config["bridge.pinned_tag"], dialog.pinned)
await self._tag_room(puppet, portal, config["bridge.archive_tag"], dialog.archived)
await self._tag_room(puppet, portal, self.config["bridge.pinned_tag"],
dialog.pinned)
await self._tag_room(puppet, portal, self.config["bridge.archive_tag"],
dialog.archived)
async def get_cached_portals(self) -> dict[tuple[TelegramID, TelegramID], po.Portal]:
if self._portals_cache is None:
self._portals_cache = {
(tgid, tg_receiver): await po.Portal.get_by_tgid(tgid, tg_receiver=tg_receiver)
for tgid, tg_receiver in await self.get_portals()
}
return self._portals_cache
async def sync_dialogs(self) -> None:
if self.is_bot:
return
creators = []
update_limit = config["bridge.sync_update_limit"] or None
create_limit = config["bridge.sync_create_limit"]
update_limit = self.config["bridge.sync_update_limit"] or None
create_limit = self.config["bridge.sync_create_limit"]
index = 0
self.log.debug(f"Syncing dialogs (update_limit={update_limit}, "
f"create_limit={create_limit})")
await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
puppet = await pu.Puppet.get_by_custom_mxid(self.mxid)
dialog: Dialog
old_portal_cache = await self.get_cached_portals()
new_portal_cache = old_portal_cache.copy()
async for dialog in self.client.iter_dialogs(limit=update_limit, ignore_migrated=True,
archived=False):
entity = dialog.entity
@@ -585,125 +552,152 @@ class User(AbstractUser, BaseUser):
elif isinstance(entity, Chat) and (entity.deactivated or entity.left):
self.log.warning(f"Ignoring deactivated or left chat {entity} while syncing")
continue
elif isinstance(entity, TLUser) and not config["bridge.sync_direct_chats"]:
elif isinstance(entity, TLUser) and not self.config["bridge.sync_direct_chats"]:
self.log.trace(f"Ignoring user {entity.id} while syncing")
continue
portal = po.Portal.get_by_entity(entity, receiver_id=self.tgid)
self.portals[portal.tgid_full] = portal
portal = await po.Portal.get_by_entity(entity, tg_receiver=self.tgid)
new_portal_cache[portal.tgid_full] = portal
coro = self._sync_dialog(portal=portal, dialog=dialog, puppet=puppet,
should_create=not create_limit or index < create_limit)
creators.append(self.loop.create_task(coro))
index += 1
await self.save(portals=True)
if new_portal_cache.keys() != old_portal_cache.keys():
await self.set_portals(new_portal_cache.keys())
self._portals_cache = new_portal_cache
await asyncio.gather(*creators)
await self.update_direct_chats()
self.log.debug("Dialog syncing complete")
async def register_portal(self, portal: po.Portal) -> None:
self.log.trace(f"Registering portal {portal.tgid_full}")
try:
if self.portals[portal.tgid_full] == portal:
if self._portals_cache is not None:
if self._portals_cache.get(portal.tgid_full) == portal:
return
except KeyError:
pass
self.portals[portal.tgid_full] = portal
await self.save(portals=True)
self._portals_cache[portal.tgid_full] = portal
await super().register_portal(portal.tgid, portal.tg_receiver)
async def unregister_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None:
self.log.trace(f"Unregistering portal {(tgid, tg_receiver)}")
try:
del self.portals[(tgid, tg_receiver)]
await self.save(portals=True)
except KeyError:
pass
if self._portals_cache is not None:
self._portals_cache.pop((tgid, tg_receiver), None)
await super().unregister_portal(tgid, tg_receiver)
async def needs_relaybot(self, portal: po.Portal) -> bool:
return not await self.is_logged_in() or (
(portal.has_bot or self.is_bot) and portal.tgid_full not in self.portals)
(portal.has_bot or self.is_bot)
and portal.tgid_full not in await self.get_cached_portals()
)
def _hash_contacts(self) -> int:
@staticmethod
def _hash_contacts(count: int, ids: list[TelegramID]) -> int:
acc = 0
for contact in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]):
for contact in sorted([count] + ids):
acc = (acc * 20261 + contact) & 0xffffffff
return acc & 0x7fffffff
async def sync_contacts(self) -> None:
response = await self.client(GetContactsRequest(hash=self._hash_contacts()))
existing_contacts = await self.get_contacts()
contact_hash = self._hash_contacts(self.saved_contacts, existing_contacts)
response = await self.client(GetContactsRequest(hash=contact_hash))
if isinstance(response, ContactsNotModified):
return
self.log.debug(f"Updating contacts of {self.name}...")
self.contacts = []
self.saved_contacts = response.saved_count
if self.saved_contacts != response.saved_count:
self.saved_contacts = response.saved_count
await self.save()
await self.set_contacts(user.id for user in response.users)
for user in response.users:
puppet = pu.Puppet.get(user.id)
puppet = await pu.Puppet.get_by_tgid(user.id)
await puppet.update_info(self, user)
self.contacts.append(puppet)
await self.save(contacts=True)
# endregion
# region Class instance lookup
@classmethod
def get_by_mxid(cls, mxid: UserID, create: bool = True, check_db: bool = True
) -> Optional['User']:
if not mxid:
raise ValueError("Matrix ID can't be empty")
def _add_to_cache(self) -> None:
self.by_mxid[self.mxid] = self
if self.tgid:
self.by_tgid[self.tgid] = self
@classmethod
async def get_and_start_by_mxid(cls, mxid: UserID, even_if_no_session: bool = False) -> User:
user = await cls.get_by_mxid(mxid, create=True)
await user.ensure_started(even_if_no_session=even_if_no_session)
return user
@classmethod
async def all_with_tgid(cls) -> AsyncGenerator[User, None]:
users = await super().all_with_tgid()
user: cls
for user in users:
try:
yield cls.by_mxid[user.mxid]
except KeyError:
user._add_to_cache()
yield user
@classmethod
@async_getter_lock
async def get_by_mxid(
cls, mxid: UserID, *, check_db: bool = True, create: bool = True
) -> User | None:
if not mxid or pu.Puppet.get_id_from_mxid(mxid) or mxid == cls.az.bot_mxid:
return None
try:
return cls.by_mxid[mxid]
except KeyError:
pass
if check_db:
user = DBUser.get_by_mxid(mxid)
if user:
user = cls.from_db(user)
return user
if not check_db:
return None
user = cast(cls, await super().get_by_mxid(mxid))
if user is not None:
user._add_to_cache()
return user
if create:
cls.log.debug(f"Creating user instance for {mxid}")
user = cls(mxid)
user.db_instance.insert()
await user.insert()
user._add_to_cache()
return user
return None
@classmethod
def get_by_tgid(cls, tgid: TelegramID) -> Optional['User']:
@async_getter_lock
async def get_by_tgid(cls, tgid: TelegramID) -> User | None:
try:
return cls.by_tgid[tgid]
except KeyError:
pass
user = DBUser.get_by_tgid(tgid)
if user:
user = cls.from_db(user)
user = cast(cls, await super().get_by_tgid(tgid))
if user is not None:
user._add_to_cache()
return user
return None
@classmethod
def find_by_username(cls, username: str) -> Optional['User']:
async def find_by_username(cls, username: str) -> User | None:
if not username:
return None
username = username.lower()
for _, user in cls.by_tgid.items():
if user.username and user.username.lower() == username:
if user.tg_username and user.tg_username.lower() == username:
return user
puppet = DBUser.get_by_username(username)
if puppet:
return cls.from_db(puppet)
user = cast(cls, await super().find_by_username(username))
if user:
try:
return cls.by_mxid[user.mxid]
except KeyError:
user._add_to_cache()
return user
return None
# endregion
def init(context: 'Context') -> Iterable[Awaitable['User']]:
global config
config = context.config
User.bridge = context.bridge
return (User.from_db(db_user).try_ensure_started()
for db_user in DBUser.all_with_tgid())