Update to mautrix-python v0.7

This commit is contained in:
Tulir Asokan
2020-08-06 20:34:09 +03:00
parent 92c572d761
commit eae7bba649
20 changed files with 98 additions and 93 deletions
+2 -2
View File
@@ -125,10 +125,10 @@ class TelegramBridge(Bridge):
return Portal.get_by_mxid(room_id) return Portal.get_by_mxid(room_id)
async def get_puppet(self, user_id: UserID, create: bool = False) -> Puppet: async def get_puppet(self, user_id: UserID, create: bool = False) -> Puppet:
return Puppet.get_by_mxid(user_id, create=create) return await Puppet.get_by_mxid(user_id, create=create)
async def get_double_puppet(self, user_id: UserID) -> Puppet: async def get_double_puppet(self, user_id: UserID) -> Puppet:
return Puppet.get_by_custom_mxid(user_id) return await Puppet.get_by_custom_mxid(user_id)
TelegramBridge().run() TelegramBridge().run()
+6 -6
View File
@@ -182,11 +182,11 @@ class AbstractUser(ABC):
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def register_portal(self, portal: po.Portal) -> None: async def register_portal(self, portal: po.Portal) -> None:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def unregister_portal(self, tgid: int, tg_receiver: int) -> None: async def unregister_portal(self, tgid: int, tg_receiver: int) -> None:
raise NotImplementedError() raise NotImplementedError()
async def _update_catch(self, update: TypeUpdate) -> None: async def _update_catch(self, update: TypeUpdate) -> None:
@@ -358,10 +358,10 @@ class AbstractUser(ABC):
if isinstance(update, UpdateUserName): if isinstance(update, UpdateUserName):
puppet.username = update.username puppet.username = update.username
if await puppet.update_displayname(self, update): if await puppet.update_displayname(self, update):
puppet.save() await puppet.save()
elif isinstance(update, UpdateUserPhoto): elif isinstance(update, UpdateUserPhoto):
if await puppet.update_avatar(self, update.photo): if await puppet.update_avatar(self, update.photo):
puppet.save() await puppet.save()
else: else:
self.log.warning(f"Unexpected other user info update: {type(update)}") self.log.warning(f"Unexpected other user info update: {type(update)}")
@@ -461,8 +461,8 @@ class AbstractUser(ABC):
if isinstance(update.action, MessageActionChannelMigrateFrom): if isinstance(update.action, MessageActionChannelMigrateFrom):
self.log.trace(f"Received %s in %s by %d, unregistering portal...", self.log.trace(f"Received %s in %s by %d, unregistering portal...",
update.action, portal.tgid_log, sender.id) update.action, portal.tgid_log, sender.id)
self.unregister_portal(update.action.chat_id, update.action.chat_id) await self.unregister_portal(update.action.chat_id, update.action.chat_id)
self.register_portal(portal) await self.register_portal(portal)
return return
self.log.trace("Handling action %s to %s by %d", update.action, portal.tgid_log, self.log.trace("Handling action %s to %s by %d", update.action, portal.tgid_log,
sender.id) sender.id)
+2 -2
View File
@@ -117,10 +117,10 @@ class Bot(AbstractUser):
except (ChannelPrivateError, ChannelInvalidError): except (ChannelPrivateError, ChannelInvalidError):
self.remove_chat(TelegramID(channel_id.channel_id)) self.remove_chat(TelegramID(channel_id.channel_id))
def register_portal(self, portal: po.Portal) -> None: async def register_portal(self, portal: po.Portal) -> None:
self.add_chat(portal.tgid, portal.peer_type) self.add_chat(portal.tgid, portal.peer_type)
def unregister_portal(self, tgid: int, tg_receiver: int) -> None: async def unregister_portal(self, tgid: int, tg_receiver: int) -> None:
self.remove_chat(tgid) self.remove_chat(tgid)
def add_chat(self, chat_id: TelegramID, chat_type: str) -> None: def add_chat(self, chat_id: TelegramID, chat_type: str) -> None:
+1 -2
View File
@@ -115,8 +115,7 @@ def command_handler(_func: Optional[CommandHandlerFunc] = None, *, needs_auth: b
class CommandProcessor(BaseCommandProcessor): class CommandProcessor(BaseCommandProcessor):
def __init__(self, context: c.Context) -> None: def __init__(self, context: c.Context) -> None:
super().__init__(az=context.az, config=context.config, event_class=CommandEvent, super().__init__(event_class=CommandEvent, bridge=context.bridge)
loop=context.loop, bridge=context.bridge)
self.tgbot = context.bot self.tgbot = context.bot
self.bridge = context.bridge self.bridge = context.bridge
self.az, self.config, self.loop, self.tgbot = context.core self.az, self.config, self.loop, self.tgbot = context.core
+1 -1
View File
@@ -86,7 +86,7 @@ async def reload_user(evt: CommandEvent) -> EventID:
user = u.User.get_by_mxid(mxid, create=False) user = u.User.get_by_mxid(mxid, create=False)
if not user: if not user:
return await evt.reply("User not found") return await evt.reply("User not found")
puppet = pu.Puppet.get_by_custom_mxid(mxid) puppet = await pu.Puppet.get_by_custom_mxid(mxid)
if puppet: if puppet:
puppet.sync_task.cancel() puppet.sync_task.cancel()
await user.stop() await user.stop()
+1 -1
View File
@@ -177,7 +177,7 @@ async def confirm_bridge(evt: CommandEvent) -> Optional[EventID]:
portal.mxid = bridge_to_mxid portal.mxid = bridge_to_mxid
portal.title, portal.about, levels = await get_initial_state(evt.az.intent, evt.room_id) portal.title, portal.about, levels = await get_initial_state(evt.az.intent, evt.room_id)
portal.photo_id = "" portal.photo_id = ""
portal.save() await portal.save()
asyncio.ensure_future(portal.update_matrix_room(user, entity, direct, levels=levels), asyncio.ensure_future(portal.update_matrix_room(user, entity, direct, levels=levels),
loop=evt.loop) loop=evt.loop)
+1 -1
View File
@@ -63,7 +63,7 @@ async def config(evt: CommandEvent) -> None:
await config_add_del(evt, portal, key, value, cmd) await config_add_del(evt, portal, key, value, cmd)
else: else:
return return
portal.save() await portal.save()
def config_help(evt: CommandEvent) -> Awaitable[EventID]: def config_help(evt: CommandEvent) -> Awaitable[EventID]:
+3
View File
@@ -114,6 +114,9 @@ class Config(BaseBridgeConfig):
copy("bridge.encryption.allow") copy("bridge.encryption.allow")
copy("bridge.encryption.default") copy("bridge.encryption.default")
copy("bridge.encryption.database") copy("bridge.encryption.database")
copy("bridge.encryption.key_sharing.allow")
copy("bridge.encryption.key_sharing.require_cross_signing")
copy("bridge.encryption.key_sharing.require_verification")
copy("bridge.private_chat_portal_meta") copy("bridge.private_chat_portal_meta")
copy("bridge.delivery_receipts") copy("bridge.delivery_receipts")
copy("bridge.delivery_error_reports") copy("bridge.delivery_error_reports")
+12
View File
@@ -223,6 +223,18 @@ bridge:
# Pickle: pickle:///filename.pickle # Pickle: pickle:///filename.pickle
# Postgres: postgres://username:password@hostname/dbname # Postgres: postgres://username:password@hostname/dbname
database: default database: default
# Options for automatic key sharing.
key_sharing:
# Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled.
# You must use a client that supports requesting keys from other users to use this feature.
allow: false
# Require the requesting device to have a valid cross-signing signature?
# This doesn't require that the bridge has verified the device, only that the user has verified it.
# Not yet implemented.
require_cross_signing: false
# Require devices to be verified by the bridge?
# Verification by the bridge is not yet implemented.
require_verification: true
# Whether or not to explicitly set the avatar and room name for private # Whether or not to explicitly set the avatar and room name for private
# chat portal rooms. This will be implicitly enabled if encryption.default is true. # chat portal rooms. This will be implicitly enabled if encryption.default is true.
private_chat_portal_meta: false private_chat_portal_meta: false
@@ -48,7 +48,7 @@ class MatrixParser(BaseMatrixParser[TelegramMessage]):
@classmethod @classmethod
def user_pill_to_fstring(cls, msg: TelegramMessage, user_id: UserID) -> TelegramMessage: def user_pill_to_fstring(cls, msg: TelegramMessage, user_id: UserID) -> TelegramMessage:
user = (pu.Puppet.get_by_mxid(user_id) user = (pu.Puppet.deprecated_sync_get_by_mxid(user_id)
or u.User.get_by_mxid(user_id, create=False)) or u.User.get_by_mxid(user_id, create=False))
if not user: if not user:
return msg return msg
+1 -1
View File
@@ -130,7 +130,7 @@ async def _add_reply_header(source: 'AbstractUser', content: TextMessageEventCon
event: MessageEvent = await main_intent.get_event(msg.mx_room, msg.mxid) event: MessageEvent = await main_intent.get_event(msg.mx_room, msg.mxid)
if isinstance(event.content, TextMessageEventContent): if isinstance(event.content, TextMessageEventContent):
event.content.trim_reply_fallback() event.content.trim_reply_fallback()
puppet = pu.Puppet.get_by_mxid(event.sender, create=False) puppet = await pu.Puppet.get_by_mxid(event.sender, create=False)
content.set_reply(event, displayname=puppet.displayname if puppet else event.sender) content.set_reply(event, displayname=puppet.displayname if puppet else event.sender)
except MatrixRequestError: except MatrixRequestError:
log.exception("Failed to get event to add reply fallback") log.exception("Failed to get event to add reply fallback")
+5 -12
View File
@@ -53,9 +53,7 @@ class MatrixHandler(BaseMatrixHandler):
self.user_id_prefix = f"@{prefix}" self.user_id_prefix = f"@{prefix}"
self.user_id_suffix = f"{suffix}:{homeserver}" self.user_id_suffix = f"{suffix}:{homeserver}"
super(MatrixHandler, self).__init__(context.az, context.config, loop=context.loop, super().__init__(command_processor=com.CommandProcessor(context), bridge=context.bridge)
command_processor=com.CommandProcessor(context),
bridge=context.bridge)
self.bot = context.bot self.bot = context.bot
self.previously_typing = {} self.previously_typing = {}
@@ -107,8 +105,8 @@ class MatrixHandler(BaseMatrixHandler):
e2be_ok = None e2be_ok = None
if self.config["bridge.encryption.default"] and self.e2ee: if self.config["bridge.encryption.default"] and self.e2ee:
e2be_ok = await portal.enable_dm_encryption() e2be_ok = await portal.enable_dm_encryption()
portal.save() await portal.save()
inviter.register_portal(portal) await inviter.register_portal(portal)
if e2be_ok is True: if e2be_ok is True:
evt_type, content = await self.e2ee.encrypt( evt_type, content = await self.e2ee.encrypt(
room_id, EventType.ROOM_MESSAGE, room_id, EventType.ROOM_MESSAGE,
@@ -208,7 +206,7 @@ class MatrixHandler(BaseMatrixHandler):
return return
await sender.ensure_started() await sender.ensure_started()
puppet = pu.Puppet.get_by_mxid(user_id) puppet = await pu.Puppet.get_by_mxid(user_id)
if puppet: if puppet:
if ban: if ban:
await portal.ban_matrix(puppet, sender) await portal.ban_matrix(puppet, sender)
@@ -375,7 +373,7 @@ class MatrixHandler(BaseMatrixHandler):
if not isinstance(evt, (RedactionEvent, MessageEvent, StateEvent, EncryptedEvent)): if not isinstance(evt, (RedactionEvent, MessageEvent, StateEvent, EncryptedEvent)):
return True return True
if evt.content.get(self.az.real_user_content_key, False): if evt.content.get(self.az.real_user_content_key, False):
puppet = pu.Puppet.get_by_custom_mxid(evt.sender) puppet = pu.Puppet.deprecated_sync_get_by_custom_mxid(evt.sender)
if puppet: if puppet:
self.log.debug("Ignoring puppet-sent event %s", evt.event_id) self.log.debug("Ignoring puppet-sent event %s", evt.event_id)
return True return True
@@ -412,11 +410,6 @@ class MatrixHandler(BaseMatrixHandler):
elif evt.type == EventType.ROOM_TOMBSTONE: elif evt.type == EventType.ROOM_TOMBSTONE:
await self.handle_room_upgrade(evt.room_id, evt.sender, evt.content.replacement_room, await self.handle_room_upgrade(evt.room_id, evt.sender, evt.content.replacement_room,
evt.event_id) evt.event_id)
elif evt.type == EventType.ROOM_ENCRYPTION:
portal = po.Portal.get_by_mxid(evt.room_id)
if portal:
portal.encrypted = True
portal.save()
async def log_event_handle_duration(self, evt: Event, duration: float) -> None: async def log_event_handle_duration(self, evt: Event, duration: float) -> None:
if EVENT_TIME: if EVENT_TIME:
+6 -13
View File
@@ -35,6 +35,7 @@ from mautrix.types import (RoomID, RoomAlias, UserID, EventID, EventType, Messag
from mautrix.util.simple_template import SimpleTemplate from mautrix.util.simple_template import SimpleTemplate
from mautrix.util.simple_lock import SimpleLock from mautrix.util.simple_lock import SimpleLock
from mautrix.util.logging import TraceLogger from mautrix.util.logging import TraceLogger
from mautrix.bridge import BasePortal as MautrixBasePortal
from ..types import TelegramID from ..types import TelegramID
from ..context import Context from ..context import Context
@@ -57,7 +58,7 @@ InviteList = Union[UserID, List[UserID]]
config: Optional['Config'] = None config: Optional['Config'] = None
class BasePortal(ABC): class BasePortal(MautrixBasePortal, ABC):
base_log: TraceLogger = logging.getLogger("mau.portal") base_log: TraceLogger = logging.getLogger("mau.portal")
az: AppService = None az: AppService = None
bot: 'Bot' = None bot: 'Bot' = None
@@ -129,7 +130,7 @@ class BasePortal(ABC):
self.deleted = False self.deleted = False
self.log = self.base_log.getChild(self.tgid_log if self.tgid else self.mxid) self.log = self.base_log.getChild(self.tgid_log if self.tgid else self.mxid)
self.backfill_lock = SimpleLock("Waiting for backfilling to finish before handling %s", self.backfill_lock = SimpleLock("Waiting for backfilling to finish before handling %s",
log=self.log, loop=self.loop) log=self.log)
self.backfill_leave = None self.backfill_leave = None
self.dedup = PortalDedup(self) self.dedup = PortalDedup(self)
@@ -289,12 +290,13 @@ class BasePortal(ABC):
@classmethod @classmethod
async def cleanup_room(cls, intent: IntentAPI, room_id: RoomID, message: str, async def cleanup_room(cls, intent: IntentAPI, room_id: RoomID, message: str,
puppets_only: bool = False) -> None: puppets_only: bool = False) -> None:
# TODO use the cleanup_room from BasePortal instead of this
try: try:
members = await intent.get_room_members(room_id) members = await intent.get_room_members(room_id)
except MatrixRequestError: except MatrixRequestError:
members = [] members = []
for user in members: for user in members:
puppet = p.Puppet.get_by_mxid(UserID(user), create=False) puppet = await p.Puppet.get_by_mxid(UserID(user), create=False)
if user != intent.mxid and (not puppets_only or puppet): if user != intent.mxid and (not puppets_only or puppet):
try: try:
if puppet: if puppet:
@@ -340,7 +342,7 @@ class BasePortal(ABC):
config=json.dumps(self.local_config), avatar_url=self.avatar_url, config=json.dumps(self.local_config), avatar_url=self.avatar_url,
encrypted=self.encrypted) encrypted=self.encrypted)
def save(self) -> None: async def save(self) -> None:
self.db_instance.edit(mxid=self.mxid, username=self.username, title=self.title, self.db_instance.edit(mxid=self.mxid, username=self.username, title=self.title,
about=self.about, photo_id=self.photo_id, megagroup=self.megagroup, about=self.about, photo_id=self.photo_id, megagroup=self.megagroup,
config=json.dumps(self.local_config), avatar_url=self.avatar_url, config=json.dumps(self.local_config), avatar_url=self.avatar_url,
@@ -474,15 +476,6 @@ class BasePortal(ABC):
type_name if create else None) type_name if create else None)
# endregion # endregion
async def _send_message(self, intent: IntentAPI, content: MessageEventContent,
event_type: EventType = EventType.ROOM_MESSAGE, **kwargs) -> EventID:
if self.encrypted and self.matrix.e2ee:
if intent.api.is_real_user:
content[intent.api.real_user_content_key] = True
event_type, content = await self.matrix.e2ee.encrypt(self.mxid, event_type, content)
return await intent.send_message_event(self.mxid, event_type, content, **kwargs)
# region Abstract methods (cross-called in matrix/metadata/telegram classes) # region Abstract methods (cross-called in matrix/metadata/telegram classes)
@abstractmethod @abstractmethod
+4 -5
View File
@@ -37,7 +37,6 @@ from telethon.tl.types import (
from mautrix.types import (EventID, RoomID, UserID, ContentURI, MessageType, MessageEventContent, from mautrix.types import (EventID, RoomID, UserID, ContentURI, MessageType, MessageEventContent,
TextMessageEventContent, MediaMessageEventContent, Format, TextMessageEventContent, MediaMessageEventContent, Format,
LocationMessageEventContent) LocationMessageEventContent)
from mautrix.bridge import BasePortal as MautrixBasePortal
from ..types import TelegramID from ..types import TelegramID
from ..db import Message as DBMessage from ..db import Message as DBMessage
@@ -61,7 +60,7 @@ TypeMessage = Union[Message, MessageService]
config: Optional['Config'] = None config: Optional['Config'] = None
class PortalMatrix(BasePortal, MautrixBasePortal, ABC): class PortalMatrix(BasePortal, ABC):
async def _get_state_change_message(self, event: str, user: 'u.User', **kwargs: Any async def _get_state_change_message(self, event: str, user: 'u.User', **kwargs: Any
) -> Optional[str]: ) -> Optional[str]:
tpl = self.get_config(f"state_event_formats.{event}") tpl = self.get_config(f"state_event_formats.{event}")
@@ -484,7 +483,7 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
peer = await self.get_input_entity(sender) peer = await self.get_input_entity(sender)
await sender.client(EditChatAboutRequest(peer=peer, about=about)) await sender.client(EditChatAboutRequest(peer=peer, about=about))
self.about = about self.about = about
self.save() await self.save()
await self._send_delivery_receipt(event_id) await self._send_delivery_receipt(event_id)
async def handle_matrix_title(self, sender: 'u.User', title: str, event_id: EventID) -> None: async def handle_matrix_title(self, sender: 'u.User', title: str, event_id: EventID) -> None:
@@ -498,7 +497,7 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
response = await sender.client(EditTitleRequest(channel=channel, title=title)) response = await sender.client(EditTitleRequest(channel=channel, title=title))
self.dedup.register_outgoing_actions(response) self.dedup.register_outgoing_actions(response)
self.title = title self.title = title
self.save() await self.save()
await self._send_delivery_receipt(event_id) await self._send_delivery_receipt(event_id)
await self.update_bridge_info() await self.update_bridge_info()
@@ -530,7 +529,7 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
if is_photo_update: if is_photo_update:
loc, size = self._get_largest_photo_size(update.message.action.photo) loc, size = self._get_largest_photo_size(update.message.action.photo)
self.photo_id = f"{size.location.volume_id}-{size.location.local_id}" self.photo_id = f"{size.location.volume_id}-{size.location.local_id}"
self.save() await self.save()
break break
await self._send_delivery_receipt(event_id) await self._send_delivery_receipt(event_id)
await self.update_bridge_info() await self.update_bridge_info()
+23 -23
View File
@@ -114,7 +114,7 @@ class PortalMetadata(BasePortal, ABC):
await source.client( await source.client(
UpdateUsernameRequest(await self.get_input_entity(source), username)) UpdateUsernameRequest(await self.get_input_entity(source), username))
if await self._update_username(username): if await self._update_username(username):
self.save() await self.save()
async def create_telegram_chat(self, source: 'u.User', supergroup: bool = False) -> None: async def create_telegram_chat(self, source: 'u.User', supergroup: bool = False) -> None:
if not self.mxid: if not self.mxid:
@@ -217,10 +217,10 @@ class PortalMetadata(BasePortal, ABC):
changed = await self._update_title(puppet.displayname) changed = await self._update_title(puppet.displayname)
changed = await self._update_avatar(user, entity.photo) or changed changed = await self._update_avatar(user, entity.photo) or changed
if changed: if changed:
self.save() await self.save()
await self.update_bridge_info() await self.update_bridge_info()
puppet = p.Puppet.get_by_custom_mxid(user.mxid) puppet = await p.Puppet.get_by_custom_mxid(user.mxid)
if puppet: if puppet:
try: try:
await puppet.intent.ensure_joined(self.mxid) await puppet.intent.ensure_joined(self.mxid)
@@ -352,7 +352,7 @@ class PortalMetadata(BasePortal, ABC):
invites += extra_invites invites += extra_invites
for invite in extra_invites: for invite in extra_invites:
power_levels.users.setdefault(invite, 100) power_levels.users.setdefault(invite, 100)
self._participants_to_power_levels(participants, power_levels) await self._participants_to_power_levels(participants, power_levels)
elif self.bot and self.tg_receiver == self.bot.tgid: elif self.bot and self.tg_receiver == self.bot.tgid:
invites = config["bridge.relaybot.private_chat.invite"] invites = config["bridge.relaybot.private_chat.invite"]
for invite in invites: for invite in invites:
@@ -408,9 +408,9 @@ class PortalMetadata(BasePortal, ABC):
self.mxid = room_id self.mxid = room_id
self.by_mxid[self.mxid] = self self.by_mxid[self.mxid] = self
self.save() await self.save()
await self.az.state_store.set_power_levels(self.mxid, power_levels) await self.az.state_store.set_power_levels(self.mxid, power_levels)
user.register_portal(self) await user.register_portal(self)
update_room = self.loop.create_task(self.update_matrix_room( update_room = self.loop.create_task(self.update_matrix_room(
user, entity, direct, puppet, user, entity, direct, puppet,
@@ -497,8 +497,8 @@ class PortalMetadata(BasePortal, ABC):
return True return True
return False return False
def _participants_to_power_levels(self, participants: List[TypeParticipant], async def _participants_to_power_levels(self, participants: List[TypeParticipant],
levels: PowerLevelStateEventContent) -> bool: levels: PowerLevelStateEventContent) -> bool:
bot_level = levels.get_user_level(self.main_intent.mxid) bot_level = levels.get_user_level(self.main_intent.mxid)
if bot_level < levels.get_event_level(EventType.ROOM_POWER_LEVELS): if bot_level < levels.get_event_level(EventType.ROOM_POWER_LEVELS):
return False return False
@@ -514,7 +514,7 @@ class PortalMetadata(BasePortal, ABC):
new_level = self._get_level_from_participant(participant) new_level = self._get_level_from_participant(participant)
if user: if user:
user.register_portal(self) await user.register_portal(self)
changed = self._participant_to_power_levels(levels, user, new_level, changed = self._participant_to_power_levels(levels, user, new_level,
bot_level) or changed bot_level) or changed
@@ -527,17 +527,17 @@ class PortalMetadata(BasePortal, ABC):
levels: PowerLevelStateEventContent = None) -> None: levels: PowerLevelStateEventContent = None) -> None:
if not levels: if not levels:
levels = await self.main_intent.get_power_levels(self.mxid) levels = await self.main_intent.get_power_levels(self.mxid)
if self._participants_to_power_levels(participants, levels): if await self._participants_to_power_levels(participants, levels):
await self.main_intent.set_power_levels(self.mxid, levels) await self.main_intent.set_power_levels(self.mxid, levels)
def _add_bot_chat(self, bot: User) -> None: async def _add_bot_chat(self, bot: User) -> None:
if self.bot and bot.id == self.bot.tgid: if self.bot and bot.id == self.bot.tgid:
self.bot.add_chat(self.tgid, self.peer_type) self.bot.add_chat(self.tgid, self.peer_type)
return return
user = u.User.get_by_tgid(TelegramID(bot.id)) user = u.User.get_by_tgid(TelegramID(bot.id))
if user and user.is_bot: if user and user.is_bot:
user.register_portal(self) await user.register_portal(self)
async def _sync_telegram_users(self, source: 'AbstractUser', users: List[User]) -> None: async def _sync_telegram_users(self, source: 'AbstractUser', users: List[User]) -> None:
allowed_tgids = set() allowed_tgids = set()
@@ -547,7 +547,7 @@ class PortalMetadata(BasePortal, ABC):
continue continue
puppet = p.Puppet.get(TelegramID(entity.id)) puppet = p.Puppet.get(TelegramID(entity.id))
if entity.bot: if entity.bot:
self._add_bot_chat(entity) await self._add_bot_chat(entity)
allowed_tgids.add(entity.id) allowed_tgids.add(entity.id)
await puppet.intent_for(self).ensure_joined(self.mxid) await puppet.intent_for(self).ensure_joined(self.mxid)
await puppet.update_info(source, entity) await puppet.update_info(source, entity)
@@ -556,7 +556,7 @@ class PortalMetadata(BasePortal, ABC):
if user: if user:
await self.invite_to_matrix(user.mxid) await self.invite_to_matrix(user.mxid)
puppet = p.Puppet.get_by_custom_mxid(user.mxid) puppet = await p.Puppet.get_by_custom_mxid(user.mxid)
if puppet: if puppet:
try: try:
await puppet.intent.ensure_joined(self.mxid) await puppet.intent.ensure_joined(self.mxid)
@@ -587,7 +587,7 @@ class PortalMetadata(BasePortal, ABC):
continue continue
mx_user = u.User.get_by_mxid(user_mxid, create=False) mx_user = u.User.get_by_mxid(user_mxid, create=False)
if mx_user and mx_user.is_bot and mx_user.tgid not in allowed_tgids: if mx_user and mx_user.is_bot and mx_user.tgid not in allowed_tgids:
mx_user.unregister_portal(*self.tgid_full) await mx_user.unregister_portal(*self.tgid_full)
if mx_user and not self.has_bot and mx_user.tgid not in allowed_tgids: if mx_user and not self.has_bot and mx_user.tgid not in allowed_tgids:
try: try:
@@ -607,7 +607,7 @@ class PortalMetadata(BasePortal, ABC):
user = u.User.get_by_tgid(user_id) user = u.User.get_by_tgid(user_id)
if user: if user:
user.register_portal(self) await user.register_portal(self)
await self.invite_to_matrix(user.mxid) await self.invite_to_matrix(user.mxid)
async def _delete_telegram_user(self, user_id: TelegramID, sender: p.Puppet) -> None: async def _delete_telegram_user(self, user_id: TelegramID, sender: p.Puppet) -> None:
@@ -624,7 +624,7 @@ class PortalMetadata(BasePortal, ABC):
else: else:
await puppet.intent_for(self).leave_room(self.mxid) await puppet.intent_for(self).leave_room(self.mxid)
if user: if user:
user.unregister_portal(*self.tgid_full) await user.unregister_portal(*self.tgid_full)
if sender.tgid != puppet.tgid: if sender.tgid != puppet.tgid:
try: try:
await sender.intent_for(self).kick_user(self.mxid, puppet.mxid) await sender.intent_for(self).kick_user(self.mxid, puppet.mxid)
@@ -664,7 +664,7 @@ class PortalMetadata(BasePortal, ABC):
self.log.exception(f"Failed to update info from source {user.tgid}") self.log.exception(f"Failed to update info from source {user.tgid}")
if changed: if changed:
self.save() await self.save()
await self.update_bridge_info() await self.update_bridge_info()
async def _update_username(self, username: str, save: bool = False) -> bool: async def _update_username(self, username: str, save: bool = False) -> bool:
@@ -682,7 +682,7 @@ class PortalMetadata(BasePortal, ABC):
await self.main_intent.set_join_rule(self.mxid, "invite") await self.main_intent.set_join_rule(self.mxid, "invite")
if save: if save:
self.save() await self.save()
return True return True
async def _try_set_state(self, sender: Optional['p.Puppet'], evt_type: EventType, async def _try_set_state(self, sender: Optional['p.Puppet'], evt_type: EventType,
@@ -707,7 +707,7 @@ class PortalMetadata(BasePortal, ABC):
await self._try_set_state(sender, EventType.ROOM_TOPIC, await self._try_set_state(sender, EventType.ROOM_TOPIC,
RoomTopicStateEventContent(topic=self.about)) RoomTopicStateEventContent(topic=self.about))
if save: if save:
self.save() await self.save()
return True return True
async def _update_title(self, title: str, sender: Optional['p.Puppet'] = None, async def _update_title(self, title: str, sender: Optional['p.Puppet'] = None,
@@ -719,7 +719,7 @@ class PortalMetadata(BasePortal, ABC):
await self._try_set_state(sender, EventType.ROOM_NAME, await self._try_set_state(sender, EventType.ROOM_NAME,
RoomNameStateEventContent(name=self.title)) RoomNameStateEventContent(name=self.title))
if save: if save:
self.save() await self.save()
return True return True
async def _update_avatar(self, user: 'AbstractUser', photo: TypeChatPhoto, async def _update_avatar(self, user: 'AbstractUser', photo: TypeChatPhoto,
@@ -750,7 +750,7 @@ class PortalMetadata(BasePortal, ABC):
self.photo_id = "" self.photo_id = ""
self.avatar_url = None self.avatar_url = None
if save: if save:
self.save() await self.save()
return True return True
file = await util.transfer_file_to_matrix(user.client, self.main_intent, loc) file = await util.transfer_file_to_matrix(user.client, self.main_intent, loc)
if file: if file:
@@ -759,7 +759,7 @@ class PortalMetadata(BasePortal, ABC):
self.photo_id = photo_id self.photo_id = photo_id
self.avatar_url = file.mxc self.avatar_url = file.mxc
if save: if save:
self.save() await self.save()
return True return True
return False return False
+14 -6
View File
@@ -24,7 +24,7 @@ from telethon.tl.types import (UserProfilePhoto, User, UpdateUserName, PeerUser,
from mautrix.appservice import AppService, IntentAPI from mautrix.appservice import AppService, IntentAPI
from mautrix.errors import MatrixRequestError from mautrix.errors import MatrixRequestError
from mautrix.bridge import CustomPuppetMixin from mautrix.bridge import BasePuppet
from mautrix.types import UserID, SyncToken, RoomID from mautrix.types import UserID, SyncToken, RoomID
from mautrix.util.simple_template import SimpleTemplate from mautrix.util.simple_template import SimpleTemplate
@@ -41,7 +41,7 @@ if TYPE_CHECKING:
config: Optional['Config'] = None config: Optional['Config'] = None
class Puppet(CustomPuppetMixin): class Puppet(BasePuppet):
log: logging.Logger = logging.getLogger("mau.puppet") log: logging.Logger = logging.getLogger("mau.puppet")
az: AppService az: AppService
mx: 'MatrixHandler' mx: 'MatrixHandler'
@@ -166,7 +166,7 @@ class Puppet(CustomPuppetMixin):
def new_db_instance(self) -> DBPuppet: def new_db_instance(self) -> DBPuppet:
return DBPuppet(id=self.id, **self._fields) return DBPuppet(id=self.id, **self._fields)
def save(self) -> None: async def save(self) -> None:
self.db_instance.edit(**self._fields) self.db_instance.edit(**self._fields)
@classmethod @classmethod
@@ -249,7 +249,7 @@ class Puppet(CustomPuppetMixin):
self.is_bot = info.bot self.is_bot = info.bot
if changed: if changed:
self.save() await self.save()
async def update_displayname(self, source: 'AbstractUser', info: Union[User, UpdateUserName] async def update_displayname(self, source: 'AbstractUser', info: Union[User, UpdateUserName]
) -> bool: ) -> bool:
@@ -355,7 +355,7 @@ class Puppet(CustomPuppetMixin):
return None return None
@classmethod @classmethod
def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']: def deprecated_sync_get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
tgid = cls.get_id_from_mxid(mxid) tgid = cls.get_id_from_mxid(mxid)
if tgid: if tgid:
return cls.get(tgid, create) return cls.get(tgid, create)
@@ -363,7 +363,11 @@ class Puppet(CustomPuppetMixin):
return None return None
@classmethod @classmethod
def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']: async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
return cls.deprecated_sync_get_by_mxid(mxid, create)
@classmethod
def deprecated_sync_get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
if not mxid: if not mxid:
raise ValueError("Matrix ID can't be empty") raise ValueError("Matrix ID can't be empty")
@@ -379,6 +383,10 @@ class Puppet(CustomPuppetMixin):
return None return None
@classmethod
async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
return cls.deprecated_sync_get_by_custom_mxid(mxid)
@classmethod @classmethod
def all_with_custom_mxid(cls) -> Iterable['Puppet']: def all_with_custom_mxid(cls) -> Iterable['Puppet']:
return (cls.by_custom_mxid[puppet.custom_mxid] return (cls.by_custom_mxid[puppet.custom_mxid]
+11 -11
View File
@@ -152,7 +152,7 @@ class User(AbstractUser, BaseUser):
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username, return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
saved_contacts=self.saved_contacts, portals=self.db_portals) saved_contacts=self.saved_contacts, portals=self.db_portals)
def save(self, contacts: bool = False, portals: bool = False) -> None: async def save(self, contacts: bool = False, portals: bool = False) -> None:
self.db_instance.edit(tgid=self.tgid, tg_username=self.username, tg_phone=self.phone, self.db_instance.edit(tgid=self.tgid, tg_username=self.username, tg_phone=self.phone,
saved_contacts=self.saved_contacts) saved_contacts=self.saved_contacts)
if contacts: if contacts:
@@ -242,7 +242,7 @@ class User(AbstractUser, BaseUser):
return False return False
if portal: if portal:
self.register_portal(portal) await self.register_portal(portal)
return False return False
# Don't bother handling the update # Don't bother handling the update
@@ -271,7 +271,7 @@ class User(AbstractUser, BaseUser):
self.tgid = TelegramID(info.id) self.tgid = TelegramID(info.id)
self.by_tgid[self.tgid] = self self.by_tgid[self.tgid] = self
if changed: if changed:
self.save() await self.save()
async def log_out(self) -> bool: async def log_out(self) -> bool:
puppet = pu.Puppet.get(self.tgid) puppet = pu.Puppet.get(self.tgid)
@@ -287,14 +287,14 @@ class User(AbstractUser, BaseUser):
pass pass
self.portals = {} self.portals = {}
self.contacts = [] self.contacts = []
self.save(portals=True, contacts=True) await self.save(portals=True, contacts=True)
if self.tgid: if self.tgid:
try: try:
del self.by_tgid[self.tgid] del self.by_tgid[self.tgid]
except KeyError: except KeyError:
pass pass
self.tgid = None self.tgid = None
self.save() await self.save()
ok = await self.client.log_out() ok = await self.client.log_out()
if not ok: if not ok:
return False return False
@@ -367,11 +367,11 @@ class User(AbstractUser, BaseUser):
create_task = portal.create_matrix_room(self, entity, invites=[self.mxid]) create_task = portal.create_matrix_room(self, entity, invites=[self.mxid])
creators.append(self.loop.create_task(create_task)) creators.append(self.loop.create_task(create_task))
index += 1 index += 1
self.save(portals=True) await self.save(portals=True)
await asyncio.gather(*creators) await asyncio.gather(*creators)
self.log.debug("Dialog syncing complete") self.log.debug("Dialog syncing complete")
def register_portal(self, portal: po.Portal) -> None: async def register_portal(self, portal: po.Portal) -> None:
self.log.trace(f"Registering portal {portal.tgid_full}") self.log.trace(f"Registering portal {portal.tgid_full}")
try: try:
if self.portals[portal.tgid_full] == portal: if self.portals[portal.tgid_full] == portal:
@@ -379,13 +379,13 @@ class User(AbstractUser, BaseUser):
except KeyError: except KeyError:
pass pass
self.portals[portal.tgid_full] = portal self.portals[portal.tgid_full] = portal
self.save(portals=True) await self.save(portals=True)
def unregister_portal(self, tgid: int, tg_receiver: int) -> None: async def unregister_portal(self, tgid: int, tg_receiver: int) -> None:
self.log.trace(f"Unregistering portal {(tgid, tg_receiver)}") self.log.trace(f"Unregistering portal {(tgid, tg_receiver)}")
try: try:
del self.portals[(tgid, tg_receiver)] del self.portals[(tgid, tg_receiver)]
self.save(portals=True) await self.save(portals=True)
except KeyError: except KeyError:
pass pass
@@ -410,7 +410,7 @@ class User(AbstractUser, BaseUser):
puppet = pu.Puppet.get(user.id) puppet = pu.Puppet.get(user.id)
await puppet.update_info(self, user) await puppet.update_info(self, user)
self.contacts.append(puppet) self.contacts.append(puppet)
self.save(contacts=True) await self.save(contacts=True)
# endregion # endregion
# region Class instance lookup # region Class instance lookup
+2 -4
View File
@@ -136,8 +136,7 @@ async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: In
decryption_info = None decryption_info = None
upload_mime_type = mime_type upload_mime_type = mime_type
if encrypt: if encrypt:
file, decryption_info_dict = encrypt_attachment(file) file, decryption_info = encrypt_attachment(file)
decryption_info = EncryptedFile.deserialize(decryption_info_dict)
upload_mime_type = "application/octet-stream" upload_mime_type = "application/octet-stream"
content_uri = await intent.upload_media(file, upload_mime_type) content_uri = await intent.upload_media(file, upload_mime_type)
if decryption_info: if decryption_info:
@@ -232,8 +231,7 @@ async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, inten
decryption_info = None decryption_info = None
upload_mime_type = mime_type upload_mime_type = mime_type
if encrypt and encrypt_attachment: if encrypt and encrypt_attachment:
file, decryption_info_dict = encrypt_attachment(file) file, decryption_info = encrypt_attachment(file)
decryption_info = EncryptedFile.deserialize(decryption_info_dict)
upload_mime_type = "application/octet-stream" upload_mime_type = "application/octet-stream"
content_uri = await intent.upload_media(file, upload_mime_type) content_uri = await intent.upload_media(file, upload_mime_type)
if decryption_info: if decryption_info:
@@ -183,7 +183,7 @@ class ProvisioningAPI(AuthAPI):
portal.mxid = room_id portal.mxid = room_id
portal.title, portal.about, levels = await get_initial_state(self.az.intent, room_id) portal.title, portal.about, levels = await get_initial_state(self.az.intent, room_id)
portal.photo_id = "" portal.photo_id = ""
portal.save() await portal.save()
asyncio.ensure_future(portal.update_matrix_room(user, entity, direct, levels=levels), asyncio.ensure_future(portal.update_matrix_room(user, entity, direct, levels=levels),
loop=self.loop) loop=self.loop)
+1 -1
View File
@@ -4,6 +4,6 @@ ruamel.yaml>=0.15.35,<0.17
python-magic>=0.4,<0.5 python-magic>=0.4,<0.5
commonmark>=0.8,<0.10 commonmark>=0.8,<0.10
aiohttp>=3,<4 aiohttp>=3,<4
mautrix>=0.6,<0.7 mautrix==0.7.0rc1
telethon>=1.16,<1.17 telethon>=1.16,<1.17
telethon-session-sqlalchemy>=0.2.14,<0.3 telethon-session-sqlalchemy>=0.2.14,<0.3