Add option to sync portals in backfill queue

This commit is contained in:
Tulir Asokan
2022-10-14 13:55:12 +03:00
parent af2f20f7b2
commit 0bbf64d240
10 changed files with 315 additions and 116 deletions
+1 -1
View File
@@ -620,7 +620,7 @@ class AbstractUser(ABC):
self.log.info( self.log.info(
"Creating Matrix room with data fetched by Telethon due to UpdateChannel" "Creating Matrix room with data fetched by Telethon due to UpdateChannel"
) )
await portal.create_matrix_room(self, chan) await portal.create_matrix_room(self, chan, invites=[self.mxid])
async def update_message(self, original_update: UpdateMessage) -> None: async def update_message(self, original_update: UpdateMessage) -> None:
update, sender, portal = await self.get_message_details(original_update) update, sender, portal = await self.get_message_details(original_update)
+1
View File
@@ -113,6 +113,7 @@ class Config(BaseBridgeConfig):
else: else:
copy("bridge.sync_update_limit") copy("bridge.sync_update_limit")
copy("bridge.sync_create_limit") copy("bridge.sync_create_limit")
copy("bridge.sync_deferred_create_all")
copy("bridge.sync_direct_chats") copy("bridge.sync_direct_chats")
copy("bridge.max_telegram_delete") copy("bridge.max_telegram_delete")
copy("bridge.sync_matrix_state") copy("bridge.sync_matrix_state")
+1 -1
View File
@@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.async_db import Database from mautrix.util.async_db import Database
from .backfill_queue import Backfill from .backfill_queue import Backfill, BackfillType
from .bot_chat import BotChat from .bot_chat import BotChat
from .disappearing_message import DisappearingMessage from .disappearing_message import DisappearingMessage
from .message import Message from .message import Message
+80 -27
View File
@@ -15,8 +15,10 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar from typing import TYPE_CHECKING, Any, ClassVar
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum
import json
from asyncpg import Record from asyncpg import Record
from attr import dataclass from attr import dataclass
@@ -29,6 +31,11 @@ from ..types import TelegramID
fake_db = Database.create("") if TYPE_CHECKING else None fake_db = Database.create("") if TYPE_CHECKING else None
# Kind of work a backfill queue entry represents. The string value is what gets written to
# and read back from the queue's "type" column (see `self.type.value` in Backfill.insert and
# `BackfillType(data.pop("type"))` in Backfill._from_row).
class BackfillType(Enum):
# Same string the v16 migration uses as the default for rows that existed before the
# "type" column was added.
HISTORICAL = "historical"
# Added by this commit ("sync portals in backfill queue").
# NOTE(review): presumably a deferred portal/dialog sync request — confirm against the consumer.
SYNC_DIALOG = "sync_dialog"
@dataclass @dataclass
class Backfill: class Backfill:
db: ClassVar[Database] = fake_db db: ClassVar[Database] = fake_db
@@ -36,9 +43,11 @@ class Backfill:
queue_id: int | None queue_id: int | None
user_mxid: UserID user_mxid: UserID
priority: int priority: int
type: BackfillType
portal_tgid: TelegramID portal_tgid: TelegramID
portal_tg_receiver: TelegramID portal_tg_receiver: TelegramID
anchor_msg_id: TelegramID | None anchor_msg_id: TelegramID | None
extra_data: dict[str, Any]
messages_per_batch: int messages_per_batch: int
post_batch_delay: int post_batch_delay: int
max_batches: int max_batches: int
@@ -50,10 +59,12 @@ class Backfill:
def new( def new(
user_mxid: UserID, user_mxid: UserID,
priority: int, priority: int,
type: BackfillType,
portal_tgid: TelegramID, portal_tgid: TelegramID,
portal_tg_receiver: TelegramID, portal_tg_receiver: TelegramID,
messages_per_batch: int, messages_per_batch: int,
anchor_msg_id: TelegramID | None = None, anchor_msg_id: TelegramID | None = None,
extra_data: dict[str, Any] | None = None,
post_batch_delay: int = 0, post_batch_delay: int = 0,
max_batches: int = -1, max_batches: int = -1,
) -> "Backfill": ) -> "Backfill":
@@ -61,9 +72,11 @@ class Backfill:
queue_id=None, queue_id=None,
user_mxid=user_mxid, user_mxid=user_mxid,
priority=priority, priority=priority,
type=type,
portal_tgid=portal_tgid, portal_tgid=portal_tgid,
portal_tg_receiver=portal_tg_receiver, portal_tg_receiver=portal_tg_receiver,
anchor_msg_id=anchor_msg_id, anchor_msg_id=anchor_msg_id,
extra_data=extra_data or {},
messages_per_batch=messages_per_batch, messages_per_batch=messages_per_batch,
post_batch_delay=post_batch_delay, post_batch_delay=post_batch_delay,
max_batches=max_batches, max_batches=max_batches,
@@ -76,14 +89,19 @@ class Backfill:
# Build a Backfill from a database row; returns None when no row was given.
# Old revision: the row is passed straight to cls(**row).
# New revision: the row is copied into a dict first so that
#   - "type" is converted from its stored string to a BackfillType member, and
#   - "extra_data" is decoded from JSON (a missing, NULL or empty value becomes {}),
# and the remaining keys are forwarded to cls(...) unchanged.
def _from_row(cls, row: Record | None) -> Backfill | None: def _from_row(cls, row: Record | None) -> Backfill | None:
if row is None: if row is None:
return None return None
return cls(**row) data = {**row}
type = BackfillType(data.pop("type"))
extra_data = json.loads(data.pop("extra_data", None) or "{}")
return cls(**data, type=type, extra_data=extra_data)
columns = [ columns = [
"user_mxid", "user_mxid",
"priority", "priority",
"type",
"portal_tgid", "portal_tgid",
"portal_tg_receiver", "portal_tg_receiver",
"anchor_msg_id", "anchor_msg_id",
"extra_data",
"messages_per_batch", "messages_per_batch",
"post_batch_delay", "post_batch_delay",
"max_batches", "max_batches",
@@ -118,22 +136,37 @@ class Backfill:
) )
# Old revision (`get`): fetch the first queue entry (ordered by priority, then queue_id) for
# the given user and portal.
# New revision (`delete_existing`): takes an additional `type` and filters on it ($4). Its query
# declares two CTEs: `deleted_entries` deletes matching entries that have neither been
# dispatched nor completed, and `dispatched_entries` selects matching entries that have been
# dispatched but not completed.
# NOTE(review): the new query is only two `WITH ... AS (...)` clauses, each introduced by its
# own WITH keyword, with no main statement after them — it looks unfinished and is unlikely
# to be accepted by the database as written; confirm before anything calls this.
# NOTE(review): the return annotation is still `Backfill | None` and the fetched row is still
# passed to `_from_row`, although both CTEs only produce the constant 1 — verify the intended
# return value.
@classmethod @classmethod
async def get( async def delete_existing(
cls, cls,
user_mxid: UserID, user_mxid: UserID,
portal_tgid: int, portal_tgid: int,
portal_tg_receiver: int, portal_tg_receiver: int,
type: BackfillType,
) -> Backfill | None: ) -> Backfill | None:
q = f""" q = f"""
SELECT queue_id, {cls.columns_str} WITH deleted_entries AS (
FROM backfill_queue DELETE FROM backfill_queue
WHERE user_mxid=$1 WHERE user_mxid=$1
AND portal_tgid=$2 AND portal_tgid=$2
AND portal_tg_receiver=$3 AND portal_tg_receiver=$3
ORDER BY priority, queue_id AND type=$4
LIMIT 1 AND dispatch_time IS NULL
AND completed_at IS NULL
RETURNING 1
)
WITH dispatched_entries AS (
SELECT 1 FROM backfill_queue
WHERE user_mxid=$1
AND portal_tgid=$2
AND portal_tg_receiver=$3
AND type=$4
AND dispatch_time IS NOT NULL
AND completed_at IS NULL
)
""" """
return cls._from_row(await cls.db.fetchrow(q, user_mxid, portal_tgid, portal_tg_receiver)) return cls._from_row(
await cls.db.fetchrow(q, user_mxid, portal_tgid, portal_tg_receiver, type.value)
)
@classmethod @classmethod
async def delete_all(cls, user_mxid: UserID) -> None: async def delete_all(cls, user_mxid: UserID) -> None:
@@ -144,27 +177,47 @@ class Backfill:
q = "DELETE FROM backfill_queue WHERE portal_tgid=$1 AND portal_tg_receiver=$2" q = "DELETE FROM backfill_queue WHERE portal_tgid=$1 AND portal_tg_receiver=$2"
await cls.db.execute(q, tgid, tg_receiver) await cls.db.execute(q, tgid, tg_receiver)
# Old revision: insert this entry into backfill_queue and store the generated queue_id.
# New revision: first delete any entries for the same user, portal and type that have been
# neither dispatched nor completed, then insert this entry (now including `type` and
# `extra_data`), store the generated queue_id, and return the deleted rows converted to
# Backfill objects. `extra_data` is written as a JSON string, or as NULL when the dict is empty.
# NOTE(review): the new revision opens `conn` with a transaction, but both queries are issued
# through `self.db.fetch` / `self.db.fetchval` rather than `conn` — it is not evident from here
# that they run inside that transaction; they probably should go through `conn`. Confirm.
# NOTE(review): `delete_q` returns only `{self.columns_str}`. The old `get` query selected
# `queue_id` separately from `columns_str`, which suggests queue_id is not part of it, yet the
# deleted rows are fed to `_from_row`, which forwards them to `cls(...)` where `queue_id` is a
# declared field — verify that converting a deleted row does not fail on the missing key.
async def insert(self) -> None: async def insert(self) -> list[Backfill]:
delete_q = f"""
DELETE FROM backfill_queue
WHERE user_mxid=$1
AND portal_tgid=$2
AND portal_tg_receiver=$3
AND type=$4
AND dispatch_time IS NULL
AND completed_at IS NULL
RETURNING {self.columns_str}
"""
q = f""" q = f"""
INSERT INTO backfill_queue ({self.columns_str}) INSERT INTO backfill_queue ({self.columns_str})
VALUES ({','.join(f'${i+1}' for i in range(len(self.columns)))}) VALUES ({','.join(f'${i+1}' for i in range(len(self.columns)))})
RETURNING queue_id RETURNING queue_id
""" """
row = await self.db.fetchrow( async with self.db.acquire() as conn, conn.transaction():
q, deleted_rows = await self.db.fetch(
self.user_mxid, delete_q,
self.priority, self.user_mxid,
self.portal_tgid, self.portal_tgid,
self.portal_tg_receiver, self.portal_tg_receiver,
self.anchor_msg_id, self.type.value,
self.messages_per_batch, )
self.post_batch_delay, self.queue_id = await self.db.fetchval(
self.max_batches, q,
self.dispatch_time, self.user_mxid,
self.completed_at, self.priority,
self.cooldown_timeout, self.type.value,
) self.portal_tgid,
self.queue_id = row["queue_id"] self.portal_tg_receiver,
self.anchor_msg_id,
json.dumps(self.extra_data) if self.extra_data else None,
self.messages_per_batch,
self.post_batch_delay,
self.max_batches,
self.dispatch_time,
self.completed_at,
self.cooldown_timeout,
)
return [self._from_row(row) for row in deleted_rows]
async def mark_dispatched(self) -> None: async def mark_dispatched(self) -> None:
q = "UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2" q = "UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2"
+1
View File
@@ -18,4 +18,5 @@ from . import (
v13_multiple_reactions, v13_multiple_reactions,
v14_puppet_custom_mxid_index, v14_puppet_custom_mxid_index,
v15_backfill_anchor_id, v15_backfill_anchor_id,
v16_backfill_type,
) )
@@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.async_db import Connection, Scheme from mautrix.util.async_db import Connection, Scheme
latest_version = 15 latest_version = 16
async def create_latest_tables(conn: Connection, scheme: Scheme) -> int: async def create_latest_tables(conn: Connection, scheme: Scheme) -> int:
@@ -219,9 +219,11 @@ async def create_latest_tables(conn: Connection, scheme: Scheme) -> int:
queue_id INTEGER PRIMARY KEY {gen}, queue_id INTEGER PRIMARY KEY {gen},
user_mxid TEXT, user_mxid TEXT,
priority INTEGER NOT NULL, priority INTEGER NOT NULL,
type TEXT NOT NULL,
portal_tgid BIGINT, portal_tgid BIGINT,
portal_tg_receiver BIGINT, portal_tg_receiver BIGINT,
anchor_msg_id BIGINT, anchor_msg_id BIGINT,
extra_data jsonb,
messages_per_batch INTEGER NOT NULL, messages_per_batch INTEGER NOT NULL,
post_batch_delay INTEGER NOT NULL, post_batch_delay INTEGER NOT NULL,
max_batches INTEGER NOT NULL, max_batches INTEGER NOT NULL,
@@ -0,0 +1,28 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.async_db import Connection, Scheme
from . import upgrade_table
@upgrade_table.register(description="Add type for backfill queue items")
async def upgrade_v16(conn: Connection, scheme: Scheme) -> None:
    """Add the ``type`` and ``extra_data`` columns to ``backfill_queue``.

    ``type`` is created as ``TEXT NOT NULL`` with a default of ``'historical'``
    so that existing rows receive a value; ``extra_data`` is a nullable
    ``jsonb`` column. On every scheme except SQLite the default on ``type`` is
    dropped again afterwards.
    """
    add_column_statements = (
        "ALTER TABLE backfill_queue ADD COLUMN type TEXT NOT NULL DEFAULT 'historical'",
        "ALTER TABLE backfill_queue ADD COLUMN extra_data jsonb",
    )
    # Order matters only in that both columns must exist before the default is dropped.
    for statement in add_column_statements:
        await conn.execute(statement)

    # NOTE(review): presumably SQLite is skipped because it cannot alter a column
    # default in place — confirm.
    if scheme == Scheme.SQLITE:
        return
    await conn.execute("ALTER TABLE backfill_queue ALTER COLUMN type DROP DEFAULT")
