Switch from SQLAlchemy to asyncpg/aiosqlite
This commit is contained in:
+210
-216
@@ -13,10 +13,10 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import (Awaitable, Dict, List, Iterable, NamedTuple, Optional, Tuple, Any, cast,
|
||||
TYPE_CHECKING)
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Awaitable, AsyncIterable, NamedTuple, AsyncGenerator, TYPE_CHECKING, cast
|
||||
from datetime import datetime, timezone
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from telethon.tl.types import (TypeUpdate, UpdateNewMessage, UpdateNewChannelMessage,
|
||||
@@ -35,21 +35,17 @@ from mautrix.client import Client
|
||||
from mautrix.errors import MatrixRequestError, MNotFound
|
||||
from mautrix.types import UserID, RoomID, PushRuleScope, PushRuleKind, PushActionType, RoomTagInfo
|
||||
from mautrix.appservice import DOUBLE_PUPPET_SOURCE_KEY
|
||||
from mautrix.bridge import BaseUser
|
||||
from mautrix.bridge import BaseUser, async_getter_lock
|
||||
from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
|
||||
from mautrix.util.logging import TraceLogger
|
||||
from mautrix.util.opt_prometheus import Gauge
|
||||
|
||||
from .types import TelegramID
|
||||
from .db import User as DBUser, Portal as DBPortal, Message as DBMessage
|
||||
from .db import User as DBUser, Message as DBMessage, PgSession
|
||||
from .abstract_user import AbstractUser
|
||||
from . import portal as po, puppet as pu
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config import Config
|
||||
from .context import Context
|
||||
|
||||
config: Optional['Config'] = None
|
||||
from .__main__ import TelegramBridge
|
||||
|
||||
SearchResult = NamedTuple('SearchResult', puppet='pu.Puppet', similarity=int)
|
||||
|
||||
@@ -64,54 +60,46 @@ BridgeState.human_readable_errors.update({
|
||||
})
|
||||
|
||||
|
||||
class User(AbstractUser, BaseUser):
|
||||
log: TraceLogger = logging.getLogger("mau.user")
|
||||
by_mxid: Dict[str, 'User'] = {}
|
||||
by_tgid: Dict[int, 'User'] = {}
|
||||
class User(DBUser, AbstractUser, BaseUser):
|
||||
by_mxid: dict[str, User] = {}
|
||||
by_tgid: dict[int, User] = {}
|
||||
|
||||
phone: Optional[str]
|
||||
contacts: List['pu.Puppet']
|
||||
saved_contacts: int
|
||||
portals: Dict[Tuple[TelegramID, TelegramID], 'po.Portal']
|
||||
command_status: Optional[Dict[str, Any]]
|
||||
_portals_cache: dict[tuple[TelegramID, TelegramID], po.Portal] | None
|
||||
|
||||
_db_instance: Optional[DBUser]
|
||||
_ensure_started_lock: asyncio.Lock
|
||||
_track_connection_task: Optional[asyncio.Task]
|
||||
_track_connection_task: asyncio.Task | None
|
||||
_is_backfilling: bool
|
||||
|
||||
def __init__(self, mxid: UserID, tgid: Optional[TelegramID] = None,
|
||||
username: Optional[str] = None, phone: Optional[str] = None,
|
||||
db_contacts: Optional[Iterable[TelegramID]] = None,
|
||||
saved_contacts: int = 0, is_bot: bool = False,
|
||||
db_portals: Optional[Iterable[Tuple[TelegramID, TelegramID]]] = None,
|
||||
db_instance: Optional[DBUser] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
mxid: UserID,
|
||||
tgid: TelegramID | None = None,
|
||||
tg_username: str | None = None,
|
||||
tg_phone: str | None = None,
|
||||
is_bot: bool = False,
|
||||
saved_contacts: int = 0,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
mxid=mxid,
|
||||
tgid=tgid,
|
||||
tg_username=tg_username,
|
||||
tg_phone=tg_phone,
|
||||
is_bot=is_bot,
|
||||
saved_contacts=saved_contacts,
|
||||
)
|
||||
AbstractUser.__init__(self)
|
||||
self.mxid = mxid
|
||||
BaseUser.__init__(self)
|
||||
self.tgid = tgid
|
||||
self.is_bot = is_bot
|
||||
self.username = username
|
||||
self.phone = phone
|
||||
self.contacts = []
|
||||
self.saved_contacts = saved_contacts
|
||||
self.db_contacts = db_contacts
|
||||
self.portals = {}
|
||||
self.db_portals = db_portals or []
|
||||
self._db_instance = db_instance
|
||||
self._ensure_started_lock = asyncio.Lock()
|
||||
self._track_connection_task = None
|
||||
self._is_backfilling = False
|
||||
self._portals_cache = None
|
||||
|
||||
(self.relaybot_whitelisted,
|
||||
self.whitelisted,
|
||||
self.puppet_whitelisted,
|
||||
self.matrix_puppet_whitelisted,
|
||||
self.is_admin,
|
||||
self.permissions) = config.get_permissions(self.mxid)
|
||||
|
||||
self.by_mxid[mxid] = self
|
||||
if tgid:
|
||||
self.by_tgid[tgid] = self
|
||||
self.permissions) = self.config.get_permissions(self.mxid)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -124,7 +112,7 @@ class User(AbstractUser, BaseUser):
|
||||
|
||||
@property
|
||||
def human_tg_id(self) -> str:
|
||||
return f"@{self.username}" if self.username else f"+{self.phone}" or None
|
||||
return f"@{self.tg_username}" if self.tg_username else f"+{self.tg_phone}" or None
|
||||
|
||||
# TODO replace with proper displayname getting everywhere
|
||||
@property
|
||||
@@ -135,65 +123,15 @@ class User(AbstractUser, BaseUser):
|
||||
def plain_displayname(self) -> str:
|
||||
return self.displayname
|
||||
|
||||
@property
|
||||
def db_contacts(self) -> Iterable[TelegramID]:
|
||||
return (puppet.id
|
||||
for puppet in self.contacts
|
||||
if puppet)
|
||||
|
||||
@db_contacts.setter
|
||||
def db_contacts(self, contacts: Iterable[TelegramID]) -> None:
|
||||
self.contacts = [pu.Puppet.get(entry) for entry in contacts] if contacts else []
|
||||
|
||||
@property
|
||||
def db_portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]:
|
||||
return (portal.tgid_full
|
||||
for portal in self.portals.values()
|
||||
if portal and not portal.deleted)
|
||||
|
||||
@db_portals.setter
|
||||
def db_portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None:
|
||||
self.portals = {
|
||||
tgid_full: po.Portal.get_by_tgid(*tgid_full)
|
||||
for tgid_full in portals
|
||||
} if portals else {}
|
||||
|
||||
# region Database conversion
|
||||
|
||||
@property
|
||||
def db_instance(self) -> DBUser:
|
||||
if not self._db_instance:
|
||||
self._db_instance = self.new_db_instance()
|
||||
return self._db_instance
|
||||
|
||||
def new_db_instance(self) -> DBUser:
|
||||
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
|
||||
saved_contacts=self.saved_contacts, portals=self.db_portals)
|
||||
|
||||
async def save(self, contacts: bool = False, portals: bool = False) -> None:
|
||||
self.db_instance.edit(tgid=self.tgid, tg_username=self.username, tg_phone=self.phone,
|
||||
saved_contacts=self.saved_contacts)
|
||||
if contacts:
|
||||
self.db_instance.contacts = self.db_contacts
|
||||
if portals:
|
||||
self.db_instance.portals = self.db_portals
|
||||
|
||||
def delete(self, delete_db: bool = True) -> None:
|
||||
try:
|
||||
del self.by_mxid[self.mxid]
|
||||
del self.by_tgid[self.tgid]
|
||||
except KeyError:
|
||||
pass
|
||||
if delete_db and self._db_instance:
|
||||
self._db_instance.delete()
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_user: DBUser) -> 'User':
|
||||
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.tg_phone,
|
||||
db_user.contacts, db_user.saved_contacts, False, db_user.portals,
|
||||
db_instance=db_user)
|
||||
def init_cls(cls, bridge: 'TelegramBridge') -> AsyncIterable[Awaitable[User]]:
|
||||
cls.config = bridge.config
|
||||
cls.bridge = bridge
|
||||
cls.az = bridge.az
|
||||
cls.loop = bridge.loop
|
||||
|
||||
return (user.try_ensure_started() async for user in cls.all_with_tgid())
|
||||
|
||||
# endregion
|
||||
# region Telegram connection management
|
||||
|
||||
async def try_ensure_started(self) -> None:
|
||||
@@ -202,19 +140,19 @@ class User(AbstractUser, BaseUser):
|
||||
except Exception:
|
||||
self.log.exception("Exception in ensure_started")
|
||||
else:
|
||||
if not self.client and not self.session_container.has_session(self.mxid):
|
||||
if not self.client and not await PgSession.has(self.mxid):
|
||||
self.log.warning("Didn't start user: no session stored")
|
||||
if self.tgid:
|
||||
await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS,
|
||||
error="tg-no-auth")
|
||||
|
||||
async def ensure_started(self, even_if_no_session=False) -> 'User':
|
||||
async def ensure_started(self, even_if_no_session=False) -> User:
|
||||
if not self.puppet_whitelisted or self.connected:
|
||||
return self
|
||||
async with self._ensure_started_lock:
|
||||
return cast(User, await super().ensure_started(even_if_no_session))
|
||||
|
||||
async def start(self, delete_unless_authenticated: bool = False) -> 'User':
|
||||
async def start(self, delete_unless_authenticated: bool = False) -> User:
|
||||
try:
|
||||
await super().start()
|
||||
except AuthKeyDuplicatedError:
|
||||
@@ -222,7 +160,7 @@ class User(AbstractUser, BaseUser):
|
||||
await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS,
|
||||
error="tg-auth-key-duplicated")
|
||||
await self.client.disconnect()
|
||||
self.client.session.delete()
|
||||
await self.client.session.delete()
|
||||
self.client = None
|
||||
if not delete_unless_authenticated:
|
||||
# The caller wants the client to be connected, so restart the connection.
|
||||
@@ -257,7 +195,7 @@ class User(AbstractUser, BaseUser):
|
||||
if delete_unless_authenticated:
|
||||
self.log.debug(f"Unauthenticated user {self.name} start()ed, deleting session...")
|
||||
await self.client.disconnect()
|
||||
self.client.session.delete()
|
||||
await self.client.session.delete()
|
||||
return self
|
||||
|
||||
@property
|
||||
@@ -283,7 +221,7 @@ class User(AbstractUser, BaseUser):
|
||||
state.remote_id = str(self.tgid)
|
||||
state.remote_name = self.human_tg_id
|
||||
|
||||
async def get_bridge_states(self) -> List[BridgeState]:
|
||||
async def get_bridge_states(self) -> list[BridgeState]:
|
||||
if not self.tgid:
|
||||
return []
|
||||
if self._is_connected and await self.is_logged_in():
|
||||
@@ -295,10 +233,10 @@ class User(AbstractUser, BaseUser):
|
||||
ttl = 240
|
||||
return [BridgeState(state_event=state_event, ttl=ttl)]
|
||||
|
||||
async def get_puppet(self) -> Optional['pu.Puppet']:
|
||||
async def get_puppet(self) -> pu.Puppet | None:
|
||||
if not self.tgid:
|
||||
return None
|
||||
return pu.Puppet.get(self.tgid)
|
||||
return await pu.Puppet.get_by_tgid(self.tgid)
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._track_connection_task:
|
||||
@@ -308,7 +246,7 @@ class User(AbstractUser, BaseUser):
|
||||
self._track_metric(METRIC_CONNECTED, False)
|
||||
|
||||
async def post_login(self, info: TLUser = None, first_login: bool = False) -> None:
|
||||
if config["metrics.enabled"] and not self._track_connection_task:
|
||||
if self.config["metrics.enabled"] and not self._track_connection_task:
|
||||
self._track_connection_task = self.loop.create_task(self._track_connection())
|
||||
|
||||
try:
|
||||
@@ -320,14 +258,14 @@ class User(AbstractUser, BaseUser):
|
||||
self._track_metric(METRIC_LOGGED_IN, True)
|
||||
|
||||
try:
|
||||
puppet = pu.Puppet.get(self.tgid)
|
||||
puppet = await pu.Puppet.get_by_tgid(self.tgid)
|
||||
if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
|
||||
self.log.info(f"Automatically enabling custom puppet")
|
||||
await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
|
||||
except Exception:
|
||||
self.log.exception("Failed to automatically enable custom puppet")
|
||||
|
||||
if not self.is_bot and config["bridge.startup_sync"]:
|
||||
if not self.is_bot and self.config["bridge.startup_sync"]:
|
||||
try:
|
||||
self._is_backfilling = True
|
||||
await self.sync_dialogs()
|
||||
@@ -342,11 +280,13 @@ class User(AbstractUser, BaseUser):
|
||||
return False
|
||||
|
||||
if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
|
||||
portal = po.Portal.get_by_entity(update.message.peer_id, receiver_id=self.tgid)
|
||||
portal = await po.Portal.get_by_entity(update.message.peer_id, tg_receiver=self.tgid)
|
||||
elif isinstance(update, UpdateShortChatMessage):
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.chat_id))
|
||||
portal = await po.Portal.get_by_tgid(TelegramID(update.chat_id))
|
||||
elif isinstance(update, UpdateShortMessage):
|
||||
portal = po.Portal.get_by_tgid(TelegramID(update.user_id), self.tgid, "user")
|
||||
portal = await po.Portal.get_by_tgid(
|
||||
TelegramID(update.user_id), tg_receiver=self.tgid, peer_type="user"
|
||||
)
|
||||
else:
|
||||
return False
|
||||
|
||||
@@ -364,7 +304,7 @@ class User(AbstractUser, BaseUser):
|
||||
if not self.is_bot:
|
||||
await self.client(UpdateStatusRequest(offline=not online))
|
||||
|
||||
async def get_me(self) -> Optional[TLUser]:
|
||||
async def get_me(self) -> TLUser | None:
|
||||
try:
|
||||
return (await self.client(GetUsersRequest([InputUserSelf()])))[0]
|
||||
except UnauthorizedError as e:
|
||||
@@ -384,11 +324,11 @@ class User(AbstractUser, BaseUser):
|
||||
if self.is_bot != info.bot:
|
||||
self.is_bot = info.bot
|
||||
changed = True
|
||||
if self.username != info.username:
|
||||
self.username = info.username
|
||||
if self.tg_username != info.username:
|
||||
self.tg_username = info.username
|
||||
changed = True
|
||||
if self.phone != info.phone:
|
||||
self.phone = info.phone
|
||||
if self.tg_phone != info.phone:
|
||||
self.tg_phone = info.phone
|
||||
changed = True
|
||||
if self.tgid != info.id:
|
||||
self.tgid = TelegramID(info.id)
|
||||
@@ -396,11 +336,11 @@ class User(AbstractUser, BaseUser):
|
||||
if changed:
|
||||
await self.save()
|
||||
|
||||
async def log_out(self) -> bool:
|
||||
puppet = pu.Puppet.get(self.tgid)
|
||||
if puppet.is_real_user:
|
||||
await puppet.switch_mxid(None, None)
|
||||
for _, portal in self.portals.items():
|
||||
async def kick_from_portals(self) -> None:
|
||||
if not self.config["bridge.kick_on_logout"]:
|
||||
return
|
||||
portals = await self.get_cached_portals()
|
||||
for _, portal in portals.values():
|
||||
if not portal or portal.deleted or not portal.mxid or portal.has_bot:
|
||||
continue
|
||||
if portal.peer_type == "user":
|
||||
@@ -411,9 +351,15 @@ class User(AbstractUser, BaseUser):
|
||||
"Logged out of Telegram.")
|
||||
except MatrixRequestError:
|
||||
pass
|
||||
self.portals = {}
|
||||
self.contacts = []
|
||||
await self.save(portals=True, contacts=True)
|
||||
|
||||
async def log_out(self) -> bool:
|
||||
puppet = await pu.Puppet.get_by_tgid(self.tgid)
|
||||
if puppet.is_real_user:
|
||||
await puppet.switch_mxid(None, None)
|
||||
try:
|
||||
await self.kick_from_portals()
|
||||
except Exception:
|
||||
self.log.exception("Failed to kick user from portals on logout")
|
||||
await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
|
||||
if self.tgid:
|
||||
try:
|
||||
@@ -421,51 +367,54 @@ class User(AbstractUser, BaseUser):
|
||||
except KeyError:
|
||||
pass
|
||||
self.tgid = None
|
||||
await self.save()
|
||||
ok = await self.client.log_out()
|
||||
self.client.session.delete()
|
||||
self.delete()
|
||||
await self.client.session.delete()
|
||||
await self.delete()
|
||||
self.by_mxid.pop(self.mxid, None)
|
||||
await self.stop()
|
||||
self._track_metric(METRIC_LOGGED_IN, False)
|
||||
return ok
|
||||
|
||||
def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45
|
||||
) -> List[SearchResult]:
|
||||
results: List[SearchResult] = []
|
||||
for contact in self.contacts:
|
||||
async def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45
|
||||
) -> list[SearchResult]:
|
||||
results: list[SearchResult] = []
|
||||
for contact_id in await self.get_contacts():
|
||||
contact = await pu.Puppet.get_by_tgid(contact_id, create=False)
|
||||
if not contact:
|
||||
continue
|
||||
similarity = contact.similarity(query)
|
||||
if similarity >= min_similarity:
|
||||
results.append(SearchResult(contact, similarity))
|
||||
results.sort(key=lambda tup: tup[1], reverse=True)
|
||||
return results[0:max_results]
|
||||
|
||||
async def _search_remote(self, query: str, max_results: int = 5) -> List[SearchResult]:
|
||||
async def _search_remote(self, query: str, max_results: int = 5) -> list[SearchResult]:
|
||||
if len(query) < 5:
|
||||
return []
|
||||
server_results = await self.client(SearchRequest(q=query, limit=max_results))
|
||||
results: List[SearchResult] = []
|
||||
results: list[SearchResult] = []
|
||||
for user in server_results.users:
|
||||
puppet = pu.Puppet.get(user.id)
|
||||
puppet = await pu.Puppet.get_by_tgid(user.id)
|
||||
await puppet.update_info(self, user)
|
||||
results.append(SearchResult(puppet, puppet.similarity(query)))
|
||||
results.sort(key=lambda tup: tup[1], reverse=True)
|
||||
return results[0:max_results]
|
||||
|
||||
async def search(self, query: str, force_remote: bool = False
|
||||
) -> Tuple[List[SearchResult], bool]:
|
||||
) -> tuple[list[SearchResult], bool]:
|
||||
if force_remote:
|
||||
return await self._search_remote(query), True
|
||||
|
||||
results = self._search_local(query)
|
||||
results = await self._search_local(query)
|
||||
if results:
|
||||
return results, False
|
||||
|
||||
return await self._search_remote(query), True
|
||||
|
||||
async def get_direct_chats(self) -> Dict[UserID, List[RoomID]]:
|
||||
async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
|
||||
return {
|
||||
pu.Puppet.get_mxid_from_id(portal.tgid): [portal.mxid]
|
||||
for portal in DBPortal.find_private_chats(self.tgid)
|
||||
async for portal in po.Portal.find_private_chats(self.tgid)
|
||||
if portal.mxid
|
||||
}
|
||||
|
||||
@@ -478,12 +427,14 @@ class User(AbstractUser, BaseUser):
|
||||
tag_info = RoomTagInfo(order=0.5)
|
||||
tag_info[DOUBLE_PUPPET_SOURCE_KEY] = self.bridge.name
|
||||
await puppet.intent.set_room_tag(portal.mxid, tag, tag_info)
|
||||
elif not active and tag_info and tag_info.get(DOUBLE_PUPPET_SOURCE_KEY) == self.bridge.name:
|
||||
elif (
|
||||
not active and tag_info
|
||||
and tag_info.get(DOUBLE_PUPPET_SOURCE_KEY) == self.bridge.name
|
||||
):
|
||||
await puppet.intent.remove_room_tag(portal.mxid, tag)
|
||||
|
||||
@staticmethod
|
||||
async def _mute_room(puppet: pu.Puppet, portal: po.Portal, mute_until: datetime) -> None:
|
||||
if not config["bridge.mute_bridging"] or not portal or not portal.mxid:
|
||||
async def _mute_room(cls, puppet: pu.Puppet, portal: po.Portal, mute_until: datetime) -> None:
|
||||
if not cls.config["bridge.mute_bridging"] or not portal or not portal.mxid:
|
||||
return
|
||||
now = datetime.utcnow().replace(tzinfo=timezone.utc)
|
||||
if mute_until is not None and mute_until > now:
|
||||
@@ -497,29 +448,31 @@ class User(AbstractUser, BaseUser):
|
||||
pass
|
||||
|
||||
async def update_folder_peers(self, update: UpdateFolderPeers) -> None:
|
||||
if config["bridge.tag_only_on_create"]:
|
||||
if self.config["bridge.tag_only_on_create"]:
|
||||
return
|
||||
puppet = await pu.Puppet.get_by_custom_mxid(self.mxid)
|
||||
if not puppet or not puppet.is_real_user:
|
||||
return
|
||||
for peer in update.folder_peers:
|
||||
portal = po.Portal.get_by_entity(peer.peer, receiver_id=self.tgid, create=False)
|
||||
await self._tag_room(puppet, portal, config["bridge.archive_tag"],
|
||||
portal = await po.Portal.get_by_entity(peer.peer, tg_receiver=self.tgid, create=False)
|
||||
await self._tag_room(puppet, portal, self.config["bridge.archive_tag"],
|
||||
peer.folder_id == 1)
|
||||
|
||||
async def update_pinned_dialogs(self, update: UpdatePinnedDialogs) -> None:
|
||||
if config["bridge.tag_only_on_create"]:
|
||||
if self.config["bridge.tag_only_on_create"]:
|
||||
return
|
||||
puppet = await pu.Puppet.get_by_custom_mxid(self.mxid)
|
||||
if not puppet or not puppet.is_real_user:
|
||||
return
|
||||
# TODO bridge unpinning properly
|
||||
for pinned in update.order:
|
||||
portal = po.Portal.get_by_entity(pinned.peer, receiver_id=self.tgid, create=False)
|
||||
await self._tag_room(puppet, portal, config["bridge.pinned_tag"], True)
|
||||
portal = await po.Portal.get_by_entity(
|
||||
pinned.peer, tg_receiver=self.tgid, create=False
|
||||
)
|
||||
await self._tag_room(puppet, portal, self.config["bridge.pinned_tag"], True)
|
||||
|
||||
async def update_notify_settings(self, update: UpdateNotifySettings) -> None:
|
||||
if config["bridge.tag_only_on_create"]:
|
||||
if self.config["bridge.tag_only_on_create"]:
|
||||
return
|
||||
elif not isinstance(update.peer, NotifyPeer):
|
||||
# TODO handle global notification setting changes?
|
||||
@@ -527,11 +480,13 @@ class User(AbstractUser, BaseUser):
|
||||
puppet = await pu.Puppet.get_by_custom_mxid(self.mxid)
|
||||
if not puppet or not puppet.is_real_user:
|
||||
return
|
||||
portal = po.Portal.get_by_entity(update.peer.peer, receiver_id=self.tgid, create=False)
|
||||
portal = await po.Portal.get_by_entity(
|
||||
update.peer.peer, tg_receiver=self.tgid, create=False
|
||||
)
|
||||
await self._mute_room(puppet, portal, update.notify_settings.mute_until)
|
||||
|
||||
async def _sync_dialog(self, portal: po.Portal, dialog: Dialog, should_create: bool,
|
||||
puppet: Optional[pu.Puppet]) -> None:
|
||||
puppet: pu.Puppet | None) -> None:
|
||||
was_created = False
|
||||
if portal.mxid:
|
||||
try:
|
||||
@@ -553,29 +508,41 @@ class User(AbstractUser, BaseUser):
|
||||
if dialog.unread_count == 0:
|
||||
# This is usually more reliable than finding a specific message
|
||||
# e.g. if the last read message is a service message that isn't in the message db
|
||||
last_read = DBMessage.find_last(portal.mxid, tg_space)
|
||||
last_read = await DBMessage.find_last(portal.mxid, tg_space)
|
||||
else:
|
||||
last_read = DBMessage.get_one_by_tgid(portal.tgid, tg_space,
|
||||
dialog.dialog.read_inbox_max_id)
|
||||
last_read = await DBMessage.get_one_by_tgid(portal.tgid, tg_space,
|
||||
dialog.dialog.read_inbox_max_id)
|
||||
if last_read:
|
||||
await puppet.intent.mark_read(last_read.mx_room, last_read.mxid)
|
||||
if was_created or not config["bridge.tag_only_on_create"]:
|
||||
if was_created or not self.config["bridge.tag_only_on_create"]:
|
||||
await self._mute_room(puppet, portal, dialog.dialog.notify_settings.mute_until)
|
||||
await self._tag_room(puppet, portal, config["bridge.pinned_tag"], dialog.pinned)
|
||||
await self._tag_room(puppet, portal, config["bridge.archive_tag"], dialog.archived)
|
||||
await self._tag_room(puppet, portal, self.config["bridge.pinned_tag"],
|
||||
dialog.pinned)
|
||||
await self._tag_room(puppet, portal, self.config["bridge.archive_tag"],
|
||||
dialog.archived)
|
||||
|
||||
async def get_cached_portals(self) -> dict[tuple[TelegramID, TelegramID], po.Portal]:
|
||||
if self._portals_cache is None:
|
||||
self._portals_cache = {
|
||||
(tgid, tg_receiver): await po.Portal.get_by_tgid(tgid, tg_receiver=tg_receiver)
|
||||
for tgid, tg_receiver in await self.get_portals()
|
||||
}
|
||||
return self._portals_cache
|
||||
|
||||
async def sync_dialogs(self) -> None:
|
||||
if self.is_bot:
|
||||
return
|
||||
creators = []
|
||||
update_limit = config["bridge.sync_update_limit"] or None
|
||||
create_limit = config["bridge.sync_create_limit"]
|
||||
update_limit = self.config["bridge.sync_update_limit"] or None
|
||||
create_limit = self.config["bridge.sync_create_limit"]
|
||||
index = 0
|
||||
self.log.debug(f"Syncing dialogs (update_limit={update_limit}, "
|
||||
f"create_limit={create_limit})")
|
||||
await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
|
||||
puppet = await pu.Puppet.get_by_custom_mxid(self.mxid)
|
||||
dialog: Dialog
|
||||
old_portal_cache = await self.get_cached_portals()
|
||||
new_portal_cache = old_portal_cache.copy()
|
||||
async for dialog in self.client.iter_dialogs(limit=update_limit, ignore_migrated=True,
|
||||
archived=False):
|
||||
entity = dialog.entity
|
||||
@@ -585,125 +552,152 @@ class User(AbstractUser, BaseUser):
|
||||
elif isinstance(entity, Chat) and (entity.deactivated or entity.left):
|
||||
self.log.warning(f"Ignoring deactivated or left chat {entity} while syncing")
|
||||
continue
|
||||
elif isinstance(entity, TLUser) and not config["bridge.sync_direct_chats"]:
|
||||
elif isinstance(entity, TLUser) and not self.config["bridge.sync_direct_chats"]:
|
||||
self.log.trace(f"Ignoring user {entity.id} while syncing")
|
||||
continue
|
||||
portal = po.Portal.get_by_entity(entity, receiver_id=self.tgid)
|
||||
self.portals[portal.tgid_full] = portal
|
||||
portal = await po.Portal.get_by_entity(entity, tg_receiver=self.tgid)
|
||||
new_portal_cache[portal.tgid_full] = portal
|
||||
coro = self._sync_dialog(portal=portal, dialog=dialog, puppet=puppet,
|
||||
should_create=not create_limit or index < create_limit)
|
||||
creators.append(self.loop.create_task(coro))
|
||||
index += 1
|
||||
await self.save(portals=True)
|
||||
if new_portal_cache.keys() != old_portal_cache.keys():
|
||||
await self.set_portals(new_portal_cache.keys())
|
||||
self._portals_cache = new_portal_cache
|
||||
await asyncio.gather(*creators)
|
||||
await self.update_direct_chats()
|
||||
self.log.debug("Dialog syncing complete")
|
||||
|
||||
async def register_portal(self, portal: po.Portal) -> None:
|
||||
self.log.trace(f"Registering portal {portal.tgid_full}")
|
||||
try:
|
||||
if self.portals[portal.tgid_full] == portal:
|
||||
if self._portals_cache is not None:
|
||||
if self._portals_cache.get(portal.tgid_full) == portal:
|
||||
return
|
||||
except KeyError:
|
||||
pass
|
||||
self.portals[portal.tgid_full] = portal
|
||||
await self.save(portals=True)
|
||||
self._portals_cache[portal.tgid_full] = portal
|
||||
await super().register_portal(portal.tgid, portal.tg_receiver)
|
||||
|
||||
async def unregister_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None:
|
||||
self.log.trace(f"Unregistering portal {(tgid, tg_receiver)}")
|
||||
try:
|
||||
del self.portals[(tgid, tg_receiver)]
|
||||
await self.save(portals=True)
|
||||
except KeyError:
|
||||
pass
|
||||
if self._portals_cache is not None:
|
||||
self._portals_cache.pop((tgid, tg_receiver), None)
|
||||
await super().unregister_portal(tgid, tg_receiver)
|
||||
|
||||
async def needs_relaybot(self, portal: po.Portal) -> bool:
|
||||
return not await self.is_logged_in() or (
|
||||
(portal.has_bot or self.is_bot) and portal.tgid_full not in self.portals)
|
||||
(portal.has_bot or self.is_bot)
|
||||
and portal.tgid_full not in await self.get_cached_portals()
|
||||
)
|
||||
|
||||
def _hash_contacts(self) -> int:
|
||||
@staticmethod
|
||||
def _hash_contacts(count: int, ids: list[TelegramID]) -> int:
|
||||
acc = 0
|
||||
for contact in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]):
|
||||
for contact in sorted([count] + ids):
|
||||
acc = (acc * 20261 + contact) & 0xffffffff
|
||||
return acc & 0x7fffffff
|
||||
|
||||
async def sync_contacts(self) -> None:
|
||||
response = await self.client(GetContactsRequest(hash=self._hash_contacts()))
|
||||
existing_contacts = await self.get_contacts()
|
||||
contact_hash = self._hash_contacts(self.saved_contacts, existing_contacts)
|
||||
response = await self.client(GetContactsRequest(hash=contact_hash))
|
||||
if isinstance(response, ContactsNotModified):
|
||||
return
|
||||
self.log.debug(f"Updating contacts of {self.name}...")
|
||||
self.contacts = []
|
||||
self.saved_contacts = response.saved_count
|
||||
if self.saved_contacts != response.saved_count:
|
||||
self.saved_contacts = response.saved_count
|
||||
await self.save()
|
||||
await self.set_contacts(user.id for user in response.users)
|
||||
for user in response.users:
|
||||
puppet = pu.Puppet.get(user.id)
|
||||
puppet = await pu.Puppet.get_by_tgid(user.id)
|
||||
await puppet.update_info(self, user)
|
||||
self.contacts.append(puppet)
|
||||
await self.save(contacts=True)
|
||||
|
||||
# endregion
|
||||
# region Class instance lookup
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: UserID, create: bool = True, check_db: bool = True
|
||||
) -> Optional['User']:
|
||||
if not mxid:
|
||||
raise ValueError("Matrix ID can't be empty")
|
||||
def _add_to_cache(self) -> None:
|
||||
self.by_mxid[self.mxid] = self
|
||||
if self.tgid:
|
||||
self.by_tgid[self.tgid] = self
|
||||
|
||||
@classmethod
|
||||
async def get_and_start_by_mxid(cls, mxid: UserID, even_if_no_session: bool = False) -> User:
|
||||
user = await cls.get_by_mxid(mxid, create=True)
|
||||
await user.ensure_started(even_if_no_session=even_if_no_session)
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def all_with_tgid(cls) -> AsyncGenerator[User, None]:
|
||||
users = await super().all_with_tgid()
|
||||
user: cls
|
||||
for user in users:
|
||||
try:
|
||||
yield cls.by_mxid[user.mxid]
|
||||
except KeyError:
|
||||
user._add_to_cache()
|
||||
yield user
|
||||
|
||||
@classmethod
|
||||
@async_getter_lock
|
||||
async def get_by_mxid(
|
||||
cls, mxid: UserID, *, check_db: bool = True, create: bool = True
|
||||
) -> User | None:
|
||||
if not mxid or pu.Puppet.get_id_from_mxid(mxid) or mxid == cls.az.bot_mxid:
|
||||
return None
|
||||
try:
|
||||
return cls.by_mxid[mxid]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if check_db:
|
||||
user = DBUser.get_by_mxid(mxid)
|
||||
if user:
|
||||
user = cls.from_db(user)
|
||||
return user
|
||||
if not check_db:
|
||||
return None
|
||||
|
||||
user = cast(cls, await super().get_by_mxid(mxid))
|
||||
if user is not None:
|
||||
user._add_to_cache()
|
||||
return user
|
||||
|
||||
if create:
|
||||
cls.log.debug(f"Creating user instance for {mxid}")
|
||||
user = cls(mxid)
|
||||
user.db_instance.insert()
|
||||
await user.insert()
|
||||
user._add_to_cache()
|
||||
return user
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID) -> Optional['User']:
|
||||
@async_getter_lock
|
||||
async def get_by_tgid(cls, tgid: TelegramID) -> User | None:
|
||||
try:
|
||||
return cls.by_tgid[tgid]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
user = DBUser.get_by_tgid(tgid)
|
||||
if user:
|
||||
user = cls.from_db(user)
|
||||
user = cast(cls, await super().get_by_tgid(tgid))
|
||||
if user is not None:
|
||||
user._add_to_cache()
|
||||
return user
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def find_by_username(cls, username: str) -> Optional['User']:
|
||||
async def find_by_username(cls, username: str) -> User | None:
|
||||
if not username:
|
||||
return None
|
||||
|
||||
username = username.lower()
|
||||
|
||||
for _, user in cls.by_tgid.items():
|
||||
if user.username and user.username.lower() == username:
|
||||
if user.tg_username and user.tg_username.lower() == username:
|
||||
return user
|
||||
|
||||
puppet = DBUser.get_by_username(username)
|
||||
if puppet:
|
||||
return cls.from_db(puppet)
|
||||
user = cast(cls, await super().find_by_username(username))
|
||||
if user:
|
||||
try:
|
||||
return cls.by_mxid[user.mxid]
|
||||
except KeyError:
|
||||
user._add_to_cache()
|
||||
return user
|
||||
|
||||
return None
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def init(context: 'Context') -> Iterable[Awaitable['User']]:
|
||||
global config
|
||||
config = context.config
|
||||
User.bridge = context.bridge
|
||||
|
||||
return (User.from_db(db_user).try_ensure_started()
|
||||
for db_user in DBUser.all_with_tgid())
|
||||
|
||||
Reference in New Issue
Block a user