diff --git a/mautrix_telegram/__main__.py b/mautrix_telegram/__main__.py
index 97656dc6..dde8aaa3 100644
--- a/mautrix_telegram/__main__.py
+++ b/mautrix_telegram/__main__.py
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Optional
+from typing import Coroutine, List, Optional
import argparse
import asyncio
import logging.config
@@ -115,7 +115,7 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st
startup_actions = (init_puppet(context) +
init_user(context) +
[start,
- context.mx.init_as_bot()])
+ context.mx.init_as_bot()]) # type: List[Coroutine]
if context.bot:
startup_actions.append(context.bot.start())
diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py
index f1a72db4..5452328f 100644
--- a/mautrix_telegram/abstract_user.py
+++ b/mautrix_telegram/abstract_user.py
@@ -38,6 +38,7 @@ from .db import Message as DBMessage
from .tgclient import MautrixTelegramClient
if TYPE_CHECKING:
+ from .types import TelegramID
from .context import Context
from .config import Config
from .bot import Bot
@@ -60,17 +61,18 @@ class AbstractUser(ABC):
bot = None # type: Bot
ignore_incoming_bot_events = True # type: bool
- def __init__(self):
+ def __init__(self) -> None:
self.is_admin = False # type: bool
self.matrix_puppet_whitelisted = False # type: bool
self.puppet_whitelisted = False # type: bool
self.whitelisted = False # type: bool
self.relaybot_whitelisted = False # type: bool
self.client = None # type: MautrixTelegramClient
- self.tgid = None # type: int
+ self.tgid = None # type: TelegramID
self.mxid = None # type: str
self.is_relaybot = False # type: bool
self.is_bot = False # type: bool
+ self.relaybot = None # type: Optional[Bot]
@property
def connected(self) -> bool:
@@ -93,7 +95,7 @@ class AbstractUser(ABC):
config["telegram.proxy.rdns"],
config["telegram.proxy.username"], config["telegram.proxy.password"])
- def _init_client(self):
+ def _init_client(self) -> None:
self.log.debug(f"Initializing client for {self.name}")
device = f"{platform.system()} {platform.release()}"
sysversion = MautrixTelegramClient.__version__
@@ -114,18 +116,18 @@ class AbstractUser(ABC):
return False
@abstractmethod
- async def post_login(self):
+ async def post_login(self) -> None:
raise NotImplementedError()
@abstractmethod
- def register_portal(self, portal: po.Portal):
+ def register_portal(self, portal: po.Portal) -> None:
raise NotImplementedError()
@abstractmethod
- def unregister_portal(self, portal: po.Portal):
+ def unregister_portal(self, portal: po.Portal) -> None:
raise NotImplementedError()
- async def _update_catch(self, update: TypeUpdate):
+ async def _update_catch(self, update: TypeUpdate) -> None:
try:
if not await self.update(update):
await self._update(update)
@@ -154,14 +156,14 @@ class AbstractUser(ABC):
and (not self.is_bot or allow_bot)
and await self.is_logged_in())
- async def start(self, delete_unless_authenticated: bool = False) -> "AbstractUser":
+ async def start(self, delete_unless_authenticated: bool = False) -> 'AbstractUser':
if not self.client:
self._init_client()
await self.client.connect()
self.log.debug("%s connected: %s", self.mxid, self.connected)
return self
- async def ensure_started(self, even_if_no_session=False) -> "AbstractUser":
+ async def ensure_started(self, even_if_no_session=False) -> 'AbstractUser':
if not self.puppet_whitelisted:
return self
self.log.debug("ensure_started(%s, connected=%s, even_if_no_session=%s, session_count=%s)",
@@ -175,13 +177,13 @@ class AbstractUser(ABC):
await self.start(delete_unless_authenticated=not even_if_no_session)
return self
- async def stop(self):
+ async def stop(self) -> None:
await self.client.disconnect()
self.client = None
# region Telegram update handling
- async def _update(self, update: TypeUpdate):
+ async def _update(self, update: TypeUpdate) -> None:
if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage,
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage)):
await self.update_message(update)
@@ -207,18 +209,18 @@ class AbstractUser(ABC):
self.log.debug("Unhandled update: %s", update)
@staticmethod
- async def update_pinned_messages(update: UpdateChannelPinnedMessage):
+ async def update_pinned_messages(update: UpdateChannelPinnedMessage) -> None:
portal = po.Portal.get_by_tgid(update.channel_id)
if portal and portal.mxid:
await portal.receive_telegram_pin_id(update.id)
@staticmethod
- async def update_participants(update: UpdateChatParticipants):
+ async def update_participants(update: UpdateChatParticipants) -> None:
portal = po.Portal.get_by_tgid(update.participants.chat_id)
if portal and portal.mxid:
await portal.update_telegram_participants(update.participants.participants)
- async def update_read_receipt(self, update: UpdateReadHistoryOutbox):
+ async def update_read_receipt(self, update: UpdateReadHistoryOutbox) -> None:
if not isinstance(update.peer, PeerUser):
self.log.debug("Unexpected read receipt peer: %s", update.peer)
return
@@ -235,7 +237,8 @@ class AbstractUser(ABC):
puppet = pu.Puppet.get(update.peer.user_id)
await puppet.intent.mark_read(portal.mxid, message.mxid)
- async def update_admin(self, update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]):
+ async def update_admin(self,
+ update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]) -> None:
# TODO duplication not checked
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
if isinstance(update, UpdateChatAdmins):
@@ -245,7 +248,7 @@ class AbstractUser(ABC):
else:
self.log.warning("Unexpected admin status update: %s", update)
- async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]):
+ async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]) -> None:
if isinstance(update, UpdateUserTyping):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
else:
@@ -253,7 +256,7 @@ class AbstractUser(ABC):
sender = pu.Puppet.get(update.user_id)
await portal.handle_telegram_typing(sender, update)
- async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]):
+ async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]) -> None:
# TODO duplication not checked
puppet = pu.Puppet.get(update.user_id)
if isinstance(update, UpdateUserName):
@@ -265,7 +268,7 @@ class AbstractUser(ABC):
else:
self.log.warning("Unexpected other user info update: %s", update)
- async def update_status(self, update: UpdateUserStatus):
+ async def update_status(self, update: UpdateUserStatus) -> None:
puppet = pu.Puppet.get(update.user_id)
if isinstance(update.status, UserStatusOnline):
await puppet.default_mxid_intent.set_presence("online")
@@ -300,7 +303,7 @@ class AbstractUser(ABC):
return update, sender, portal
@staticmethod
- async def _try_redact(portal: po.Portal, message: DBMessage):
+ async def _try_redact(portal: po.Portal, message: DBMessage) -> None:
if not portal:
return
try:
@@ -308,7 +311,7 @@ class AbstractUser(ABC):
except MatrixRequestError:
pass
- async def delete_message(self, update: UpdateDeleteMessages):
+ async def delete_message(self, update: UpdateDeleteMessages) -> None:
if len(update.messages) > MAX_DELETIONS:
return
@@ -324,7 +327,7 @@ class AbstractUser(ABC):
await self._try_redact(portal, message)
self.db.commit()
- async def delete_channel_message(self, update: UpdateDeleteChannelMessages):
+ async def delete_channel_message(self, update: UpdateDeleteChannelMessages) -> None:
if len(update.messages) > MAX_DELETIONS:
return
@@ -340,7 +343,7 @@ class AbstractUser(ABC):
await self._try_redact(portal, message)
self.db.commit()
- async def update_message(self, original_update: UpdateMessage):
+ async def update_message(self, original_update: UpdateMessage) -> None:
update, sender, portal = self.get_message_details(original_update)
if self.ignore_incoming_bot_events and self.bot and sender.id == self.bot.tgid:
self.log.debug(f"Ignoring relaybot-sent message %s to %s", update, portal.tgid_log)
@@ -369,9 +372,9 @@ class AbstractUser(ABC):
# endregion
-def init(context: "Context"):
+def init(context: "Context") -> None:
global config, MAX_DELETIONS
- AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context
+ AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context.core
AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"]
AbstractUser.session_container = context.session_container
MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10)
diff --git a/mautrix_telegram/bot.py b/mautrix_telegram/bot.py
index 51a6a110..78b156db 100644
--- a/mautrix_telegram/bot.py
+++ b/mautrix_telegram/bot.py
@@ -14,21 +14,27 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Awaitable, Callable, Pattern, Dict, TYPE_CHECKING
+from typing import Awaitable, Callable, Dict, List, Optional, Pattern, TYPE_CHECKING
import logging
import re
-from telethon.tl.types import *
+from telethon.tl.types import (
+ ChannelParticipantAdmin, ChannelParticipantCreator, ChatForbidden, ChatParticipantAdmin,
+ ChatParticipantCreator, InputChannel, InputUser, Message, MessageActionChatAddUser,
+ MessageActionChatDeleteUser, MessageEntityBotCommand, MessageService, PeerChannel, PeerChat,
+ TypePeer, UpdateNewChannelMessage, UpdateNewMessage)
from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest
from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest
from telethon.errors import ChannelInvalidError, ChannelPrivateError
+from .types import MatrixUserID
from .abstract_user import AbstractUser
from .db import BotChat
from . import puppet as pu, portal as po, user as u
if TYPE_CHECKING:
from .config import Config
+ from .context import Context
config = None # type: Config
@@ -39,7 +45,7 @@ class Bot(AbstractUser):
log = logging.getLogger("mau.bot") # type: logging.Logger
mxid_regex = re.compile("@.+:.+") # type: Pattern
- def __init__(self, token: str):
+ def __init__(self, token: str) -> None:
super().__init__()
self.token = token # type: str
self.puppet_whitelisted = True # type: bool
@@ -53,7 +59,7 @@ class Bot(AbstractUser):
self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"]
or False) # type: bool
- async def init_permissions(self):
+ async def init_permissions(self) -> None:
whitelist = config["bridge.relaybot.whitelist"] or []
for id in whitelist:
if isinstance(id, str):
@@ -65,14 +71,14 @@ class Bot(AbstractUser):
if isinstance(id, int):
self.tg_whitelist.append(id)
- async def start(self, delete_unless_authenticated: bool = False) -> "Bot":
+ async def start(self, delete_unless_authenticated: bool = False) -> 'Bot':
await super().start(delete_unless_authenticated)
if not await self.is_logged_in():
await self.client.sign_in(bot_token=self.token)
await self.post_login()
return self
- async def post_login(self):
+ async def post_login(self) -> None:
await self.init_permissions()
info = await self.client.get_me()
self.tgid = info.id
@@ -100,19 +106,19 @@ class Bot(AbstractUser):
except Exception:
self.log.exception("Failed to run catch_up() for bot")
- def register_portal(self, portal: po.Portal):
+ def register_portal(self, portal: po.Portal) -> None:
self.add_chat(portal.tgid, portal.peer_type)
- def unregister_portal(self, portal: po.Portal):
+ def unregister_portal(self, portal: po.Portal) -> None:
self.remove_chat(portal.tgid)
- def add_chat(self, id: int, type: str):
+ def add_chat(self, id: int, type: str) -> None:
if id not in self.chats:
self.chats[id] = type
self.db.add(BotChat(id=id, type=type))
self.db.commit()
- def remove_chat(self, id: int):
+ def remove_chat(self, id: int) -> None:
try:
del self.chats[id]
except KeyError:
@@ -141,6 +147,7 @@ class Bot(AbstractUser):
for p in participants:
if p.user_id == tgid:
return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin))
+ return False
async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
if not await self._can_use_commands(event.to_id, event.from_id):
@@ -148,7 +155,7 @@ class Bot(AbstractUser):
return False
return True
- async def handle_command_portal(self, portal: po.Portal, reply: ReplyFunc):
+ async def handle_command_portal(self, portal: po.Portal, reply: ReplyFunc) -> None:
if not config["bridge.relaybot.authless_portals"]:
return await reply("This bridge doesn't allow portal creation from Telegram.")
@@ -164,15 +171,16 @@ class Bot(AbstractUser):
return await reply(
"Portal is not public. Use `/invite ` to get an invite.")
- async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc, mxid: str):
- if len(mxid) == 0:
+ async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc,
+ mxid_input: MatrixUserID) -> Message:
+ if len(mxid_input) == 0:
return await reply("Usage: `/invite `")
elif not portal.mxid:
return await reply("Portal does not have Matrix room. "
"Create one with /portal first.")
- if not self.mxid_regex.match(mxid):
+ if not self.mxid_regex.match(mxid_input):
return await reply("That doesn't look like a Matrix ID.")
- user = await u.User.get_by_mxid(mxid).ensure_started()
+ user = await u.User.get_by_mxid(MatrixUserID(mxid_input)).ensure_started()
if not user.relaybot_whitelisted:
return await reply("That user is not whitelisted to use the bridge.")
elif await user.is_logged_in():
@@ -183,7 +191,7 @@ class Bot(AbstractUser):
await portal.main_intent.invite(portal.mxid, user.mxid)
return await reply(f"Invited `{user.mxid}` to the portal.")
- def handle_command_id(self, message: Message, reply: ReplyFunc):
+ def handle_command_id(self, message: Message, reply: ReplyFunc) -> Awaitable[Message]:
# Provide the prefixed ID to the user so that the user wouldn't need to specify whether the
# chat is a normal group or a supergroup/channel when using the ID.
if isinstance(message.to_id, PeerChannel):
@@ -205,8 +213,8 @@ class Bot(AbstractUser):
return False
- async def handle_command(self, message: Message):
- def reply(reply_text):
+ async def handle_command(self, message: Message) -> None:
+ def reply(reply_text: str) -> Awaitable[Message]:
return self.client.send_message(message.to_id, reply_text, reply_to=message.id)
text = message.message
@@ -227,9 +235,9 @@ class Bot(AbstractUser):
mxid = text[text.index(" ") + 1:]
except ValueError:
mxid = ""
- await self.handle_command_invite(portal, reply, mxid=mxid)
+ await self.handle_command_invite(portal, reply, mxid_input=mxid)
- def handle_service_message(self, message: MessageService):
+ def handle_service_message(self, message: MessageService) -> None:
to_id = message.to_id
if isinstance(to_id, PeerChannel):
to_id = to_id.channel_id
@@ -246,11 +254,12 @@ class Bot(AbstractUser):
elif isinstance(action, MessageActionChatDeleteUser) and action.user_id == self.tgid:
self.remove_chat(to_id)
- async def update(self, update):
+ async def update(self, update) -> bool:
if not isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
- return
+ return False
if isinstance(update.message, MessageService):
- return self.handle_service_message(update.message)
+ self.handle_service_message(update.message)
+ return False
is_command = (isinstance(update.message, Message)
and update.message.entities and len(update.message.entities) > 0
@@ -266,7 +275,7 @@ class Bot(AbstractUser):
return "bot"
-def init(context) -> Optional[Bot]:
+def init(context: 'Context') -> Optional[Bot]:
global config
config = context.config
token = config["telegram.bot_token"]
diff --git a/mautrix_telegram/commands/auth.py b/mautrix_telegram/commands/auth.py
index 38b1fbe2..12151e4f 100644
--- a/mautrix_telegram/commands/auth.py
+++ b/mautrix_telegram/commands/auth.py
@@ -14,10 +14,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Dict
+from typing import Any, Awaitable, Dict, Optional
import asyncio
-from telethon.errors import *
+from telethon.errors import (
+ AccessTokenExpiredError, AccessTokenInvalidError, FirstNameInvalidError, FloodWaitError,
+ PasswordHashInvalidError, PhoneCodeExpiredError, PhoneCodeInvalidError,
+ PhoneNumberAppSignupForbiddenError, PhoneNumberBannedError, PhoneNumberFloodError,
+ PhoneNumberOccupiedError, PhoneNumberUnoccupiedError, SessionPasswordNeededError)
from . import command_handler, CommandEvent, SECTION_AUTH
from .. import puppet as pu
@@ -27,7 +31,7 @@ from ..util import format_duration
@command_handler(needs_auth=False,
help_section=SECTION_AUTH,
help_text="Check if you're logged into Telegram.")
-async def ping(evt: CommandEvent):
+async def ping(evt: CommandEvent) -> Optional[Dict]:
me = await evt.sender.client.get_me() if await evt.sender.is_logged_in() else None
if me:
return await evt.reply(f"You're logged in as @{me.username}")
@@ -38,7 +42,7 @@ async def ping(evt: CommandEvent):
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_AUTH,
help_text="Get the info of the message relay Telegram bot.")
-async def ping_bot(evt: CommandEvent):
+async def ping_bot(evt: CommandEvent) -> Optional[Dict]:
if not evt.tgbot:
return await evt.reply("Telegram message relay bot not configured.")
bot_info = await evt.tgbot.client.get_me()
@@ -53,19 +57,19 @@ async def ping_bot(evt: CommandEvent):
help_section=SECTION_AUTH,
help_text="Revert your Telegram account's Matrix puppet to use the default Matrix "
"account.")
-async def logout_matrix(evt: CommandEvent):
+async def logout_matrix(evt: CommandEvent) -> Optional[Dict]:
puppet = pu.Puppet.get(evt.sender.tgid)
if not puppet.is_real_user:
return await evt.reply("You are not logged in with your Matrix account.")
await puppet.switch_mxid(None, None)
- await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.")
+ return await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.")
@command_handler(needs_auth=True, management_only=True, needs_matrix_puppeting=True,
help_section=SECTION_AUTH,
help_text="Replace your Telegram account's Matrix puppet with your own Matrix "
"account")
-async def login_matrix(evt: CommandEvent):
+async def login_matrix(evt: CommandEvent) -> Optional[Dict]:
puppet = pu.Puppet.get(evt.sender.tgid)
if puppet.is_real_user:
return await evt.reply("You have already logged in with your Matrix account. "
@@ -96,7 +100,7 @@ async def login_matrix(evt: CommandEvent):
return await evt.reply("This bridge instance has been configured to not allow logging in.")
-async def enter_matrix_token(evt: CommandEvent):
+async def enter_matrix_token(evt: CommandEvent) -> Dict:
evt.sender.command_status = None
puppet = pu.Puppet.get(evt.sender.tgid)
@@ -105,10 +109,11 @@ async def enter_matrix_token(evt: CommandEvent):
"Log out with `$cmdprefix+sp logout-matrix` first.")
resp = await puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid)
- if resp == 2:
+ if resp == pu.PuppetError.OnlyLoginSelf:
return await evt.reply("You can only log in as your own Matrix user.")
- elif resp == 1:
+ elif resp == pu.PuppetError.InvalidAccessToken:
return await evt.reply("Failed to verify access token.")
+ assert resp == pu.PuppetError.Success, "Encountered an unhandled PuppetError."
return await evt.reply(
f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.")
@@ -117,7 +122,7 @@ async def enter_matrix_token(evt: CommandEvent):
help_section=SECTION_AUTH,
help_args="<_phone_> <_full name_>",
help_text="Register to Telegram")
-async def register(evt: CommandEvent):
+async def register(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.is_logged_in():
return await evt.reply("You are already logged in.")
elif len(evt.args) < 1:
@@ -134,9 +139,10 @@ async def register(evt: CommandEvent):
"action": "Register",
"full_name": full_name,
})
+ return None
-async def enter_code_register(evt: CommandEvent):
+async def enter_code_register(evt: CommandEvent) -> Dict:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp `")
try:
@@ -165,7 +171,7 @@ async def enter_code_register(evt: CommandEvent):
@command_handler(needs_auth=False, management_only=True,
help_section=SECTION_AUTH,
help_text="Get instructions on how to log in.")
-async def login(evt: CommandEvent):
+async def login(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.is_logged_in():
return await evt.reply("You are already logged in.")
@@ -196,7 +202,8 @@ async def login(evt: CommandEvent):
return await evt.reply("This bridge instance has been configured to not allow logging in.")
-async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, str]):
+async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, Any]
+ ) -> Dict:
ok = False
try:
await evt.sender.ensure_started(even_if_no_session=True)
@@ -228,7 +235,7 @@ async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[s
@command_handler(needs_auth=False)
-async def enter_phone_or_token(evt: CommandEvent):
+async def enter_phone_or_token(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-phone-or-token `")
elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -248,10 +255,11 @@ async def enter_phone_or_token(evt: CommandEvent):
"next": enter_code,
"action": "Login",
})
+ return None
@command_handler(needs_auth=False)
-async def enter_code(evt: CommandEvent):
+async def enter_code(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-code `")
elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -263,10 +271,11 @@ async def enter_code(evt: CommandEvent):
evt.log.exception("Error sending phone code")
return await evt.reply("Unhandled exception while sending code. "
"Check console for more details.")
+ return None
@command_handler(needs_auth=False)
-async def enter_password(evt: CommandEvent):
+async def enter_password(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-password `")
elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -282,9 +291,10 @@ async def enter_password(evt: CommandEvent):
evt.log.exception("Error sending password")
return await evt.reply("Unhandled exception while sending password. "
"Check console for more details.")
+ return None
-async def sign_in(evt: CommandEvent, **sign_in_info):
+async def sign_in(evt: CommandEvent, **sign_in_info) -> Dict:
try:
await evt.sender.ensure_started(even_if_no_session=True)
user = await evt.sender.client.sign_in(**sign_in_info)
@@ -309,7 +319,7 @@ async def sign_in(evt: CommandEvent, **sign_in_info):
@command_handler(needs_auth=True,
help_section=SECTION_AUTH,
help_text="Log out from Telegram.")
-async def logout(evt: CommandEvent):
+async def logout(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.log_out():
return await evt.reply("Logged out successfully.")
return await evt.reply("Failed to log out.")
diff --git a/mautrix_telegram/commands/clean_rooms.py b/mautrix_telegram/commands/clean_rooms.py
index aac5a54d..57760def 100644
--- a/mautrix_telegram/commands/clean_rooms.py
+++ b/mautrix_telegram/commands/clean_rooms.py
@@ -14,21 +14,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Tuple, List
+from typing import Dict, List, NewType, Optional, Tuple, Union
from mautrix_appservice import MatrixRequestError, IntentAPI
+from ..types import MatrixRoomID, MatrixUserID
from . import command_handler, CommandEvent, SECTION_ADMIN
from .. import puppet as pu, portal as po
-ManagementRoomList = List[Tuple[str, str]]
-RoomIDList = List[str]
+ManagementRoom = NewType('ManagementRoom', Tuple[MatrixRoomID, MatrixUserID])
-async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList,
- List["po.Portal"], List["po.Portal"]]:
- management_rooms = [] # type: ManagementRoomList
- unidentified_rooms = [] # type: RoomIDList
+async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[MatrixRoomID],
+ List['po.Portal'], List['po.Portal']]:
+ management_rooms = [] # type: List[ManagementRoom]
+ unidentified_rooms = [] # type: List[MatrixRoomID]
portals = [] # type: List[po.Portal]
empty_portals = [] # type: List[po.Portal]
@@ -45,7 +45,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList
if pu.Puppet.get_id_from_mxid(other_member):
unidentified_rooms.append(room)
else:
- management_rooms.append((room, other_member))
+ management_rooms.append(ManagementRoom((room, other_member)))
else:
unidentified_rooms.append(room)
else:
@@ -61,7 +61,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList
@command_handler(needs_admin=True, needs_auth=False, management_only=True, name="clean-rooms",
help_section=SECTION_ADMIN,
help_text="Clean up unused portal/management rooms.")
-async def clean_rooms(evt: CommandEvent):
+async def clean_rooms(evt: CommandEvent) -> Optional[Dict]:
management_rooms, unidentified_rooms, portals, empty_portals = await _find_rooms(evt.az.intent)
reply = ["#### Management rooms (M)"]
@@ -106,13 +106,14 @@ async def clean_rooms(evt: CommandEvent):
return await evt.reply("\n".join(reply))
-async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
- unidentified_rooms: RoomIDList, portals: List["po.Portal"],
- empty_portals: List["po.Portal"]):
+async def set_rooms_to_clean(evt, management_rooms: List[ManagementRoom],
+ unidentified_rooms: List[MatrixRoomID], portals: List["po.Portal"],
+ empty_portals: List["po.Portal"]) -> None:
command = evt.args[0]
- rooms_to_clean = []
+ rooms_to_clean = [] # type: List[Union[po.Portal, MatrixRoomID]]
if command == "clean-recommended":
- rooms_to_clean = empty_portals + unidentified_rooms
+ rooms_to_clean += empty_portals
+ rooms_to_clean += unidentified_rooms
elif command == "clean-groups":
if len(evt.args) < 2:
return await evt.reply("**Usage:** `$cmdprefix+sp clean-groups [M][A][U][I]")
@@ -158,7 +159,7 @@ async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
"`$cmdprefix+sp confirm-clean`.")
-async def execute_room_cleanup(evt, rooms_to_clean):
+async def execute_room_cleanup(evt, rooms_to_clean: List[Union[po.Portal, MatrixRoomID]]) -> None:
if len(evt.args) > 0 and evt.args[0] == "confirm-clean":
await evt.reply(f"Cleaning {len(rooms_to_clean)} rooms. "
"This might take a while.")
@@ -167,7 +168,7 @@ async def execute_room_cleanup(evt, rooms_to_clean):
if isinstance(room, po.Portal):
await room.cleanup_and_delete()
cleaned += 1
- elif isinstance(room, str):
+ elif isinstance(room, str): # str is aliased by MatrixRoomID
await po.Portal.cleanup_room(evt.az.intent, room, message="Room deleted")
cleaned += 1
evt.sender.command_status = None
diff --git a/mautrix_telegram/commands/handler.py b/mautrix_telegram/commands/handler.py
index c7d2b1c2..2f0f7750 100644
--- a/mautrix_telegram/commands/handler.py
+++ b/mautrix_telegram/commands/handler.py
@@ -14,19 +14,20 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import List, Dict, Callable, Optional
+from typing import Any, Awaitable, Callable, Coroutine, Dict, List, NamedTuple, Optional, Union
from collections import namedtuple
import markdown
import logging
from telethon.errors import FloodWaitError
+from ..types import MatrixRoomID
from ..util import format_duration
from .. import user as u, context as c
command_handlers = {} # type: Dict[str, CommandHandler]
-HelpSection = namedtuple("HelpSection", "name order description")
+HelpSection = NamedTuple('HelpSection', [('name', str), ('order', int), ('description', str)])
SECTION_GENERAL = HelpSection("General", 0, "")
SECTION_AUTH = HelpSection("Authentication", 10, "")
@@ -37,8 +38,8 @@ SECTION_ADMIN = HelpSection("Administration", 50, "")
class CommandEvent:
- def __init__(self, processor: "CommandProcessor", room: str, sender: u.User, command: str,
- args: List[str], is_management: bool, is_portal: bool):
+ def __init__(self, processor: 'CommandProcessor', room: MatrixRoomID, sender: u.User,
+ command: str, args: List[str], is_management: bool, is_portal: bool) -> None:
self.az = processor.az
self.log = processor.log
self.loop = processor.loop
@@ -53,7 +54,8 @@ class CommandEvent:
self.is_management = is_management
self.is_portal = is_portal
- def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True):
+ def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True
+ ) -> Awaitable[Dict]:
message = message.replace("$cmdprefix+sp ",
"" if self.is_management else f"{self.command_prefix} ")
message = message.replace("$cmdprefix", self.command_prefix)
@@ -66,10 +68,10 @@ class CommandEvent:
class CommandHandler:
- def __init__(self, handler: Callable[[CommandEvent], None], needs_auth: bool,
+ def __init__(self, handler: Callable[[CommandEvent], Awaitable[Dict]], needs_auth: bool,
needs_puppeting: bool, needs_matrix_puppeting: bool, needs_admin: bool,
management_only: bool, name: str, help_text: str, help_args: str,
- help_section: HelpSection):
+ help_section: HelpSection) -> None:
self._handler = handler
self.needs_auth = needs_auth
self.needs_puppeting = needs_puppeting
@@ -103,7 +105,8 @@ class CommandHandler:
(not self.needs_admin or is_admin) and
(not self.needs_auth or is_logged_in))
- async def __call__(self, evt: CommandEvent):
+ async def __call__(self, evt: CommandEvent
+ ) -> Dict:
error = await self.get_permission_error(evt)
if error is not None:
return await evt.reply(error)
@@ -118,13 +121,21 @@ class CommandHandler:
return f"**{self.name}** {self._help_args} - {self._help_text}"
-def command_handler(_func: Optional[Callable[[CommandEvent], None]] = None, *, needs_auth=True,
- needs_puppeting=True, needs_matrix_puppeting=False, needs_admin=False,
- management_only=False, name=None, help_text="", help_args="",
- help_section=None):
+def command_handler(_func: Optional[Callable[[CommandEvent], Awaitable[Dict]]] = None, *,
+ needs_auth: bool = True,
+ needs_puppeting: bool = True,
+ needs_matrix_puppeting: bool = False,
+ needs_admin: bool = False,
+ management_only: bool = False,
+ name: Optional[str] = None,
+ help_text: str = "",
+ help_args: str = "",
+ help_section: HelpSection = None
+ ) -> Callable[[Callable[[CommandEvent], Awaitable[Optional[Dict]]]],
+ CommandHandler]:
input_name = name
- def decorator(func: Callable[[CommandEvent], None]):
+ def decorator(func: Callable[[CommandEvent], Awaitable[Optional[Dict]]]) -> CommandHandler:
name = input_name or func.__name__.replace("_", "-")
handler = CommandHandler(func, needs_auth, needs_puppeting, needs_matrix_puppeting,
needs_admin, management_only, name, help_text, help_args,
@@ -138,27 +149,27 @@ def command_handler(_func: Optional[Callable[[CommandEvent], None]] = None, *, n
class CommandProcessor:
log = logging.getLogger("mau.commands")
- def __init__(self, context: c.Context):
- self.az, self.db, self.config, self.loop, self.tgbot = context
+ def __init__(self, context: c.Context) -> None:
+ self.az, self.db, self.config, self.loop, self.tgbot = context.core
self.public_website = context.public_website
self.command_prefix = self.config["bridge.command_prefix"]
- async def handle(self, room: str, sender: u.User, command: str, args: List[str],
- is_management: bool, is_portal: bool):
+ async def handle(self, room: MatrixRoomID, sender: u.User, command: str, args: List[str],
+ is_management: bool, is_portal: bool) -> Optional[Dict]:
evt = CommandEvent(self, room, sender, command, args, is_management, is_portal)
orig_command = command
command = command.lower()
try:
- command = command_handlers[command]
+ command_handler = command_handlers[command]
except KeyError:
if sender.command_status and "next" in sender.command_status:
args.insert(0, orig_command)
evt.command = ""
command = sender.command_status["next"]
else:
- command = command_handlers["unknown-command"]
+ command_handler = command_handlers["unknown-command"]
try:
- await command(evt)
+ await command_handler(evt)
except FloodWaitError as e:
return await evt.reply(f"Flood error: Please wait {format_duration(e.seconds)}")
except Exception:
@@ -166,3 +177,4 @@ class CommandProcessor:
f"{evt.command} {' '.join(args)} from {sender.mxid}")
return await evt.reply("Unhandled error while handling command. "
"Check logs for more details.")
+ return None
diff --git a/mautrix_telegram/commands/meta.py b/mautrix_telegram/commands/meta.py
index f50a83e0..5920dbee 100644
--- a/mautrix_telegram/commands/meta.py
+++ b/mautrix_telegram/commands/meta.py
@@ -14,46 +14,49 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import Dict, List, Optional, Tuple
+
from . import command_handler, CommandEvent, _command_handlers, SECTION_GENERAL
+from .handler import HelpSection
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_GENERAL,
help_text="Cancel an ongoing action (such as login)")
-def cancel(evt: CommandEvent):
+async def cancel(evt: CommandEvent) -> Optional[Dict]:
if evt.sender.command_status:
action = evt.sender.command_status["action"]
evt.sender.command_status = None
- return evt.reply(f"{action} cancelled.")
+ return await evt.reply(f"{action} cancelled.")
else:
- return evt.reply("No ongoing command.")
+ return await evt.reply("No ongoing command.")
@command_handler(needs_auth=False, needs_puppeting=False)
-def unknown_command(evt: CommandEvent):
- return evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.")
+async def unknown_command(evt: CommandEvent) -> Optional[Dict]:
+ return await evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.")
-help_cache = {}
+help_cache = {} # type: Dict[Tuple[bool, bool, bool, bool, bool], str]
-async def _get_help_text(evt: CommandEvent):
+async def _get_help_text(evt: CommandEvent) -> str:
cache_key = (evt.is_management, evt.sender.puppet_whitelisted,
evt.sender.matrix_puppet_whitelisted, evt.sender.is_admin,
await evt.sender.is_logged_in())
if cache_key not in help_cache:
- help = {}
+ help_sections = {} # type: Dict[HelpSection, List[str]]
for handler in _command_handlers.values():
if handler.has_help and handler.has_permission(*cache_key):
- help.setdefault(handler.help_section, [])
- help[handler.help_section].append(handler.help + " ")
- help = sorted(help.items(), key=lambda item: item[0].order)
- help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help]
+ help_sections.setdefault(handler.help_section, [])
+ help_sections[handler.help_section].append(handler.help + " ")
+ help_sorted = sorted(help_sections.items(), key=lambda item: item[0].order)
+ help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help_sorted]
help_cache[cache_key] = "\n".join(help)
return help_cache[cache_key]
-def _get_management_status(evt: CommandEvent):
+def _get_management_status(evt: CommandEvent) -> str:
if evt.is_management:
return "This is a management room: prefixing commands with `$cmdprefix` is not required."
elif evt.is_portal:
@@ -65,5 +68,5 @@ def _get_management_status(evt: CommandEvent):
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_GENERAL,
help_text="Show this help message.")
-async def help(evt: CommandEvent):
+async def help(evt: CommandEvent) -> Optional[Dict]:
return await evt.reply(_get_management_status(evt) + "\n" + await _get_help_text(evt))
diff --git a/mautrix_telegram/commands/portal.py b/mautrix_telegram/commands/portal.py
index c2ff2347..4b6adde1 100644
--- a/mautrix_telegram/commands/portal.py
+++ b/mautrix_telegram/commands/portal.py
@@ -14,13 +14,15 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Optional, Callable
+from typing import Awaitable, Dict, Callable, Coroutine, Optional, Tuple, Union, cast
import asyncio
-from telethon.errors import *
+from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError,
+ UsernameNotModifiedError, UsernameOccupiedError)
from telethon.tl.types import ChatForbidden, ChannelForbidden
from mautrix_appservice import MatrixRequestError, IntentAPI
+from ..types import MatrixRoomID, TelegramID
from .. import portal as po, user as u
from . import (command_handler, CommandEvent,
SECTION_ADMIN, SECTION_CREATING_PORTALS, SECTION_PORTAL_MANAGEMENT)
@@ -30,7 +32,7 @@ from . import (command_handler, CommandEvent,
help_section=SECTION_ADMIN,
help_args="<_level_> [_mxid_]",
help_text="Set a temporary power level without affecting Telegram.")
-async def set_power_level(evt: CommandEvent):
+async def set_power_level(evt: CommandEvent) -> Dict:
try:
level = int(evt.args[0])
except KeyError:
@@ -45,11 +47,12 @@ async def set_power_level(evt: CommandEvent):
except MatrixRequestError:
evt.log.exception("Failed to set power level.")
return await evt.reply("Failed to set power level.")
+ return {}
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Get a Telegram invite link to the current chat.")
-async def invite_link(evt: CommandEvent):
+async def invite_link(evt: CommandEvent) -> Dict:
portal = po.Portal.get_by_mxid(evt.room_id)
if not portal:
return await evt.reply("This is not a portal room.")
@@ -66,7 +69,8 @@ async def invite_link(evt: CommandEvent):
return await evt.reply("You don't have the permission to create an invite link.")
-async def user_has_power_level(room: str, intent, sender: u.User, event: str, default: int = 50):
+async def user_has_power_level(room: str, intent, sender: u.User, event: str, default: int = 50
+ ) -> bool:
if sender.is_admin:
return True
# Make sure the state store contains the power levels.
@@ -80,8 +84,9 @@ async def user_has_power_level(room: str, intent, sender: u.User, event: str, de
async def _get_portal_and_check_permission(evt: CommandEvent, permission: str,
- action: Optional[str] = None):
- room_id = evt.args[0] if len(evt.args) > 0 else evt.room_id
+ action: Optional[str] = None
+ ) -> Tuple[Union[Dict, po.Portal], bool]:
+ room_id = MatrixRoomID(evt.args[0]) if len(evt.args) > 0 else evt.room_id
portal = po.Portal.get_by_mxid(room_id)
if not portal:
@@ -95,8 +100,8 @@ async def _get_portal_and_check_permission(evt: CommandEvent, permission: str,
def _get_portal_murder_function(action: str, room_id: str, function: Callable, command: str,
- completed_message: str):
- async def post_confirm(confirm):
+ completed_message: str) -> Dict:
+ async def post_confirm(confirm) -> Optional[Dict]:
confirm.sender.command_status = None
if len(confirm.args) > 0 and confirm.args[0] == f"confirm-{command}":
await function()
@@ -104,6 +109,7 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
return await confirm.reply(completed_message)
else:
return await confirm.reply(f"{action} cancelled.")
+ return None
return {
"next": post_confirm,
@@ -116,10 +122,11 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
help_text="Remove all users from the current portal room and forget the portal. "
"Only works for group chats; to delete a private chat portal, simply "
"leave the room.")
-async def delete_portal(evt: CommandEvent):
- portal, ok = await _get_portal_and_check_permission(evt, "unbridge")
+async def delete_portal(evt: CommandEvent) -> Optional[Dict]:
+ result, ok = await _get_portal_and_check_permission(evt, "unbridge")
if not ok:
- return
+ return None
+ portal = cast('po.Portal', result)
evt.sender.command_status = _get_portal_murder_function("Portal deletion", portal.mxid,
portal.cleanup_and_delete, "delete",
@@ -137,10 +144,11 @@ async def delete_portal(evt: CommandEvent):
@command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Remove puppets from the current portal room and forget the portal.")
-async def unbridge(evt: CommandEvent):
- portal, ok = await _get_portal_and_check_permission(evt, "unbridge")
+async def unbridge(evt: CommandEvent) -> Optional[Dict]:
+ result, ok = await _get_portal_and_check_permission(evt, "unbridge")
if not ok:
- return
+ return None
+ portal = cast('po.Portal', result)
evt.sender.command_status = _get_portal_murder_function("Room unbridging", portal.mxid,
portal.unbridge, "unbridge",
@@ -156,11 +164,11 @@ async def unbridge(evt: CommandEvent):
help_text="Bridge the current Matrix room to the Telegram chat with the given "
"ID. The ID must be the prefixed version that you get with the `/id` "
"command of the Telegram-side bot.")
-async def bridge(evt: CommandEvent):
+async def bridge(evt: CommandEvent) -> Dict:
if len(evt.args) == 0:
return await evt.reply("**Usage:** "
"`$cmdprefix+sp bridge [Matrix room ID]`")
- room_id = evt.args[1] if len(evt.args) > 1 else evt.room_id
+ room_id = MatrixRoomID(evt.args[1]) if len(evt.args) > 1 else evt.room_id
that_this = "This" if room_id == evt.room_id else "That"
portal = po.Portal.get_by_mxid(room_id)
@@ -171,12 +179,12 @@ async def bridge(evt: CommandEvent):
return await evt.reply(f"You do not have the permissions to bridge {that_this} room.")
# The /id bot command provides the prefixed ID, so we assume
- tgid = evt.args[0]
- if tgid.startswith("-100"):
- tgid = int(tgid[4:])
+ tgid_str = evt.args[0]
+ if tgid_str.startswith("-100"):
+ tgid = TelegramID(int(tgid_str[4:]))
peer_type = "channel"
- elif tgid.startswith("-"):
- tgid = -int(tgid)
+ elif tgid_str.startswith("-"):
+ tgid = TelegramID(-int(tgid_str))
peer_type = "chat"
else:
return await evt.reply("That doesn't seem like a prefixed Telegram chat ID.\n\n"
@@ -222,7 +230,8 @@ async def bridge(evt: CommandEvent):
"chat to this room, use `$cmdprefix+sp continue`")
-async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"):
+async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"
+ ) -> Tuple[bool, Coroutine[None, None, None]]:
if not portal.mxid:
await evt.reply("The portal seems to have lost its Matrix room between you"
"calling `$cmdprefix+sp bridge` and this command.\n\n"
@@ -245,7 +254,7 @@ async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Porta
return False, None
-async def confirm_bridge(evt: CommandEvent):
+async def confirm_bridge(evt: CommandEvent) -> Optional[Dict]:
status = evt.sender.command_status
try:
portal = po.Portal.get_by_tgid(status["tgid"], peer_type=status["peer_type"])
@@ -258,7 +267,7 @@ async def confirm_bridge(evt: CommandEvent):
if "mxid" in status:
ok, coro = await cleanup_old_portal_while_bridging(evt, portal)
if not ok:
- return
+ return None
elif coro:
asyncio.ensure_future(coro, loop=evt.loop)
await evt.reply("Cleaning up previous portal room...")
@@ -302,7 +311,7 @@ async def confirm_bridge(evt: CommandEvent):
return await evt.reply("Bridging complete. Portal synchronization should begin momentarily.")
-async def get_initial_state(intent: IntentAPI, room_id: str):
+async def get_initial_state(intent: IntentAPI, room_id: str) -> Tuple[str, str, Dict]:
state = await intent.get_room_state(room_id)
title = None
about = None
@@ -328,7 +337,7 @@ async def get_initial_state(intent: IntentAPI, room_id: str):
help_text="Create a Telegram chat of the given type for the current Matrix room. "
"The type is either `group`, `supergroup` or `channel` (defaults to "
"`group`).")
-async def create(evt: CommandEvent):
+async def create(evt: CommandEvent) -> Dict:
type = evt.args[0] if len(evt.args) > 0 else "group"
if type not in {"chat", "group", "supergroup", "channel"}:
return await evt.reply(
@@ -363,7 +372,7 @@ async def create(evt: CommandEvent):
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Upgrade a normal Telegram group to a supergroup.")
-async def upgrade(evt: CommandEvent):
+async def upgrade(evt: CommandEvent) -> Dict:
portal = po.Portal.get_by_mxid(evt.room_id)
if not portal:
return await evt.reply("This is not a portal room.")
@@ -385,7 +394,7 @@ async def upgrade(evt: CommandEvent):
help_args="<_name_|`-`>",
help_text="Change the username of a supergroup/channel. "
"To disable, use a dash (`-`) as the name.")
-async def group_name(evt: CommandEvent):
+async def group_name(evt: CommandEvent) -> Dict:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp group-name `")
@@ -421,7 +430,7 @@ async def group_name(evt: CommandEvent):
help_args="<`whitelist`|`blacklist`>",
help_text="Change whether the bridge will allow or disallow bridging rooms by "
"default.")
-async def filter_mode(evt: CommandEvent):
+async def filter_mode(evt: CommandEvent) -> Dict:
try:
mode = evt.args[0]
if mode not in ("whitelist", "blacklist"):
@@ -446,19 +455,19 @@ async def filter_mode(evt: CommandEvent):
help_section=SECTION_ADMIN,
help_args="<`whitelist`|`blacklist`> <_chat ID_>",
help_text="Allow or disallow bridging a specific chat.")
-async def filter(evt: CommandEvent):
+async def filter(evt: CommandEvent) -> Optional[Dict]:
try:
action = evt.args[0]
if action not in ("whitelist", "blacklist", "add", "remove"):
raise ValueError()
- id = evt.args[1]
- if id.startswith("-100"):
- id = int(id[4:])
- elif id.startswith("-"):
- id = int(id[1:])
+ id_str = evt.args[1]
+ if id_str.startswith("-100"):
+ id = int(id_str[4:])
+ elif id_str.startswith("-"):
+ id = int(id_str[1:])
else:
- id = int(id)
+ id = int(id_str)
except (IndexError, ValueError):
return await evt.reply("**Usage:** `$cmdprefix+sp filter `")
@@ -471,7 +480,7 @@ async def filter(evt: CommandEvent):
if action in ("blacklist", "whitelist"):
action = "add" if mode == action else "remove"
- def save():
+ def save() -> None:
evt.config["bridge.filter.list"] = list
evt.config.save()
po.Portal.filter_list = list
@@ -488,3 +497,4 @@ async def filter(evt: CommandEvent):
list.remove(id)
save()
return await evt.reply(f"Chat ID removed from {mode}.")
+ return None
diff --git a/mautrix_telegram/commands/telegram.py b/mautrix_telegram/commands/telegram.py
index 75221c21..2f968742 100644
--- a/mautrix_telegram/commands/telegram.py
+++ b/mautrix_telegram/commands/telegram.py
@@ -14,8 +14,13 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from telethon.errors import *
+from typing import Awaitable, Dict, List, Optional, Tuple
+import re
+
+from telethon.errors import (
+ InviteHashInvalidError, InviteHashExpiredError, UserAlreadyParticipantError)
from telethon.tl.types import User as TLUser
+from telethon.tl.types import TypeUpdates
from telethon.tl.functions.messages import ImportChatInviteRequest, CheckChatInviteRequest
from telethon.tl.functions.channels import JoinChannelRequest
@@ -26,7 +31,7 @@ from . import command_handler, CommandEvent, SECTION_MISC, SECTION_CREATING_PORT
@command_handler(help_section=SECTION_MISC,
help_args="[_-r|--remote_] <_query_>",
help_text="Search your contacts or the Telegram servers for users.")
-async def search(evt: CommandEvent):
+async def search(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] `")
@@ -47,7 +52,7 @@ async def search(evt: CommandEvent):
"Minimum length of remote query is 5 characters.")
return await evt.reply("No results 3:")
- reply = []
+ reply = [] # type: List[str]
if remote:
reply += ["**Results from Telegram server:**", ""]
else:
@@ -68,7 +73,7 @@ async def search(evt: CommandEvent):
"either the internal user ID, the username or the phone number. "
"**N.B.** The phone numbers you start chats with must already be in "
"your contacts.")
-async def private_message(evt: CommandEvent):
+async def private_message(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp pm `")
@@ -87,7 +92,7 @@ async def private_message(evt: CommandEvent):
f"{pu.Puppet.get_displayname(user, False)}")
-async def _join(evt: CommandEvent, arg: str):
+async def _join(evt: CommandEvent, arg: str) -> Tuple[TypeUpdates, Dict]:
if arg.startswith("joinchat/"):
invite_hash = arg[len("joinchat/"):]
try:
@@ -110,7 +115,7 @@ async def _join(evt: CommandEvent, arg: str):
@command_handler(help_section=SECTION_CREATING_PORTALS,
help_args="<_link_>",
help_text="Join a chat with an invite link.")
-async def join(evt: CommandEvent):
+async def join(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp join `")
@@ -121,7 +126,7 @@ async def join(evt: CommandEvent):
updates, _ = await _join(evt, arg.group(1))
if not updates:
- return
+ return None
for chat in updates.chats:
portal = po.Portal.get_by_entity(chat)
@@ -132,12 +137,13 @@ async def join(evt: CommandEvent):
await evt.reply(f"Creating room for {chat.title}... This might take a while.")
await portal.create_matrix_room(evt.sender, chat, [evt.sender.mxid])
return await evt.reply(f"Created room for {portal.title}")
+ return None
@command_handler(help_section=SECTION_MISC,
help_args="[`chats`|`contacts`|`me`]",
help_text="Synchronize your chat portals, contacts and/or own info.")
-async def sync(evt: CommandEvent):
+async def sync(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) > 0:
sync_only = evt.args[0]
if sync_only not in ("chats", "contacts", "me"):
diff --git a/mautrix_telegram/config.py b/mautrix_telegram/config.py
index 159a61d5..7f0029d3 100644
--- a/mautrix_telegram/config.py
+++ b/mautrix_telegram/config.py
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Tuple, Any, Optional
+from typing import Any, Dict, Optional, Tuple
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
import random
@@ -25,7 +25,7 @@ yaml.indent(4)
class DictWithRecursion:
- def __init__(self, data: CommentedMap = None):
+ def __init__(self, data: Optional[CommentedMap] = None) -> None:
self._data = data or CommentedMap() # type: CommentedMap
def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
@@ -46,7 +46,7 @@ class DictWithRecursion:
def __contains__(self, key: str) -> bool:
return self[key] is not None
- def _recursive_set(self, data: CommentedMap, key: str, value: Any):
+ def _recursive_set(self, data: CommentedMap, key: str, value: Any) -> None:
if '.' in key:
key, next_key = key.split('.', 1)
if key not in data:
@@ -56,16 +56,16 @@ class DictWithRecursion:
return
data[key] = value
- def set(self, key: str, value: Any, allow_recursion: bool = True):
+ def set(self, key: str, value: Any, allow_recursion: bool = True) -> None:
if allow_recursion and '.' in key:
self._recursive_set(self._data, key, value)
return
self._data[key] = value
- def __setitem__(self, key: str, value: Any):
+ def __setitem__(self, key: str, value: Any) -> None:
self.set(key, value)
- def _recursive_del(self, data: CommentedMap, key: str):
+ def _recursive_del(self, data: CommentedMap, key: str) -> None:
if '.' in key:
key, next_key = key.split('.', 1)
if key not in data:
@@ -79,7 +79,7 @@ class DictWithRecursion:
except KeyError:
pass
- def delete(self, key: str, allow_recursion: bool = True):
+ def delete(self, key: str, allow_recursion: bool = True) -> None:
if allow_recursion and '.' in key:
self._recursive_del(self._data, key)
return
@@ -89,19 +89,19 @@ class DictWithRecursion:
except KeyError:
pass
- def __delitem__(self, key: str):
+ def __delitem__(self, key: str) -> None:
self.delete(key)
class Config(DictWithRecursion):
- def __init__(self, path: str, registration_path: str, base_path: str):
+ def __init__(self, path: str, registration_path: str, base_path: str) -> None:
super().__init__()
self.path = path # type: str
self.registration_path = registration_path # type: str
self.base_path = base_path # type: str
- self._registration = None # type: dict
+ self._registration = None # type: Optional[Dict]
- def load(self):
+ def load(self) -> None:
with open(self.path, 'r') as stream:
self._data = yaml.load(stream)
@@ -113,7 +113,7 @@ class Config(DictWithRecursion):
pass
return None
- def save(self):
+ def save(self) -> None:
with open(self.path, 'w') as stream:
yaml.dump(self._data, stream)
if self._registration and self.registration_path:
@@ -124,16 +124,16 @@ class Config(DictWithRecursion):
def _new_token() -> str:
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
- def update(self):
+ def update(self) -> None:
base = self.load_base()
if not base:
return
- def copy(from_path, to_path=None):
+ def copy(from_path, to_path=None) -> None:
if from_path in self:
base[to_path or from_path] = self[from_path]
- def copy_dict(from_path, to_path=None, override_existing_map=True):
+ def copy_dict(from_path, to_path=None, override_existing_map=True) -> None:
if from_path in self:
to_path = to_path or from_path
if override_existing_map or to_path not in base:
@@ -273,7 +273,7 @@ class Config(DictWithRecursion):
return self._get_permissions("*")
- def generate_registration(self):
+ def generate_registration(self) -> None:
homeserver = self["homeserver.domain"]
username_format = self.get("bridge.username_template", "telegram_{userid}") \
diff --git a/mautrix_telegram/context.py b/mautrix_telegram/context.py
index de257477..4330c102 100644
--- a/mautrix_telegram/context.py
+++ b/mautrix_telegram/context.py
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import TYPE_CHECKING, Optional
+from typing import Generator, Optional, Tuple, Union, TYPE_CHECKING
if TYPE_CHECKING:
import asyncio
@@ -32,7 +32,8 @@ if TYPE_CHECKING:
class Context:
def __init__(self, az: "AppService", db: "scoped_session", config: "Config",
- loop: "asyncio.AbstractEventLoop", session_container: "AlchemySessionContainer"):
+ loop: "asyncio.AbstractEventLoop", session_container: "AlchemySessionContainer"
+ ) -> None:
self.az = az # type: AppService
self.db = db # type: scoped_session
self.config = config # type: Config
@@ -43,9 +44,7 @@ class Context:
self.public_website = None # type: PublicBridgeWebsite
self.provisioning_api = None # type: ProvisioningAPI
- def __iter__(self):
- yield self.az
- yield self.db
- yield self.config
- yield self.loop
- yield self.bot
+ @property
+ def core(self) -> Tuple['AppService', 'scoped_session', 'Config',
+ 'asyncio.AbstractEventLoop', Optional['Bot']]:
+ return (self.az, self.db, self.config, self.loop, self.bot)
diff --git a/mautrix_telegram/db.py b/mautrix_telegram/db.py
index 5a0baf70..751ca76f 100644
--- a/mautrix_telegram/db.py
+++ b/mautrix_telegram/db.py
@@ -14,6 +14,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import Dict
+
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
BigInteger, String, Boolean, Text)
from sqlalchemy.sql import expression
@@ -88,20 +90,20 @@ class RoomState(Base):
room_id = Column(String, primary_key=True)
_power_levels_text = Column("power_levels", Text, nullable=True)
- _power_levels_json = None
+ _power_levels_json = {} # type: Dict
@property
- def has_power_levels(self):
+ def has_power_levels(self) -> bool:
return bool(self._power_levels_text)
@property
- def power_levels(self):
+ def power_levels(self) -> Dict:
if not self._power_levels_json and self._power_levels_text:
self._power_levels_json = json.loads(self._power_levels_text)
- return self._power_levels_json or {}
+ return self._power_levels_json
@power_levels.setter
- def power_levels(self, val):
+ def power_levels(self, val: Dict) -> None:
self._power_levels_json = val
self._power_levels_text = json.dumps(val)
@@ -116,7 +118,7 @@ class UserProfile(Base):
displayname = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
- def dict(self):
+ def dict(self) -> Dict[str, Column]:
return {
"membership": self.membership,
"displayname": self.displayname,
@@ -171,7 +173,7 @@ class TelegramFile(Base):
thumbnail = relationship("TelegramFile", uselist=False)
-def init(db_session):
+def init(db_session) -> None:
Portal.query = db_session.query_property()
Message.query = db_session.query_property()
UserPortal.query = db_session.query_property()
diff --git a/mautrix_telegram/formatter/from_matrix/parser_htmlparser.py b/mautrix_telegram/formatter/from_matrix/parser_htmlparser.py
index 4de0cba1..ad085fe9 100644
--- a/mautrix_telegram/formatter/from_matrix/parser_htmlparser.py
+++ b/mautrix_telegram/formatter/from_matrix/parser_htmlparser.py
@@ -80,12 +80,12 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
args["url"] = url
return MessageEntityTextUrl, None
- def handle_starttag(self, tag: str, attrs: List[Tuple[str, str]]):
+ def handle_starttag(self, tag: str, attrs_list: List[Tuple[str, str]]):
self._open_tags.appendleft(tag)
self._open_tags_meta.appendleft(0)
- attrs = dict(attrs)
- entity_type = None # type: type(TypeMessageEntity)
+ attrs = dict(attrs_list)
+ entity_type = None # type: Optional[Type[TypeMessageEntity]]
args = {} # type: Dict[str, Any]
if tag in ("strong", "b"):
entity_type = MessageEntityBold
diff --git a/mautrix_telegram/formatter/from_matrix/parser_lxml.py b/mautrix_telegram/formatter/from_matrix/parser_lxml.py
index 0a997db3..d529feab 100644
--- a/mautrix_telegram/formatter/from_matrix/parser_lxml.py
+++ b/mautrix_telegram/formatter/from_matrix/parser_lxml.py
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Optional, List, Tuple, Union, Callable
+from typing import Callable, List, Optional, Sequence, Tuple, Type, Union
from lxml import html
from telethon.tl.types import (MessageEntityMention as Mention,
@@ -83,11 +83,11 @@ def offset_length_multiply(amount: int):
class TelegramMessage:
- def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None):
+ def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None) -> None:
self.text = text # type: str
self.entities = entities or [] # type: List[TypeMessageEntity]
- def offset_entities(self, offset: int) -> "TelegramMessage":
+ def offset_entities(self, offset: int) -> 'TelegramMessage':
def apply_offset(entity: TypeMessageEntity, inner_offset: int
) -> Optional[TypeMessageEntity]:
entity = Entity.copy(entity)
@@ -104,7 +104,7 @@ class TelegramMessage:
self.entities = [x for x in self.entities if x is not None]
return self
- def append(self, *args: Union[str, "TelegramMessage"]) -> "TelegramMessage":
+ def append(self, *args: Union[str, 'TelegramMessage']) -> 'TelegramMessage':
for msg in args:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
@@ -112,7 +112,7 @@ class TelegramMessage:
self.text += msg.text
return self
- def prepend(self, *args: Union[str, "TelegramMessage"]) -> "TelegramMessage":
+ def prepend(self, *args: Union[str, 'TelegramMessage']) -> 'TelegramMessage':
for msg in args:
if isinstance(msg, str):
msg = TelegramMessage(text=msg)
@@ -120,17 +120,17 @@ class TelegramMessage:
self.text = msg.text + self.text
return self
- def format(self, entity_type: type(TypeMessageEntity), offset: int = None, length: int = None,
- **kwargs) -> "TelegramMessage":
+ def format(self, entity_type: Type[TypeMessageEntity], offset: int = None, length: int = None,
+ **kwargs) -> 'TelegramMessage':
self.entities.append(entity_type(offset=offset or 0,
length=length if length is not None else len(self.text),
**kwargs))
return self
- def concat(self, *args: Union[str, "TelegramMessage"]) -> "TelegramMessage":
+ def concat(self, *args: Union[str, 'TelegramMessage']) -> 'TelegramMessage':
return TelegramMessage().append(self, *args)
- def trim(self) -> "TelegramMessage":
+ def trim(self) -> 'TelegramMessage':
orig_len = len(self.text)
self.text = self.text.lstrip()
diff = orig_len - len(self.text)
@@ -138,7 +138,7 @@ class TelegramMessage:
self.offset_entities(-diff)
return self
- def split(self, separator, max_items: int = 0) -> List["TelegramMessage"]:
+ def split(self, separator, max_items: int = 0) -> List['TelegramMessage']:
text_parts = self.text.split(separator, max_items - 1)
output = [] # type: List[TelegramMessage]
@@ -158,7 +158,8 @@ class TelegramMessage:
return output
@staticmethod
- def join(items: List[Union[str, "TelegramMessage"]], separator: str = " ") -> "TelegramMessage":
+ def join(items: Sequence[Union[str, 'TelegramMessage']],
+ separator: str = " ") -> 'TelegramMessage':
main = TelegramMessage()
for msg in items:
if isinstance(msg, str):
diff --git a/mautrix_telegram/formatter/from_telegram.py b/mautrix_telegram/formatter/from_telegram.py
index 33f8a335..7014f600 100644
--- a/mautrix_telegram/formatter/from_telegram.py
+++ b/mautrix_telegram/formatter/from_telegram.py
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Optional, List, Tuple, TYPE_CHECKING
+from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from html import escape
import logging
import re
@@ -28,6 +28,7 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName,
from mautrix_appservice import MatrixRequestError
from mautrix_appservice.intent_api import IntentAPI
+from ..types import TelegramID
from .. import user as u, puppet as pu, portal as po
from ..db import Message as DBMessage
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
@@ -40,14 +41,14 @@ if TYPE_CHECKING:
try:
from lxml.html.diff import htmldiff
except ImportError:
- htmldiff = None # type: function
+ htmldiff = None # type: ignore
log = logging.getLogger("mau.fmt.tg") # type: logging.Logger
should_highlight_edits = False # type: bool
-def telegram_reply_to_matrix(evt: Message, source: "AbstractUser") -> dict:
+def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict:
if evt.reply_to_msg_id:
space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
@@ -116,7 +117,7 @@ def highlight_edits(new_html: str, old_html: str) -> str:
async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message,
- relates_to: dict, main_intent: IntentAPI, is_edit: bool
+ relates_to: Dict, main_intent: IntentAPI, is_edit: bool
) -> Tuple[str, str]:
space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
@@ -177,10 +178,10 @@ async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: M
async def telegram_to_matrix(evt: Message, source: "AbstractUser",
main_intent: Optional[IntentAPI] = None,
is_edit: bool = False, prefix_text: Optional[str] = None,
- prefix_html: Optional[str] = None) -> Tuple[str, str, dict]:
+ prefix_html: Optional[str] = None) -> Tuple[str, str, Dict]:
text = add_surrogates(evt.message)
html = _telegram_entities_to_matrix_catch(text, evt.entities) if evt.entities else None
- relates_to = {}
+ relates_to = {} # type: Dict
if prefix_html:
html = prefix_html + (html or escape(text))
@@ -217,6 +218,7 @@ def _telegram_entities_to_matrix_catch(text: str, entities: List[TypeMessageEnti
"message=%s\n"
"entities=%s",
text, entities)
+ return "[failed conversion in _telegram_entities_to_matrix]"
def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity]) -> str:
@@ -290,7 +292,7 @@ def _parse_mention(html: List[str], entity_text: str) -> bool:
return False
-def _parse_name_mention(html: List[str], entity_text: str, user_id: int) -> bool:
+def _parse_name_mention(html: List[str], entity_text: str, user_id: TelegramID) -> bool:
user = u.User.get_by_tgid(user_id)
if user:
mxid = user.mxid
@@ -315,8 +317,8 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
message_link_match = message_link_regex.match(url)
if message_link_match:
- group, msgid = message_link_match.groups()
- msgid = int(msgid)
+ group, msgid_str = message_link_match.groups()
+ msgid = int(msgid_str)
portal = po.Portal.find_by_username(group)
if portal:
@@ -328,6 +330,6 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
return False
-def init_tg(context: "Context"):
+def init_tg(context: "Context") -> None:
global should_highlight_edits
should_highlight_edits = htmldiff and context.config["bridge.highlight_edits"]
diff --git a/mautrix_telegram/matrix.py b/mautrix_telegram/matrix.py
index 4deda884..ace90e7c 100644
--- a/mautrix_telegram/matrix.py
+++ b/mautrix_telegram/matrix.py
@@ -14,27 +14,35 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import List, Dict, Tuple, Set, Match
+from typing import Dict, List, Match, Optional, Set, Tuple, TYPE_CHECKING
import logging
import asyncio
import re
from mautrix_appservice import MatrixRequestError, IntentError
+from .types import MatrixEvent, MatrixEventID, MatrixRoomID, MatrixUserID
from . import user as u, portal as po, puppet as pu, commands as com
+if TYPE_CHECKING:
+ from mautrix_appservice import AppService
+ from .context import Context
+ from sqlalchemy.orm import scoped_session
+ from .config import Config
+ from .bot import Bot
+
class MatrixHandler:
log = logging.getLogger("mau.mx") # type: logging.Logger
- def __init__(self, context):
- self.az, self.db, self.config, _, self.tgbot = context
+ def __init__(self, context: 'Context') -> None:
+ self.az, self.db, self.config, _, self.tgbot = context.core
self.commands = com.CommandProcessor(context) # type: com.CommandProcessor
- self.previously_typing = [] # type: List[str]
+ self.previously_typing = [] # type: List[MatrixUserID]
self.az.matrix_event_handler(self.handle_event)
- async def init_as_bot(self):
+ async def init_as_bot(self) -> None:
displayname = self.config["appservice.bot_displayname"]
if displayname:
try:
@@ -50,7 +58,8 @@ class MatrixHandler:
except asyncio.TimeoutError:
self.log.exception("TimeoutError when trying to set avatar")
- async def handle_puppet_invite(self, room_id, puppet: pu.Puppet, inviter: u.User):
+ async def handle_puppet_invite(self, room_id: MatrixRoomID, puppet: pu.Puppet, inviter: u.User
+ ) -> None:
intent = puppet.default_mxid_intent
self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}")
if not await inviter.is_logged_in():
@@ -80,6 +89,7 @@ class MatrixHandler:
await intent.join_room(room_id)
portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
+ # TODO: if portal is None:
if portal.mxid:
try:
await intent.invite(portal.mxid, inviter.mxid)
@@ -95,13 +105,13 @@ class MatrixHandler:
portal.mxid = room_id
portal.save()
inviter.register_portal(portal)
- await intent.send_notice(room_id, "po.Portal to private chat created.")
+ await intent.send_notice(room_id, "Portal to private chat created.")
else:
await intent.join_room(room_id)
await intent.send_notice(room_id, "This puppet will remain inactive until a "
"Telegram chat is created for this room.")
- async def accept_bot_invite(self, room_id: str, inviter: u.User):
+ async def accept_bot_invite(self, room_id: MatrixRoomID, inviter: u.User) -> None:
tries = 0
while tries < 5:
try:
@@ -126,9 +136,13 @@ class MatrixHandler:
"bridge.permissions section in your config file.")
await self.az.intent.leave_room(room_id)
- async def handle_invite(self, room_id: str, user_id: str, inviter_mxid: str):
+ async def handle_invite(self, room_id: MatrixRoomID, user_id: MatrixUserID,
+ inviter_mxid: MatrixUserID) -> None:
self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}")
- inviter = await u.User.get_by_mxid(inviter_mxid).ensure_started()
+ inviter = u.User.get_by_mxid(inviter_mxid)
+ if inviter is None:
+ self.log.exception("Failed to find user with Matrix ID {inviter_mxid}")
+ await inviter.ensure_started()
if user_id == self.az.bot_mxid:
return await self.accept_bot_invite(room_id, inviter)
elif not inviter.whitelisted:
@@ -150,7 +164,8 @@ class MatrixHandler:
# The rest can probably be ignored
- async def handle_join(self, room_id: str, user_id: str, event_id: str):
+ async def handle_join(self, room_id: MatrixRoomID, user_id: MatrixUserID,
+ event_id: MatrixEventID) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started()
portal = po.Portal.get_by_mxid(room_id)
@@ -171,7 +186,8 @@ class MatrixHandler:
if await user.is_logged_in() or portal.has_bot:
await portal.join_matrix(user, event_id)
- async def handle_part(self, room_id: str, user_id, sender_mxid: str, event_id: str):
+ async def handle_part(self, room_id: MatrixRoomID, user_id: MatrixUserID,
+ sender_mxid: MatrixUserID, event_id: MatrixEventID) -> None:
self.log.debug(f"{user_id} left {room_id}")
sender = u.User.get_by_mxid(sender_mxid, create=False)
@@ -185,6 +201,7 @@ class MatrixHandler:
puppet = pu.Puppet.get_by_mxid(user_id)
if sender and puppet:
+ # TODO: Puppet should probably be an AbstractUser
await portal.leave_matrix(puppet, sender, event_id)
user = u.User.get_by_mxid(user_id, create=False)
@@ -194,7 +211,7 @@ class MatrixHandler:
if await user.is_logged_in() or portal.has_bot:
await portal.leave_matrix(user, sender, event_id)
- def is_command(self, message: dict) -> Tuple[bool, str]:
+ def is_command(self, message: Dict) -> Tuple[bool, str]:
text = message.get("body", "")
prefix = self.config["bridge.command_prefix"]
is_command = text.startswith(prefix)
@@ -202,9 +219,10 @@ class MatrixHandler:
text = text[len(prefix) + 1:]
return is_command, text
- async def handle_message(self, room, sender, message, event_id):
+ async def handle_message(self, room: MatrixRoomID, sender_id: MatrixUserID, message: Dict,
+ event_id: MatrixEventID) -> None:
is_command, text = self.is_command(message)
- sender = await u.User.get_by_mxid(sender).ensure_started()
+ sender = await u.User.get_by_mxid(sender_id).ensure_started()
if not sender.relaybot_whitelisted:
self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:"
" u.User is not whitelisted.")
@@ -237,7 +255,8 @@ class MatrixHandler:
is_portal=portal is not None)
@staticmethod
- async def handle_redaction(room_id: str, sender_mxid: str, event_id: str):
+ async def handle_redaction(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
+ event_id: MatrixEventID) -> None:
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if not sender.relaybot_whitelisted:
return
@@ -249,14 +268,16 @@ class MatrixHandler:
await portal.handle_matrix_deletion(sender, event_id)
@staticmethod
- async def handle_power_levels(room_id: str, sender_mxid: str, new: dict, old: dict):
+ async def handle_power_levels(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
+ new: Dict, old: Dict) -> None:
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
await portal.handle_matrix_power_levels(sender, new["users"], old["users"])
@staticmethod
- async def handle_room_meta(evt_type: str, room_id: str, sender_mxid: str, content: dict):
+ async def handle_room_meta(evt_type: str, room_id: MatrixRoomID, sender_mxid: MatrixUserID,
+ content: dict) -> None:
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
@@ -270,8 +291,8 @@ class MatrixHandler:
await handler(sender, content[content_key])
@staticmethod
- async def handle_room_pin(room_id: str, sender_mxid: str, new_events: Set[str],
- old_events: Set[str]):
+ async def handle_room_pin(room_id: MatrixRoomID, sender_mxid: MatrixUserID,
+ new_events: Set[str], old_events: Set[str]) -> None:
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal:
@@ -284,8 +305,8 @@ class MatrixHandler:
await portal.handle_matrix_pin(sender, None)
@staticmethod
- async def handle_name_change(room_id: str, user_id: str, displayname: str,
- prev_displayname: str, event_id: str):
+ async def handle_name_change(room_id: MatrixRoomID, user_id: MatrixUserID, displayname: str,
+ prev_displayname: str, event_id: MatrixEventID) -> None:
portal = po.Portal.get_by_mxid(room_id)
if not portal or not portal.has_bot:
return
@@ -295,13 +316,14 @@ class MatrixHandler:
await portal.name_change_matrix(user, displayname, prev_displayname, event_id)
@staticmethod
- def parse_read_receipts(content: dict) -> Dict[str, str]:
+ def parse_read_receipts(content: Dict) -> Dict[MatrixUserID, MatrixEventID]:
return {user_id: event_id
for event_id, receipts in content.items()
for user_id in receipts.get("m.read", {})}
@staticmethod
- async def handle_read_receipts(room_id: str, receipts: Dict[str, str]):
+ async def handle_read_receipts(room_id: MatrixRoomID,
+ receipts: Dict[MatrixUserID, MatrixEventID]) -> None:
portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
@@ -313,13 +335,13 @@ class MatrixHandler:
await portal.mark_read(user, event_id)
@staticmethod
- async def handle_presence(user_id: str, presence: str):
+ async def handle_presence(user_id: MatrixUserID, presence: str) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in():
return
- await user.set_presence(presence == "online")
+ user.set_presence(presence == "online")
- async def handle_typing(self, room_id: str, now_typing: List[str]):
+ async def handle_typing(self, room_id: MatrixRoomID, now_typing: List[MatrixUserID]) -> None:
portal = po.Portal.get_by_mxid(room_id)
if not portal:
return
@@ -334,35 +356,35 @@ class MatrixHandler:
if not await user.is_logged_in():
continue
- await portal.set_typing(user, is_typing)
+ portal.set_typing(user, is_typing)
self.previously_typing = now_typing
- def filter_matrix_event(self, event: dict):
+ def filter_matrix_event(self, event: MatrixEvent) -> bool:
sender = event.get("sender", None)
if not sender:
return False
return (sender == self.az.bot_mxid
or pu.Puppet.get_id_from_mxid(sender) is not None)
- async def try_handle_event(self, evt: dict):
+ async def try_handle_event(self, evt: MatrixEvent) -> None:
try:
await self.handle_event(evt)
except Exception:
self.log.exception("Error handling manually received Matrix event")
- async def handle_event(self, evt: dict):
+ async def handle_event(self, evt: MatrixEvent) -> None:
if self.filter_matrix_event(evt):
return
self.log.debug("Received event: %s", evt)
evt_type = evt.get("type", "m.unknown") # type: str
- room_id = evt.get("room_id", None) # type: str
- event_id = evt.get("event_id", None) # type: str
- sender = evt.get("sender", None) # type: str
- content = evt.get("content", {}) # type: dict
+ room_id = evt.get("room_id", None) # type: Optional[MatrixRoomID]
+ event_id = evt.get("event_id", None) # type: Optional[MatrixEventID]
+ sender = evt.get("sender", None) # type: Optional[MatrixUserID]
+ content = evt.get("content", {}) # type: Dict
if evt_type == "m.room.member":
- state_key = evt["state_key"] # type: str
- prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict
+ state_key = evt["state_key"] # type: MatrixUserID
+ prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: Dict
membership = content.get("membership", "") # type: str
prev_membership = prev_content.get("membership", "leave") # type: str
if membership == prev_membership:
@@ -386,7 +408,7 @@ class MatrixHandler:
elif evt_type == "m.room.redaction":
await self.handle_redaction(room_id, sender, evt["redacts"])
elif evt_type == "m.room.power_levels":
- prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict
+ prev_content = evt.get("unsigned", {}).get("prev_content", {})
await self.handle_power_levels(room_id, sender, evt["content"], prev_content)
elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"):
await self.handle_room_meta(evt_type, room_id, sender, evt["content"])
diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py
index 7b90f531..3803da5a 100644
--- a/mautrix_telegram/portal.py
+++ b/mautrix_telegram/portal.py
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Pattern, Dict, Tuple, Awaitable, TYPE_CHECKING
+from typing import Awaitable, Dict, List, Optional, Pattern, Tuple, Union, cast, TYPE_CHECKING
from collections import deque
from datetime import datetime
from string import Template
@@ -32,14 +32,37 @@ from sqlalchemy import orm
from sqlalchemy.exc import IntegrityError, InvalidRequestError
from sqlalchemy.orm.exc import FlushError
-from telethon.tl.functions.messages import *
-from telethon.tl.functions.channels import *
+from telethon.tl.functions.messages import (
+ AddChatUserRequest, CreateChatRequest, DeleteChatUserRequest, EditChatAdminRequest,
+ EditChatPhotoRequest, EditChatTitleRequest, ExportChatInviteRequest, GetFullChatRequest,
+ MigrateChatRequest, SetTypingRequest)
+from telethon.tl.functions.channels import (
+ CreateChannelRequest, EditAboutRequest, EditAdminRequest, EditBannedRequest, EditPhotoRequest,
+ EditTitleRequest, ExportInviteRequest, GetParticipantsRequest, InviteToChannelRequest,
+ JoinChannelRequest, LeaveChannelRequest, UpdatePinnedMessageRequest, UpdateUsernameRequest)
from telethon.tl.functions.messages import ReadHistoryRequest as ReadMessageHistoryRequest
from telethon.tl.functions.channels import ReadHistoryRequest as ReadChannelHistoryRequest
from telethon.errors import ChatAdminRequiredError, ChatNotModifiedError
-from telethon.tl.types import *
+from telethon.tl.types import (
+ Channel, ChannelAdminRights, ChannelBannedRights, ChannelFull, ChannelParticipantAdmin,
+ ChannelParticipantCreator, ChannelParticipantsRecent, ChannelParticipantsSearch, Chat,
+ ChatFull, ChatInviteEmpty, ChatParticipantAdmin, ChatParticipantCreator, ChatPhoto,
+ DocumentAttributeFilename, DocumentAttributeImageSize, DocumentAttributeSticker,
+ DocumentAttributeVideo, FileLocation, GeoPoint, InputChannel, InputChatUploadedPhoto,
+ InputPeerChannel, InputPeerChat, InputPeerUser, InputUser, InputUserSelf, Message,
+ MessageActionChannelCreate, MessageActionChatAddUser, MessageActionChatCreate,
+ MessageActionChatDeletePhoto, MessageActionChatDeleteUser, MessageActionChatEditPhoto,
+ MessageActionChatEditTitle, MessageActionChatJoinedByLink, MessageActionChatMigrateTo,
+ MessageActionPinMessage, MessageMediaContact, MessageMediaDocument, MessageMediaGeo,
+ MessageMediaPhoto, MessageService, PeerChannel, PeerChat, PeerUser, Photo, PhotoCachedSize,
+ SendMessageCancelAction, SendMessageTypingAction, TypeChannelParticipant, TypeChat,
+ TypeChatParticipant, TypeDocumentAttribute, TypeInputPeer, TypeMessageAction,
+ TypeMessageEntity, TypePeer, TypePhotoSize, TypeUpdates, TypeUser, TypeUserFull,
+ UpdateChatUserTyping, UpdateNewChannelMessage, UpdateNewMessage, UpdateUserTyping, User,
+ UserFull)
from mautrix_appservice import MatrixRequestError, IntentError, AppService, IntentAPI
+from .types import MatrixEventID, MatrixRoomID, MatrixUserID, TelegramID
from .context import Context
from .db import Portal as DBPortal, Message as DBMessage, TelegramFile as DBTelegramFile
from . import puppet as p, user as u, formatter, util
@@ -82,18 +105,18 @@ class Portal:
by_mxid = {} # type: Dict[str, Portal]
by_tgid = {} # type: Dict[Tuple[int, int], Portal]
- def __init__(self, tgid: int, peer_type: str, tg_receiver: Optional[int] = None,
- mxid: Optional[str] = None, username: Optional[str] = None,
+ def __init__(self, tgid: TelegramID, peer_type: str, tg_receiver: Optional[int] = None,
+ mxid: Optional[MatrixRoomID] = None, username: Optional[str] = None,
megagroup: Optional[bool] = False, title: Optional[str] = None,
about: Optional[str] = None, photo_id: Optional[str] = None,
- db_instance: DBPortal = None):
- self.mxid = mxid # type: str
- self.tgid = tgid # type: int
+ db_instance: DBPortal = None) -> None:
+ self.mxid = mxid # type: Optional[MatrixRoomID]
+ self.tgid = tgid # type: TelegramID
self.tg_receiver = tg_receiver or tgid # type: int
self.peer_type = peer_type # type: str
self.username = username # type: str
self.megagroup = megagroup # type: bool
- self.title = title # type: str
+ self.title = title # type: Optional[str]
self.about = about # type: str
self.photo_id = photo_id # type: str
self._db_instance = db_instance # type: DBPortal
@@ -138,7 +161,7 @@ class Portal:
@property
def has_bot(self) -> bool:
- return self.bot and self.bot.is_in_chat(self.tgid)
+ return bool(self.bot and self.bot.is_in_chat(self.tgid))
@property
def main_intent(self) -> IntentAPI:
@@ -232,13 +255,13 @@ class Portal:
del self._dedup_mxid[self._dedup.popleft()]
return None
- def get_input_entity(self, user: u.User) -> Awaitable[TypeInputPeer]:
+ def get_input_entity(self, user: 'u.User') -> Awaitable[TypeInputPeer]:
return user.client.get_input_entity(self.peer)
# endregion
# region Matrix room info updating
- async def invite_to_matrix(self, users: InviteList):
+ async def invite_to_matrix(self, users: InviteList) -> None:
if isinstance(users, str):
await self.main_intent.invite(self.mxid, users, check_cache=True)
elif isinstance(users, list):
@@ -247,10 +270,10 @@ class Portal:
else:
raise ValueError("Invalid invite identifier given to invite_matrix()")
- async def update_matrix_room(self, user: "AbstractUser", entity: TypeChat, direct: bool,
- puppet: p.Puppet = None, levels: dict = None,
+ async def update_matrix_room(self, user: 'AbstractUser', entity: TypeChat, direct: bool,
+ puppet: p.Puppet = None, levels: Dict = None,
users: List[User] = None,
- participants: List[TypeParticipant] = None):
+ participants: List[TypeParticipant] = None) -> None:
if not direct:
await self.update_info(user, entity)
if not users or not participants:
@@ -280,8 +303,8 @@ class Portal:
async with self._room_create_lock:
return await self._create_matrix_room(user, entity, invites)
- async def _create_matrix_room(self, user: "AbstractUser", entity: TypeChat, invites: InviteList
- ) -> Optional[str]:
+ async def _create_matrix_room(self, user: 'AbstractUser', entity: TypeChat, invites: InviteList
+ ) -> Optional[MatrixRoomID]:
direct = self.peer_type == "user"
if self.mxid:
@@ -346,6 +369,8 @@ class Portal:
participants=participants),
loop=self.loop)
+ return self.mxid
+
def _get_base_power_levels(self, levels: dict = None, entity: TypeChat = None) -> dict:
levels = levels or {}
power_level_requirement = (0 if self.peer_type == "chat" and not entity.admins_enabled
@@ -383,7 +408,7 @@ class Portal:
return None
return self.alias_template.format(groupname=username)
- def add_bot_chat(self, bot: User):
+ def add_bot_chat(self, bot: User) -> None:
if self.bot and bot.id == self.bot.tgid:
self.bot.add_chat(self.tgid, self.peer_type)
return
@@ -392,7 +417,7 @@ class Portal:
if user and user.is_bot:
user.register_portal(self)
- async def sync_telegram_users(self, source: "AbstractUser", users: List[User]):
+ async def sync_telegram_users(self, source: "AbstractUser", users: List[User]) -> None:
allowed_tgids = set()
for entity in users:
puppet = p.Puppet.get(entity.id)
@@ -414,18 +439,19 @@ class Portal:
and config["bridge.max_initial_member_sync"] == -1
and (self.megagroup or self.peer_type != "channel"))
if trust_member_list:
- joined_mxids = await self.main_intent.get_room_members(self.mxid)
- for user in joined_mxids:
- if user == self.az.bot_mxid:
+ joined_mxids = cast(List[MatrixUserID],
+ await self.main_intent.get_room_members(self.mxid))
+ for user_mxid in joined_mxids:
+ if user_mxid == self.az.bot_mxid:
continue
- puppet_id = p.Puppet.get_id_from_mxid(user)
+ puppet_id = p.Puppet.get_id_from_mxid(user_mxid)
if puppet_id and puppet_id not in allowed_tgids:
if self.bot and puppet_id == self.bot.tgid:
self.bot.remove_chat(self.tgid)
- await self.main_intent.kick(self.mxid, user,
+ await self.main_intent.kick(self.mxid, user_mxid,
"User had left this Telegram chat.")
continue
- mx_user = u.User.get_by_mxid(user, create=False)
+ mx_user = u.User.get_by_mxid(user_mxid, create=False)
if mx_user and mx_user.is_bot and mx_user.tgid not in allowed_tgids:
mx_user.unregister_portal(self)
@@ -434,7 +460,8 @@ class Portal:
"You had left this Telegram chat.")
continue
- async def add_telegram_user(self, user_id: int, source: Optional["AbstractUser"] = None):
+ async def add_telegram_user(self, user_id: TelegramID, source: Optional['AbstractUser'] = None
+ ) -> None:
puppet = p.Puppet.get(user_id)
if source:
entity = await source.client.get_entity(PeerUser(user_id))
@@ -446,7 +473,7 @@ class Portal:
user.register_portal(self)
await self.invite_to_matrix(user.mxid)
- async def delete_telegram_user(self, user_id: int, sender: p.Puppet):
+ async def delete_telegram_user(self, user_id: TelegramID, sender: p.Puppet) -> None:
puppet = p.Puppet.get(user_id)
user = u.User.get_by_tgid(user_id)
kick_message = (f"Kicked by {sender.displayname}"
@@ -460,7 +487,7 @@ class Portal:
user.unregister_portal(self)
await self.main_intent.kick(self.mxid, user.mxid, kick_message)
- async def update_info(self, user: "AbstractUser", entity: TypeChat = None):
+ async def update_info(self, user: "AbstractUser", entity: TypeChat = None) -> None:
if self.peer_type == "user":
self.log.warning(f"Called update_info() for direct chat portal {self.tgid_log}")
return
@@ -524,7 +551,7 @@ class Portal:
return max(photo.sizes, key=(lambda photo2: (
len(photo2.bytes) if isinstance(photo2, PhotoCachedSize) else photo2.size)))
- async def remove_avatar(self, _: "AbstractUser", save: bool = False):
+ async def remove_avatar(self, _: "AbstractUser", save: bool = False) -> None:
await self.main_intent.set_room_avatar(self.mxid, None)
self.photo_id = None
if save:
@@ -544,8 +571,9 @@ class Portal:
return True
return False
- async def _get_users(self, user: "AbstractUser", entity: Union[TypeInputPeer, InputUser,
- TypeChat, TypeUser]
+ async def _get_users(self,
+ user: 'AbstractUser',
+ entity: Union[TypeInputPeer, InputUser, TypeChat, TypeUser]
) -> Tuple[List[TypeUser], List[TypeParticipant]]:
if self.peer_type == "chat":
chat = await user.client(GetFullChatRequest(chat_id=self.tgid))
@@ -564,7 +592,7 @@ class Portal:
entity, ChannelParticipantsRecent(), offset=0, limit=limit, hash=0))
return response.users, response.participants
elif limit > 200 or limit == -1:
- users, participants = [], []
+ users, participants = [], [] # type: Tuple[List[TypeUser], List[TypeParticipant]]
offset = 0
remaining_quota = limit if limit > 0 else 1000000
query = (ChannelParticipantsSearch("") if limit == -1
@@ -585,8 +613,9 @@ class Portal:
return [], []
elif self.peer_type == "user":
return [entity], []
+ return [], []
- async def get_invite_link(self, user: u.User) -> str:
+ async def get_invite_link(self, user: 'u.User') -> str:
if self.peer_type == "user":
raise ValueError("You can't invite users to private chats.")
elif self.peer_type == "chat":
@@ -604,7 +633,7 @@ class Portal:
return link.link
- async def get_authenticated_matrix_users(self) -> List[u.User]:
+ async def get_authenticated_matrix_users(self) -> List['u.User']:
try:
members = await self.main_intent.get_room_members(self.mxid)
except MatrixRequestError:
@@ -622,7 +651,7 @@ class Portal:
@staticmethod
async def cleanup_room(intent: IntentAPI, room_id: str, message: str = "Portal deleted",
- puppets_only: bool = False):
+ puppets_only: bool = False) -> None:
try:
members = await intent.get_room_members(room_id)
except MatrixRequestError:
@@ -639,11 +668,11 @@ class Portal:
pass
await intent.leave_room(room_id)
- async def unbridge(self):
+ async def unbridge(self) -> None:
await self.cleanup_room(self.main_intent, self.mxid, "Room unbridged", puppets_only=True)
self.delete()
- async def cleanup_and_delete(self):
+ async def cleanup_and_delete(self) -> None:
await self.cleanup_room(self.main_intent, self.mxid)
self.delete()
@@ -663,8 +692,8 @@ class Portal:
else:
return ""
- async def _get_state_change_message(self, event: str, user: u.User,
- arguments: Optional[dict] = None) -> Optional[dict]:
+ async def _get_state_change_message(self, event: str, user: 'u.User',
+ arguments: Optional[Dict] = None) -> Optional[Dict]:
tpl = config[f"bridge.state_event_formats.{event}"]
if len(tpl) == 0:
# Empty format means they don't want the message
@@ -681,8 +710,8 @@ class Portal:
"formatted_body": message,
}
- async def name_change_matrix(self, user: u.User, displayname: str, prev_displayname: str,
- event_id: str):
+ async def name_change_matrix(self, user: 'u.User', displayname: str, prev_displayname: str,
+ event_id: str) -> None:
async with self.require_send_lock(self.bot.tgid):
message = await self._get_state_change_message(
"name_change", user,
@@ -695,15 +724,16 @@ class Portal:
space = self.tgid if self.peer_type == "channel" else self.bot.tgid
self.is_duplicate(response, (event_id, space))
- async def get_displayname(self, user: u.User) -> str:
+ async def get_displayname(self, user: 'u.User') -> str:
return (await self.main_intent.get_displayname(self.mxid, user.mxid)
or user.mxid_localpart)
- def set_typing(self, user: u.User, typing: bool = True, action=SendMessageTypingAction):
+ def set_typing(self, user: 'u.User', typing: bool = True,
+ action: type = SendMessageTypingAction) -> bool:
return user.client(SetTypingRequest(
self.peer, action() if typing else SendMessageCancelAction()))
- async def mark_read(self, user: u.User, event_id: str):
+ async def mark_read(self, user: 'u.User', event_id: MatrixEventID) -> None:
if user.is_bot:
return
space = self.tgid if self.peer_type == "channel" else user.tgid
@@ -718,7 +748,8 @@ class Portal:
else:
await user.client(ReadMessageHistoryRequest(peer=self.peer, max_id=message.tgid))
- async def leave_matrix(self, user: u.User, source: u.User, event_id: str):
+ async def leave_matrix(self, user: 'u.User', source: 'u.User', event_id: MatrixEventID
+ ) -> None:
if await user.needs_relaybot(self):
async with self.require_send_lock(self.bot.tgid):
message = await self._get_state_change_message("leave", user)
@@ -754,7 +785,7 @@ class Portal:
channel = await self.get_input_entity(user)
await user.client(LeaveChannelRequest(channel=channel))
- async def join_matrix(self, user: u.User, event_id: str):
+ async def join_matrix(self, user: 'u.User', event_id: str) -> None:
if await user.needs_relaybot(self):
async with self.require_send_lock(self.bot.tgid):
message = await self._get_state_change_message("join", user)
@@ -773,7 +804,7 @@ class Portal:
# We'll just assume the user is already in the chat.
pass
- async def _apply_msg_format(self, sender: u.User, msgtype: str, message: dict):
+ async def _apply_msg_format(self, sender: 'u.User', msgtype: str, message: Dict) -> None:
if "formatted_body" not in message:
message["format"] = "org.matrix.custom.html"
message["formatted_body"] = escape_html(message.get("body", ""))
@@ -788,7 +819,8 @@ class Portal:
message=body)
message["formatted_body"] = Template(tpl).safe_substitute(tpl_args)
- async def _pre_process_matrix_message(self, sender: u.User, use_relaybot: bool, message: dict):
+ async def _pre_process_matrix_message(self, sender: 'u.User', use_relaybot: bool,
+ message: dict) -> None:
msgtype = message.get("msgtype", "m.text")
if msgtype == "m.emote":
await self._apply_msg_format(sender, msgtype, message)
@@ -797,7 +829,7 @@ class Portal:
await self._apply_msg_format(sender, msgtype, message)
@staticmethod
- def _matrix_event_to_entities(event: dict) -> Tuple[str, Optional[List[TypeMessageEntity]]]:
+ def _matrix_event_to_entities(event: Dict) -> Tuple[str, Optional[List[TypeMessageEntity]]]:
try:
if event.get("format", None) == "org.matrix.custom.html":
message, entities = formatter.matrix_to_telegram(event.get("formatted_body", ""))
@@ -825,7 +857,8 @@ class Portal:
return None
async def _handle_matrix_text(self, sender_id: int, event_id: str, space: int,
- client: "MautrixTelegramClient", message: dict, reply_to: int):
+ client: 'MautrixTelegramClient', message: Dict, reply_to: int
+ ) -> None:
lock = self.require_send_lock(sender_id)
async with lock:
response = await client.send_message(self.peer, message, reply_to=reply_to,
@@ -833,7 +866,8 @@ class Portal:
self._add_telegram_message_to_db(event_id, space, response)
async def _handle_matrix_file(self, msgtype: str, sender_id: int, event_id: str, space: int,
- client: "MautrixTelegramClient", message: dict, reply_to: int):
+ client: 'MautrixTelegramClient', message: dict, reply_to: int
+ ) -> None:
file = await self.main_intent.download_file(message["url"])
info = message.get("info", {})
@@ -867,24 +901,25 @@ class Portal:
self._add_telegram_message_to_db(event_id, space, response)
async def _handle_matrix_location(self, sender_id: int, event_id: str, space: int,
- client: "MautrixTelegramClient", message: dict,
- reply_to: int):
+ client: 'MautrixTelegramClient', message: Dict,
+ reply_to: int) -> None:
try:
lat, long = message["geo_uri"][len("geo:"):].split(",")
lat, long = float(lat), float(long)
except (KeyError, ValueError):
self.log.exception("Failed to parse location")
return None
- message, entities = self._matrix_event_to_entities(message)
+ caption, entities = self._matrix_event_to_entities(message)
media = MessageMediaGeo(geo=GeoPoint(lat, long, access_hash=0))
lock = self.require_send_lock(sender_id)
async with lock:
response = await client.send_media(self.peer, media, reply_to=reply_to,
- caption=message, entities=entities)
+ caption=caption, entities=entities)
self._add_telegram_message_to_db(event_id, space, response)
- def _add_telegram_message_to_db(self, event_id: str, space: int, response: TypeMessage):
+ def _add_telegram_message_to_db(self, event_id: str, space: int,
+ response: TypeMessage) -> None:
self.log.debug("Handled Matrix message: %s", response)
self.is_duplicate(response, (event_id, space))
self.db.add(DBMessage(
@@ -894,7 +929,7 @@ class Portal:
mxid=event_id))
self.db.commit()
- async def handle_matrix_message(self, sender: u.User, message: dict, event_id: str):
+ async def handle_matrix_message(self, sender: 'u.User', message: dict, event_id: str) -> None:
puppet = p.Puppet.get_by_custom_mxid(sender.mxid)
if puppet and message.get("net.maunium.telegram.puppet", False):
self.log.debug("Ignoring puppet-sent message by confirmed puppet user %s", sender.mxid)
@@ -922,7 +957,7 @@ class Portal:
else:
self.log.debug(f"Unhandled Matrix event: {message}")
- async def handle_matrix_pin(self, sender: u.User, pinned_message: Optional[str]):
+ async def handle_matrix_pin(self, sender: 'u.User', pinned_message: Optional[str]) -> None:
if self.peer_type != "channel":
return
try:
@@ -936,17 +971,18 @@ class Portal:
except ChatNotModifiedError:
pass
- async def handle_matrix_deletion(self, deleter: u.User, event_id: str):
- deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
- space = self.tgid if self.peer_type == "channel" else deleter.tgid
+ async def handle_matrix_deletion(self, deleter: 'u.User', event_id: MatrixEventID) -> None:
+ real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
+ space = self.tgid if self.peer_type == "channel" else real_deleter.tgid
message = DBMessage.query.filter(DBMessage.mxid == event_id,
DBMessage.tg_space == space,
DBMessage.mx_room == self.mxid).one_or_none()
if not message:
return
- await deleter.client.delete_messages(self.peer, [message.tgid])
+ await real_deleter.client.delete_messages(self.peer, [message.tgid])
- async def _update_telegram_power_level(self, sender: u.User, user_id: int, level: int):
+ async def _update_telegram_power_level(self, sender: 'u.User', user_id: TelegramID,
+ level: int) -> None:
if self.peer_type == "chat":
await sender.client(EditChatAdminRequest(
chat_id=self.tgid, user_id=user_id, is_admin=level >= 50))
@@ -962,8 +998,9 @@ class Portal:
EditAdminRequest(channel=await self.get_input_entity(sender),
user_id=user_id, admin_rights=rights))
- async def handle_matrix_power_levels(self, sender: u.User, new_users: Dict[str, int],
- old_users: Dict[str, int]):
+ async def handle_matrix_power_levels(self, sender: 'u.User',
+ new_users: Dict[MatrixUserID, int],
+ old_users: Dict[str, int]) -> None:
# TODO handle all power level changes and bridge exact admin rights to supergroups/channels
for user, level in new_users.items():
if not user or user == self.main_intent.mxid or user == sender.mxid:
@@ -979,7 +1016,7 @@ class Portal:
if user not in old_users or level != old_users[user]:
await self._update_telegram_power_level(sender, user_id, level)
- async def handle_matrix_about(self, sender: u.User, about: str):
+ async def handle_matrix_about(self, sender: 'u.User', about: str) -> None:
if self.peer_type not in {"channel"}:
return
channel = await self.get_input_entity(sender)
@@ -987,7 +1024,7 @@ class Portal:
self.about = about
self.save()
- async def handle_matrix_title(self, sender: u.User, title: str):
+ async def handle_matrix_title(self, sender: 'u.User', title: str) -> None:
if self.peer_type not in {"chat", "channel"}:
return
@@ -1000,7 +1037,7 @@ class Portal:
self.title = title
self.save()
- async def handle_matrix_avatar(self, sender: u.User, url: str):
+ async def handle_matrix_avatar(self, sender: 'u.User', url: str) -> None:
if self.peer_type not in {"chat", "channel"}:
# Invalid peer type
return
@@ -1027,7 +1064,7 @@ class Portal:
self.save()
break
- def _register_outgoing_actions_for_dedup(self, response: TypeUpdates):
+ def _register_outgoing_actions_for_dedup(self, response: TypeUpdates) -> None:
for update in response.updates:
check_dedup = (isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage))
and isinstance(update.message, MessageService))
@@ -1051,7 +1088,7 @@ class Portal:
user_tgids.add(puppet_id)
return list(user_tgids)
- async def upgrade_telegram_chat(self, source: u.User):
+ async def upgrade_telegram_chat(self, source: 'u.User') -> None:
if self.peer_type != "chat":
raise ValueError("Only normal group chats are upgradable to supergroups.")
@@ -1067,7 +1104,7 @@ class Portal:
self.migrate_and_save(entity.id)
await self.update_info(source, entity)
- async def set_telegram_username(self, source: u.User, username: str):
+ async def set_telegram_username(self, source: 'u.User', username: str) -> None:
if self.peer_type != "channel":
raise ValueError("Only channels and supergroups have usernames.")
await source.client(
@@ -1075,7 +1112,7 @@ class Portal:
if await self.update_username(username):
self.save()
- async def create_telegram_chat(self, source: u.User, supergroup: bool = False):
+ async def create_telegram_chat(self, source: 'u.User', supergroup: bool = False) -> None:
if not self.mxid:
raise ValueError("Can't create Telegram chat for portal without Matrix room.")
elif self.tgid:
@@ -1116,7 +1153,8 @@ class Portal:
await self.main_intent.set_power_levels(self.mxid, levels)
await self.handle_matrix_power_levels(source, levels["users"], {})
- async def invite_telegram(self, source: u.User, puppet: Union[p.Puppet, "AbstractUser"]):
+ async def invite_telegram(self, source: 'u.User',
+ puppet: Union[p.Puppet, "AbstractUser"]) -> None:
if self.peer_type == "chat":
await source.client(
AddChatUserRequest(chat_id=self.tgid, user_id=puppet.tgid, fwd_limit=0))
@@ -1129,7 +1167,7 @@ class Portal:
# region Telegram event handling
async def handle_telegram_typing(self, user: p.Puppet,
- _: Union[UpdateUserTyping, UpdateChatUserTyping]):
+ _: Union[UpdateUserTyping, UpdateChatUserTyping]) -> None:
if self.mxid:
await user.intent.set_typing(self.mxid, is_typing=True)
@@ -1139,7 +1177,7 @@ class Portal:
return None
async def handle_telegram_photo(self, source: "AbstractUser", intent: IntentAPI, evt: Message,
- relates_to=None):
+ relates_to: Dict = {}) -> None:
largest_size = self._get_largest_photo_size(evt.media.photo)
file = await util.transfer_file_to_matrix(self.db, source.client, intent,
largest_size.location)
@@ -1169,7 +1207,7 @@ class Portal:
external_url=self.get_external_url(evt))
@staticmethod
- def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> dict:
+ def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> Dict:
attrs = {
"name": None,
"mime_type": None,
@@ -1177,7 +1215,7 @@ class Portal:
"sticker_alt": None,
"width": None,
"height": None,
- }
+ } # type: Dict
for attr in attributes:
if isinstance(attr, DocumentAttributeFilename):
attrs["name"] = attrs["name"] or attr.file_name
@@ -1190,8 +1228,8 @@ class Portal:
return attrs
@staticmethod
- def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: dict
- ) -> Tuple[dict, str]:
+ def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: Dict
+ ) -> Tuple[Dict, str]:
document = evt.media.document
name = evt.message or attrs["name"]
if attrs["is_sticker"]:
@@ -1225,7 +1263,7 @@ class Portal:
async def handle_telegram_document(self, source: "AbstractUser", intent: IntentAPI,
evt: Message,
- relates_to: dict = None) -> Optional[dict]:
+ relates_to: dict = None) -> Optional[Dict]:
document = evt.media.document
attrs = self._parse_telegram_document_attributes(document.attributes)
@@ -1300,7 +1338,8 @@ class Portal:
msgtype=msgtype, timestamp=evt.date,
external_url=self.get_external_url(evt))
- async def handle_telegram_edit(self, source: "AbstractUser", sender: p.Puppet, evt: Message):
+ async def handle_telegram_edit(self, source: "AbstractUser", sender: p.Puppet,
+ evt: Message) -> None:
if not self.mxid:
return
elif not config["bridge.edits_as_replies"]:
@@ -1349,7 +1388,8 @@ class Portal:
.update({"mxid": mxid})
self.db.commit()
- async def handle_telegram_message(self, source: "AbstractUser", sender: p.Puppet, evt: Message):
+ async def handle_telegram_message(self, source: "AbstractUser", sender: p.Puppet,
+ evt: Message) -> None:
if not self.mxid:
await self.create_matrix_room(source, invites=[source.mxid], update_if_exists=False)
@@ -1461,7 +1501,7 @@ class Portal:
return True
async def handle_telegram_action(self, source: "AbstractUser", sender: p.Puppet,
- update: MessageService):
+ update: MessageService) -> None:
action = update.action
should_ignore = ((not self.mxid and not await self._create_room_on_action(source, action))
or self.is_duplicate_action(update))
@@ -1491,9 +1531,9 @@ class Portal:
else:
self.log.debug("Unhandled Telegram action in %s: %s", self.title, action)
- async def set_telegram_admin(self, user_id: int):
+ async def set_telegram_admin(self, user_id: TelegramID) -> None:
puppet = p.Puppet.get(user_id)
- user = await u.User.get_by_tgid(user_id)
+ user = u.User.get_by_tgid(user_id)
levels = await self.main_intent.get_power_levels(self.mxid)
if user:
@@ -1502,12 +1542,12 @@ class Portal:
levels["users"][puppet.mxid] = 50
await self.main_intent.set_power_levels(self.mxid, levels)
- async def receive_telegram_pin_sender(self, sender: p.Puppet):
+ async def receive_telegram_pin_sender(self, sender: p.Puppet) -> None:
self._temp_pinned_message_sender = sender
if self._temp_pinned_message_id:
await self.update_telegram_pin()
- async def update_telegram_pin(self):
+ async def update_telegram_pin(self) -> None:
intent = (self._temp_pinned_message_sender.intent
if self._temp_pinned_message_sender else self.main_intent)
msg_id = self._temp_pinned_message_id
@@ -1520,7 +1560,7 @@ class Portal:
else:
await intent.set_pinned_messages(self.mxid, [])
- async def receive_telegram_pin_id(self, msg_id: int):
+ async def receive_telegram_pin_id(self, msg_id: int) -> None:
if msg_id == 0:
return await self.update_telegram_pin()
self._temp_pinned_message_id = msg_id
@@ -1528,7 +1568,7 @@ class Portal:
await self.update_telegram_pin()
@staticmethod
- def _get_level_from_participant(participant: TypeParticipant, _) -> int:
+ def _get_level_from_participant(participant: TypeParticipant, _: Dict) -> int:
# TODO use the power level requirements to get better precision in channels
if isinstance(participant, (ChatParticipantAdmin, ChannelParticipantAdmin)):
return 50
@@ -1537,7 +1577,7 @@ class Portal:
return 0
@staticmethod
- def _participant_to_power_levels(levels: dict, user: Union[u.User, p.Puppet], new_level: int,
+ def _participant_to_power_levels(levels: dict, user: Union['u.User', p.Puppet], new_level: int,
bot_level: int) -> bool:
new_level = min(new_level, bot_level)
default_level = levels["users_default"] if "users_default" in levels else 0
@@ -1569,7 +1609,7 @@ class Portal:
except KeyError:
return 50
- def _participants_to_power_levels(self, participants: List[TypeParticipant], levels: dict
+ def _participants_to_power_levels(self, participants: List[TypeParticipant], levels: Dict
) -> bool:
bot_level = self._get_bot_level(levels)
if bot_level < self._get_powerlevel_level(levels):
@@ -1596,13 +1636,13 @@ class Portal:
return changed
async def update_telegram_participants(self, participants: List[TypeParticipant],
- levels: dict = None):
+ levels: dict = None) -> None:
if not levels:
levels = await self.main_intent.get_power_levels(self.mxid)
if self._participants_to_power_levels(participants, levels):
await self.main_intent.set_power_levels(self.mxid, levels)
- async def set_telegram_admins_enabled(self, enabled: bool):
+ async def set_telegram_admins_enabled(self, enabled: bool) -> None:
level = 50 if enabled else 10
levels = await self.main_intent.get_power_levels(self.mxid)
levels["invite"] = level
@@ -1624,7 +1664,7 @@ class Portal:
mxid=self.mxid, username=self.username, megagroup=self.megagroup,
title=self.title, about=self.about, photo_id=self.photo_id)
- def migrate_and_save(self, new_id: int):
+ def migrate_and_save(self, new_id: TelegramID) -> None:
existing = DBPortal.query.get(self.tgid_full)
if existing:
self.db.delete(existing)
@@ -1637,7 +1677,7 @@ class Portal:
self.by_tgid[self.tgid_full] = self
self.save()
- def save(self):
+ def save(self) -> None:
self.db_instance.mxid = self.mxid
self.db_instance.username = self.username
self.db_instance.title = self.title
@@ -1645,7 +1685,7 @@ class Portal:
self.db_instance.photo_id = self.photo_id
self.db.commit()
- def delete(self):
+ def delete(self) -> None:
try:
del self.by_tgid[self.tgid_full]
except KeyError:
@@ -1660,7 +1700,7 @@ class Portal:
self.deleted = True
@classmethod
- def from_db(cls, db_portal: DBPortal) -> "Portal":
+ def from_db(cls, db_portal: DBPortal) -> 'Portal':
return Portal(tgid=db_portal.tgid, tg_receiver=db_portal.tg_receiver,
peer_type=db_portal.peer_type, mxid=db_portal.mxid,
username=db_portal.username, megagroup=db_portal.megagroup,
@@ -1671,7 +1711,7 @@ class Portal:
# region Class instance lookup
@classmethod
- def get_by_mxid(cls, mxid: str) -> Optional["Portal"]:
+ def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['Portal']:
try:
return cls.by_mxid[mxid]
except KeyError:
@@ -1691,7 +1731,7 @@ class Portal:
return None
@classmethod
- def find_by_username(cls, username: str) -> Optional["Portal"]:
+ def find_by_username(cls, username: str) -> Optional['Portal']:
if not username:
return None
@@ -1699,15 +1739,15 @@ class Portal:
if portal.username and portal.username.lower() == username.lower():
return portal
- portal = DBPortal.query.filter(DBPortal.username == username).one_or_none()
- if portal:
- return cls.from_db(portal)
+ dbportal = DBPortal.query.filter(DBPortal.username == username).one_or_none()
+ if dbportal:
+ return cls.from_db(dbportal)
return None
@classmethod
- def get_by_tgid(cls, tgid: int, tg_receiver: int = None, peer_type: str = None
- ) -> Optional["Portal"]:
+ def get_by_tgid(cls, tgid: TelegramID, tg_receiver: Optional[TelegramID] = None,
+ peer_type: str = None) -> Optional['Portal']:
tg_receiver = tg_receiver or tgid
tgid_full = (tgid, tg_receiver)
try:
@@ -1728,8 +1768,10 @@ class Portal:
return None
@classmethod
- def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull, TypeInputPeer],
- receiver_id: int = None, create: bool = True) -> Optional["Portal"]:
+ def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull,
+ TypeInputPeer],
+ receiver_id: Optional[TelegramID] = None, create: bool = True
+ ) -> Optional['Portal']:
entity_type = type(entity)
if entity_type in {Chat, ChatFull}:
type_name = "chat"
@@ -1758,9 +1800,9 @@ class Portal:
# endregion
-def init(context: Context):
+def init(context: Context) -> None:
global config
- Portal.az, Portal.db, config, Portal.loop, Portal.bot = context
+ Portal.az, Portal.db, config, Portal.loop, Portal.bot = context.core
Portal.bridge_notices = config["bridge.bridge_notices"]
Portal.filter_mode = config["bridge.filter.mode"]
Portal.filter_list = config["bridge.filter.list"]
diff --git a/mautrix_telegram/puppet.py b/mautrix_telegram/puppet.py
index f5642bc2..9c568b07 100644
--- a/mautrix_telegram/puppet.py
+++ b/mautrix_telegram/puppet.py
@@ -14,17 +14,19 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING
+from typing import Awaitable, Coroutine, Dict, List, NewType, Optional, Pattern, TYPE_CHECKING
from difflib import SequenceMatcher
import re
import logging
import asyncio
+from enum import Enum
from sqlalchemy import orm
-from telethon.tl.types import UserProfilePhoto
+from telethon.tl.types import UserProfilePhoto, User, FileLocation
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
+from .types import MatrixUserID, TelegramID
from .db import Puppet as DBPuppet
from . import util
@@ -32,6 +34,11 @@ if TYPE_CHECKING:
from .matrix import MatrixHandler
from .config import Config
from .context import Context
+ from . import user as u
+ from .abstract_user import AbstractUser
+
+
+PuppetError = Enum('PuppetError', 'Success OnlyLoginSelf InvalidAccessToken')
config = None # type: Config
@@ -45,87 +52,100 @@ class Puppet:
mxid_regex = None # type: Pattern
username_template = None # type: str
hs_domain = None # type: str
- cache = {} # type: Dict[str, Puppet]
+ cache = {} # type: Dict[TelegramID, Puppet]
by_custom_mxid = {} # type: Dict[str, Puppet]
- def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
- displayname=None, displayname_source=None, photo_id=None, is_bot=None,
- is_registered=False, db_instance=None):
- self.id = id
- self.access_token = access_token
- self.custom_mxid = custom_mxid
- self.is_real_user = self.custom_mxid and self.access_token
- self.default_mxid = self.get_mxid_from_id(self.id)
- self.mxid = self.custom_mxid or self.default_mxid
+ def __init__(self,
+ id: TelegramID,
+ access_token: Optional[str] = None,
+ custom_mxid: Optional[MatrixUserID] = None,
+ username: Optional[str] = None,
+ displayname: Optional[str] = None,
+ displayname_source: Optional[TelegramID] = None,
+ photo_id: Optional[str] = None,
+ is_bot: bool = False,
+ is_registered: bool = False,
+ db_instance: Optional[DBPuppet] = None) -> None:
+ self.id = id # type: TelegramID
+ self.access_token = access_token # type: Optional[str]
+ self.custom_mxid = custom_mxid # type: Optional[MatrixUserID]
+ self.default_mxid = self.get_mxid_from_id(self.id) # type: MatrixUserID
- self.username = username
- self.displayname = displayname
- self.displayname_source = displayname_source
- self.photo_id = photo_id
- self.is_bot = is_bot
- self.is_registered = is_registered
- self._db_instance = db_instance
+ self.username = username # type: Optional[str]
+ self.displayname = displayname # type: Optional[str]
+ self.displayname_source = displayname_source # type: Optional[TelegramID]
+ self.photo_id = photo_id # type: Optional[str]
+ self.is_bot = is_bot # type: bool
+ self.is_registered = is_registered # type: bool
+ self._db_instance = db_instance # type: Optional[DBPuppet]
self.default_mxid_intent = self.az.intent.user(self.default_mxid)
- self.intent = None # type: IntentAPI
- self.refresh_intents()
+ self.intent = self._fresh_intent() # type: IntentAPI
self.cache[id] = self
if self.custom_mxid:
self.by_custom_mxid[self.custom_mxid] = self
@property
- def tgid(self):
+ def mxid(self):
+ return self.custom_mxid or self.default_mxid
+
+ @property
+ def tgid(self) -> TelegramID:
return self.id
+ @property
+ def is_real_user(self) -> bool:
+ """ Is True when the puppet is a real Matrix user. """
+ return bool(self.custom_mxid and self.access_token)
+
@staticmethod
- async def is_logged_in():
+ async def is_logged_in() -> bool:
+ """ Is True if the puppet is logged in. """
return True
# region Custom puppet management
- def refresh_intents(self):
- self.is_real_user = self.custom_mxid and self.access_token
- self.intent = (self.az.intent.user(self.custom_mxid, self.access_token)
- if self.is_real_user else self.default_mxid_intent)
+ def _fresh_intent(self) -> IntentAPI:
+ return (self.az.intent.user(self.custom_mxid, self.access_token)
+ if self.is_real_user else self.default_mxid_intent)
- async def switch_mxid(self, access_token, mxid):
+ async def switch_mxid(self, access_token: str, mxid: MatrixUserID) -> PuppetError:
prev_mxid = self.custom_mxid
self.custom_mxid = mxid
self.access_token = access_token
- self.refresh_intents()
+ self.intent = self._fresh_intent()
err = await self.init_custom_mxid()
- if err != 0:
+ if err != PuppetError.Success:
return err
try:
- del self.by_custom_mxid[prev_mxid]
+ del self.by_custom_mxid[prev_mxid] # type: ignore
except KeyError:
pass
- self.mxid = self.custom_mxid or self.default_mxid
if self.mxid != self.default_mxid:
self.by_custom_mxid[self.mxid] = self
await self.leave_rooms_with_default_user()
self.save()
- return 0
+ return PuppetError.Success
- async def init_custom_mxid(self):
+ async def init_custom_mxid(self) -> PuppetError:
if not self.is_real_user:
- return 0
+ return PuppetError.Success
mxid = await self.intent.whoami()
if not mxid or mxid != self.custom_mxid:
self.custom_mxid = None
self.access_token = None
- self.refresh_intents()
+ self.intent = self._fresh_intent()
if mxid != self.custom_mxid:
- return 2
- return 1
+ return PuppetError.OnlyLoginSelf
+ return PuppetError.InvalidAccessToken
if config["bridge.sync_with_custom_puppets"]:
asyncio.ensure_future(self.sync(), loop=self.loop)
- return 0
+ return PuppetError.Success
- async def leave_rooms_with_default_user(self):
+ async def leave_rooms_with_default_user(self) -> None:
for room_id in await self.default_mxid_intent.get_joined_rooms():
try:
await self.default_mxid_intent.leave_room(room_id)
@@ -159,7 +179,7 @@ class Puppet:
},
})
- def filter_events(self, events):
+ def filter_events(self, events: List[Dict]) -> List:
new_events = []
for event in events:
evt_type = event.get("type", None)
@@ -186,28 +206,28 @@ class Puppet:
new_events.append(event)
return new_events
- def handle_sync(self, presence, ephemeral):
- presence = [self.mx.try_handle_event(event) for event in presence]
+ def handle_sync(self, presence: List, ephemeral: Dict) -> None:
+ presence_events = [self.mx.try_handle_event(event) for event in presence]
for room_id, events in ephemeral.items():
for event in events:
event["room_id"] = room_id
- ephemeral = [self.mx.try_handle_event(event)
- for events in ephemeral.values()
- for event in self.filter_events(events)]
+ ephemeral_events = [self.mx.try_handle_event(event)
+ for events in ephemeral.values()
+ for event in self.filter_events(events)]
- events = ephemeral + presence
+ events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]]
coro = asyncio.gather(*events, loop=self.loop)
asyncio.ensure_future(coro, loop=self.loop)
- async def sync(self):
+ async def sync(self) -> None:
try:
await self._sync()
except Exception:
self.log.exception("Fatal error syncing")
- async def _sync(self):
+ async def _sync(self) -> None:
if not self.is_real_user:
self.log.warning("Called sync() for non-custom puppet.")
return
@@ -220,13 +240,14 @@ class Puppet:
while access_token_at_start == self.access_token:
try:
sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch,
- set_presence="offline")
+ set_presence="offline") # type: Dict
errors = 0
if next_batch is not None:
- presence = sync_resp.get("presence", {}).get("events", [])
+ presence = sync_resp.get("presence", {}).get("events", []) # type: List
ephemeral = {room: data.get("ephemeral", {}).get("events", [])
for room, data
- in sync_resp.get("rooms", {}).get("join", {}).items()}
+ in sync_resp.get("rooms", {}).get("join", {}).items()
+ } # type: Dict
self.handle_sync(presence, ephemeral)
next_batch = sync_resp.get("next_batch", None)
except MatrixRequestError as e:
@@ -241,25 +262,25 @@ class Puppet:
# region DB conversion
@property
- def db_instance(self):
+ def db_instance(self) -> DBPuppet:
if not self._db_instance:
self._db_instance = self.new_db_instance()
return self._db_instance
- def new_db_instance(self):
+ def new_db_instance(self) -> DBPuppet:
return DBPuppet(id=self.id, access_token=self.access_token, custom_mxid=self.custom_mxid,
username=self.username, displayname=self.displayname,
displayname_source=self.displayname_source, photo_id=self.photo_id,
is_bot=self.is_bot, matrix_registered=self.is_registered)
@classmethod
- def from_db(cls, db_puppet):
+ def from_db(cls, db_puppet: DBPuppet) -> 'Puppet':
return Puppet(db_puppet.id, db_puppet.access_token, db_puppet.custom_mxid,
db_puppet.username, db_puppet.displayname, db_puppet.displayname_source,
db_puppet.photo_id, db_puppet.is_bot, db_puppet.matrix_registered,
db_instance=db_puppet)
- def save(self):
+ def save(self) -> None:
self.db_instance.access_token = self.access_token
self.db_instance.custom_mxid = self.custom_mxid
self.db_instance.username = self.username
@@ -272,16 +293,16 @@ class Puppet:
# endregion
# region Info updating
- def similarity(self, query):
+ def similarity(self, query: str) -> int:
username_similarity = (SequenceMatcher(None, self.username, query).ratio()
if self.username else 0)
displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio()
if self.displayname else 0)
similarity = max(username_similarity, displayname_similarity)
- return round(similarity * 1000) / 10
+ return int(round(similarity * 1000) / 10)
@staticmethod
- def get_displayname(info, enable_format=True):
+ def get_displayname(info: User, enable_format: bool = True) -> str:
data = {
"phone number": info.phone if hasattr(info, "phone") else None,
"username": info.username,
@@ -308,7 +329,7 @@ class Puppet:
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
displayname=name)
- async def update_info(self, source, info):
+ async def update_info(self, source: 'AbstractUser', info: User) -> None:
changed = False
if self.username != info.username:
self.username = info.username
@@ -323,24 +344,26 @@ class Puppet:
if changed:
self.save()
- async def update_displayname(self, source, info):
+ async def update_displayname(self, source: 'AbstractUser', info: User) -> bool:
ignore_source = (not source.is_relaybot
and self.displayname_source is not None
and self.displayname_source != source.tgid)
if ignore_source:
- return
+ return False
displayname = self.get_displayname(info)
if displayname != self.displayname:
await self.default_mxid_intent.set_display_name(displayname)
self.displayname = displayname
- self.displayname_source = source.tgid
+ self.displayname_source = TelegramID(source.tgid)
return True
elif source.is_relaybot or self.displayname_source is None:
- self.displayname_source = source.tgid
+ self.displayname_source = TelegramID(source.tgid)
return True
+ else:
+ return False
- async def update_avatar(self, source, photo):
+ async def update_avatar(self, source: 'AbstractUser', photo: FileLocation) -> bool:
photo_id = f"{photo.volume_id}-{photo.local_id}"
if self.photo_id != photo_id:
file = await util.transfer_file_to_matrix(self.db, source.client,
@@ -355,7 +378,7 @@ class Puppet:
# region Getters
@classmethod
- def get(cls, tgid, create=True) -> "Optional[Puppet]":
+ def get(cls, tgid: TelegramID, create: bool = True) -> Optional['Puppet']:
try:
return cls.cache[tgid]
except KeyError:
@@ -374,12 +397,15 @@ class Puppet:
return None
@classmethod
- def get_by_mxid(cls, mxid, create=True) -> "Optional[Puppet]":
+ def get_by_mxid(cls, mxid: MatrixUserID, create: bool = True) -> Optional['Puppet']:
tgid = cls.get_id_from_mxid(mxid)
- return cls.get(tgid, create) if tgid else None
+ if tgid:
+ return cls.get(tgid, create)
+
+ return None
@classmethod
- def get_by_custom_mxid(cls, mxid):
+ def get_by_custom_mxid(cls, mxid: MatrixUserID) -> Optional['Puppet']:
if not mxid:
raise ValueError("Matrix ID can't be empty")
@@ -396,25 +422,25 @@ class Puppet:
return None
@classmethod
- def get_all_with_custom_mxid(cls):
+ def get_all_with_custom_mxid(cls) -> List['Puppet']:
return [cls.by_custom_mxid[puppet.mxid]
if puppet.custom_mxid in cls.by_custom_mxid
else cls.from_db(puppet)
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
@classmethod
- def get_id_from_mxid(cls, mxid):
+ def get_id_from_mxid(cls, mxid: MatrixUserID) -> Optional[TelegramID]:
match = cls.mxid_regex.match(mxid)
if match:
- return int(match.group(1))
+ return TelegramID(int(match.group(1)))
return None
@classmethod
- def get_mxid_from_id(cls, tgid):
- return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}"
+ def get_mxid_from_id(cls, tgid: TelegramID) -> MatrixUserID:
+ return MatrixUserID(f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}")
@classmethod
- def find_by_username(cls, username) -> "Optional[Puppet]":
+ def find_by_username(cls, username: str) -> Optional['Puppet']:
if not username:
return None
@@ -422,14 +448,14 @@ class Puppet:
if puppet.username and puppet.username.lower() == username.lower():
return puppet
- puppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
- if puppet:
- return cls.from_db(puppet)
+ dbpuppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
+ if dbpuppet:
+ return cls.from_db(dbpuppet)
return None
@classmethod
- def find_by_displayname(cls, displayname) -> "Optional[Puppet]":
+ def find_by_displayname(cls, displayname: str) -> Optional['Puppet']:
if not displayname:
return None
@@ -437,17 +463,17 @@ class Puppet:
if puppet.displayname and puppet.displayname == displayname:
return puppet
- puppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
- if puppet:
- return cls.from_db(puppet)
+ dbpuppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
+ if dbpuppet:
+ return cls.from_db(dbpuppet)
return None
# endregion
-def init(context: "Context") -> List[Awaitable[int]]:
+def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError]
global config
- Puppet.az, Puppet.db, config, Puppet.loop, _ = context
+ Puppet.az, Puppet.db, config, Puppet.loop, _ = context.core
Puppet.mx = context.mx
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
Puppet.hs_domain = config["homeserver"]["domain"]
diff --git a/mautrix_telegram/scripts/telematrix_import/__main__.py b/mautrix_telegram/scripts/telematrix_import/__main__.py
index 2de531c7..119c7689 100644
--- a/mautrix_telegram/scripts/telematrix_import/__main__.py
+++ b/mautrix_telegram/scripts/telematrix_import/__main__.py
@@ -40,7 +40,7 @@ telematrix_db_engine.dispose()
portals = {}
chats = {}
messages = {}
-puppets = {}
+puppets = {} # Dict[int, Puppet]
for chat_link in chat_links:
if type(chat_link.tg_room) is str:
diff --git a/mautrix_telegram/sqlstatestore.py b/mautrix_telegram/sqlstatestore.py
index 68e9fd9d..ee5b3609 100644
--- a/mautrix_telegram/sqlstatestore.py
+++ b/mautrix_telegram/sqlstatestore.py
@@ -20,37 +20,39 @@ from sqlalchemy import orm
from mautrix_appservice import StateStore
+from .types import MatrixUserID, MatrixRoomID
from . import puppet as pu
from .db import RoomState, UserProfile
class SQLStateStore(StateStore):
- def __init__(self, db):
+ def __init__(self, db: orm.Session) -> None:
super().__init__()
self.db = db # type: orm.Session
self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile]
self.room_state_cache = {} # type: Dict[str, RoomState]
@staticmethod
- def is_registered(user: str) -> bool:
+ def is_registered(user: MatrixUserID) -> bool:
puppet = pu.Puppet.get_by_mxid(user)
return puppet.is_registered if puppet else False
@staticmethod
- def registered(user: str):
+ def registered(user: MatrixUserID) -> None:
puppet = pu.Puppet.get_by_mxid(user)
if puppet:
puppet.is_registered = True
puppet.save()
- def update_state(self, event: dict):
+ def update_state(self, event: Dict) -> None:
event_type = event["type"]
if event_type == "m.room.power_levels":
self.set_power_levels(event["room_id"], event["content"])
elif event_type == "m.room.member":
self.set_member(event["room_id"], event["state_key"], event["content"])
- def _get_user_profile(self, room_id: str, user_id: str, create: bool = True) -> UserProfile:
+ def _get_user_profile(self, room_id: MatrixRoomID, user_id: MatrixUserID, create: bool = True
+ ) -> UserProfile:
key = (room_id, user_id)
try:
return self.profile_cache[key]
@@ -67,22 +69,22 @@ class SQLStateStore(StateStore):
self.profile_cache[key] = profile
return profile
- def get_member(self, room: str, user: str) -> dict:
+ def get_member(self, room: MatrixRoomID, user: MatrixUserID) -> Dict:
return self._get_user_profile(room, user).dict()
- def set_member(self, room: str, user: str, member: dict):
+ def set_member(self, room: MatrixRoomID, user: MatrixUserID, member: Dict) -> None:
profile = self._get_user_profile(room, user)
profile.membership = member.get("membership", profile.membership or "leave")
profile.displayname = member.get("displayname", profile.displayname)
profile.avatar_url = member.get("avatar_url", profile.avatar_url)
self.db.commit()
- def set_membership(self, room: str, user: str, membership: str):
+ def set_membership(self, room: MatrixRoomID, user: MatrixUserID, membership: str) -> None:
self.set_member(room, user, {
"membership": membership,
})
- def _get_room_state(self, room_id: str, create: bool = True) -> RoomState:
+ def _get_room_state(self, room_id: MatrixRoomID, create: bool = True) -> RoomState:
try:
return self.room_state_cache[room_id]
except KeyError:
@@ -96,13 +98,13 @@ class SQLStateStore(StateStore):
self.room_state_cache[room_id] = room
return room
- def has_power_levels(self, room: str) -> bool:
+ def has_power_levels(self, room: MatrixRoomID) -> bool:
return self._get_room_state(room).has_power_levels
- def get_power_levels(self, room: str) -> dict:
+ def get_power_levels(self, room: MatrixRoomID) -> Dict:
return self._get_room_state(room).power_levels
- def set_power_level(self, room: str, user: str, level: int):
+ def set_power_level(self, room: MatrixRoomID, user: MatrixUserID, level: int) -> None:
room_state = self._get_room_state(room)
power_levels = room_state.power_levels
if not power_levels:
@@ -114,7 +116,7 @@ class SQLStateStore(StateStore):
room_state.power_levels = power_levels
self.db.commit()
- def set_power_levels(self, room: str, content: dict):
+ def set_power_levels(self, room: MatrixRoomID, content: Dict) -> None:
state = self._get_room_state(room)
state.power_levels = content
self.db.commit()
diff --git a/mautrix_telegram/tgclient.py b/mautrix_telegram/tgclient.py
index 4534524e..769d431c 100644
--- a/mautrix_telegram/tgclient.py
+++ b/mautrix_telegram/tgclient.py
@@ -14,9 +14,13 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import List, Union, Optional
+
from telethon import TelegramClient, utils
from telethon.tl.functions.messages import SendMediaRequest
-from telethon.tl.types import *
+from telethon.tl.types import (
+ InputMediaUploadedDocument, InputMediaUploadedPhoto, TypeDocumentAttribute, TypeInputMedia,
+ TypeInputPeer, TypeMessageEntity, TypeMessageMedia, TypePeer)
from telethon.tl import custom
diff --git a/mautrix_telegram/types.py b/mautrix_telegram/types.py
new file mode 100644
index 00000000..9ae209f7
--- /dev/null
+++ b/mautrix_telegram/types.py
@@ -0,0 +1,10 @@
+from typing import Dict, NewType
+
+# MatrixId = NewType('MatrixId', str)
+MatrixUserID = NewType('MatrixUserID', str)
+MatrixRoomID = NewType('MatrixRoomID', str)
+MatrixEventID = NewType('MatrixEventID', str)
+
+MatrixEvent = NewType('MatrixEvent', Dict)
+
+TelegramID = NewType('TelegramID', int)
diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py
index c8a28131..af86ce46 100644
--- a/mautrix_telegram/user.py
+++ b/mautrix_telegram/user.py
@@ -14,18 +14,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Dict, Awaitable, Optional, Match, Tuple, TYPE_CHECKING
+from typing import Coroutine, Dict, List, Match, NewType, Optional, Tuple, cast, TYPE_CHECKING
import logging
import asyncio
import re
-from telethon.tl.types import *
+from telethon.tl.types import (
+ TypeUpdate, UpdateNewMessage, UpdateNewChannelMessage, PeerUser,
+ UpdateShortChatMessage, UpdateShortMessage)
from telethon.tl.types import User as TLUser
from telethon.tl.types.contacts import ContactsNotModified
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from telethon.tl.functions.account import UpdateStatusRequest
from mautrix_appservice import MatrixRequestError
+from .types import MatrixUserID, TelegramID
from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
from .abstract_user import AbstractUser
from . import portal as po, puppet as pu
@@ -36,7 +39,7 @@ if TYPE_CHECKING:
config = None # type: Config
-SearchResults = List[Tuple["pu.Puppet", int]]
+SearchResult = NewType('SearchResult', Tuple['pu.Puppet', int])
class User(AbstractUser):
@@ -44,23 +47,23 @@ class User(AbstractUser):
by_mxid = {} # type: Dict[str, User]
by_tgid = {} # type: Dict[int, User]
- def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None,
- db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0,
- is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None,
- db_instance: Optional[DBUser] = None):
+ def __init__(self, mxid: MatrixUserID, tgid: Optional[TelegramID] = None,
+ username: Optional[str] = None, db_contacts: Optional[List[DBContact]] = None,
+ saved_contacts: int = 0, is_bot: bool = False, db_portals: List[DBPortal] = [],
+ db_instance: Optional[DBUser] = None) -> None:
super().__init__()
- self.mxid = mxid # type: str
- self.tgid = tgid # type: int
+ self.mxid = mxid # type: MatrixUserID
+ self.tgid = tgid # type: TelegramID
self.is_bot = is_bot # type: bool
self.username = username # type: str
self.contacts = [] # type: List[pu.Puppet]
self.saved_contacts = saved_contacts # type: int
self.db_contacts = db_contacts # type: List[DBContact]
self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
- self.db_portals = db_portals # type: List[DBPortal]
- self._db_instance = db_instance # type: DBUser
+ self.db_portals = db_portals or [] # type: List[DBPortal]
+ self._db_instance = db_instance # type: Optional[DBUser]
- self.command_status = None # type: dict
+ self.command_status = None # type: Dict
(self.relaybot_whitelisted,
self.whitelisted,
@@ -93,7 +96,7 @@ class User(AbstractUser):
for puppet in self.contacts]
@db_contacts.setter
- def db_contacts(self, contacts: List[DBContact]):
+ def db_contacts(self, contacts: List[DBContact]) -> None:
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
@property
@@ -101,7 +104,7 @@ class User(AbstractUser):
return [portal.db_instance for portal in self.portals.values() if not portal.deleted]
@db_portals.setter
- def db_portals(self, portals: List[DBPortal]):
+ def db_portals(self, portals: List[DBPortal]) -> None:
self.portals = {(portal.tgid, portal.tg_receiver):
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
for portal in portals} if portals else {}
@@ -119,7 +122,7 @@ class User(AbstractUser):
contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0,
portals=self.db_portals)
- def save(self):
+ def save(self) -> None:
self.db_instance.tgid = self.tgid
self.db_instance.username = self.username
self.db_instance.contacts = self.db_contacts
@@ -127,7 +130,7 @@ class User(AbstractUser):
self.db_instance.portals = self.db_portals
self.db.commit()
- def delete(self):
+ def delete(self) -> None:
try:
del self.by_mxid[self.mxid]
del self.by_tgid[self.tgid]
@@ -138,14 +141,14 @@ class User(AbstractUser):
self.db.commit()
@classmethod
- def from_db(cls, db_user: DBUser) -> "User":
+ def from_db(cls, db_user: DBUser) -> 'User':
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts,
False, db_user.saved_contacts, db_user.portals, db_instance=db_user)
# endregion
# region Telegram connection management
- async def start(self, delete_unless_authenticated: bool = False) -> "User":
+ async def start(self, delete_unless_authenticated: bool = False) -> 'User':
await super().start()
if await self.is_logged_in():
self.log.debug(f"Ensuring post_login() for {self.name}")
@@ -156,7 +159,7 @@ class User(AbstractUser):
self.client.session.delete()
return self
- async def post_login(self, info: TLUser = None):
+ async def post_login(self, info: TLUser = None) -> None:
try:
await self.update_info(info)
if not self.is_bot:
@@ -167,9 +170,9 @@ class User(AbstractUser):
except Exception:
self.log.exception("Failed to run post-login functions for %s", self.mxid)
- async def update(self, update: TypeUpdate):
+ async def update(self, update: TypeUpdate) -> bool:
if not self.is_bot:
- return
+ return False
if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
message = update.message
@@ -183,22 +186,25 @@ class User(AbstractUser):
elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
else:
- return
+ return False
- self.register_portal(portal)
+ if portal:
+ self.register_portal(portal)
+
+ return True
# endregion
# region Telegram actions that need custom methods
- def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]":
- return super().ensure_started(even_if_no_session)
+ def ensure_started(self, even_if_no_session: bool = False) -> Coroutine[None, None, 'User']:
+ return cast(Coroutine[None, None, 'User'], super().ensure_started(even_if_no_session))
- def set_presence(self, online: bool = True):
+ def set_presence(self, online: bool = True) -> bool:
if self.is_bot:
- return
+ return False
return self.client(UpdateStatusRequest(offline=not online))
- async def update_info(self, info: TLUser = None):
+ async def update_info(self, info: TLUser = None) -> None:
info = info or await self.client.get_me()
changed = False
if self.is_bot != info.bot:
@@ -213,7 +219,7 @@ class User(AbstractUser):
if changed:
self.save()
- async def log_out(self):
+ async def log_out(self) -> bool:
puppet = pu.Puppet.get(self.tgid)
if puppet.is_real_user:
await puppet.switch_mxid(None, None)
@@ -241,28 +247,29 @@ class User(AbstractUser):
return True
def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45
- ) -> SearchResults:
- results = [] # type: SearchResults
+ ) -> List[SearchResult]:
+ results = [] # type: List[SearchResult]
for contact in self.contacts:
similarity = contact.similarity(query)
if similarity >= min_similarity:
- results.append((contact, similarity))
+ results.append(SearchResult((contact, similarity)))
results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results]
- async def _search_remote(self, query: str, max_results: int = 5) -> SearchResults:
+ async def _search_remote(self, query: str, max_results: int = 5) -> List[SearchResult]:
if len(query) < 5:
return []
server_results = await self.client(SearchRequest(q=query, limit=max_results))
- results = [] # type: SearchResults
+ results = [] # type: List[SearchResult]
for user in server_results.users:
puppet = pu.Puppet.get(user.id)
await puppet.update_info(self, user)
- results.append((puppet, puppet.similarity(query)))
+ results.append(SearchResult((puppet, puppet.similarity(query))))
results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results]
- async def search(self, query: str, force_remote: bool = False) -> Tuple[SearchResults, bool]:
+ async def search(self, query: str, force_remote: bool = False
+ ) -> Tuple[List[SearchResult], bool]:
if force_remote:
return await self._search_remote(query), True
@@ -272,7 +279,7 @@ class User(AbstractUser):
return await self._search_remote(query), True
- async def sync_dialogs(self, synchronous_create: bool = False):
+ async def sync_dialogs(self, synchronous_create: bool = False) -> None:
creators = []
for entity in await self.get_dialogs(limit=30):
portal = po.Portal.get_by_entity(entity)
@@ -283,7 +290,7 @@ class User(AbstractUser):
self.save()
await asyncio.gather(*creators, loop=self.loop)
- def register_portal(self, portal: po.Portal):
+ def register_portal(self, portal: po.Portal) -> None:
try:
if self.portals[portal.tgid_full] == portal:
return
@@ -292,7 +299,7 @@ class User(AbstractUser):
self.portals[portal.tgid_full] = portal
self.save()
- def unregister_portal(self, portal: po.Portal):
+ def unregister_portal(self, portal: po.Portal) -> None:
try:
del self.portals[portal.tgid_full]
self.save()
@@ -309,7 +316,7 @@ class User(AbstractUser):
acc = (acc * 20261 + id) & 0xffffffff
return acc & 0x7fffffff
- async def sync_contacts(self):
+ async def sync_contacts(self) -> None:
response = await self.client(GetContactsRequest(hash=self._hash_contacts()))
if isinstance(response, ContactsNotModified):
return
@@ -326,7 +333,7 @@ class User(AbstractUser):
# region Class instance lookup
@classmethod
- def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]":
+ def get_by_mxid(cls, mxid: MatrixUserID, create: bool = True) -> Optional['User']:
if not mxid:
raise ValueError("Matrix ID can't be empty")
@@ -349,7 +356,7 @@ class User(AbstractUser):
return None
@classmethod
- def get_by_tgid(cls, tgid: int) -> "Optional[User]":
+ def get_by_tgid(cls, tgid: int) -> Optional['User']:
try:
return cls.by_tgid[tgid]
except KeyError:
@@ -363,7 +370,7 @@ class User(AbstractUser):
return None
@classmethod
- def find_by_username(cls, username: str) -> "Optional[User]":
+ def find_by_username(cls, username: str) -> Optional['User']:
if not username:
return None
@@ -379,7 +386,7 @@ class User(AbstractUser):
# endregion
-def init(context: "Context") -> List[Awaitable[User]]:
+def init(context: 'Context') -> List[Coroutine]: # [None, None, AbstractUser]
global config
config = context.config
diff --git a/mautrix_telegram/util/file_transfer.py b/mautrix_telegram/util/file_transfer.py
index d950b2a0..b60fc71b 100644
--- a/mautrix_telegram/util/file_transfer.py
+++ b/mautrix_telegram/util/file_transfer.py
@@ -27,7 +27,8 @@ from sqlalchemy.orm.exc import FlushError
from telethon.tl.types import (Document, FileLocation, InputFileLocation,
InputDocumentFileLocation, PhotoSize, PhotoCachedSize)
-from telethon.errors import *
+from telethon.errors import (AuthBytesInvalidError, AuthKeyInvalidError, LocationInvalidError,
+ SecurityError)
from mautrix_appservice import IntentAPI
from ..tgclient import MautrixTelegramClient
diff --git a/mautrix_telegram/util/format_duration.py b/mautrix_telegram/util/format_duration.py
index 9402b83e..44d16550 100644
--- a/mautrix_telegram/util/format_duration.py
+++ b/mautrix_telegram/util/format_duration.py
@@ -17,10 +17,10 @@
def format_duration(seconds: int) -> str:
- def pluralize(count, singular):
+ def pluralize(count: int, singular: str) -> str:
return singular if count == 1 else singular + "s"
- def include(count, word):
+ def include(count: int, word: str) -> str:
return f"{count} {pluralize(count, word)}" if count > 0 else ""
minutes, seconds = divmod(seconds, 60)
diff --git a/mautrix_telegram/util/signed_token.py b/mautrix_telegram/util/signed_token.py
index 13281012..febb2aa4 100644
--- a/mautrix_telegram/util/signed_token.py
+++ b/mautrix_telegram/util/signed_token.py
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Optional
+from typing import Dict, Optional
import json
import base64
import hashlib
@@ -28,13 +28,13 @@ def _get_checksum(key: str, payload: bytes) -> str:
return checksum
-def sign_token(key: str, payload: dict) -> str:
- payload = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
- checksum = _get_checksum(key, payload)
- return f"{checksum}:{payload.decode('utf-8')}"
+def sign_token(key: str, payload: Dict) -> str:
+ payload_b64 = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
+ checksum = _get_checksum(key, payload_b64)
+ return f"{checksum}:{payload_b64.decode('utf-8')}"
-def verify_token(key: str, data: str) -> Optional[dict]:
+def verify_token(key: str, data: str) -> Optional[Dict]:
if not data:
return None
diff --git a/mautrix_telegram/web/common/auth_api.py b/mautrix_telegram/web/common/auth_api.py
index 24fa74e9..b293f368 100644
--- a/mautrix_telegram/web/common/auth_api.py
+++ b/mautrix_telegram/web/common/auth_api.py
@@ -23,7 +23,7 @@ from telethon.errors import *
from ...commands.auth import enter_password
from ...util import format_duration
-from ...puppet import Puppet
+from ...puppet import Puppet, PuppetError
from ...user import User
@@ -51,12 +51,13 @@ class AuthAPI(abc.ABC):
"account.", errcode="already-logged-in")
resp = await puppet.switch_mxid(token, user.mxid)
- if resp == 2:
+ if resp == PuppetError.OnlyLoginSelf:
return self.get_mx_login_response(status=403, errcode="only-login-self",
error="You can only log in as your own Matrix user.")
- elif resp == 1:
+ elif resp == PuppetError.InvalidAccessToken:
return self.get_mx_login_response(status=401, errcode="invalid-access-token",
error="Failed to verify access token.")
+ assert resp == PuppetError.Success, "Encountered an unhandled PuppetError."
return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in")
diff --git a/mautrix_telegram/web/provisioning/__init__.py b/mautrix_telegram/web/provisioning/__init__.py
index 04aa499a..50a0dddd 100644
--- a/mautrix_telegram/web/provisioning/__init__.py
+++ b/mautrix_telegram/web/provisioning/__init__.py
@@ -15,7 +15,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
from aiohttp import web
-from typing import Tuple, Optional, Callable, Awaitable, TYPE_CHECKING
+from typing import Awaitable, Callable, Dict, Optional, Tuple, TYPE_CHECKING
import asyncio
import logging
import json
@@ -24,6 +24,7 @@ from telethon.utils import get_peer_id, resolve_id
from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat
from mautrix_appservice import AppService, MatrixRequestError, IntentError
+from ...types import MatrixUserID, TelegramID
from ...user import User
from ...portal import Portal
from ...commands.portal import user_has_power_level, get_initial_state
@@ -36,7 +37,7 @@ if TYPE_CHECKING:
class ProvisioningAPI(AuthAPI):
log = logging.getLogger("mau.web.provisioning")
- def __init__(self, context: "Context"):
+ def __init__(self, context: "Context") -> None:
super().__init__(context.loop)
self.secret = context.config["appservice.provisioning.shared_secret"]
self.az = context.az # type: AppService
@@ -118,10 +119,10 @@ class ProvisioningAPI(AuthAPI):
chat_id = request.match_info["chat_id"]
if chat_id.startswith("-100"):
- tgid = int(chat_id[4:])
+ tgid = TelegramID(int(chat_id[4:]))
peer_type = "channel"
elif chat_id.startswith("-"):
- tgid = -int(chat_id)
+ tgid = TelegramID(-int(chat_id))
peer_type = "chat"
else:
return self.get_error_response(400, "tgid_invalid", "Invalid Telegram chat ID.")
@@ -153,14 +154,14 @@ class ProvisioningAPI(AuthAPI):
"Matrix room.")
is_logged_in = user is not None and await user.is_logged_in()
- user = user if is_logged_in else self.context.bot
- if not user:
+ acting_user = user if is_logged_in else self.context.bot
+ if not acting_user:
return self.get_login_response(status=403, errcode="not_logged_in",
error="You are not logged in and there is no relay bot.")
entity = None # type: Optional[TypeChat]
try:
- entity = await user.client.get_entity(portal.peer)
+ entity = await acting_user.client.get_entity(portal.peer)
except Exception:
self.log.exception("Failed to get_entity(%s) for manual bridging.", portal.peer)
@@ -411,7 +412,7 @@ class ProvisioningAPI(AuthAPI):
except json.JSONDecodeError:
return None
- async def get_user(self, mxid: str, expect_logged_in: Optional[bool] = False,
+ async def get_user(self, mxid: MatrixUserID, expect_logged_in: Optional[bool] = False,
require_puppeting: bool = True, require_user: bool = True
) -> Tuple[Optional[User], Optional[web.Response]]:
if not mxid:
@@ -439,7 +440,7 @@ class ProvisioningAPI(AuthAPI):
expect_logged_in: Optional[bool] = False,
require_puppeting: bool = False,
want_data: bool = True,
- ) -> (Tuple[Optional[dict],
+ ) -> (Tuple[Optional[Dict],
Optional[User],
Optional[web.Response]]):
err = self.check_authorization(request)