+4 -1
View File
@@ -169,7 +169,10 @@ bridge:
sync_update_limit: 0 sync_update_limit: 0
# Number of most recently active dialogs to create portals for when syncing chats. # Number of most recently active dialogs to create portals for when syncing chats.
# Set to 0 to remove limit. # Set to 0 to remove limit.
sync_create_limit: 30 sync_create_limit: 15
# Should all chats be scheduled to be created later?
# This is best used in combination with MSC2716 infinite backfill.
sync_deferred_create_all: false
# Whether or not to sync and create portals for direct chats at startup. # Whether or not to sync and create portals for direct chats at startup.
sync_direct_chats: false sync_direct_chats: false
# The maximum number of simultaneous Telegram deletions to handle. # The maximum number of simultaneous Telegram deletions to handle.
+81 -37
View File
@@ -187,6 +187,7 @@ from . import (
from .config import Config from .config import Config
from .db import ( from .db import (
Backfill, Backfill,
BackfillType,
DisappearingMessage, DisappearingMessage,
Message as DBMessage, Message as DBMessage,
Portal as DBPortal, Portal as DBPortal,
@@ -257,6 +258,7 @@ class Portal(DBPortal, BasePortal):
backfill_method_lock: asyncio.Lock backfill_method_lock: asyncio.Lock
backfill_leave: set[IntentAPI] | None backfill_leave: set[IntentAPI] | None
backfill_msc2716: bool backfill_msc2716: bool
backfill_enable: bool
alias: RoomAlias | None alias: RoomAlias | None
@@ -439,6 +441,7 @@ class Portal(DBPortal, BasePortal):
cls.filter_list = cls.config["bridge.filter.list"] cls.filter_list = cls.config["bridge.filter.list"]
cls.hs_domain = cls.config["homeserver.domain"] cls.hs_domain = cls.config["homeserver.domain"]
cls.backfill_msc2716 = cls.config["bridge.backfill.msc2716"] cls.backfill_msc2716 = cls.config["bridge.backfill.msc2716"]
cls.backfill_enable = cls.config["bridge.backfill.enable"]
cls.alias_template = SimpleTemplate( cls.alias_template = SimpleTemplate(
cls.config["bridge.alias_template"], cls.config["bridge.alias_template"],
"groupname", "groupname",
@@ -645,9 +648,10 @@ class Portal(DBPortal, BasePortal):
puppet: p.Puppet = None, puppet: p.Puppet = None,
levels: PowerLevelStateEventContent = None, levels: PowerLevelStateEventContent = None,
users: list[User] = None, users: list[User] = None,
client: MautrixTelegramClient | None = None,
) -> None: ) -> None:
try: try:
await self._update_matrix_room(user, entity, puppet, levels, users) await self._update_matrix_room(user, entity, puppet, levels, users, client)
except Exception: except Exception:
self.log.exception("Fatal error updating Matrix room") self.log.exception("Fatal error updating Matrix room")
@@ -658,12 +662,15 @@ class Portal(DBPortal, BasePortal):
puppet: p.Puppet = None, puppet: p.Puppet = None,
levels: PowerLevelStateEventContent = None, levels: PowerLevelStateEventContent = None,
users: list[User] = None, users: list[User] = None,
client: MautrixTelegramClient | None = None,
) -> None: ) -> None:
if not client:
client = user.client
if not self.is_direct: if not self.is_direct:
await self.update_info(user, entity) await self.update_info(user, entity, client=client)
if not users: if not users:
users = await self._get_users(user, entity) users = await self._get_users(client, entity)
await self._sync_telegram_users(user, users) await self._sync_telegram_users(user, users, client=client)
await self.update_power_levels(users, levels) await self.update_power_levels(users, levels)
else: else:
if not puppet: if not puppet:
@@ -708,12 +715,13 @@ class Portal(DBPortal, BasePortal):
entity: TypeChat | User = None, entity: TypeChat | User = None,
invites: InviteList = None, invites: InviteList = None,
update_if_exists: bool = True, update_if_exists: bool = True,
client: MautrixTelegramClient | None = None,
) -> RoomID | None: ) -> RoomID | None:
if self.mxid: if self.mxid:
if update_if_exists: if update_if_exists:
if not entity: if not entity:
try: try:
entity = await self.get_entity(user) entity = await self.get_entity(user, client)
except Exception: except Exception:
self.log.exception(f"Failed to get entity through {user.tgid} for update") self.log.exception(f"Failed to get entity through {user.tgid} for update")
return self.mxid return self.mxid
@@ -723,7 +731,7 @@ class Portal(DBPortal, BasePortal):
return self.mxid return self.mxid
async with self._room_create_lock: async with self._room_create_lock:
try: try:
return await self._create_matrix_room(user, entity, invites) return await self._create_matrix_room(user, entity, invites, client=client)
except Exception: except Exception:
self.log.exception("Fatal error creating Matrix room") self.log.exception("Fatal error creating Matrix room")
@@ -774,17 +782,23 @@ class Portal(DBPortal, BasePortal):
self.log.warning("Failed to update bridge info", exc_info=True) self.log.warning("Failed to update bridge info", exc_info=True)
async def _create_matrix_room( async def _create_matrix_room(
self, user: au.AbstractUser, entity: TypeChat | User, invites: InviteList self,
user: au.AbstractUser,
entity: TypeChat | User,
invites: InviteList,
client: MautrixTelegramClient | None = None,
) -> RoomID | None: ) -> RoomID | None:
if self.mxid: if self.mxid:
return self.mxid return self.mxid
elif not self.allow_bridging: elif not self.allow_bridging:
return None return None
if not client:
client = user.client
invites = invites or [] invites = invites or []
if not entity: if not entity:
entity = await self.get_entity(user) entity = await self.get_entity(user, client)
self.log.trace("Fetched data: %s", entity) self.log.trace("Fetched data: %s", entity)
participants_count = 2 participants_count = 2
@@ -794,17 +808,17 @@ class Portal(DBPortal, BasePortal):
participants_count = entity.participants_count participants_count = entity.participants_count
if participants_count is None and self.config["bridge.max_member_count"] > 0: if participants_count is None and self.config["bridge.max_member_count"] > 0:
self.log.warning(f"Participant count not found in entity, fetching manually") self.log.warning(f"Participant count not found in entity, fetching manually")
participants_count = (await user.client.get_participants(entity, limit=0)).total participants_count = (await client.get_participants(entity, limit=0)).total
if participants_count and 0 < self.config["bridge.max_member_count"] < participants_count: if participants_count and 0 < self.config["bridge.max_member_count"] < participants_count:
self.log.warning(f"Not bridging chat, too many participants (%d)", participants_count) self.log.warning(f"Not bridging chat, too many participants (%d)", participants_count)
self._bridging_blocked_at_runtime = True self._bridging_blocked_at_runtime = True
return None return None
self.log.debug("Creating room") self.log.debug("Preparing to create room")
if self.is_direct: if self.is_direct:
puppet = await self.get_dm_puppet() puppet = await self.get_dm_puppet()
await puppet.update_info(user, entity) await puppet.update_info(user, entity, client_override=client)
self._main_intent = puppet.intent_for(self) self._main_intent = puppet.intent_for(self)
if self.tgid == user.tgid: if self.tgid == user.tgid:
self.title = "Telegram Saved Messages" self.title = "Telegram Saved Messages"
@@ -812,7 +826,7 @@ class Portal(DBPortal, BasePortal):
else: else:
puppet = None puppet = None
self._main_intent = self.az.intent self._main_intent = self.az.intent
await self.update_info(user, entity) await self.update_info(user, entity, client=client)
preset = RoomCreatePreset.PRIVATE preset = RoomCreatePreset.PRIVATE
if self.peer_type == "channel" and entity.username: if self.peer_type == "channel" and entity.username:
@@ -831,7 +845,7 @@ class Portal(DBPortal, BasePortal):
power_levels = putil.get_base_power_levels(self, entity=entity) power_levels = putil.get_base_power_levels(self, entity=entity)
users = None users = None
if not self.is_direct: if not self.is_direct:
users = await self._get_users(user, entity) users = await self._get_users(client, entity)
if self.has_bot: if self.has_bot:
extra_invites = self.config["bridge.relaybot.group_chat_invite"] extra_invites = self.config["bridge.relaybot.group_chat_invite"]
invites += extra_invites invites += extra_invites
@@ -840,7 +854,7 @@ class Portal(DBPortal, BasePortal):
await putil.participants_to_power_levels(self, users, power_levels) await putil.participants_to_power_levels(self, users, power_levels)
elif self.bot and self.tg_receiver == self.bot.tgid: elif self.bot and self.tg_receiver == self.bot.tgid:
assert puppet is not None assert puppet is not None
invites = self.config["bridge.relaybot.private_chat.invite"] invites += self.config["bridge.relaybot.private_chat.invite"]
for invite in invites: for invite in invites:
power_levels.users.setdefault(invite, 100) power_levels.users.setdefault(invite, 100)
self.title = puppet.displayname self.title = puppet.displayname
@@ -865,10 +879,10 @@ class Portal(DBPortal, BasePortal):
autojoin_invites = self.bridge.homeserver_software.is_hungry autojoin_invites = self.bridge.homeserver_software.is_hungry
create_invites = set() create_invites = set()
if autojoin_invites: if autojoin_invites:
invites = []
create_invites |= set(invites) create_invites |= set(invites)
invites = []
if not self.is_direct: if not self.is_direct:
create_invites |= await self._sync_telegram_users(user, users) create_invites |= await self._sync_telegram_users(user, users, client=client)
if self.config["bridge.encryption.default"] and self.matrix.e2ee: if self.config["bridge.encryption.default"] and self.matrix.e2ee:
self.encrypted = True self.encrypted = True
initial_state.append( initial_state.append(
@@ -896,6 +910,11 @@ class Portal(DBPortal, BasePortal):
) )
with self.backfill_lock: with self.backfill_lock:
self.log.debug(
f"Creating room with parameters invite={create_invites}, {autojoin_invites=}, "
f"{preset=}, {alias=!r}, name={self.title!r}, topic={self.about!r}, "
f"{creation_content=}, is_direct={self.is_direct}"
)
room_id = await self.main_intent.create_room( room_id = await self.main_intent.create_room(
alias_localpart=alias, alias_localpart=alias,
preset=preset, preset=preset,
@@ -912,7 +931,7 @@ class Portal(DBPortal, BasePortal):
self.name_set = bool(self.title) self.name_set = bool(self.title)
self.avatar_set = bool(self.avatar_url) self.avatar_set = bool(self.avatar_url)
if self.encrypted and self.matrix.e2ee and self.is_direct: if not autojoin_invites and self.encrypted and self.matrix.e2ee and self.is_direct:
try: try:
await self.az.intent.ensure_joined(room_id) await self.az.intent.ensure_joined(room_id)
except Exception: except Exception:
@@ -928,7 +947,7 @@ class Portal(DBPortal, BasePortal):
if not autojoin_invites or not self.is_direct: if not autojoin_invites or not self.is_direct:
await self.invite_to_matrix(invites) await self.invite_to_matrix(invites)
await self.update_matrix_room( await self.update_matrix_room(
user, entity, puppet, levels=power_levels, users=users user, entity, puppet, levels=power_levels, users=users, client=client
) )
else: else:
# When using autojoining, all metadata is already set, so just update state caches # When using autojoining, all metadata is already set, so just update state caches
@@ -943,9 +962,9 @@ class Portal(DBPortal, BasePortal):
) )
await self.save() await self.save()
if isinstance(user, u.User) or not self.backfill_msc2716: if self.backfill_enable and (isinstance(user, u.User) or not self.backfill_msc2716):
try: try:
await self.forward_backfill(user, initial=True) await self.forward_backfill(user, initial=True, client=client)
except Exception: except Exception:
self.log.exception("Error in initial backfill") self.log.exception("Error in initial backfill")
if self.backfill_msc2716: if self.backfill_msc2716:
@@ -955,7 +974,7 @@ class Portal(DBPortal, BasePortal):
async def _get_users( async def _get_users(
self, self,
user: au.AbstractUser, client: MautrixTelegramClient,
entity: TypeInputPeer | InputUser | TypeChat | TypeUser | InputChannel, entity: TypeInputPeer | InputUser | TypeChat | TypeUser | InputChannel,
) -> list[TypeUser]: ) -> list[TypeUser]:
if self.peer_type == "channel" and not self.megagroup and not self.sync_channel_members: if self.peer_type == "channel" and not self.megagroup and not self.sync_channel_members:
@@ -963,7 +982,7 @@ class Portal(DBPortal, BasePortal):
limit = self.max_initial_member_sync limit = self.max_initial_member_sync
if limit == 0: if limit == 0:
return [] return []
return await putil.get_users(user.client, self.tgid, entity, limit, self.peer_type) return await putil.get_users(client, self.tgid, entity, limit, self.peer_type)
async def update_power_levels( async def update_power_levels(
self, self,
@@ -985,7 +1004,10 @@ class Portal(DBPortal, BasePortal):
await user.register_portal(self) await user.register_portal(self)
async def _sync_telegram_users( async def _sync_telegram_users(
self, source: au.AbstractUser, users: list[User] self,
source: au.AbstractUser,
users: list[User],
client: MautrixTelegramClient | None = None,
) -> set[UserID] | None: ) -> set[UserID] | None:
allowed_tgids = set() allowed_tgids = set()
join_mxids = set() join_mxids = set()
@@ -996,7 +1018,7 @@ class Portal(DBPortal, BasePortal):
await self._add_bot_chat(entity) await self._add_bot_chat(entity)
allowed_tgids.add(entity.id) allowed_tgids.add(entity.id)
await puppet.update_info(source, entity) await puppet.update_info(source, entity, client_override=client)
if skip_deleted and entity.deleted: if skip_deleted and entity.deleted:
continue continue
@@ -1122,7 +1144,12 @@ class Portal(DBPortal, BasePortal):
except MForbidden as e: except MForbidden as e:
self.log.warning(f"Failed to kick {user.mxid}: {e}") self.log.warning(f"Failed to kick {user.mxid}: {e}")
async def update_info(self, user: au.AbstractUser, entity: TypeChat = None) -> None: async def update_info(
self,
user: au.AbstractUser,
entity: TypeChat = None,
client: MautrixTelegramClient | None = None,
) -> None:
if self.peer_type == "user": if self.peer_type == "user":
self.log.warning("Called update_info() for direct chat portal") self.log.warning("Called update_info() for direct chat portal")
return return
@@ -1131,7 +1158,7 @@ class Portal(DBPortal, BasePortal):
self.log.debug("Updating info") self.log.debug("Updating info")
try: try:
if not entity: if not entity:
entity = await self.get_entity(user) entity = await self.get_entity(user, client)
self.log.trace("Fetched data: %s", entity) self.log.trace("Fetched data: %s", entity)
if self.peer_type == "channel": if self.peer_type == "channel":
@@ -1145,7 +1172,7 @@ class Portal(DBPortal, BasePortal):
changed = await self._update_title(entity.title) or changed changed = await self._update_title(entity.title) or changed
if isinstance(entity.photo, ChatPhoto): if isinstance(entity.photo, ChatPhoto):
changed = await self._update_avatar(user, entity.photo) or changed changed = await self._update_avatar(user, entity.photo, client=client) or changed
except Exception: except Exception:
self.log.exception(f"Failed to update info from source {user.tgid}") self.log.exception(f"Failed to update info from source {user.tgid}")
@@ -1254,6 +1281,7 @@ class Portal(DBPortal, BasePortal):
photo: TypeChatPhoto | TypeUserProfilePhoto, photo: TypeChatPhoto | TypeUserProfilePhoto,
sender: p.Puppet | None = None, sender: p.Puppet | None = None,
save: bool = False, save: bool = False,
client: MautrixTelegramClient | None = None,
) -> bool: ) -> bool:
if isinstance(photo, (ChatPhoto, UserProfilePhoto)): if isinstance(photo, (ChatPhoto, UserProfilePhoto)):
loc = InputPeerPhotoFileLocation( loc = InputPeerPhotoFileLocation(
@@ -1280,7 +1308,7 @@ class Portal(DBPortal, BasePortal):
self.avatar_url = None self.avatar_url = None
elif self.photo_id != photo_id or not self.avatar_url: elif self.photo_id != photo_id or not self.avatar_url:
file = await util.transfer_file_to_matrix( file = await util.transfer_file_to_matrix(
user.client, client or user.client,
self.main_intent, self.main_intent,
loc, loc,
async_upload=self.config["homeserver.async_media"], async_upload=self.config["homeserver.async_media"],
@@ -2649,21 +2677,28 @@ class Portal(DBPortal, BasePortal):
max_batches: int | None = None, max_batches: int | None = None,
messages_per_batch: int | None = None, messages_per_batch: int | None = None,
anchor_msg_id: int | None = None, anchor_msg_id: int | None = None,
extra_data: dict[str, Any] | None = None,
type: BackfillType = BackfillType.HISTORICAL,
) -> None: ) -> None:
# TODO check that there are no queued backfills new_backfill = Backfill.new(
# if not await Backfill.get(source.mxid, self.tgid, self.tg_receiver):
await Backfill.new(
user_mxid=source.mxid, user_mxid=source.mxid,
priority=priority, priority=priority,
type=type,
portal_tgid=self.tgid, portal_tgid=self.tgid,
portal_tg_receiver=self.tg_receiver, portal_tg_receiver=self.tg_receiver,
anchor_msg_id=anchor_msg_id, anchor_msg_id=anchor_msg_id,
extra_data=extra_data,
messages_per_batch=( messages_per_batch=(
messages_per_batch or self.config["bridge.backfill.incremental.messages_per_batch"] messages_per_batch or self.config["bridge.backfill.incremental.messages_per_batch"]
), ),
post_batch_delay=self.config["bridge.backfill.incremental.post_batch_delay"], post_batch_delay=self.config["bridge.backfill.incremental.post_batch_delay"],
max_batches=max_batches or self._default_max_batches, max_batches=max_batches or self._default_max_batches,
).insert() )
deleted_entries = await new_backfill.insert()
if deleted_entries:
self.log.debug(
"Deleted backfill queue entries while inserting new item: %s", deleted_entries
)
source.wakeup_backfill_task.set() source.wakeup_backfill_task.set()
async def forward_backfill( async def forward_backfill(
@@ -2672,14 +2707,17 @@ class Portal(DBPortal, BasePortal):
initial: bool, initial: bool,
last_tgid: int | None = None, last_tgid: int | None = None,
override_limit: int | None = None, override_limit: int | None = None,
client: MautrixTelegramClient | None = None,
) -> str: ) -> str:
if not client:
client = source.client
type = "initial" if initial else "sync" type = "initial" if initial else "sync"
limit = override_limit or self.config[f"bridge.backfill.forward.{type}_limit"] limit = override_limit or self.config[f"bridge.backfill.forward.{type}_limit"]
if limit == 0: if limit == 0:
return "Limit is zero, not backfilling" return "Limit is zero, not backfilling"
with self.backfill_lock: with self.backfill_lock:
output = await self.backfill( output = await self.backfill(
source, source.client, forward=True, forward_limit=limit, last_tgid=last_tgid source, client, forward=True, forward_limit=limit, last_tgid=last_tgid
) )
self.log.debug(f"Forward backfill complete, status: {output}") self.log.debug(f"Forward backfill complete, status: {output}")
return output return output
@@ -2693,6 +2731,8 @@ class Portal(DBPortal, BasePortal):
forward_limit: int | None = None, forward_limit: int | None = None,
last_tgid: int | None = None, last_tgid: int | None = None,
) -> str: ) -> str:
if not self.backfill_enable:
return "Backfilling is disabled in the bridge config"
async with self.backfill_method_lock: async with self.backfill_method_lock:
return await self._locked_backfill( return await self._locked_backfill(
source, client, req, forward, forward_limit, last_tgid source, client, req, forward, forward_limit, last_tgid
@@ -2778,7 +2818,7 @@ class Portal(DBPortal, BasePortal):
self.log.debug(f"Enqueuing more backfill through {source.mxid}") self.log.debug(f"Enqueuing more backfill through {source.mxid}")
await self.enqueue_backfill( await self.enqueue_backfill(
source, source,
priority=100, priority=max(100, req.priority + 1),
messages_per_batch=req.messages_per_batch, messages_per_batch=req.messages_per_batch,
max_batches=-1 if req.max_batches < 0 else (req.max_batches - 1), max_batches=-1 if req.max_batches < 0 else (req.max_batches - 1),
anchor_msg_id=lowest_id, anchor_msg_id=lowest_id,
@@ -3515,9 +3555,13 @@ class Portal(DBPortal, BasePortal):
) -> Awaitable[TypeInputPeer | TypeInputChannel]: ) -> Awaitable[TypeInputPeer | TypeInputChannel]:
return user.client.get_input_entity(self.peer) return user.client.get_input_entity(self.peer)
async def get_entity(self, user: au.AbstractUser) -> TypeChat: async def get_entity(
self, user: au.AbstractUser, client: MautrixTelegramClient | None = None
) -> TypeChat:
if not client:
client = user.client
try: try:
return await user.client.get_entity(self.peer) return await client.get_entity(self.peer)
except ValueError: except ValueError:
if user.is_bot: if user.is_bot:
self.log.warning(f"Could not find entity with bot {user.tgid}. Failing...") self.log.warning(f"Could not find entity with bot {user.tgid}. Failing...")
@@ -3525,7 +3569,7 @@ class Portal(DBPortal, BasePortal):
self.log.warning( self.log.warning(
f"Could not find entity with user {user.tgid}. falling back to get_dialogs." f"Could not find entity with user {user.tgid}. falling back to get_dialogs."
) )
async for dialog in user.client.iter_dialogs(): async for dialog in client.iter_dialogs():
if dialog.entity.id == self.tgid: if dialog.entity.id == self.tgid:
return dialog.entity return dialog.entity
raise raise
+115 -48
View File
@@ -16,7 +16,7 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, NamedTuple, cast from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, NamedTuple, cast
from datetime import datetime, timedelta, timezone from datetime import datetime
import asyncio import asyncio
import time import time
@@ -39,6 +39,7 @@ from telethon.tl.types import (
InputUserSelf, InputUserSelf,
NotifyPeer, NotifyPeer,
PeerUser, PeerUser,
TypeChat,
TypeUpdate, TypeUpdate,
UpdateFolderPeers, UpdateFolderPeers,
UpdateNewChannelMessage, UpdateNewChannelMessage,
@@ -62,7 +63,7 @@ from mautrix.util.opt_prometheus import Gauge
from . import portal as po, puppet as pu, util from . import portal as po, puppet as pu, util
from .abstract_user import AbstractUser from .abstract_user import AbstractUser
from .db import Backfill, Message as DBMessage, PgSession, User as DBUser from .db import Backfill, BackfillType, Message as DBMessage, PgSession, User as DBUser
from .tgclient import MautrixTelegramClient from .tgclient import MautrixTelegramClient
from .types import TelegramID from .types import TelegramID
@@ -347,7 +348,7 @@ class User(DBUser, AbstractUser, BaseUser):
self._track_metric(METRIC_LOGGED_IN, True) self._track_metric(METRIC_LOGGED_IN, True)
if not self._backfill_task or self._backfill_task.done(): if not self._backfill_task or self._backfill_task.done():
self._backfill_task = asyncio.create_task(self._handle_backfill_requests_loop()) self._backfill_task = asyncio.create_task(self._try_handle_backfill_requests_loop())
try: try:
puppet = await pu.Puppet.get_by_tgid(self.tgid) puppet = await pu.Puppet.get_by_tgid(self.tgid)
@@ -378,6 +379,14 @@ class User(DBUser, AbstractUser, BaseUser):
"max_file_size": min(self.bridge.matrix.media_config.upload_size, 2000 * 1024 * 1024), "max_file_size": min(self.bridge.matrix.media_config.upload_size, 2000 * 1024 * 1024),
} }
async def _try_handle_backfill_requests_loop(self) -> None:
    """
    Entry point for the background backfill task.

    Does nothing when ``bridge.backfill.enable`` is falsy in the config. Otherwise runs
    the backfill request loop and logs any exception that escapes it instead of letting
    it propagate out of the task.
    """
    backfill_enabled = self.config["bridge.backfill.enable"]
    if backfill_enabled:
        try:
            await self._handle_backfill_requests_loop()
        except Exception:
            self.log.exception("Fatal error in backfill request loop")
async def _handle_backfill_requests_loop(self) -> None: async def _handle_backfill_requests_loop(self) -> None:
while True: while True:
req = await Backfill.get_next(self.mxid) req = await Backfill.get_next(self.mxid)
@@ -388,7 +397,11 @@ class User(DBUser, AbstractUser, BaseUser):
pass pass
self.wakeup_backfill_task.clear() self.wakeup_backfill_task.clear()
else: else:
await self._takeout_and_backfill(req) try:
await self._takeout_and_backfill(req)
except Exception:
self.log.exception("Error in takeout backfill loop, retrying in an hour")
await asyncio.sleep(3600)
async def _takeout_and_backfill(self, first_req: Backfill, first_attempt: bool = True) -> None: async def _takeout_and_backfill(self, first_req: Backfill, first_attempt: bool = True) -> None:
self.takeout_retry_immediate.clear() self.takeout_retry_immediate.clear()
@@ -437,13 +450,33 @@ class User(DBUser, AbstractUser, BaseUser):
TelegramID(req.portal_tgid), tg_receiver=TelegramID(req.portal_tg_receiver) TelegramID(req.portal_tgid), tg_receiver=TelegramID(req.portal_tg_receiver)
) )
await req.mark_dispatched() await req.mark_dispatched()
await portal.backfill(self, client, req=req) if req.type == BackfillType.HISTORICAL:
await portal.backfill(self, client, req=req)
elif req.type == BackfillType.SYNC_DIALOG:
await self._backfill_sync_dialog(portal, client, req.extra_data)
await req.mark_done() await req.mark_done()
await asyncio.sleep(req.post_batch_delay) await asyncio.sleep(req.post_batch_delay)
except Exception: except Exception:
self.log.exception("Error handling backfill request for %s", req.portal_tgid) self.log.exception("Error handling backfill request for %s", req.portal_tgid)
await req.set_cooldown_timeout(1800) await req.set_cooldown_timeout(1800)
async def _backfill_sync_dialog(
    self, portal: po.Portal, client: MautrixTelegramClient, post_sync_args: dict[str, Any]
) -> None:
    """
    Handle a dialog sync item from the backfill queue.

    Creates the Matrix room for ``portal`` if it does not have one yet, then runs the
    dialog post-sync step with the arguments that were stored in the queue item.

    Args:
        portal: The portal whose Matrix room should be created.
        client: The Telegram client passed on to ``create_matrix_room``.
        post_sync_args: Keyword arguments forwarded to ``_post_sync_dialog``.
    """
    if portal.mxid:
        self.log.debug("Portal already exists, skipping dialog sync backfill queue item")
        return
    self.log.info(f"Creating portal for {portal.tgid_log} as part of backfill loop")
    try:
        await portal.create_matrix_room(
            self, client=client, update_if_exists=False, invites=[self.mxid]
        )
    except Exception:
        self.log.exception(f"Error while creating {portal.tgid_log}")
        return
    puppet = await pu.Puppet.get_by_custom_mxid(self.mxid)
    # Post-sync acts through the user's double puppet, so apply the same guard that
    # the direct (non-queued) dialog sync path uses before running it.
    if not (portal.mxid and puppet and puppet.is_real_user):
        self.log.debug(
            f"Skipping dialog post-sync for {portal.tgid_log}: "
            "room or double puppet not available"
        )
        return
    await self._post_sync_dialog(portal, puppet, was_created=True, **post_sync_args)
async def update(self, update: TypeUpdate) -> bool: async def update(self, update: TypeUpdate) -> bool:
if not self.is_bot: if not self.is_bot:
return False return False
@@ -604,19 +637,18 @@ class User(DBUser, AbstractUser, BaseUser):
if active and tag_info is None: if active and tag_info is None:
tag_info = RoomTagInfo(order=0.5) tag_info = RoomTagInfo(order=0.5)
tag_info[DOUBLE_PUPPET_SOURCE_KEY] = self.bridge.name tag_info[DOUBLE_PUPPET_SOURCE_KEY] = self.bridge.name
self.log.debug("Adding tag {tag} to {portal.mxid}/{portal.tgid}") self.log.debug(f"Adding tag {tag} to {portal.mxid}/{portal.tgid}")
await puppet.intent.set_room_tag(portal.mxid, tag, tag_info) await puppet.intent.set_room_tag(portal.mxid, tag, tag_info)
elif ( elif (
not active and tag_info and tag_info.get(DOUBLE_PUPPET_SOURCE_KEY) == self.bridge.name not active and tag_info and tag_info.get(DOUBLE_PUPPET_SOURCE_KEY) == self.bridge.name
): ):
self.log.debug("Removing tag {tag} from {portal.mxid}/{portal.tgid}") self.log.debug(f"Removing tag {tag} from {portal.mxid}/{portal.tgid}")
await puppet.intent.remove_room_tag(portal.mxid, tag) await puppet.intent.remove_room_tag(portal.mxid, tag)
async def _mute_room(self, puppet: pu.Puppet, portal: po.Portal, mute_until: datetime) -> None: async def _mute_room(self, puppet: pu.Puppet, portal: po.Portal, mute_until: float) -> None:
if not self.config["bridge.mute_bridging"] or not portal or not portal.mxid: if not self.config["bridge.mute_bridging"] or not portal or not portal.mxid:
return return
now = datetime.utcnow().replace(tzinfo=timezone.utc) if mute_until is not None and mute_until > time.time():
if mute_until is not None and mute_until > now:
self.log.debug( self.log.debug(
f"Muting {portal.mxid}/{portal.tgid} (muted until {mute_until} on Telegram)" f"Muting {portal.mxid}/{portal.tgid} (muted until {mute_until} on Telegram)"
) )
@@ -672,12 +704,24 @@ class User(DBUser, AbstractUser, BaseUser):
portal = await po.Portal.get_by_entity( portal = await po.Portal.get_by_entity(
update.peer.peer, tg_receiver=self.tgid, create=False update.peer.peer, tg_receiver=self.tgid, create=False
) )
await self._mute_room(puppet, portal, update.notify_settings.mute_until) await self._mute_room(puppet, portal, update.notify_settings.mute_until.timestamp())
async def _sync_dialog( async def _sync_dialog(
self, portal: po.Portal, dialog: Dialog, should_create: bool, puppet: pu.Puppet | None self, portal: po.Portal, dialog: Dialog, should_create: bool, puppet: pu.Puppet | None
) -> None: ) -> None:
was_created = False was_created = False
post_sync_args = {
"last_message_ts": cast(datetime, dialog.date).timestamp(),
"unread_count": dialog.unread_count,
"max_read_id": dialog.dialog.read_inbox_max_id,
"mute_until": (
dialog.dialog.notify_settings.mute_until.timestamp()
if dialog.dialog.notify_settings.mute_until
else None
),
"pinned": dialog.pinned,
"archived": dialog.archived,
}
if portal.mxid: if portal.mxid:
try: try:
await portal.forward_backfill(self, initial=False, last_tgid=dialog.message.id) await portal.forward_backfill(self, initial=False, last_tgid=dialog.message.id)
@@ -693,41 +737,65 @@ class User(DBUser, AbstractUser, BaseUser):
was_created = True was_created = True
except Exception: except Exception:
self.log.exception(f"Error while creating {portal.tgid_log}") self.log.exception(f"Error while creating {portal.tgid_log}")
if portal.mxid and puppet and puppet.is_real_user: elif self.config["bridge.sync_deferred_create_all"]:
tg_space = portal.tgid if portal.peer_type == "channel" else self.tgid await portal.enqueue_backfill(
last_message_date: float = cast(datetime, dialog.date).timestamp() self,
unread_threshold_hours = self.config["bridge.backfill.unread_hours_threshold"] priority=40,
force_read = ( type=BackfillType.SYNC_DIALOG,
was_created extra_data=post_sync_args,
and unread_threshold_hours >= 0
and last_message_date + (unread_threshold_hours * 60 * 60) < time.time()
) )
if dialog.unread_count == 0 or force_read: if portal.mxid and puppet and puppet.is_real_user:
# This is usually more reliable than finding a specific message await self._post_sync_dialog(
# e.g. if the last read message is a service message that isn't in the message db portal=portal,
last_read = await DBMessage.find_last(portal.mxid, tg_space) puppet=puppet,
if force_read: was_created=was_created,
self.log.debug( **post_sync_args,
f"Marking {portal.tgid_log} as read because the last message is from " )
f"{dialog.date} (unread threshold is {unread_threshold_hours} hours)"
) async def _post_sync_dialog(
else: self,
last_read = await DBMessage.get_one_by_tgid( portal: po.Portal,
portal.tgid, tg_space, dialog.dialog.read_inbox_max_id puppet: pu.Puppet,
was_created: bool,
max_read_id: int,
last_message_ts: float,
unread_count: int,
mute_until: float,
pinned: bool,
archived: bool,
) -> None:
self.log.debug(
f"Running dialog post-sync for {portal.tgid_log} with args "
f"{was_created=}, {max_read_id=}, {last_message_ts=}, {unread_count=}, "
f"{mute_until=}, {pinned=}, {archived=}"
)
tg_space = portal.tgid if portal.peer_type == "channel" else self.tgid
unread_threshold_hours = self.config["bridge.backfill.unread_hours_threshold"]
force_read = (
was_created
and unread_threshold_hours >= 0
and last_message_ts + (unread_threshold_hours * 60 * 60) < time.time()
)
if unread_count == 0 or force_read:
# This is usually more reliable than finding a specific message
# e.g. if the last read message is a service message that isn't in the message db
last_read = await DBMessage.find_last(portal.mxid, tg_space)
if force_read:
self.log.debug(
f"Marking {portal.tgid_log} as read because the last message is from "
f"{last_message_ts} (unread threshold is {unread_threshold_hours} hours)"
) )
try: else:
if last_read: last_read = await DBMessage.get_one_by_tgid(portal.tgid, tg_space, max_read_id)
await puppet.intent.mark_read(last_read.mx_room, last_read.mxid) try:
if was_created or not self.config["bridge.tag_only_on_create"]: if last_read:
await self._mute_room(puppet, portal, dialog.dialog.notify_settings.mute_until) await puppet.intent.mark_read(last_read.mx_room, last_read.mxid)
await self._tag_room( if was_created or not self.config["bridge.tag_only_on_create"]:
puppet, portal, self.config["bridge.pinned_tag"], dialog.pinned await self._mute_room(puppet, portal, mute_until)
) await self._tag_room(puppet, portal, self.config["bridge.pinned_tag"], pinned)
await self._tag_room( await self._tag_room(puppet, portal, self.config["bridge.archive_tag"], archived)
puppet, portal, self.config["bridge.archive_tag"], dialog.archived except Exception:
) self.log.exception(f"Error updating read status and tags for {portal.tgid_log}")
except Exception:
self.log.exception(f"Error updating read status and tags for {portal.tgid_log}")
async def get_cached_portals(self) -> dict[tuple[TelegramID, TelegramID], po.Portal]: async def get_cached_portals(self) -> dict[tuple[TelegramID, TelegramID], po.Portal]:
if self._portals_cache is None: if self._portals_cache is None:
@@ -744,9 +812,7 @@ class User(DBUser, AbstractUser, BaseUser):
update_limit = self.config["bridge.sync_update_limit"] or None update_limit = self.config["bridge.sync_update_limit"] or None
create_limit = self.config["bridge.sync_create_limit"] create_limit = self.config["bridge.sync_create_limit"]
index = 0 index = 0
self.log.debug( self.log.debug(f"Syncing dialogs ({update_limit=}, {create_limit=})")
f"Syncing dialogs (update_limit={update_limit}, create_limit={create_limit})"
)
await self.push_bridge_state(BridgeStateEvent.BACKFILLING) await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
puppet = await pu.Puppet.get_by_custom_mxid(self.mxid) puppet = await pu.Puppet.get_by_custom_mxid(self.mxid)
dialog: Dialog dialog: Dialog
@@ -767,11 +833,12 @@ class User(DBUser, AbstractUser, BaseUser):
continue continue
portal = await po.Portal.get_by_entity(entity, tg_receiver=self.tgid) portal = await po.Portal.get_by_entity(entity, tg_receiver=self.tgid)
new_portal_cache[portal.tgid_full] = portal new_portal_cache[portal.tgid_full] = portal
should_create = not create_limit or index < create_limit
coro = self._sync_dialog( coro = self._sync_dialog(
portal=portal, portal=portal,
dialog=dialog, dialog=dialog,
puppet=puppet, puppet=puppet,
should_create=not create_limit or index < create_limit, should_create=should_create,
) )
creators.append(asyncio.create_task(coro)) creators.append(asyncio.create_task(coro))
index += 1 index += 1