Add missing type hints and fix most type errors except for Optionals.

This commit is contained in:
Kai A. Hiller
2018-08-09 02:19:55 +02:00
parent 01e153662e
commit 0f8009b1e9
26 changed files with 505 additions and 384 deletions
+2 -2
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional from typing import Coroutine, List, Optional
import argparse import argparse
import asyncio import asyncio
import logging.config import logging.config
@@ -115,7 +115,7 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st
startup_actions = (init_puppet(context) + startup_actions = (init_puppet(context) +
init_user(context) + init_user(context) +
[start, [start,
context.mx.init_as_bot()]) context.mx.init_as_bot()]) # type: List[Coroutine]
if context.bot: if context.bot:
startup_actions.append(context.bot.start()) startup_actions.append(context.bot.start())
+4 -2
View File
@@ -38,6 +38,7 @@ from .db import Message as DBMessage
from .tgclient import MautrixTelegramClient from .tgclient import MautrixTelegramClient
if TYPE_CHECKING: if TYPE_CHECKING:
from .types import TelegramId
from .context import Context from .context import Context
from .config import Config from .config import Config
from .bot import Bot from .bot import Bot
@@ -67,10 +68,11 @@ class AbstractUser(ABC):
self.whitelisted = False # type: bool self.whitelisted = False # type: bool
self.relaybot_whitelisted = False # type: bool self.relaybot_whitelisted = False # type: bool
self.client = None # type: MautrixTelegramClient self.client = None # type: MautrixTelegramClient
self.tgid = None # type: int self.tgid = None # type: TelegramId
self.mxid = None # type: str self.mxid = None # type: str
self.is_relaybot = False # type: bool self.is_relaybot = False # type: bool
self.is_bot = False # type: bool self.is_bot = False # type: bool
self.relaybot = None # type: Optional[Bot]
@property @property
def connected(self) -> bool: def connected(self) -> bool:
@@ -372,7 +374,7 @@ class AbstractUser(ABC):
def init(context: "Context") -> None: def init(context: "Context") -> None:
global config, MAX_DELETIONS global config, MAX_DELETIONS
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context.core
AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"] AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"]
AbstractUser.session_container = context.session_container AbstractUser.session_container = context.session_container
MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10) MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10)
+17 -12
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Callable, Dict, Optional, Pattern, TYPE_CHECKING from typing import Awaitable, Callable, Dict, List, Optional, Pattern, TYPE_CHECKING
import logging import logging
import re import re
@@ -27,12 +27,14 @@ from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest
from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest
from telethon.errors import ChannelInvalidError, ChannelPrivateError from telethon.errors import ChannelInvalidError, ChannelPrivateError
from .types import MatrixUserId
from .abstract_user import AbstractUser from .abstract_user import AbstractUser
from .db import BotChat from .db import BotChat
from . import puppet as pu, portal as po, user as u from . import puppet as pu, portal as po, user as u
if TYPE_CHECKING: if TYPE_CHECKING:
from .config import Config from .config import Config
from .context import Context
config = None # type: Config config = None # type: Config
@@ -145,6 +147,7 @@ class Bot(AbstractUser):
for p in participants: for p in participants:
if p.user_id == tgid: if p.user_id == tgid:
return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin)) return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin))
return False
async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool: async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
if not await self._can_use_commands(event.to_id, event.from_id): if not await self._can_use_commands(event.to_id, event.from_id):
@@ -168,15 +171,16 @@ class Bot(AbstractUser):
return await reply( return await reply(
"Portal is not public. Use `/invite <mxid>` to get an invite.") "Portal is not public. Use `/invite <mxid>` to get an invite.")
async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc, mxid: str) -> None: async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc,
if len(mxid) == 0: mxid_input: MatrixUserId) -> Message:
if len(mxid_input) == 0:
return await reply("Usage: `/invite <mxid>`") return await reply("Usage: `/invite <mxid>`")
elif not portal.mxid: elif not portal.mxid:
return await reply("Portal does not have Matrix room. " return await reply("Portal does not have Matrix room. "
"Create one with /portal first.") "Create one with /portal first.")
if not self.mxid_regex.match(mxid): if not self.mxid_regex.match(mxid_input):
return await reply("That doesn't look like a Matrix ID.") return await reply("That doesn't look like a Matrix ID.")
user = await u.User.get_by_mxid(mxid).ensure_started() user = await u.User.get_by_mxid(MatrixUserId(mxid_input)).ensure_started()
if not user.relaybot_whitelisted: if not user.relaybot_whitelisted:
return await reply("That user is not whitelisted to use the bridge.") return await reply("That user is not whitelisted to use the bridge.")
elif await user.is_logged_in(): elif await user.is_logged_in():
@@ -187,7 +191,7 @@ class Bot(AbstractUser):
await portal.main_intent.invite(portal.mxid, user.mxid) await portal.main_intent.invite(portal.mxid, user.mxid)
return await reply(f"Invited `{user.mxid}` to the portal.") return await reply(f"Invited `{user.mxid}` to the portal.")
def handle_command_id(self, message: Message, reply: ReplyFunc) -> None: def handle_command_id(self, message: Message, reply: ReplyFunc) -> Awaitable[Message]:
# Provide the prefixed ID to the user so that the user wouldn't need to specify whether the # Provide the prefixed ID to the user so that the user wouldn't need to specify whether the
# chat is a normal group or a supergroup/channel when using the ID. # chat is a normal group or a supergroup/channel when using the ID.
if isinstance(message.to_id, PeerChannel): if isinstance(message.to_id, PeerChannel):
@@ -210,7 +214,7 @@ class Bot(AbstractUser):
return False return False
async def handle_command(self, message: Message) -> None: async def handle_command(self, message: Message) -> None:
def reply(reply_text) -> None: def reply(reply_text: str) -> Awaitable[Message]:
return self.client.send_message(message.to_id, reply_text, reply_to=message.id) return self.client.send_message(message.to_id, reply_text, reply_to=message.id)
text = message.message text = message.message
@@ -231,7 +235,7 @@ class Bot(AbstractUser):
mxid = text[text.index(" ") + 1:] mxid = text[text.index(" ") + 1:]
except ValueError: except ValueError:
mxid = "" mxid = ""
await self.handle_command_invite(portal, reply, mxid=mxid) await self.handle_command_invite(portal, reply, mxid_input=mxid)
def handle_service_message(self, message: MessageService) -> None: def handle_service_message(self, message: MessageService) -> None:
to_id = message.to_id to_id = message.to_id
@@ -250,11 +254,12 @@ class Bot(AbstractUser):
elif isinstance(action, MessageActionChatDeleteUser) and action.user_id == self.tgid: elif isinstance(action, MessageActionChatDeleteUser) and action.user_id == self.tgid:
self.remove_chat(to_id) self.remove_chat(to_id)
async def update(self, update) -> None: async def update(self, update) -> bool:
if not isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)): if not isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
return return False
if isinstance(update.message, MessageService): if isinstance(update.message, MessageService):
return self.handle_service_message(update.message) self.handle_service_message(update.message)
return False
is_command = (isinstance(update.message, Message) is_command = (isinstance(update.message, Message)
and update.message.entities and len(update.message.entities) > 0 and update.message.entities and len(update.message.entities) > 0
@@ -270,7 +275,7 @@ class Bot(AbstractUser):
return "bot" return "bot"
def init(context) -> Optional[Bot]: def init(context: 'Context') -> Optional[Bot]:
global config global config
config = context.config config = context.config
token = config["telegram.bot_token"] token = config["telegram.bot_token"]
+24 -18
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict from typing import Any, Awaitable, Dict, Optional
import asyncio import asyncio
from telethon.errors import ( from telethon.errors import (
@@ -31,7 +31,7 @@ from ..util import format_duration
@command_handler(needs_auth=False, @command_handler(needs_auth=False,
help_section=SECTION_AUTH, help_section=SECTION_AUTH,
help_text="Check if you're logged into Telegram.") help_text="Check if you're logged into Telegram.")
async def ping(evt: CommandEvent) -> None: async def ping(evt: CommandEvent) -> Optional[Dict]:
me = await evt.sender.client.get_me() if await evt.sender.is_logged_in() else None me = await evt.sender.client.get_me() if await evt.sender.is_logged_in() else None
if me: if me:
return await evt.reply(f"You're logged in as @{me.username}") return await evt.reply(f"You're logged in as @{me.username}")
@@ -42,7 +42,7 @@ async def ping(evt: CommandEvent) -> None:
@command_handler(needs_auth=False, needs_puppeting=False, @command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_AUTH, help_section=SECTION_AUTH,
help_text="Get the info of the message relay Telegram bot.") help_text="Get the info of the message relay Telegram bot.")
async def ping_bot(evt: CommandEvent) -> None: async def ping_bot(evt: CommandEvent) -> Optional[Dict]:
if not evt.tgbot: if not evt.tgbot:
return await evt.reply("Telegram message relay bot not configured.") return await evt.reply("Telegram message relay bot not configured.")
bot_info = await evt.tgbot.client.get_me() bot_info = await evt.tgbot.client.get_me()
@@ -57,19 +57,19 @@ async def ping_bot(evt: CommandEvent) -> None:
help_section=SECTION_AUTH, help_section=SECTION_AUTH,
help_text="Revert your Telegram account's Matrix puppet to use the default Matrix " help_text="Revert your Telegram account's Matrix puppet to use the default Matrix "
"account.") "account.")
async def logout_matrix(evt: CommandEvent) -> None: async def logout_matrix(evt: CommandEvent) -> Optional[Dict]:
puppet = pu.Puppet.get(evt.sender.tgid) puppet = pu.Puppet.get(evt.sender.tgid)
if not puppet.is_real_user: if not puppet.is_real_user:
return await evt.reply("You are not logged in with your Matrix account.") return await evt.reply("You are not logged in with your Matrix account.")
await puppet.switch_mxid(None, None) await puppet.switch_mxid(None, None)
await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.") return await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.")
@command_handler(needs_auth=True, management_only=True, needs_matrix_puppeting=True, @command_handler(needs_auth=True, management_only=True, needs_matrix_puppeting=True,
help_section=SECTION_AUTH, help_section=SECTION_AUTH,
help_text="Replace your Telegram account's Matrix puppet with your own Matrix " help_text="Replace your Telegram account's Matrix puppet with your own Matrix "
"account") "account")
async def login_matrix(evt: CommandEvent) -> None: async def login_matrix(evt: CommandEvent) -> Optional[Dict]:
puppet = pu.Puppet.get(evt.sender.tgid) puppet = pu.Puppet.get(evt.sender.tgid)
if puppet.is_real_user: if puppet.is_real_user:
return await evt.reply("You have already logged in with your Matrix account. " return await evt.reply("You have already logged in with your Matrix account. "
@@ -100,7 +100,7 @@ async def login_matrix(evt: CommandEvent) -> None:
return await evt.reply("This bridge instance has been configured to not allow logging in.") return await evt.reply("This bridge instance has been configured to not allow logging in.")
async def enter_matrix_token(evt: CommandEvent) -> None: async def enter_matrix_token(evt: CommandEvent) -> Dict:
evt.sender.command_status = None evt.sender.command_status = None
puppet = pu.Puppet.get(evt.sender.tgid) puppet = pu.Puppet.get(evt.sender.tgid)
@@ -109,10 +109,11 @@ async def enter_matrix_token(evt: CommandEvent) -> None:
"Log out with `$cmdprefix+sp logout-matrix` first.") "Log out with `$cmdprefix+sp logout-matrix` first.")
resp = await puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid) resp = await puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid)
if resp == 2: if resp == pu.PuppetError.OnlyLoginSelf:
return await evt.reply("You can only log in as your own Matrix user.") return await evt.reply("You can only log in as your own Matrix user.")
elif resp == 1: elif resp == pu.PuppetError.InvalidAccessToken:
return await evt.reply("Failed to verify access token.") return await evt.reply("Failed to verify access token.")
assert resp == pu.PuppetError.Success, "Encountered an unhandled PuppetError."
return await evt.reply( return await evt.reply(
f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.") f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.")
@@ -121,7 +122,7 @@ async def enter_matrix_token(evt: CommandEvent) -> None:
help_section=SECTION_AUTH, help_section=SECTION_AUTH,
help_args="<_phone_> <_full name_>", help_args="<_phone_> <_full name_>",
help_text="Register to Telegram") help_text="Register to Telegram")
async def register(evt: CommandEvent) -> None: async def register(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.is_logged_in(): if await evt.sender.is_logged_in():
return await evt.reply("You are already logged in.") return await evt.reply("You are already logged in.")
elif len(evt.args) < 1: elif len(evt.args) < 1:
@@ -138,9 +139,10 @@ async def register(evt: CommandEvent) -> None:
"action": "Register", "action": "Register",
"full_name": full_name, "full_name": full_name,
}) })
return None
async def enter_code_register(evt: CommandEvent) -> None: async def enter_code_register(evt: CommandEvent) -> Dict:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp <code>`") return await evt.reply("**Usage:** `$cmdprefix+sp <code>`")
try: try:
@@ -169,7 +171,7 @@ async def enter_code_register(evt: CommandEvent) -> None:
@command_handler(needs_auth=False, management_only=True, @command_handler(needs_auth=False, management_only=True,
help_section=SECTION_AUTH, help_section=SECTION_AUTH,
help_text="Get instructions on how to log in.") help_text="Get instructions on how to log in.")
async def login(evt: CommandEvent) -> None: async def login(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.is_logged_in(): if await evt.sender.is_logged_in():
return await evt.reply("You are already logged in.") return await evt.reply("You are already logged in.")
@@ -200,7 +202,8 @@ async def login(evt: CommandEvent) -> None:
return await evt.reply("This bridge instance has been configured to not allow logging in.") return await evt.reply("This bridge instance has been configured to not allow logging in.")
async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, str]) -> None: async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[str, Any]
) -> Dict:
ok = False ok = False
try: try:
await evt.sender.ensure_started(even_if_no_session=True) await evt.sender.ensure_started(even_if_no_session=True)
@@ -232,7 +235,7 @@ async def request_code(evt: CommandEvent, phone_number: str, next_status: Dict[s
@command_handler(needs_auth=False) @command_handler(needs_auth=False)
async def enter_phone_or_token(evt: CommandEvent) -> None: async def enter_phone_or_token(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-phone-or-token <phone-or-token>`") return await evt.reply("**Usage:** `$cmdprefix+sp enter-phone-or-token <phone-or-token>`")
elif not evt.config.get("bridge.allow_matrix_login", True): elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -252,10 +255,11 @@ async def enter_phone_or_token(evt: CommandEvent) -> None:
"next": enter_code, "next": enter_code,
"action": "Login", "action": "Login",
}) })
return None
@command_handler(needs_auth=False) @command_handler(needs_auth=False)
async def enter_code(evt: CommandEvent) -> None: async def enter_code(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-code <code>`") return await evt.reply("**Usage:** `$cmdprefix+sp enter-code <code>`")
elif not evt.config.get("bridge.allow_matrix_login", True): elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -267,10 +271,11 @@ async def enter_code(evt: CommandEvent) -> None:
evt.log.exception("Error sending phone code") evt.log.exception("Error sending phone code")
return await evt.reply("Unhandled exception while sending code. " return await evt.reply("Unhandled exception while sending code. "
"Check console for more details.") "Check console for more details.")
return None
@command_handler(needs_auth=False) @command_handler(needs_auth=False)
async def enter_password(evt: CommandEvent) -> None: async def enter_password(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-password <password>`") return await evt.reply("**Usage:** `$cmdprefix+sp enter-password <password>`")
elif not evt.config.get("bridge.allow_matrix_login", True): elif not evt.config.get("bridge.allow_matrix_login", True):
@@ -286,9 +291,10 @@ async def enter_password(evt: CommandEvent) -> None:
evt.log.exception("Error sending password") evt.log.exception("Error sending password")
return await evt.reply("Unhandled exception while sending password. " return await evt.reply("Unhandled exception while sending password. "
"Check console for more details.") "Check console for more details.")
return None
async def sign_in(evt: CommandEvent, **sign_in_info) -> None: async def sign_in(evt: CommandEvent, **sign_in_info) -> Dict:
try: try:
await evt.sender.ensure_started(even_if_no_session=True) await evt.sender.ensure_started(even_if_no_session=True)
user = await evt.sender.client.sign_in(**sign_in_info) user = await evt.sender.client.sign_in(**sign_in_info)
@@ -313,7 +319,7 @@ async def sign_in(evt: CommandEvent, **sign_in_info) -> None:
@command_handler(needs_auth=True, @command_handler(needs_auth=True,
help_section=SECTION_AUTH, help_section=SECTION_AUTH,
help_text="Log out from Telegram.") help_text="Log out from Telegram.")
async def logout(evt: CommandEvent) -> None: async def logout(evt: CommandEvent) -> Optional[Dict]:
if await evt.sender.log_out(): if await evt.sender.log_out():
return await evt.reply("Logged out successfully.") return await evt.reply("Logged out successfully.")
return await evt.reply("Failed to log out.") return await evt.reply("Failed to log out.")
+15 -14
View File
@@ -14,21 +14,21 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, List from typing import Dict, List, NewType, Optional, Tuple, Union
from mautrix_appservice import MatrixRequestError, IntentAPI from mautrix_appservice import MatrixRequestError, IntentAPI
from ..types import MatrixRoomId, MatrixUserId
from . import command_handler, CommandEvent, SECTION_ADMIN from . import command_handler, CommandEvent, SECTION_ADMIN
from .. import puppet as pu, portal as po from .. import puppet as pu, portal as po
ManagementRoomList = List[Tuple[str, str]] ManagementRoom = NewType('ManagementRoom', Tuple[MatrixRoomId, MatrixUserId])
RoomIDList = List[str]
async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList, async def _find_rooms(intent: IntentAPI) -> Tuple[List[ManagementRoom], List[MatrixRoomId],
List["po.Portal"], List["po.Portal"]]: List["po.Portal"], List["po.Portal"]]:
management_rooms = [] # type: ManagementRoomList management_rooms = [] # type: List[ManagementRoom]
unidentified_rooms = [] # type: RoomIDList unidentified_rooms = [] # type: List[MatrixRoomId]
portals = [] # type: List[po.Portal] portals = [] # type: List[po.Portal]
empty_portals = [] # type: List[po.Portal] empty_portals = [] # type: List[po.Portal]
@@ -45,7 +45,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList
if pu.Puppet.get_id_from_mxid(other_member): if pu.Puppet.get_id_from_mxid(other_member):
unidentified_rooms.append(room) unidentified_rooms.append(room)
else: else:
management_rooms.append((room, other_member)) management_rooms.append(ManagementRoom((room, other_member)))
else: else:
unidentified_rooms.append(room) unidentified_rooms.append(room)
else: else:
@@ -61,7 +61,7 @@ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList
@command_handler(needs_admin=True, needs_auth=False, management_only=True, name="clean-rooms", @command_handler(needs_admin=True, needs_auth=False, management_only=True, name="clean-rooms",
help_section=SECTION_ADMIN, help_section=SECTION_ADMIN,
help_text="Clean up unused portal/management rooms.") help_text="Clean up unused portal/management rooms.")
async def clean_rooms(evt: CommandEvent) -> None: async def clean_rooms(evt: CommandEvent) -> Optional[Dict]:
management_rooms, unidentified_rooms, portals, empty_portals = await _find_rooms(evt.az.intent) management_rooms, unidentified_rooms, portals, empty_portals = await _find_rooms(evt.az.intent)
reply = ["#### Management rooms (M)"] reply = ["#### Management rooms (M)"]
@@ -106,13 +106,14 @@ async def clean_rooms(evt: CommandEvent) -> None:
return await evt.reply("\n".join(reply)) return await evt.reply("\n".join(reply))
async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList, async def set_rooms_to_clean(evt, management_rooms: List[ManagementRoom],
unidentified_rooms: RoomIDList, portals: List["po.Portal"], unidentified_rooms: List[MatrixRoomId], portals: List["po.Portal"],
empty_portals: List["po.Portal"]) -> None: empty_portals: List["po.Portal"]) -> None:
command = evt.args[0] command = evt.args[0]
rooms_to_clean = [] rooms_to_clean = [] # type: List[Union[po.Portal, MatrixRoomId]]
if command == "clean-recommended": if command == "clean-recommended":
rooms_to_clean = empty_portals + unidentified_rooms rooms_to_clean += empty_portals
rooms_to_clean += unidentified_rooms
elif command == "clean-groups": elif command == "clean-groups":
if len(evt.args) < 2: if len(evt.args) < 2:
return await evt.reply("**Usage:** `$cmdprefix+sp clean-groups [M][A][U][I]") return await evt.reply("**Usage:** `$cmdprefix+sp clean-groups [M][A][U][I]")
@@ -158,7 +159,7 @@ async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
"`$cmdprefix+sp confirm-clean`.") "`$cmdprefix+sp confirm-clean`.")
async def execute_room_cleanup(evt, rooms_to_clean) -> None: async def execute_room_cleanup(evt, rooms_to_clean: List[Union[po.Portal, MatrixRoomId]]) -> None:
if len(evt.args) > 0 and evt.args[0] == "confirm-clean": if len(evt.args) > 0 and evt.args[0] == "confirm-clean":
await evt.reply(f"Cleaning {len(rooms_to_clean)} rooms. " await evt.reply(f"Cleaning {len(rooms_to_clean)} rooms. "
"This might take a while.") "This might take a while.")
@@ -167,7 +168,7 @@ async def execute_room_cleanup(evt, rooms_to_clean) -> None:
if isinstance(room, po.Portal): if isinstance(room, po.Portal):
await room.cleanup_and_delete() await room.cleanup_and_delete()
cleaned += 1 cleaned += 1
elif isinstance(room, str): elif isinstance(room, str): # str is aliased by MatrixRoomId
await po.Portal.cleanup_room(evt.az.intent, room, message="Room deleted") await po.Portal.cleanup_room(evt.az.intent, room, message="Room deleted")
cleaned += 1 cleaned += 1
evt.sender.command_status = None evt.sender.command_status = None
+30 -18
View File
@@ -14,19 +14,20 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict, Callable, Optional from typing import Any, Awaitable, Callable, Coroutine, Dict, List, NamedTuple, Optional, Union
from collections import namedtuple from collections import namedtuple
import markdown import markdown
import logging import logging
from telethon.errors import FloodWaitError from telethon.errors import FloodWaitError
from ..types import MatrixRoomId
from ..util import format_duration from ..util import format_duration
from .. import user as u, context as c from .. import user as u, context as c
command_handlers = {} # type: Dict[str, CommandHandler] command_handlers = {} # type: Dict[str, CommandHandler]
HelpSection = namedtuple("HelpSection", "name order description") HelpSection = NamedTuple('HelpSection', [('name', str), ('order', int), ('description', str)])
SECTION_GENERAL = HelpSection("General", 0, "") SECTION_GENERAL = HelpSection("General", 0, "")
SECTION_AUTH = HelpSection("Authentication", 10, "") SECTION_AUTH = HelpSection("Authentication", 10, "")
@@ -37,8 +38,8 @@ SECTION_ADMIN = HelpSection("Administration", 50, "")
class CommandEvent: class CommandEvent:
def __init__(self, processor: "CommandProcessor", room: str, sender: u.User, command: str, def __init__(self, processor: 'CommandProcessor', room: MatrixRoomId, sender: u.User,
args: List[str], is_management: bool, is_portal: bool) -> None: command: str, args: List[str], is_management: bool, is_portal: bool) -> None:
self.az = processor.az self.az = processor.az
self.log = processor.log self.log = processor.log
self.loop = processor.loop self.loop = processor.loop
@@ -53,7 +54,8 @@ class CommandEvent:
self.is_management = is_management self.is_management = is_management
self.is_portal = is_portal self.is_portal = is_portal
def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True) -> None: def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True
) -> Awaitable[Dict]:
message = message.replace("$cmdprefix+sp ", message = message.replace("$cmdprefix+sp ",
"" if self.is_management else f"{self.command_prefix} ") "" if self.is_management else f"{self.command_prefix} ")
message = message.replace("$cmdprefix", self.command_prefix) message = message.replace("$cmdprefix", self.command_prefix)
@@ -66,7 +68,7 @@ class CommandEvent:
class CommandHandler: class CommandHandler:
def __init__(self, handler: Callable[[CommandEvent], None], needs_auth: bool, def __init__(self, handler: Callable[[CommandEvent], Awaitable[Dict]], needs_auth: bool,
needs_puppeting: bool, needs_matrix_puppeting: bool, needs_admin: bool, needs_puppeting: bool, needs_matrix_puppeting: bool, needs_admin: bool,
management_only: bool, name: str, help_text: str, help_args: str, management_only: bool, name: str, help_text: str, help_args: str,
help_section: HelpSection) -> None: help_section: HelpSection) -> None:
@@ -103,7 +105,8 @@ class CommandHandler:
(not self.needs_admin or is_admin) and (not self.needs_admin or is_admin) and
(not self.needs_auth or is_logged_in)) (not self.needs_auth or is_logged_in))
async def __call__(self, evt: CommandEvent) -> None: async def __call__(self, evt: CommandEvent
) -> Dict:
error = await self.get_permission_error(evt) error = await self.get_permission_error(evt)
if error is not None: if error is not None:
return await evt.reply(error) return await evt.reply(error)
@@ -118,13 +121,21 @@ class CommandHandler:
return f"**{self.name}** {self._help_args} - {self._help_text}" return f"**{self.name}** {self._help_args} - {self._help_text}"
def command_handler(_func: Optional[Callable[[CommandEvent], None]] = None, *, needs_auth=True, def command_handler(_func: Optional[Callable[[CommandEvent], Awaitable[Dict]]] = None, *,
needs_puppeting=True, needs_matrix_puppeting=False, needs_admin=False, needs_auth: bool = True,
management_only=False, name=None, help_text="", help_args="", needs_puppeting: bool = True,
help_section=None) -> None: needs_matrix_puppeting: bool = False,
needs_admin: bool = False,
management_only: bool = False,
name: Optional[str] = None,
help_text: str = "",
help_args: str = "",
help_section: HelpSection = None
) -> Callable[[Callable[[CommandEvent], Awaitable[Optional[Dict]]]],
CommandHandler]:
input_name = name input_name = name
def decorator(func: Callable[[CommandEvent], None]) -> None: def decorator(func: Callable[[CommandEvent], Awaitable[Optional[Dict]]]) -> CommandHandler:
name = input_name or func.__name__.replace("_", "-") name = input_name or func.__name__.replace("_", "-")
handler = CommandHandler(func, needs_auth, needs_puppeting, needs_matrix_puppeting, handler = CommandHandler(func, needs_auth, needs_puppeting, needs_matrix_puppeting,
needs_admin, management_only, name, help_text, help_args, needs_admin, management_only, name, help_text, help_args,
@@ -139,26 +150,26 @@ class CommandProcessor:
log = logging.getLogger("mau.commands") log = logging.getLogger("mau.commands")
def __init__(self, context: c.Context) -> None: def __init__(self, context: c.Context) -> None:
self.az, self.db, self.config, self.loop, self.tgbot = context self.az, self.db, self.config, self.loop, self.tgbot = context.core
self.public_website = context.public_website self.public_website = context.public_website
self.command_prefix = self.config["bridge.command_prefix"] self.command_prefix = self.config["bridge.command_prefix"]
async def handle(self, room: str, sender: u.User, command: str, args: List[str], async def handle(self, room: MatrixRoomId, sender: u.User, command: str, args: List[str],
is_management: bool, is_portal: bool) -> None: is_management: bool, is_portal: bool) -> Optional[Dict]:
evt = CommandEvent(self, room, sender, command, args, is_management, is_portal) evt = CommandEvent(self, room, sender, command, args, is_management, is_portal)
orig_command = command orig_command = command
command = command.lower() command = command.lower()
try: try:
command = command_handlers[command] command_handler = command_handlers[command]
except KeyError: except KeyError:
if sender.command_status and "next" in sender.command_status: if sender.command_status and "next" in sender.command_status:
args.insert(0, orig_command) args.insert(0, orig_command)
evt.command = "" evt.command = ""
command = sender.command_status["next"] command = sender.command_status["next"]
else: else:
command = command_handlers["unknown-command"] command_handler = command_handlers["unknown-command"]
try: try:
await command(evt) await command_handler(evt)
except FloodWaitError as e: except FloodWaitError as e:
return await evt.reply(f"Flood error: Please wait {format_duration(e.seconds)}") return await evt.reply(f"Flood error: Please wait {format_duration(e.seconds)}")
except Exception: except Exception:
@@ -166,3 +177,4 @@ class CommandProcessor:
f"{evt.command} {' '.join(args)} from {sender.mxid}") f"{evt.command} {' '.join(args)} from {sender.mxid}")
return await evt.reply("Unhandled error while handling command. " return await evt.reply("Unhandled error while handling command. "
"Check logs for more details.") "Check logs for more details.")
return None
+17 -14
View File
@@ -14,46 +14,49 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Tuple
from . import command_handler, CommandEvent, _command_handlers, SECTION_GENERAL from . import command_handler, CommandEvent, _command_handlers, SECTION_GENERAL
from .handler import HelpSection
@command_handler(needs_auth=False, needs_puppeting=False, @command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_GENERAL, help_section=SECTION_GENERAL,
help_text="Cancel an ongoing action (such as login)") help_text="Cancel an ongoing action (such as login)")
def cancel(evt: CommandEvent) -> None: async def cancel(evt: CommandEvent) -> Optional[Dict]:
if evt.sender.command_status: if evt.sender.command_status:
action = evt.sender.command_status["action"] action = evt.sender.command_status["action"]
evt.sender.command_status = None evt.sender.command_status = None
return evt.reply(f"{action} cancelled.") return await evt.reply(f"{action} cancelled.")
else: else:
return evt.reply("No ongoing command.") return await evt.reply("No ongoing command.")
@command_handler(needs_auth=False, needs_puppeting=False) @command_handler(needs_auth=False, needs_puppeting=False)
def unknown_command(evt: CommandEvent) -> None: async def unknown_command(evt: CommandEvent) -> Optional[Dict]:
return evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.") return await evt.reply("Unknown command. Try `$cmdprefix+sp help` for help.")
help_cache = {} help_cache = {} # type: Dict[Tuple[bool, bool, bool, bool, bool], str]
async def _get_help_text(evt: CommandEvent) -> None: async def _get_help_text(evt: CommandEvent) -> str:
cache_key = (evt.is_management, evt.sender.puppet_whitelisted, cache_key = (evt.is_management, evt.sender.puppet_whitelisted,
evt.sender.matrix_puppet_whitelisted, evt.sender.is_admin, evt.sender.matrix_puppet_whitelisted, evt.sender.is_admin,
await evt.sender.is_logged_in()) await evt.sender.is_logged_in())
if cache_key not in help_cache: if cache_key not in help_cache:
help = {} help_sections = {} # type: Dict[HelpSection, List[str]]
for handler in _command_handlers.values(): for handler in _command_handlers.values():
if handler.has_help and handler.has_permission(*cache_key): if handler.has_help and handler.has_permission(*cache_key):
help.setdefault(handler.help_section, []) help_sections.setdefault(handler.help_section, [])
help[handler.help_section].append(handler.help + " ") help_sections[handler.help_section].append(handler.help + " ")
help = sorted(help.items(), key=lambda item: item[0].order) help_sorted = sorted(help_sections.items(), key=lambda item: item[0].order)
help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help] help = ["#### {}\n{}\n".format(key.name, "\n".join(value)) for key, value in help_sorted]
help_cache[cache_key] = "\n".join(help) help_cache[cache_key] = "\n".join(help)
return help_cache[cache_key] return help_cache[cache_key]
def _get_management_status(evt: CommandEvent) -> None: def _get_management_status(evt: CommandEvent) -> str:
if evt.is_management: if evt.is_management:
return "This is a management room: prefixing commands with `$cmdprefix` is not required." return "This is a management room: prefixing commands with `$cmdprefix` is not required."
elif evt.is_portal: elif evt.is_portal:
@@ -65,5 +68,5 @@ def _get_management_status(evt: CommandEvent) -> None:
@command_handler(needs_auth=False, needs_puppeting=False, @command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_GENERAL, help_section=SECTION_GENERAL,
help_text="Show this help message.") help_text="Show this help message.")
async def help(evt: CommandEvent) -> None: async def help(evt: CommandEvent) -> Optional[Dict]:
return await evt.reply(_get_management_status(evt) + "\n" + await _get_help_text(evt)) return await evt.reply(_get_management_status(evt) + "\n" + await _get_help_text(evt))
+44 -36
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Callable from typing import Awaitable, Dict, Callable, Coroutine, Optional, Tuple, Union, cast
import asyncio import asyncio
from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError, from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError,
@@ -22,6 +22,7 @@ from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError,
from telethon.tl.types import ChatForbidden, ChannelForbidden from telethon.tl.types import ChatForbidden, ChannelForbidden
from mautrix_appservice import MatrixRequestError, IntentAPI from mautrix_appservice import MatrixRequestError, IntentAPI
from ..types import MatrixRoomId, TelegramId
from .. import portal as po, user as u from .. import portal as po, user as u
from . import (command_handler, CommandEvent, from . import (command_handler, CommandEvent,
SECTION_ADMIN, SECTION_CREATING_PORTALS, SECTION_PORTAL_MANAGEMENT) SECTION_ADMIN, SECTION_CREATING_PORTALS, SECTION_PORTAL_MANAGEMENT)
@@ -31,7 +32,7 @@ from . import (command_handler, CommandEvent,
help_section=SECTION_ADMIN, help_section=SECTION_ADMIN,
help_args="<_level_> [_mxid_]", help_args="<_level_> [_mxid_]",
help_text="Set a temporary power level without affecting Telegram.") help_text="Set a temporary power level without affecting Telegram.")
async def set_power_level(evt: CommandEvent) -> None: async def set_power_level(evt: CommandEvent) -> Dict:
try: try:
level = int(evt.args[0]) level = int(evt.args[0])
except KeyError: except KeyError:
@@ -46,11 +47,12 @@ async def set_power_level(evt: CommandEvent) -> None:
except MatrixRequestError: except MatrixRequestError:
evt.log.exception("Failed to set power level.") evt.log.exception("Failed to set power level.")
return await evt.reply("Failed to set power level.") return await evt.reply("Failed to set power level.")
return {}
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT, @command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Get a Telegram invite link to the current chat.") help_text="Get a Telegram invite link to the current chat.")
async def invite_link(evt: CommandEvent) -> None: async def invite_link(evt: CommandEvent) -> Dict:
portal = po.Portal.get_by_mxid(evt.room_id) portal = po.Portal.get_by_mxid(evt.room_id)
if not portal: if not portal:
return await evt.reply("This is not a portal room.") return await evt.reply("This is not a portal room.")
@@ -68,7 +70,7 @@ async def invite_link(evt: CommandEvent) -> None:
async def user_has_power_level(room: str, intent, sender: u.User, event: str, default: int = 50 async def user_has_power_level(room: str, intent, sender: u.User, event: str, default: int = 50
) -> None: ) -> bool:
if sender.is_admin: if sender.is_admin:
return True return True
# Make sure the state store contains the power levels. # Make sure the state store contains the power levels.
@@ -82,8 +84,9 @@ async def user_has_power_level(room: str, intent, sender: u.User, event: str, de
async def _get_portal_and_check_permission(evt: CommandEvent, permission: str, async def _get_portal_and_check_permission(evt: CommandEvent, permission: str,
action: Optional[str] = None) -> None: action: Optional[str] = None
room_id = evt.args[0] if len(evt.args) > 0 else evt.room_id ) -> Tuple[Union[Dict, po.Portal], bool]:
room_id = MatrixRoomId(evt.args[0]) if len(evt.args) > 0 else evt.room_id
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
if not portal: if not portal:
@@ -97,8 +100,8 @@ async def _get_portal_and_check_permission(evt: CommandEvent, permission: str,
def _get_portal_murder_function(action: str, room_id: str, function: Callable, command: str, def _get_portal_murder_function(action: str, room_id: str, function: Callable, command: str,
completed_message: str) -> None: completed_message: str) -> Dict:
async def post_confirm(confirm) -> None: async def post_confirm(confirm) -> Optional[Dict]:
confirm.sender.command_status = None confirm.sender.command_status = None
if len(confirm.args) > 0 and confirm.args[0] == f"confirm-{command}": if len(confirm.args) > 0 and confirm.args[0] == f"confirm-{command}":
await function() await function()
@@ -106,6 +109,7 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
return await confirm.reply(completed_message) return await confirm.reply(completed_message)
else: else:
return await confirm.reply(f"{action} cancelled.") return await confirm.reply(f"{action} cancelled.")
return None
return { return {
"next": post_confirm, "next": post_confirm,
@@ -118,10 +122,11 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
help_text="Remove all users from the current portal room and forget the portal. " help_text="Remove all users from the current portal room and forget the portal. "
"Only works for group chats; to delete a private chat portal, simply " "Only works for group chats; to delete a private chat portal, simply "
"leave the room.") "leave the room.")
async def delete_portal(evt: CommandEvent) -> None: async def delete_portal(evt: CommandEvent) -> Optional[Dict]:
portal, ok = await _get_portal_and_check_permission(evt, "unbridge") result, ok = await _get_portal_and_check_permission(evt, "unbridge")
if not ok: if not ok:
return return None
portal = cast('po.Portal', result)
evt.sender.command_status = _get_portal_murder_function("Portal deletion", portal.mxid, evt.sender.command_status = _get_portal_murder_function("Portal deletion", portal.mxid,
portal.cleanup_and_delete, "delete", portal.cleanup_and_delete, "delete",
@@ -139,10 +144,11 @@ async def delete_portal(evt: CommandEvent) -> None:
@command_handler(needs_auth=False, needs_puppeting=False, @command_handler(needs_auth=False, needs_puppeting=False,
help_section=SECTION_PORTAL_MANAGEMENT, help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Remove puppets from the current portal room and forget the portal.") help_text="Remove puppets from the current portal room and forget the portal.")
async def unbridge(evt: CommandEvent) -> None: async def unbridge(evt: CommandEvent) -> Optional[Dict]:
portal, ok = await _get_portal_and_check_permission(evt, "unbridge") result, ok = await _get_portal_and_check_permission(evt, "unbridge")
if not ok: if not ok:
return return None
portal = cast('po.Portal', result)
evt.sender.command_status = _get_portal_murder_function("Room unbridging", portal.mxid, evt.sender.command_status = _get_portal_murder_function("Room unbridging", portal.mxid,
portal.unbridge, "unbridge", portal.unbridge, "unbridge",
@@ -158,11 +164,11 @@ async def unbridge(evt: CommandEvent) -> None:
help_text="Bridge the current Matrix room to the Telegram chat with the given " help_text="Bridge the current Matrix room to the Telegram chat with the given "
"ID. The ID must be the prefixed version that you get with the `/id` " "ID. The ID must be the prefixed version that you get with the `/id` "
"command of the Telegram-side bot.") "command of the Telegram-side bot.")
async def bridge(evt: CommandEvent) -> None: async def bridge(evt: CommandEvent) -> Dict:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** " return await evt.reply("**Usage:** "
"`$cmdprefix+sp bridge <Telegram chat ID> [Matrix room ID]`") "`$cmdprefix+sp bridge <Telegram chat ID> [Matrix room ID]`")
room_id = evt.args[1] if len(evt.args) > 1 else evt.room_id room_id = MatrixRoomId(evt.args[1]) if len(evt.args) > 1 else evt.room_id
that_this = "This" if room_id == evt.room_id else "That" that_this = "This" if room_id == evt.room_id else "That"
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
@@ -173,12 +179,12 @@ async def bridge(evt: CommandEvent) -> None:
return await evt.reply(f"You do not have the permissions to bridge {that_this} room.") return await evt.reply(f"You do not have the permissions to bridge {that_this} room.")
# The /id bot command provides the prefixed ID, so we assume # The /id bot command provides the prefixed ID, so we assume
tgid = evt.args[0] tgid_str = evt.args[0]
if tgid.startswith("-100"): if tgid_str.startswith("-100"):
tgid = int(tgid[4:]) tgid = TelegramId(int(tgid_str[4:]))
peer_type = "channel" peer_type = "channel"
elif tgid.startswith("-"): elif tgid_str.startswith("-"):
tgid = -int(tgid) tgid = TelegramId(-int(tgid_str))
peer_type = "chat" peer_type = "chat"
else: else:
return await evt.reply("That doesn't seem like a prefixed Telegram chat ID.\n\n" return await evt.reply("That doesn't seem like a prefixed Telegram chat ID.\n\n"
@@ -224,7 +230,8 @@ async def bridge(evt: CommandEvent) -> None:
"chat to this room, use `$cmdprefix+sp continue`") "chat to this room, use `$cmdprefix+sp continue`")
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal") -> None: async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"
) -> Tuple[bool, Coroutine[None, None, None]]:
if not portal.mxid: if not portal.mxid:
await evt.reply("The portal seems to have lost its Matrix room between you" await evt.reply("The portal seems to have lost its Matrix room between you"
"calling `$cmdprefix+sp bridge` and this command.\n\n" "calling `$cmdprefix+sp bridge` and this command.\n\n"
@@ -247,7 +254,7 @@ async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Porta
return False, None return False, None
async def confirm_bridge(evt: CommandEvent) -> None: async def confirm_bridge(evt: CommandEvent) -> Optional[Dict]:
status = evt.sender.command_status status = evt.sender.command_status
try: try:
portal = po.Portal.get_by_tgid(status["tgid"], peer_type=status["peer_type"]) portal = po.Portal.get_by_tgid(status["tgid"], peer_type=status["peer_type"])
@@ -260,7 +267,7 @@ async def confirm_bridge(evt: CommandEvent) -> None:
if "mxid" in status: if "mxid" in status:
ok, coro = await cleanup_old_portal_while_bridging(evt, portal) ok, coro = await cleanup_old_portal_while_bridging(evt, portal)
if not ok: if not ok:
return return None
elif coro: elif coro:
asyncio.ensure_future(coro, loop=evt.loop) asyncio.ensure_future(coro, loop=evt.loop)
await evt.reply("Cleaning up previous portal room...") await evt.reply("Cleaning up previous portal room...")
@@ -304,7 +311,7 @@ async def confirm_bridge(evt: CommandEvent) -> None:
return await evt.reply("Bridging complete. Portal synchronization should begin momentarily.") return await evt.reply("Bridging complete. Portal synchronization should begin momentarily.")
async def get_initial_state(intent: IntentAPI, room_id: str) -> None: async def get_initial_state(intent: IntentAPI, room_id: str) -> Tuple[str, str, Dict]:
state = await intent.get_room_state(room_id) state = await intent.get_room_state(room_id)
title = None title = None
about = None about = None
@@ -330,7 +337,7 @@ async def get_initial_state(intent: IntentAPI, room_id: str) -> None:
help_text="Create a Telegram chat of the given type for the current Matrix room. " help_text="Create a Telegram chat of the given type for the current Matrix room. "
"The type is either `group`, `supergroup` or `channel` (defaults to " "The type is either `group`, `supergroup` or `channel` (defaults to "
"`group`).") "`group`).")
async def create(evt: CommandEvent) -> None: async def create(evt: CommandEvent) -> Dict:
type = evt.args[0] if len(evt.args) > 0 else "group" type = evt.args[0] if len(evt.args) > 0 else "group"
if type not in {"chat", "group", "supergroup", "channel"}: if type not in {"chat", "group", "supergroup", "channel"}:
return await evt.reply( return await evt.reply(
@@ -365,7 +372,7 @@ async def create(evt: CommandEvent) -> None:
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT, @command_handler(help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Upgrade a normal Telegram group to a supergroup.") help_text="Upgrade a normal Telegram group to a supergroup.")
async def upgrade(evt: CommandEvent) -> None: async def upgrade(evt: CommandEvent) -> Dict:
portal = po.Portal.get_by_mxid(evt.room_id) portal = po.Portal.get_by_mxid(evt.room_id)
if not portal: if not portal:
return await evt.reply("This is not a portal room.") return await evt.reply("This is not a portal room.")
@@ -387,7 +394,7 @@ async def upgrade(evt: CommandEvent) -> None:
help_args="<_name_|`-`>", help_args="<_name_|`-`>",
help_text="Change the username of a supergroup/channel. " help_text="Change the username of a supergroup/channel. "
"To disable, use a dash (`-`) as the name.") "To disable, use a dash (`-`) as the name.")
async def group_name(evt: CommandEvent) -> None: async def group_name(evt: CommandEvent) -> Dict:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp group-name <name/->`") return await evt.reply("**Usage:** `$cmdprefix+sp group-name <name/->`")
@@ -423,7 +430,7 @@ async def group_name(evt: CommandEvent) -> None:
help_args="<`whitelist`|`blacklist`>", help_args="<`whitelist`|`blacklist`>",
help_text="Change whether the bridge will allow or disallow bridging rooms by " help_text="Change whether the bridge will allow or disallow bridging rooms by "
"default.") "default.")
async def filter_mode(evt: CommandEvent) -> None: async def filter_mode(evt: CommandEvent) -> Dict:
try: try:
mode = evt.args[0] mode = evt.args[0]
if mode not in ("whitelist", "blacklist"): if mode not in ("whitelist", "blacklist"):
@@ -448,19 +455,19 @@ async def filter_mode(evt: CommandEvent) -> None:
help_section=SECTION_ADMIN, help_section=SECTION_ADMIN,
help_args="<`whitelist`|`blacklist`> <_chat ID_>", help_args="<`whitelist`|`blacklist`> <_chat ID_>",
help_text="Allow or disallow bridging a specific chat.") help_text="Allow or disallow bridging a specific chat.")
async def filter(evt: CommandEvent) -> None: async def filter(evt: CommandEvent) -> Optional[Dict]:
try: try:
action = evt.args[0] action = evt.args[0]
if action not in ("whitelist", "blacklist", "add", "remove"): if action not in ("whitelist", "blacklist", "add", "remove"):
raise ValueError() raise ValueError()
id = evt.args[1] id_str = evt.args[1]
if id.startswith("-100"): if id_str.startswith("-100"):
id = int(id[4:]) id = int(id_str[4:])
elif id.startswith("-"): elif id_str.startswith("-"):
id = int(id[1:]) id = int(id_str[1:])
else: else:
id = int(id) id = int(id_str)
except (IndexError, ValueError): except (IndexError, ValueError):
return await evt.reply("**Usage:** `$cmdprefix+sp filter <whitelist/blacklist> <chat ID>`") return await evt.reply("**Usage:** `$cmdprefix+sp filter <whitelist/blacklist> <chat ID>`")
@@ -490,3 +497,4 @@ async def filter(evt: CommandEvent) -> None:
list.remove(id) list.remove(id)
save() save()
return await evt.reply(f"Chat ID removed from {mode}.") return await evt.reply(f"Chat ID removed from {mode}.")
return None
+11 -7
View File
@@ -14,10 +14,13 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Dict, List, Optional, Tuple
import re
from telethon.errors import ( from telethon.errors import (
InviteHashInvalidError, InviteHashExpiredError, UserAlreadyParticipantError) InviteHashInvalidError, InviteHashExpiredError, UserAlreadyParticipantError)
from telethon.tl.types import User as TLUser from telethon.tl.types import User as TLUser
from telethon.tl.types import TypeUpdates
from telethon.tl.functions.messages import ImportChatInviteRequest, CheckChatInviteRequest from telethon.tl.functions.messages import ImportChatInviteRequest, CheckChatInviteRequest
from telethon.tl.functions.channels import JoinChannelRequest from telethon.tl.functions.channels import JoinChannelRequest
@@ -28,7 +31,7 @@ from . import command_handler, CommandEvent, SECTION_MISC, SECTION_CREATING_PORT
@command_handler(help_section=SECTION_MISC, @command_handler(help_section=SECTION_MISC,
help_args="[_-r|--remote_] <_query_>", help_args="[_-r|--remote_] <_query_>",
help_text="Search your contacts or the Telegram servers for users.") help_text="Search your contacts or the Telegram servers for users.")
async def search(evt: CommandEvent) -> None: async def search(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] <query>`") return await evt.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] <query>`")
@@ -49,7 +52,7 @@ async def search(evt: CommandEvent) -> None:
"Minimum length of remote query is 5 characters.") "Minimum length of remote query is 5 characters.")
return await evt.reply("No results 3:") return await evt.reply("No results 3:")
reply = [] reply = [] # type: List[str]
if remote: if remote:
reply += ["**Results from Telegram server:**", ""] reply += ["**Results from Telegram server:**", ""]
else: else:
@@ -70,7 +73,7 @@ async def search(evt: CommandEvent) -> None:
"either the internal user ID, the username or the phone number. " "either the internal user ID, the username or the phone number. "
"**N.B.** The phone numbers you start chats with must already be in " "**N.B.** The phone numbers you start chats with must already be in "
"your contacts.") "your contacts.")
async def private_message(evt: CommandEvent) -> None: async def private_message(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp pm <user identifier>`") return await evt.reply("**Usage:** `$cmdprefix+sp pm <user identifier>`")
@@ -89,7 +92,7 @@ async def private_message(evt: CommandEvent) -> None:
f"{pu.Puppet.get_displayname(user, False)}") f"{pu.Puppet.get_displayname(user, False)}")
async def _join(evt: CommandEvent, arg: str) -> None: async def _join(evt: CommandEvent, arg: str) -> Tuple[TypeUpdates, Dict]:
if arg.startswith("joinchat/"): if arg.startswith("joinchat/"):
invite_hash = arg[len("joinchat/"):] invite_hash = arg[len("joinchat/"):]
try: try:
@@ -112,7 +115,7 @@ async def _join(evt: CommandEvent, arg: str) -> None:
@command_handler(help_section=SECTION_CREATING_PORTALS, @command_handler(help_section=SECTION_CREATING_PORTALS,
help_args="<_link_>", help_args="<_link_>",
help_text="Join a chat with an invite link.") help_text="Join a chat with an invite link.")
async def join(evt: CommandEvent) -> None: async def join(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp join <invite link>`") return await evt.reply("**Usage:** `$cmdprefix+sp join <invite link>`")
@@ -123,7 +126,7 @@ async def join(evt: CommandEvent) -> None:
updates, _ = await _join(evt, arg.group(1)) updates, _ = await _join(evt, arg.group(1))
if not updates: if not updates:
return return None
for chat in updates.chats: for chat in updates.chats:
portal = po.Portal.get_by_entity(chat) portal = po.Portal.get_by_entity(chat)
@@ -134,12 +137,13 @@ async def join(evt: CommandEvent) -> None:
await evt.reply(f"Creating room for {chat.title}... This might take a while.") await evt.reply(f"Creating room for {chat.title}... This might take a while.")
await portal.create_matrix_room(evt.sender, chat, [evt.sender.mxid]) await portal.create_matrix_room(evt.sender, chat, [evt.sender.mxid])
return await evt.reply(f"Created room for {portal.title}") return await evt.reply(f"Created room for {portal.title}")
return None
@command_handler(help_section=SECTION_MISC, @command_handler(help_section=SECTION_MISC,
help_args="[`chats`|`contacts`|`me`]", help_args="[`chats`|`contacts`|`me`]",
help_text="Synchronize your chat portals, contacts and/or own info.") help_text="Synchronize your chat portals, contacts and/or own info.")
async def sync(evt: CommandEvent) -> None: async def sync(evt: CommandEvent) -> Optional[Dict]:
if len(evt.args) > 0: if len(evt.args) > 0:
sync_only = evt.args[0] sync_only = evt.args[0]
if sync_only not in ("chats", "contacts", "me"): if sync_only not in ("chats", "contacts", "me"):
+3 -3
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Any, Optional from typing import Any, Dict, Optional, Tuple
from ruamel.yaml import YAML from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap from ruamel.yaml.comments import CommentedMap
import random import random
@@ -25,7 +25,7 @@ yaml.indent(4)
class DictWithRecursion: class DictWithRecursion:
def __init__(self, data: CommentedMap = None) -> None: def __init__(self, data: Optional[CommentedMap] = None) -> None:
self._data = data or CommentedMap() # type: CommentedMap self._data = data or CommentedMap() # type: CommentedMap
def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any: def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
@@ -99,7 +99,7 @@ class Config(DictWithRecursion):
self.path = path # type: str self.path = path # type: str
self.registration_path = registration_path # type: str self.registration_path = registration_path # type: str
self.base_path = base_path # type: str self.base_path = base_path # type: str
self._registration = None # type: dict self._registration = None # type: Optional[Dict]
def load(self) -> None: def load(self) -> None:
with open(self.path, 'r') as stream: with open(self.path, 'r') as stream:
+5 -7
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TYPE_CHECKING, Optional from typing import Generator, Optional, Tuple, Union, TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
import asyncio import asyncio
@@ -44,9 +44,7 @@ class Context:
self.public_website = None # type: PublicBridgeWebsite self.public_website = None # type: PublicBridgeWebsite
self.provisioning_api = None # type: ProvisioningAPI self.provisioning_api = None # type: ProvisioningAPI
def __iter__(self) -> None: @property
yield self.az def core(self) -> Tuple['AppService', 'scoped_session', 'Config',
yield self.db 'asyncio.AbstractEventLoop', Optional['Bot']]:
yield self.config return (self.az, self.db, self.config, self.loop, self.bot)
yield self.loop
yield self.bot
+8 -6
View File
@@ -14,6 +14,8 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer, from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
BigInteger, String, Boolean, Text) BigInteger, String, Boolean, Text)
from sqlalchemy.sql import expression from sqlalchemy.sql import expression
@@ -88,20 +90,20 @@ class RoomState(Base):
room_id = Column(String, primary_key=True) room_id = Column(String, primary_key=True)
_power_levels_text = Column("power_levels", Text, nullable=True) _power_levels_text = Column("power_levels", Text, nullable=True)
_power_levels_json = None _power_levels_json = {} # type: Dict
@property @property
def has_power_levels(self) -> None: def has_power_levels(self) -> bool:
return bool(self._power_levels_text) return bool(self._power_levels_text)
@property @property
def power_levels(self) -> None: def power_levels(self) -> Dict:
if not self._power_levels_json and self._power_levels_text: if not self._power_levels_json and self._power_levels_text:
self._power_levels_json = json.loads(self._power_levels_text) self._power_levels_json = json.loads(self._power_levels_text)
return self._power_levels_json or {} return self._power_levels_json
@power_levels.setter @power_levels.setter
def power_levels(self, val) -> None: def power_levels(self, val: Dict) -> None:
self._power_levels_json = val self._power_levels_json = val
self._power_levels_text = json.dumps(val) self._power_levels_text = json.dumps(val)
@@ -116,7 +118,7 @@ class UserProfile(Base):
displayname = Column(String, nullable=True) displayname = Column(String, nullable=True)
avatar_url = Column(String, nullable=True) avatar_url = Column(String, nullable=True)
def dict(self) -> None: def dict(self) -> Dict[str, Column]:
return { return {
"membership": self.membership, "membership": self.membership,
"displayname": self.displayname, "displayname": self.displayname,
@@ -80,12 +80,12 @@ class MatrixParser(HTMLParser, MatrixParserCommon):
args["url"] = url args["url"] = url
return MessageEntityTextUrl, None return MessageEntityTextUrl, None
def handle_starttag(self, tag: str, attrs: List[Tuple[str, str]]): def handle_starttag(self, tag: str, attrs_list: List[Tuple[str, str]]):
self._open_tags.appendleft(tag) self._open_tags.appendleft(tag)
self._open_tags_meta.appendleft(0) self._open_tags_meta.appendleft(0)
attrs = dict(attrs) attrs = dict(attrs_list)
entity_type = None # type: type(TypeMessageEntity) entity_type = None # type: Optional[Type[TypeMessageEntity]]
args = {} # type: Dict[str, Any] args = {} # type: Dict[str, Any]
if tag in ("strong", "b"): if tag in ("strong", "b"):
entity_type = MessageEntityBold entity_type = MessageEntityBold
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, Tuple, Union, Callable from typing import Callable, List, Optional, Sequence, Tuple, Type, Union
from lxml import html from lxml import html
from telethon.tl.types import (MessageEntityMention as Mention, from telethon.tl.types import (MessageEntityMention as Mention,
@@ -83,7 +83,7 @@ def offset_length_multiply(amount: int):
class TelegramMessage: class TelegramMessage:
def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None): def __init__(self, text: str = "", entities: Optional[List[TypeMessageEntity]] = None) -> None:
self.text = text # type: str self.text = text # type: str
self.entities = entities or [] # type: List[TypeMessageEntity] self.entities = entities or [] # type: List[TypeMessageEntity]
@@ -120,7 +120,7 @@ class TelegramMessage:
self.text = msg.text + self.text self.text = msg.text + self.text
return self return self
def format(self, entity_type: type(TypeMessageEntity), offset: int = None, length: int = None, def format(self, entity_type: Type[TypeMessageEntity], offset: int = None, length: int = None,
**kwargs) -> "TelegramMessage": **kwargs) -> "TelegramMessage":
self.entities.append(entity_type(offset=offset or 0, self.entities.append(entity_type(offset=offset or 0,
length=length if length is not None else len(self.text), length=length if length is not None else len(self.text),
@@ -158,7 +158,8 @@ class TelegramMessage:
return output return output
@staticmethod @staticmethod
def join(items: List[Union[str, "TelegramMessage"]], separator: str = " ") -> "TelegramMessage": def join(items: Sequence[Union[str, "TelegramMessage"]],
separator: str = " ") -> "TelegramMessage":
main = TelegramMessage() main = TelegramMessage()
for msg in items: for msg in items:
if isinstance(msg, str): if isinstance(msg, str):
+11 -9
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, Tuple, TYPE_CHECKING from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from html import escape from html import escape
import logging import logging
import re import re
@@ -28,6 +28,7 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName,
from mautrix_appservice import MatrixRequestError from mautrix_appservice import MatrixRequestError
from mautrix_appservice.intent_api import IntentAPI from mautrix_appservice.intent_api import IntentAPI
from ..types import TelegramId
from .. import user as u, puppet as pu, portal as po from .. import user as u, puppet as pu, portal as po
from ..db import Message as DBMessage from ..db import Message as DBMessage
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html, from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
@@ -40,14 +41,14 @@ if TYPE_CHECKING:
try: try:
from lxml.html.diff import htmldiff from lxml.html.diff import htmldiff
except ImportError: except ImportError:
htmldiff = None # type: function htmldiff = None # type: ignore
log = logging.getLogger("mau.fmt.tg") # type: logging.Logger log = logging.getLogger("mau.fmt.tg") # type: logging.Logger
should_highlight_edits = False # type: bool should_highlight_edits = False # type: bool
def telegram_reply_to_matrix(evt: Message, source: "AbstractUser") -> dict: def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict:
if evt.reply_to_msg_id: if evt.reply_to_msg_id:
space = (evt.to_id.channel_id space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel) if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
@@ -116,7 +117,7 @@ def highlight_edits(new_html: str, old_html: str) -> str:
async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message, async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message,
relates_to: dict, main_intent: IntentAPI, is_edit: bool relates_to: Dict, main_intent: IntentAPI, is_edit: bool
) -> Tuple[str, str]: ) -> Tuple[str, str]:
space = (evt.to_id.channel_id space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel) if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
@@ -177,10 +178,10 @@ async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: M
async def telegram_to_matrix(evt: Message, source: "AbstractUser", async def telegram_to_matrix(evt: Message, source: "AbstractUser",
main_intent: Optional[IntentAPI] = None, main_intent: Optional[IntentAPI] = None,
is_edit: bool = False, prefix_text: Optional[str] = None, is_edit: bool = False, prefix_text: Optional[str] = None,
prefix_html: Optional[str] = None) -> Tuple[str, str, dict]: prefix_html: Optional[str] = None) -> Tuple[str, str, Dict]:
text = add_surrogates(evt.message) text = add_surrogates(evt.message)
html = _telegram_entities_to_matrix_catch(text, evt.entities) if evt.entities else None html = _telegram_entities_to_matrix_catch(text, evt.entities) if evt.entities else None
relates_to = {} relates_to = {} # type: Dict
if prefix_html: if prefix_html:
html = prefix_html + (html or escape(text)) html = prefix_html + (html or escape(text))
@@ -217,6 +218,7 @@ def _telegram_entities_to_matrix_catch(text: str, entities: List[TypeMessageEnti
"message=%s\n" "message=%s\n"
"entities=%s", "entities=%s",
text, entities) text, entities)
return "[failed conversion in _telegram_entities_to_matrix]"
def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity]) -> str: def _telegram_entities_to_matrix(text: str, entities: List[TypeMessageEntity]) -> str:
@@ -290,7 +292,7 @@ def _parse_mention(html: List[str], entity_text: str) -> bool:
return False return False
def _parse_name_mention(html: List[str], entity_text: str, user_id: int) -> bool: def _parse_name_mention(html: List[str], entity_text: str, user_id: TelegramId) -> bool:
user = u.User.get_by_tgid(user_id) user = u.User.get_by_tgid(user_id)
if user: if user:
mxid = user.mxid mxid = user.mxid
@@ -315,8 +317,8 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
message_link_match = message_link_regex.match(url) message_link_match = message_link_regex.match(url)
if message_link_match: if message_link_match:
group, msgid = message_link_match.groups() group, msgid_str = message_link_match.groups()
msgid = int(msgid) msgid = int(msgid_str)
portal = po.Portal.find_by_username(group) portal = po.Portal.find_by_username(group)
if portal: if portal:
+58 -37
View File
@@ -14,23 +14,31 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict, Tuple, Set, Match from typing import Dict, List, Match, Optional, Set, Tuple, TYPE_CHECKING
import logging import logging
import asyncio import asyncio
import re import re
from mautrix_appservice import MatrixRequestError, IntentError from mautrix_appservice import MatrixRequestError, IntentError
from .types import MatrixEvent, MatrixEventId, MatrixRoomId, MatrixUserId
from . import user as u, portal as po, puppet as pu, commands as com from . import user as u, portal as po, puppet as pu, commands as com
if TYPE_CHECKING:
from mautrix_appservice import AppService
from .context import Context
from sqlalchemy.orm import scoped_session
from .config import Config
from .bot import Bot
class MatrixHandler: class MatrixHandler:
log = logging.getLogger("mau.mx") # type: logging.Logger log = logging.getLogger("mau.mx") # type: logging.Logger
def __init__(self, context) -> None: def __init__(self, context: 'Context') -> None:
self.az, self.db, self.config, _, self.tgbot = context self.az, self.db, self.config, _, self.tgbot = context.core
self.commands = com.CommandProcessor(context) # type: com.CommandProcessor self.commands = com.CommandProcessor(context) # type: com.CommandProcessor
self.previously_typing = [] # type: List[str] self.previously_typing = [] # type: List[MatrixUserId]
self.az.matrix_event_handler(self.handle_event) self.az.matrix_event_handler(self.handle_event)
@@ -50,7 +58,8 @@ class MatrixHandler:
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.log.exception("TimeoutError when trying to set avatar") self.log.exception("TimeoutError when trying to set avatar")
async def handle_puppet_invite(self, room_id, puppet: pu.Puppet, inviter: u.User) -> None: async def handle_puppet_invite(self, room_id: MatrixRoomId, puppet: pu.Puppet, inviter: u.User
) -> None:
intent = puppet.default_mxid_intent intent = puppet.default_mxid_intent
self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}") self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}")
if not await inviter.is_logged_in(): if not await inviter.is_logged_in():
@@ -80,6 +89,7 @@ class MatrixHandler:
await intent.join_room(room_id) await intent.join_room(room_id)
portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user") portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
# TODO: if portal is None:
if portal.mxid: if portal.mxid:
try: try:
await intent.invite(portal.mxid, inviter.mxid) await intent.invite(portal.mxid, inviter.mxid)
@@ -95,13 +105,13 @@ class MatrixHandler:
portal.mxid = room_id portal.mxid = room_id
portal.save() portal.save()
inviter.register_portal(portal) inviter.register_portal(portal)
await intent.send_notice(room_id, "po.Portal to private chat created.") await intent.send_notice(room_id, "Portal to private chat created.")
else: else:
await intent.join_room(room_id) await intent.join_room(room_id)
await intent.send_notice(room_id, "This puppet will remain inactive until a " await intent.send_notice(room_id, "This puppet will remain inactive until a "
"Telegram chat is created for this room.") "Telegram chat is created for this room.")
async def accept_bot_invite(self, room_id: str, inviter: u.User) -> None: async def accept_bot_invite(self, room_id: MatrixRoomId, inviter: u.User) -> None:
tries = 0 tries = 0
while tries < 5: while tries < 5:
try: try:
@@ -126,9 +136,13 @@ class MatrixHandler:
"<code>bridge.permissions</code> section in your config file.") "<code>bridge.permissions</code> section in your config file.")
await self.az.intent.leave_room(room_id) await self.az.intent.leave_room(room_id)
async def handle_invite(self, room_id: str, user_id: str, inviter_mxid: str) -> None: async def handle_invite(self, room_id: MatrixRoomId, user_id: MatrixUserId,
inviter_mxid: MatrixUserId) -> None:
self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}") self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}")
inviter = await u.User.get_by_mxid(inviter_mxid).ensure_started() inviter = u.User.get_by_mxid(inviter_mxid)
if inviter is None:
self.log.exception("Failed to find user with Matrix ID {inviter_mxid}")
await inviter.ensure_started()
if user_id == self.az.bot_mxid: if user_id == self.az.bot_mxid:
return await self.accept_bot_invite(room_id, inviter) return await self.accept_bot_invite(room_id, inviter)
elif not inviter.whitelisted: elif not inviter.whitelisted:
@@ -150,7 +164,8 @@ class MatrixHandler:
# The rest can probably be ignored # The rest can probably be ignored
async def handle_join(self, room_id: str, user_id: str, event_id: str) -> None: async def handle_join(self, room_id: MatrixRoomId, user_id: MatrixUserId,
event_id: MatrixEventId) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started() user = await u.User.get_by_mxid(user_id).ensure_started()
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
@@ -171,7 +186,8 @@ class MatrixHandler:
if await user.is_logged_in() or portal.has_bot: if await user.is_logged_in() or portal.has_bot:
await portal.join_matrix(user, event_id) await portal.join_matrix(user, event_id)
async def handle_part(self, room_id: str, user_id, sender_mxid: str, event_id: str) -> None: async def handle_part(self, room_id: MatrixRoomId, user_id: MatrixUserId,
sender_mxid: MatrixUserId, event_id: MatrixEventId) -> None:
self.log.debug(f"{user_id} left {room_id}") self.log.debug(f"{user_id} left {room_id}")
sender = u.User.get_by_mxid(sender_mxid, create=False) sender = u.User.get_by_mxid(sender_mxid, create=False)
@@ -185,6 +201,7 @@ class MatrixHandler:
puppet = pu.Puppet.get_by_mxid(user_id) puppet = pu.Puppet.get_by_mxid(user_id)
if sender and puppet: if sender and puppet:
# TODO: Puppet should probably be an AbstractUser
await portal.leave_matrix(puppet, sender, event_id) await portal.leave_matrix(puppet, sender, event_id)
user = u.User.get_by_mxid(user_id, create=False) user = u.User.get_by_mxid(user_id, create=False)
@@ -194,7 +211,7 @@ class MatrixHandler:
if await user.is_logged_in() or portal.has_bot: if await user.is_logged_in() or portal.has_bot:
await portal.leave_matrix(user, sender, event_id) await portal.leave_matrix(user, sender, event_id)
def is_command(self, message: dict) -> Tuple[bool, str]: def is_command(self, message: Dict) -> Tuple[bool, str]:
text = message.get("body", "") text = message.get("body", "")
prefix = self.config["bridge.command_prefix"] prefix = self.config["bridge.command_prefix"]
is_command = text.startswith(prefix) is_command = text.startswith(prefix)
@@ -202,9 +219,10 @@ class MatrixHandler:
text = text[len(prefix) + 1:] text = text[len(prefix) + 1:]
return is_command, text return is_command, text
async def handle_message(self, room, sender, message, event_id) -> None: async def handle_message(self, room: MatrixRoomId, sender_id: MatrixUserId, message: Dict,
event_id: MatrixEventId) -> None:
is_command, text = self.is_command(message) is_command, text = self.is_command(message)
sender = await u.User.get_by_mxid(sender).ensure_started() sender = await u.User.get_by_mxid(sender_id).ensure_started()
if not sender.relaybot_whitelisted: if not sender.relaybot_whitelisted:
self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:" self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:"
" u.User is not whitelisted.") " u.User is not whitelisted.")
@@ -237,7 +255,8 @@ class MatrixHandler:
is_portal=portal is not None) is_portal=portal is not None)
@staticmethod @staticmethod
async def handle_redaction(room_id: str, sender_mxid: str, event_id: str) -> None: async def handle_redaction(room_id: MatrixRoomId, sender_mxid: MatrixUserId,
event_id: MatrixEventId) -> None:
sender = await u.User.get_by_mxid(sender_mxid).ensure_started() sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if not sender.relaybot_whitelisted: if not sender.relaybot_whitelisted:
return return
@@ -249,14 +268,15 @@ class MatrixHandler:
await portal.handle_matrix_deletion(sender, event_id) await portal.handle_matrix_deletion(sender, event_id)
@staticmethod @staticmethod
async def handle_power_levels(room_id: str, sender_mxid: str, new: dict, old: dict) -> None: async def handle_power_levels(room_id: MatrixRoomId, sender_mxid: MatrixUserId,
new: Dict, old: Dict) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started() sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal: if await sender.has_full_access(allow_bot=True) and portal:
await portal.handle_matrix_power_levels(sender, new["users"], old["users"]) await portal.handle_matrix_power_levels(sender, new["users"], old["users"])
@staticmethod @staticmethod
async def handle_room_meta(evt_type: str, room_id: str, sender_mxid: str, async def handle_room_meta(evt_type: str, room_id: MatrixRoomId, sender_mxid: MatrixUserId,
content: dict) -> None: content: dict) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started() sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
@@ -271,8 +291,8 @@ class MatrixHandler:
await handler(sender, content[content_key]) await handler(sender, content[content_key])
@staticmethod @staticmethod
async def handle_room_pin(room_id: str, sender_mxid: str, new_events: Set[str], async def handle_room_pin(room_id: MatrixRoomId, sender_mxid: MatrixUserId,
old_events: Set[str]) -> None: new_events: Set[str], old_events: Set[str]) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started() sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal: if await sender.has_full_access(allow_bot=True) and portal:
@@ -285,8 +305,8 @@ class MatrixHandler:
await portal.handle_matrix_pin(sender, None) await portal.handle_matrix_pin(sender, None)
@staticmethod @staticmethod
async def handle_name_change(room_id: str, user_id: str, displayname: str, async def handle_name_change(room_id: MatrixRoomId, user_id: MatrixUserId, displayname: str,
prev_displayname: str, event_id: str) -> None: prev_displayname: str, event_id: MatrixEventId) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
if not portal or not portal.has_bot: if not portal or not portal.has_bot:
return return
@@ -296,13 +316,14 @@ class MatrixHandler:
await portal.name_change_matrix(user, displayname, prev_displayname, event_id) await portal.name_change_matrix(user, displayname, prev_displayname, event_id)
@staticmethod @staticmethod
def parse_read_receipts(content: dict) -> Dict[str, str]: def parse_read_receipts(content: Dict) -> Dict[MatrixUserId, MatrixEventId]:
return {user_id: event_id return {user_id: event_id
for event_id, receipts in content.items() for event_id, receipts in content.items()
for user_id in receipts.get("m.read", {})} for user_id in receipts.get("m.read", {})}
@staticmethod @staticmethod
async def handle_read_receipts(room_id: str, receipts: Dict[str, str]) -> None: async def handle_read_receipts(room_id: MatrixRoomId,
receipts: Dict[MatrixUserId, MatrixEventId]) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
if not portal: if not portal:
return return
@@ -314,13 +335,13 @@ class MatrixHandler:
await portal.mark_read(user, event_id) await portal.mark_read(user, event_id)
@staticmethod @staticmethod
async def handle_presence(user_id: str, presence: str) -> None: async def handle_presence(user_id: MatrixUserId, presence: str) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started() user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in(): if not await user.is_logged_in():
return return
await user.set_presence(presence == "online") user.set_presence(presence == "online")
async def handle_typing(self, room_id: str, now_typing: List[str]) -> None: async def handle_typing(self, room_id: MatrixRoomId, now_typing: List[MatrixUserId]) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
if not portal: if not portal:
return return
@@ -335,35 +356,35 @@ class MatrixHandler:
if not await user.is_logged_in(): if not await user.is_logged_in():
continue continue
await portal.set_typing(user, is_typing) portal.set_typing(user, is_typing)
self.previously_typing = now_typing self.previously_typing = now_typing
def filter_matrix_event(self, event: dict) -> None: def filter_matrix_event(self, event: MatrixEvent) -> bool:
sender = event.get("sender", None) sender = event.get("sender", None)
if not sender: if not sender:
return False return False
return (sender == self.az.bot_mxid return (sender == self.az.bot_mxid
or pu.Puppet.get_id_from_mxid(sender) is not None) or pu.Puppet.get_id_from_mxid(sender) is not None)
async def try_handle_event(self, evt: dict) -> None: async def try_handle_event(self, evt: MatrixEvent) -> None:
try: try:
await self.handle_event(evt) await self.handle_event(evt)
except Exception: except Exception:
self.log.exception("Error handling manually received Matrix event") self.log.exception("Error handling manually received Matrix event")
async def handle_event(self, evt: dict) -> None: async def handle_event(self, evt: MatrixEvent) -> None:
if self.filter_matrix_event(evt): if self.filter_matrix_event(evt):
return return
self.log.debug("Received event: %s", evt) self.log.debug("Received event: %s", evt)
evt_type = evt.get("type", "m.unknown") # type: str evt_type = evt.get("type", "m.unknown") # type: str
room_id = evt.get("room_id", None) # type: str room_id = evt.get("room_id", None) # type: Optional[MatrixRoomId]
event_id = evt.get("event_id", None) # type: str event_id = evt.get("event_id", None) # type: Optional[MatrixEventId]
sender = evt.get("sender", None) # type: str sender = evt.get("sender", None) # type: Optional[MatrixUserId]
content = evt.get("content", {}) # type: dict content = evt.get("content", {}) # type: Dict
if evt_type == "m.room.member": if evt_type == "m.room.member":
state_key = evt["state_key"] # type: str state_key = evt["state_key"] # type: MatrixUserId
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: Dict
membership = content.get("membership", "") # type: str membership = content.get("membership", "") # type: str
prev_membership = prev_content.get("membership", "leave") # type: str prev_membership = prev_content.get("membership", "leave") # type: str
if membership == prev_membership: if membership == prev_membership:
@@ -387,7 +408,7 @@ class MatrixHandler:
elif evt_type == "m.room.redaction": elif evt_type == "m.room.redaction":
await self.handle_redaction(room_id, sender, evt["redacts"]) await self.handle_redaction(room_id, sender, evt["redacts"])
elif evt_type == "m.room.power_levels": elif evt_type == "m.room.power_levels":
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict prev_content = evt.get("unsigned", {}).get("prev_content", {})
await self.handle_power_levels(room_id, sender, evt["content"], prev_content) await self.handle_power_levels(room_id, sender, evt["content"], prev_content)
elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"): elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"):
await self.handle_room_meta(evt_type, room_id, sender, evt["content"]) await self.handle_room_meta(evt_type, room_id, sender, evt["content"])
+73 -61
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Dict, List, Optional, Pattern, Tuple, Union, TYPE_CHECKING from typing import Awaitable, Dict, List, Optional, Pattern, Tuple, Union, cast, TYPE_CHECKING
from collections import deque from collections import deque
from datetime import datetime from datetime import datetime
from string import Template from string import Template
@@ -62,7 +62,7 @@ from telethon.tl.types import (
UserFull) UserFull)
from mautrix_appservice import MatrixRequestError, IntentError, AppService, IntentAPI from mautrix_appservice import MatrixRequestError, IntentError, AppService, IntentAPI
from .types import MatrixEventId, MatrixRoomId, MatrixUserId, TelegramId
from .context import Context from .context import Context
from .db import Portal as DBPortal, Message as DBMessage, TelegramFile as DBTelegramFile from .db import Portal as DBPortal, Message as DBMessage, TelegramFile as DBTelegramFile
from . import puppet as p, user as u, formatter, util from . import puppet as p, user as u, formatter, util
@@ -105,18 +105,18 @@ class Portal:
by_mxid = {} # type: Dict[str, Portal] by_mxid = {} # type: Dict[str, Portal]
by_tgid = {} # type: Dict[Tuple[int, int], Portal] by_tgid = {} # type: Dict[Tuple[int, int], Portal]
def __init__(self, tgid: int, peer_type: str, tg_receiver: Optional[int] = None, def __init__(self, tgid: TelegramId, peer_type: str, tg_receiver: Optional[int] = None,
mxid: Optional[str] = None, username: Optional[str] = None, mxid: Optional[MatrixRoomId] = None, username: Optional[str] = None,
megagroup: Optional[bool] = False, title: Optional[str] = None, megagroup: Optional[bool] = False, title: Optional[str] = None,
about: Optional[str] = None, photo_id: Optional[str] = None, about: Optional[str] = None, photo_id: Optional[str] = None,
db_instance: DBPortal = None) -> None: db_instance: DBPortal = None) -> None:
self.mxid = mxid # type: str self.mxid = mxid # type: Optional[MatrixRoomId]
self.tgid = tgid # type: int self.tgid = tgid # type: TelegramId
self.tg_receiver = tg_receiver or tgid # type: int self.tg_receiver = tg_receiver or tgid # type: int
self.peer_type = peer_type # type: str self.peer_type = peer_type # type: str
self.username = username # type: str self.username = username # type: str
self.megagroup = megagroup # type: bool self.megagroup = megagroup # type: bool
self.title = title # type: str self.title = title # type: Optional[str]
self.about = about # type: str self.about = about # type: str
self.photo_id = photo_id # type: str self.photo_id = photo_id # type: str
self._db_instance = db_instance # type: DBPortal self._db_instance = db_instance # type: DBPortal
@@ -161,7 +161,7 @@ class Portal:
@property @property
def has_bot(self) -> bool: def has_bot(self) -> bool:
return self.bot and self.bot.is_in_chat(self.tgid) return bool(self.bot and self.bot.is_in_chat(self.tgid))
@property @property
def main_intent(self) -> IntentAPI: def main_intent(self) -> IntentAPI:
@@ -270,8 +270,8 @@ class Portal:
else: else:
raise ValueError("Invalid invite identifier given to invite_matrix()") raise ValueError("Invalid invite identifier given to invite_matrix()")
async def update_matrix_room(self, user: "AbstractUser", entity: TypeChat, direct: bool, async def update_matrix_room(self, user: 'AbstractUser', entity: TypeChat, direct: bool,
puppet: p.Puppet = None, levels: dict = None, puppet: p.Puppet = None, levels: Dict = None,
users: List[User] = None, users: List[User] = None,
participants: List[TypeParticipant] = None) -> None: participants: List[TypeParticipant] = None) -> None:
if not direct: if not direct:
@@ -303,8 +303,8 @@ class Portal:
async with self._room_create_lock: async with self._room_create_lock:
return await self._create_matrix_room(user, entity, invites) return await self._create_matrix_room(user, entity, invites)
async def _create_matrix_room(self, user: "AbstractUser", entity: TypeChat, invites: InviteList async def _create_matrix_room(self, user: 'AbstractUser', entity: TypeChat, invites: InviteList
) -> Optional[str]: ) -> Optional[MatrixRoomId]:
direct = self.peer_type == "user" direct = self.peer_type == "user"
if self.mxid: if self.mxid:
@@ -369,6 +369,8 @@ class Portal:
participants=participants), participants=participants),
loop=self.loop) loop=self.loop)
return self.mxid
def _get_base_power_levels(self, levels: dict = None, entity: TypeChat = None) -> dict: def _get_base_power_levels(self, levels: dict = None, entity: TypeChat = None) -> dict:
levels = levels or {} levels = levels or {}
power_level_requirement = (0 if self.peer_type == "chat" and not entity.admins_enabled power_level_requirement = (0 if self.peer_type == "chat" and not entity.admins_enabled
@@ -437,18 +439,19 @@ class Portal:
and config["bridge.max_initial_member_sync"] == -1 and config["bridge.max_initial_member_sync"] == -1
and (self.megagroup or self.peer_type != "channel")) and (self.megagroup or self.peer_type != "channel"))
if trust_member_list: if trust_member_list:
joined_mxids = await self.main_intent.get_room_members(self.mxid) joined_mxids = cast(List[MatrixUserId],
for user in joined_mxids: await self.main_intent.get_room_members(self.mxid))
if user == self.az.bot_mxid: for user_mxid in joined_mxids:
if user_mxid == self.az.bot_mxid:
continue continue
puppet_id = p.Puppet.get_id_from_mxid(user) puppet_id = p.Puppet.get_id_from_mxid(user_mxid)
if puppet_id and puppet_id not in allowed_tgids: if puppet_id and puppet_id not in allowed_tgids:
if self.bot and puppet_id == self.bot.tgid: if self.bot and puppet_id == self.bot.tgid:
self.bot.remove_chat(self.tgid) self.bot.remove_chat(self.tgid)
await self.main_intent.kick(self.mxid, user, await self.main_intent.kick(self.mxid, user_mxid,
"User had left this Telegram chat.") "User had left this Telegram chat.")
continue continue
mx_user = u.User.get_by_mxid(user, create=False) mx_user = u.User.get_by_mxid(user_mxid, create=False)
if mx_user and mx_user.is_bot and mx_user.tgid not in allowed_tgids: if mx_user and mx_user.is_bot and mx_user.tgid not in allowed_tgids:
mx_user.unregister_portal(self) mx_user.unregister_portal(self)
@@ -457,7 +460,7 @@ class Portal:
"You had left this Telegram chat.") "You had left this Telegram chat.")
continue continue
async def add_telegram_user(self, user_id: int, source: Optional["AbstractUser"] = None async def add_telegram_user(self, user_id: TelegramId, source: Optional['AbstractUser'] = None
) -> None: ) -> None:
puppet = p.Puppet.get(user_id) puppet = p.Puppet.get(user_id)
if source: if source:
@@ -470,7 +473,7 @@ class Portal:
user.register_portal(self) user.register_portal(self)
await self.invite_to_matrix(user.mxid) await self.invite_to_matrix(user.mxid)
async def delete_telegram_user(self, user_id: int, sender: p.Puppet) -> None: async def delete_telegram_user(self, user_id: TelegramId, sender: p.Puppet) -> None:
puppet = p.Puppet.get(user_id) puppet = p.Puppet.get(user_id)
user = u.User.get_by_tgid(user_id) user = u.User.get_by_tgid(user_id)
kick_message = (f"Kicked by {sender.displayname}" kick_message = (f"Kicked by {sender.displayname}"
@@ -568,8 +571,9 @@ class Portal:
return True return True
return False return False
async def _get_users(self, user: "AbstractUser", entity: Union[TypeInputPeer, InputUser, async def _get_users(self,
TypeChat, TypeUser] user: 'AbstractUser',
entity: Union[TypeInputPeer, InputUser, TypeChat, TypeUser]
) -> Tuple[List[TypeUser], List[TypeParticipant]]: ) -> Tuple[List[TypeUser], List[TypeParticipant]]:
if self.peer_type == "chat": if self.peer_type == "chat":
chat = await user.client(GetFullChatRequest(chat_id=self.tgid)) chat = await user.client(GetFullChatRequest(chat_id=self.tgid))
@@ -588,7 +592,7 @@ class Portal:
entity, ChannelParticipantsRecent(), offset=0, limit=limit, hash=0)) entity, ChannelParticipantsRecent(), offset=0, limit=limit, hash=0))
return response.users, response.participants return response.users, response.participants
elif limit > 200 or limit == -1: elif limit > 200 or limit == -1:
users, participants = [], [] users, participants = [], [] # type: Tuple[List[TypeUser], List[TypeParticipant]]
offset = 0 offset = 0
remaining_quota = limit if limit > 0 else 1000000 remaining_quota = limit if limit > 0 else 1000000
query = (ChannelParticipantsSearch("") if limit == -1 query = (ChannelParticipantsSearch("") if limit == -1
@@ -609,6 +613,7 @@ class Portal:
return [], [] return [], []
elif self.peer_type == "user": elif self.peer_type == "user":
return [entity], [] return [entity], []
return [], []
async def get_invite_link(self, user: 'u.User') -> str: async def get_invite_link(self, user: 'u.User') -> str:
if self.peer_type == "user": if self.peer_type == "user":
@@ -688,7 +693,7 @@ class Portal:
return "" return ""
async def _get_state_change_message(self, event: str, user: 'u.User', async def _get_state_change_message(self, event: str, user: 'u.User',
arguments: Optional[dict] = None) -> Optional[dict]: arguments: Optional[Dict] = None) -> Optional[Dict]:
tpl = config[f"bridge.state_event_formats.{event}"] tpl = config[f"bridge.state_event_formats.{event}"]
if len(tpl) == 0: if len(tpl) == 0:
# Empty format means they don't want the message # Empty format means they don't want the message
@@ -724,11 +729,11 @@ class Portal:
or user.mxid_localpart) or user.mxid_localpart)
def set_typing(self, user: 'u.User', typing: bool = True, def set_typing(self, user: 'u.User', typing: bool = True,
action=SendMessageTypingAction) -> None: action: type = SendMessageTypingAction) -> bool:
return user.client(SetTypingRequest( return user.client(SetTypingRequest(
self.peer, action() if typing else SendMessageCancelAction())) self.peer, action() if typing else SendMessageCancelAction()))
async def mark_read(self, user: 'u.User', event_id: str) -> None: async def mark_read(self, user: 'u.User', event_id: MatrixEventId) -> None:
if user.is_bot: if user.is_bot:
return return
space = self.tgid if self.peer_type == "channel" else user.tgid space = self.tgid if self.peer_type == "channel" else user.tgid
@@ -743,7 +748,8 @@ class Portal:
else: else:
await user.client(ReadMessageHistoryRequest(peer=self.peer, max_id=message.tgid)) await user.client(ReadMessageHistoryRequest(peer=self.peer, max_id=message.tgid))
async def leave_matrix(self, user: 'u.User', source: 'u.User', event_id: str) -> None: async def leave_matrix(self, user: 'u.User', source: 'u.User', event_id: MatrixEventId
) -> None:
if await user.needs_relaybot(self): if await user.needs_relaybot(self):
async with self.require_send_lock(self.bot.tgid): async with self.require_send_lock(self.bot.tgid):
message = await self._get_state_change_message("leave", user) message = await self._get_state_change_message("leave", user)
@@ -798,7 +804,7 @@ class Portal:
# We'll just assume the user is already in the chat. # We'll just assume the user is already in the chat.
pass pass
async def _apply_msg_format(self, sender: 'u.User', msgtype: str, message: dict) -> None: async def _apply_msg_format(self, sender: 'u.User', msgtype: str, message: Dict) -> None:
if "formatted_body" not in message: if "formatted_body" not in message:
message["format"] = "org.matrix.custom.html" message["format"] = "org.matrix.custom.html"
message["formatted_body"] = escape_html(message.get("body", "")) message["formatted_body"] = escape_html(message.get("body", ""))
@@ -823,7 +829,7 @@ class Portal:
await self._apply_msg_format(sender, msgtype, message) await self._apply_msg_format(sender, msgtype, message)
@staticmethod @staticmethod
def _matrix_event_to_entities(event: dict) -> Tuple[str, Optional[List[TypeMessageEntity]]]: def _matrix_event_to_entities(event: Dict) -> Tuple[str, Optional[List[TypeMessageEntity]]]:
try: try:
if event.get("format", None) == "org.matrix.custom.html": if event.get("format", None) == "org.matrix.custom.html":
message, entities = formatter.matrix_to_telegram(event.get("formatted_body", "")) message, entities = formatter.matrix_to_telegram(event.get("formatted_body", ""))
@@ -851,7 +857,8 @@ class Portal:
return None return None
async def _handle_matrix_text(self, sender_id: int, event_id: str, space: int, async def _handle_matrix_text(self, sender_id: int, event_id: str, space: int,
client: "MautrixTelegramClient", message: dict, reply_to: int) -> None: client: 'MautrixTelegramClient', message: Dict, reply_to: int
) -> None:
lock = self.require_send_lock(sender_id) lock = self.require_send_lock(sender_id)
async with lock: async with lock:
response = await client.send_message(self.peer, message, reply_to=reply_to, response = await client.send_message(self.peer, message, reply_to=reply_to,
@@ -859,7 +866,8 @@ class Portal:
self._add_telegram_message_to_db(event_id, space, response) self._add_telegram_message_to_db(event_id, space, response)
async def _handle_matrix_file(self, msgtype: str, sender_id: int, event_id: str, space: int, async def _handle_matrix_file(self, msgtype: str, sender_id: int, event_id: str, space: int,
client: "MautrixTelegramClient", message: dict, reply_to: int) -> None: client: 'MautrixTelegramClient', message: dict, reply_to: int
) -> None:
file = await self.main_intent.download_file(message["url"]) file = await self.main_intent.download_file(message["url"])
info = message.get("info", {}) info = message.get("info", {})
@@ -893,7 +901,7 @@ class Portal:
self._add_telegram_message_to_db(event_id, space, response) self._add_telegram_message_to_db(event_id, space, response)
async def _handle_matrix_location(self, sender_id: int, event_id: str, space: int, async def _handle_matrix_location(self, sender_id: int, event_id: str, space: int,
client: "MautrixTelegramClient", message: dict, client: 'MautrixTelegramClient', message: Dict,
reply_to: int) -> None: reply_to: int) -> None:
try: try:
lat, long = message["geo_uri"][len("geo:"):].split(",") lat, long = message["geo_uri"][len("geo:"):].split(",")
@@ -901,13 +909,13 @@ class Portal:
except (KeyError, ValueError): except (KeyError, ValueError):
self.log.exception("Failed to parse location") self.log.exception("Failed to parse location")
return None return None
message, entities = self._matrix_event_to_entities(message) caption, entities = self._matrix_event_to_entities(message)
media = MessageMediaGeo(geo=GeoPoint(lat, long, access_hash=0)) media = MessageMediaGeo(geo=GeoPoint(lat, long, access_hash=0))
lock = self.require_send_lock(sender_id) lock = self.require_send_lock(sender_id)
async with lock: async with lock:
response = await client.send_media(self.peer, media, reply_to=reply_to, response = await client.send_media(self.peer, media, reply_to=reply_to,
caption=message, entities=entities) caption=caption, entities=entities)
self._add_telegram_message_to_db(event_id, space, response) self._add_telegram_message_to_db(event_id, space, response)
def _add_telegram_message_to_db(self, event_id: str, space: int, def _add_telegram_message_to_db(self, event_id: str, space: int,
@@ -963,17 +971,18 @@ class Portal:
except ChatNotModifiedError: except ChatNotModifiedError:
pass pass
async def handle_matrix_deletion(self, deleter: 'u.User', event_id: str) -> None: async def handle_matrix_deletion(self, deleter: 'u.User', event_id: MatrixEventId) -> None:
deleter = deleter if not await deleter.needs_relaybot(self) else self.bot real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
space = self.tgid if self.peer_type == "channel" else deleter.tgid space = self.tgid if self.peer_type == "channel" else real_deleter.tgid
message = DBMessage.query.filter(DBMessage.mxid == event_id, message = DBMessage.query.filter(DBMessage.mxid == event_id,
DBMessage.tg_space == space, DBMessage.tg_space == space,
DBMessage.mx_room == self.mxid).one_or_none() DBMessage.mx_room == self.mxid).one_or_none()
if not message: if not message:
return return
await deleter.client.delete_messages(self.peer, [message.tgid]) await real_deleter.client.delete_messages(self.peer, [message.tgid])
async def _update_telegram_power_level(self, sender: 'u.User', user_id: int, level: int) -> None: async def _update_telegram_power_level(self, sender: 'u.User', user_id: TelegramId,
level: int) -> None:
if self.peer_type == "chat": if self.peer_type == "chat":
await sender.client(EditChatAdminRequest( await sender.client(EditChatAdminRequest(
chat_id=self.tgid, user_id=user_id, is_admin=level >= 50)) chat_id=self.tgid, user_id=user_id, is_admin=level >= 50))
@@ -989,7 +998,8 @@ class Portal:
EditAdminRequest(channel=await self.get_input_entity(sender), EditAdminRequest(channel=await self.get_input_entity(sender),
user_id=user_id, admin_rights=rights)) user_id=user_id, admin_rights=rights))
async def handle_matrix_power_levels(self, sender: 'u.User', new_users: Dict[str, int], async def handle_matrix_power_levels(self, sender: 'u.User',
new_users: Dict[MatrixUserId, int],
old_users: Dict[str, int]) -> None: old_users: Dict[str, int]) -> None:
# TODO handle all power level changes and bridge exact admin rights to supergroups/channels # TODO handle all power level changes and bridge exact admin rights to supergroups/channels
for user, level in new_users.items(): for user, level in new_users.items():
@@ -1167,7 +1177,7 @@ class Portal:
return None return None
async def handle_telegram_photo(self, source: "AbstractUser", intent: IntentAPI, evt: Message, async def handle_telegram_photo(self, source: "AbstractUser", intent: IntentAPI, evt: Message,
relates_to=None) -> None: relates_to: Dict = {}) -> None:
largest_size = self._get_largest_photo_size(evt.media.photo) largest_size = self._get_largest_photo_size(evt.media.photo)
file = await util.transfer_file_to_matrix(self.db, source.client, intent, file = await util.transfer_file_to_matrix(self.db, source.client, intent,
largest_size.location) largest_size.location)
@@ -1197,7 +1207,7 @@ class Portal:
external_url=self.get_external_url(evt)) external_url=self.get_external_url(evt))
@staticmethod @staticmethod
def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> dict: def _parse_telegram_document_attributes(attributes: List[TypeDocumentAttribute]) -> Dict:
attrs = { attrs = {
"name": None, "name": None,
"mime_type": None, "mime_type": None,
@@ -1205,7 +1215,7 @@ class Portal:
"sticker_alt": None, "sticker_alt": None,
"width": None, "width": None,
"height": None, "height": None,
} } # type: Dict
for attr in attributes: for attr in attributes:
if isinstance(attr, DocumentAttributeFilename): if isinstance(attr, DocumentAttributeFilename):
attrs["name"] = attrs["name"] or attr.file_name attrs["name"] = attrs["name"] or attr.file_name
@@ -1218,8 +1228,8 @@ class Portal:
return attrs return attrs
@staticmethod @staticmethod
def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: dict def _parse_telegram_document_meta(evt: Message, file: DBTelegramFile, attrs: Dict
) -> Tuple[dict, str]: ) -> Tuple[Dict, str]:
document = evt.media.document document = evt.media.document
name = evt.message or attrs["name"] name = evt.message or attrs["name"]
if attrs["is_sticker"]: if attrs["is_sticker"]:
@@ -1253,7 +1263,7 @@ class Portal:
async def handle_telegram_document(self, source: "AbstractUser", intent: IntentAPI, async def handle_telegram_document(self, source: "AbstractUser", intent: IntentAPI,
evt: Message, evt: Message,
relates_to: dict = None) -> Optional[dict]: relates_to: dict = None) -> Optional[Dict]:
document = evt.media.document document = evt.media.document
attrs = self._parse_telegram_document_attributes(document.attributes) attrs = self._parse_telegram_document_attributes(document.attributes)
@@ -1521,9 +1531,9 @@ class Portal:
else: else:
self.log.debug("Unhandled Telegram action in %s: %s", self.title, action) self.log.debug("Unhandled Telegram action in %s: %s", self.title, action)
async def set_telegram_admin(self, user_id: int) -> None: async def set_telegram_admin(self, user_id: TelegramId) -> None:
puppet = p.Puppet.get(user_id) puppet = p.Puppet.get(user_id)
user = await u.User.get_by_tgid(user_id) user = u.User.get_by_tgid(user_id)
levels = await self.main_intent.get_power_levels(self.mxid) levels = await self.main_intent.get_power_levels(self.mxid)
if user: if user:
@@ -1558,7 +1568,7 @@ class Portal:
await self.update_telegram_pin() await self.update_telegram_pin()
@staticmethod @staticmethod
def _get_level_from_participant(participant: TypeParticipant, _) -> int: def _get_level_from_participant(participant: TypeParticipant, _: Dict) -> int:
# TODO use the power level requirements to get better precision in channels # TODO use the power level requirements to get better precision in channels
if isinstance(participant, (ChatParticipantAdmin, ChannelParticipantAdmin)): if isinstance(participant, (ChatParticipantAdmin, ChannelParticipantAdmin)):
return 50 return 50
@@ -1599,7 +1609,7 @@ class Portal:
except KeyError: except KeyError:
return 50 return 50
def _participants_to_power_levels(self, participants: List[TypeParticipant], levels: dict def _participants_to_power_levels(self, participants: List[TypeParticipant], levels: Dict
) -> bool: ) -> bool:
bot_level = self._get_bot_level(levels) bot_level = self._get_bot_level(levels)
if bot_level < self._get_powerlevel_level(levels): if bot_level < self._get_powerlevel_level(levels):
@@ -1654,7 +1664,7 @@ class Portal:
mxid=self.mxid, username=self.username, megagroup=self.megagroup, mxid=self.mxid, username=self.username, megagroup=self.megagroup,
title=self.title, about=self.about, photo_id=self.photo_id) title=self.title, about=self.about, photo_id=self.photo_id)
def migrate_and_save(self, new_id: int) -> None: def migrate_and_save(self, new_id: TelegramId) -> None:
existing = DBPortal.query.get(self.tgid_full) existing = DBPortal.query.get(self.tgid_full)
if existing: if existing:
self.db.delete(existing) self.db.delete(existing)
@@ -1701,7 +1711,7 @@ class Portal:
# region Class instance lookup # region Class instance lookup
@classmethod @classmethod
def get_by_mxid(cls, mxid: str) -> Optional["Portal"]: def get_by_mxid(cls, mxid: MatrixRoomId) -> Optional['Portal']:
try: try:
return cls.by_mxid[mxid] return cls.by_mxid[mxid]
except KeyError: except KeyError:
@@ -1721,7 +1731,7 @@ class Portal:
return None return None
@classmethod @classmethod
def find_by_username(cls, username: str) -> Optional["Portal"]: def find_by_username(cls, username: str) -> Optional['Portal']:
if not username: if not username:
return None return None
@@ -1729,15 +1739,15 @@ class Portal:
if portal.username and portal.username.lower() == username.lower(): if portal.username and portal.username.lower() == username.lower():
return portal return portal
portal = DBPortal.query.filter(DBPortal.username == username).one_or_none() dbportal = DBPortal.query.filter(DBPortal.username == username).one_or_none()
if portal: if dbportal:
return cls.from_db(portal) return cls.from_db(dbportal)
return None return None
@classmethod @classmethod
def get_by_tgid(cls, tgid: int, tg_receiver: int = None, peer_type: str = None def get_by_tgid(cls, tgid: TelegramId, tg_receiver: Optional[TelegramId] = None,
) -> Optional["Portal"]: peer_type: str = None) -> Optional['Portal']:
tg_receiver = tg_receiver or tgid tg_receiver = tg_receiver or tgid
tgid_full = (tgid, tg_receiver) tgid_full = (tgid, tg_receiver)
try: try:
@@ -1758,8 +1768,10 @@ class Portal:
return None return None
@classmethod @classmethod
def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull, TypeInputPeer], def get_by_entity(cls, entity: Union[TypeChat, TypePeer, TypeUser, TypeUserFull,
receiver_id: int = None, create: bool = True) -> Optional["Portal"]: TypeInputPeer],
receiver_id: Optional[TelegramId] = None, create: bool = True
) -> Optional['Portal']:
entity_type = type(entity) entity_type = type(entity)
if entity_type in {Chat, ChatFull}: if entity_type in {Chat, ChatFull}:
type_name = "chat" type_name = "chat"
@@ -1790,7 +1802,7 @@ class Portal:
def init(context: Context) -> None: def init(context: Context) -> None:
global config global config
Portal.az, Portal.db, config, Portal.loop, Portal.bot = context Portal.az, Portal.db, config, Portal.loop, Portal.bot = context.core
Portal.bridge_notices = config["bridge.bridge_notices"] Portal.bridge_notices = config["bridge.bridge_notices"]
Portal.filter_mode = config["bridge.filter.mode"] Portal.filter_mode = config["bridge.filter.mode"]
Portal.filter_list = config["bridge.filter.list"] Portal.filter_list = config["bridge.filter.list"]
+106 -80
View File
@@ -14,17 +14,19 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING from typing import Awaitable, Coroutine, Dict, List, NewType, Optional, Pattern, TYPE_CHECKING
from difflib import SequenceMatcher from difflib import SequenceMatcher
import re import re
import logging import logging
import asyncio import asyncio
from enum import Enum
from sqlalchemy import orm from sqlalchemy import orm
from telethon.tl.types import UserProfilePhoto from telethon.tl.types import UserProfilePhoto, User, FileLocation
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
from .types import MatrixUserId, TelegramId
from .db import Puppet as DBPuppet from .db import Puppet as DBPuppet
from . import util from . import util
@@ -32,6 +34,11 @@ if TYPE_CHECKING:
from .matrix import MatrixHandler from .matrix import MatrixHandler
from .config import Config from .config import Config
from .context import Context from .context import Context
from . import user as u
from .abstract_user import AbstractUser
PuppetError = Enum('PuppetError', 'Success OnlyLoginSelf InvalidAccessToken')
config = None # type: Config config = None # type: Config
@@ -45,85 +52,98 @@ class Puppet:
mxid_regex = None # type: Pattern mxid_regex = None # type: Pattern
username_template = None # type: str username_template = None # type: str
hs_domain = None # type: str hs_domain = None # type: str
cache = {} # type: Dict[str, Puppet] cache = {} # type: Dict[TelegramId, Puppet]
by_custom_mxid = {} # type: Dict[str, Puppet] by_custom_mxid = {} # type: Dict[str, Puppet]
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None, def __init__(self,
displayname=None, displayname_source=None, photo_id=None, is_bot=None, id: TelegramId,
is_registered=False, db_instance=None) -> None: access_token: Optional[str] = None,
self.id = id custom_mxid: Optional[MatrixUserId] = None,
self.access_token = access_token username: Optional[str] = None,
self.custom_mxid = custom_mxid displayname: Optional[str] = None,
self.is_real_user = self.custom_mxid and self.access_token displayname_source: Optional[TelegramId] = None,
self.default_mxid = self.get_mxid_from_id(self.id) photo_id: Optional[str] = None,
self.mxid = self.custom_mxid or self.default_mxid is_bot: bool = False,
is_registered: bool = False,
db_instance: Optional[DBPuppet] = None) -> None:
self.id = id # type: TelegramId
self.access_token = access_token # type: Optional[str]
self.custom_mxid = custom_mxid # type: Optional[MatrixUserId]
self.default_mxid = self.get_mxid_from_id(self.id) # type: MatrixUserId
self.username = username self.username = username # type: Optional[str]
self.displayname = displayname self.displayname = displayname # type: Optional[str]
self.displayname_source = displayname_source self.displayname_source = displayname_source # type: Optional[TelegramId]
self.photo_id = photo_id self.photo_id = photo_id # type: Optional[str]
self.is_bot = is_bot self.is_bot = is_bot # type: bool
self.is_registered = is_registered self.is_registered = is_registered # type: bool
self._db_instance = db_instance self._db_instance = db_instance # type: Optional[DBPuppet]
self.default_mxid_intent = self.az.intent.user(self.default_mxid) self.default_mxid_intent = self.az.intent.user(self.default_mxid)
self.intent = None # type: IntentAPI self.intent = self._fresh_intent() # type: IntentAPI
self.refresh_intents()
self.cache[id] = self self.cache[id] = self
if self.custom_mxid: if self.custom_mxid:
self.by_custom_mxid[self.custom_mxid] = self self.by_custom_mxid[self.custom_mxid] = self
@property @property
def tgid(self) -> None: def mxid(self):
return self.custom_mxid or self.default_mxid
@property
def tgid(self) -> TelegramId:
return self.id return self.id
@property
def is_real_user(self) -> bool:
""" Is True when the puppet is a real Matrix user. """
return bool(self.custom_mxid and self.access_token)
@staticmethod @staticmethod
async def is_logged_in() -> None: async def is_logged_in() -> bool:
""" Is True if the puppet is logged in. """
return True return True
# region Custom puppet management # region Custom puppet management
def refresh_intents(self) -> None: def _fresh_intent(self) -> IntentAPI:
self.is_real_user = self.custom_mxid and self.access_token return (self.az.intent.user(self.custom_mxid, self.access_token)
self.intent = (self.az.intent.user(self.custom_mxid, self.access_token) if self.is_real_user else self.default_mxid_intent)
if self.is_real_user else self.default_mxid_intent)
async def switch_mxid(self, access_token, mxid) -> None: async def switch_mxid(self, access_token: str, mxid: MatrixUserId) -> PuppetError:
prev_mxid = self.custom_mxid prev_mxid = self.custom_mxid
self.custom_mxid = mxid self.custom_mxid = mxid
self.access_token = access_token self.access_token = access_token
self.refresh_intents() self.intent = self._fresh_intent()
err = await self.init_custom_mxid() err = await self.init_custom_mxid()
if err != 0: if err != PuppetError.Success:
return err return err
try: try:
del self.by_custom_mxid[prev_mxid] del self.by_custom_mxid[prev_mxid] # type: ignore
except KeyError: except KeyError:
pass pass
self.mxid = self.custom_mxid or self.default_mxid
if self.mxid != self.default_mxid: if self.mxid != self.default_mxid:
self.by_custom_mxid[self.mxid] = self self.by_custom_mxid[self.mxid] = self
await self.leave_rooms_with_default_user() await self.leave_rooms_with_default_user()
self.save() self.save()
return 0 return PuppetError.Success
async def init_custom_mxid(self) -> None: async def init_custom_mxid(self) -> PuppetError:
if not self.is_real_user: if not self.is_real_user:
return 0 return PuppetError.Success
mxid = await self.intent.whoami() mxid = await self.intent.whoami()
if not mxid or mxid != self.custom_mxid: if not mxid or mxid != self.custom_mxid:
self.custom_mxid = None self.custom_mxid = None
self.access_token = None self.access_token = None
self.refresh_intents() self.intent = self._fresh_intent()
if mxid != self.custom_mxid: if mxid != self.custom_mxid:
return 2 return PuppetError.OnlyLoginSelf
return 1 return PuppetError.InvalidAccessToken
if config["bridge.sync_with_custom_puppets"]: if config["bridge.sync_with_custom_puppets"]:
asyncio.ensure_future(self.sync(), loop=self.loop) asyncio.ensure_future(self.sync(), loop=self.loop)
return 0 return PuppetError.Success
async def leave_rooms_with_default_user(self) -> None: async def leave_rooms_with_default_user(self) -> None:
for room_id in await self.default_mxid_intent.get_joined_rooms(): for room_id in await self.default_mxid_intent.get_joined_rooms():
@@ -159,7 +179,7 @@ class Puppet:
}, },
}) })
def filter_events(self, events) -> None: def filter_events(self, events: List[Dict]) -> List:
new_events = [] new_events = []
for event in events: for event in events:
evt_type = event.get("type", None) evt_type = event.get("type", None)
@@ -186,18 +206,18 @@ class Puppet:
new_events.append(event) new_events.append(event)
return new_events return new_events
def handle_sync(self, presence, ephemeral) -> None: def handle_sync(self, presence: List, ephemeral: Dict) -> None:
presence = [self.mx.try_handle_event(event) for event in presence] presence_events = [self.mx.try_handle_event(event) for event in presence]
for room_id, events in ephemeral.items(): for room_id, events in ephemeral.items():
for event in events: for event in events:
event["room_id"] = room_id event["room_id"] = room_id
ephemeral = [self.mx.try_handle_event(event) ephemeral_events = [self.mx.try_handle_event(event)
for events in ephemeral.values() for events in ephemeral.values()
for event in self.filter_events(events)] for event in self.filter_events(events)]
events = ephemeral + presence events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]]
coro = asyncio.gather(*events, loop=self.loop) coro = asyncio.gather(*events, loop=self.loop)
asyncio.ensure_future(coro, loop=self.loop) asyncio.ensure_future(coro, loop=self.loop)
@@ -220,13 +240,14 @@ class Puppet:
while access_token_at_start == self.access_token: while access_token_at_start == self.access_token:
try: try:
sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch, sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch,
set_presence="offline") set_presence="offline") # type: Dict
errors = 0 errors = 0
if next_batch is not None: if next_batch is not None:
presence = sync_resp.get("presence", {}).get("events", []) presence = sync_resp.get("presence", {}).get("events", []) # type: List
ephemeral = {room: data.get("ephemeral", {}).get("events", []) ephemeral = {room: data.get("ephemeral", {}).get("events", [])
for room, data for room, data
in sync_resp.get("rooms", {}).get("join", {}).items()} in sync_resp.get("rooms", {}).get("join", {}).items()
} # type: Dict
self.handle_sync(presence, ephemeral) self.handle_sync(presence, ephemeral)
next_batch = sync_resp.get("next_batch", None) next_batch = sync_resp.get("next_batch", None)
except MatrixRequestError as e: except MatrixRequestError as e:
@@ -241,19 +262,19 @@ class Puppet:
# region DB conversion # region DB conversion
@property @property
def db_instance(self) -> None: def db_instance(self) -> DBPuppet:
if not self._db_instance: if not self._db_instance:
self._db_instance = self.new_db_instance() self._db_instance = self.new_db_instance()
return self._db_instance return self._db_instance
def new_db_instance(self) -> None: def new_db_instance(self) -> DBPuppet:
return DBPuppet(id=self.id, access_token=self.access_token, custom_mxid=self.custom_mxid, return DBPuppet(id=self.id, access_token=self.access_token, custom_mxid=self.custom_mxid,
username=self.username, displayname=self.displayname, username=self.username, displayname=self.displayname,
displayname_source=self.displayname_source, photo_id=self.photo_id, displayname_source=self.displayname_source, photo_id=self.photo_id,
is_bot=self.is_bot, matrix_registered=self.is_registered) is_bot=self.is_bot, matrix_registered=self.is_registered)
@classmethod @classmethod
def from_db(cls, db_puppet) -> None: def from_db(cls, db_puppet: DBPuppet) -> 'Puppet':
return Puppet(db_puppet.id, db_puppet.access_token, db_puppet.custom_mxid, return Puppet(db_puppet.id, db_puppet.access_token, db_puppet.custom_mxid,
db_puppet.username, db_puppet.displayname, db_puppet.displayname_source, db_puppet.username, db_puppet.displayname, db_puppet.displayname_source,
db_puppet.photo_id, db_puppet.is_bot, db_puppet.matrix_registered, db_puppet.photo_id, db_puppet.is_bot, db_puppet.matrix_registered,
@@ -272,16 +293,16 @@ class Puppet:
# endregion # endregion
# region Info updating # region Info updating
def similarity(self, query) -> None: def similarity(self, query: str) -> int:
username_similarity = (SequenceMatcher(None, self.username, query).ratio() username_similarity = (SequenceMatcher(None, self.username, query).ratio()
if self.username else 0) if self.username else 0)
displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio() displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio()
if self.displayname else 0) if self.displayname else 0)
similarity = max(username_similarity, displayname_similarity) similarity = max(username_similarity, displayname_similarity)
return round(similarity * 1000) / 10 return int(round(similarity * 1000) / 10)
@staticmethod @staticmethod
def get_displayname(info, enable_format=True) -> None: def get_displayname(info: User, enable_format: bool = True) -> str:
data = { data = {
"phone number": info.phone if hasattr(info, "phone") else None, "phone number": info.phone if hasattr(info, "phone") else None,
"username": info.username, "username": info.username,
@@ -308,7 +329,7 @@ class Puppet:
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format( return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
displayname=name) displayname=name)
async def update_info(self, source, info) -> None: async def update_info(self, source: 'AbstractUser', info: User) -> None:
changed = False changed = False
if self.username != info.username: if self.username != info.username:
self.username = info.username self.username = info.username
@@ -323,24 +344,26 @@ class Puppet:
if changed: if changed:
self.save() self.save()
async def update_displayname(self, source, info) -> None: async def update_displayname(self, source: 'AbstractUser', info: User) -> bool:
ignore_source = (not source.is_relaybot ignore_source = (not source.is_relaybot
and self.displayname_source is not None and self.displayname_source is not None
and self.displayname_source != source.tgid) and self.displayname_source != source.tgid)
if ignore_source: if ignore_source:
return return False
displayname = self.get_displayname(info) displayname = self.get_displayname(info)
if displayname != self.displayname: if displayname != self.displayname:
await self.default_mxid_intent.set_display_name(displayname) await self.default_mxid_intent.set_display_name(displayname)
self.displayname = displayname self.displayname = displayname
self.displayname_source = source.tgid self.displayname_source = TelegramId(source.tgid)
return True return True
elif source.is_relaybot or self.displayname_source is None: elif source.is_relaybot or self.displayname_source is None:
self.displayname_source = source.tgid self.displayname_source = TelegramId(source.tgid)
return True return True
else:
return False
async def update_avatar(self, source, photo) -> None: async def update_avatar(self, source: 'AbstractUser', photo: FileLocation) -> bool:
photo_id = f"{photo.volume_id}-{photo.local_id}" photo_id = f"{photo.volume_id}-{photo.local_id}"
if self.photo_id != photo_id: if self.photo_id != photo_id:
file = await util.transfer_file_to_matrix(self.db, source.client, file = await util.transfer_file_to_matrix(self.db, source.client,
@@ -355,7 +378,7 @@ class Puppet:
# region Getters # region Getters
@classmethod @classmethod
def get(cls, tgid, create=True) -> "Optional[Puppet]": def get(cls, tgid: TelegramId, create: bool = True) -> Optional['Puppet']:
try: try:
return cls.cache[tgid] return cls.cache[tgid]
except KeyError: except KeyError:
@@ -374,12 +397,15 @@ class Puppet:
return None return None
@classmethod @classmethod
def get_by_mxid(cls, mxid, create=True) -> "Optional[Puppet]": def get_by_mxid(cls, mxid: MatrixUserId, create: bool = True) -> Optional['Puppet']:
tgid = cls.get_id_from_mxid(mxid) tgid = cls.get_id_from_mxid(mxid)
return cls.get(tgid, create) if tgid else None if tgid:
return cls.get(tgid, create)
return None
@classmethod @classmethod
def get_by_custom_mxid(cls, mxid) -> None: def get_by_custom_mxid(cls, mxid: MatrixUserId) -> Optional['Puppet']:
if not mxid: if not mxid:
raise ValueError("Matrix ID can't be empty") raise ValueError("Matrix ID can't be empty")
@@ -396,25 +422,25 @@ class Puppet:
return None return None
@classmethod @classmethod
def get_all_with_custom_mxid(cls) -> None: def get_all_with_custom_mxid(cls) -> List['Puppet']:
return [cls.by_custom_mxid[puppet.mxid] return [cls.by_custom_mxid[puppet.mxid]
if puppet.custom_mxid in cls.by_custom_mxid if puppet.custom_mxid in cls.by_custom_mxid
else cls.from_db(puppet) else cls.from_db(puppet)
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()] for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
@classmethod @classmethod
def get_id_from_mxid(cls, mxid) -> None: def get_id_from_mxid(cls, mxid: MatrixUserId) -> Optional[TelegramId]:
match = cls.mxid_regex.match(mxid) match = cls.mxid_regex.match(mxid)
if match: if match:
return int(match.group(1)) return TelegramId(int(match.group(1)))
return None return None
@classmethod @classmethod
def get_mxid_from_id(cls, tgid) -> None: def get_mxid_from_id(cls, tgid: TelegramId) -> MatrixUserId:
return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}" return MatrixUserId(f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}")
@classmethod @classmethod
def find_by_username(cls, username) -> "Optional[Puppet]": def find_by_username(cls, username: str) -> Optional['Puppet']:
if not username: if not username:
return None return None
@@ -422,14 +448,14 @@ class Puppet:
if puppet.username and puppet.username.lower() == username.lower(): if puppet.username and puppet.username.lower() == username.lower():
return puppet return puppet
puppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none() dbpuppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
if puppet: if dbpuppet:
return cls.from_db(puppet) return cls.from_db(dbpuppet)
return None return None
@classmethod @classmethod
def find_by_displayname(cls, displayname) -> "Optional[Puppet]": def find_by_displayname(cls, displayname: str) -> Optional['Puppet']:
if not displayname: if not displayname:
return None return None
@@ -437,17 +463,17 @@ class Puppet:
if puppet.displayname and puppet.displayname == displayname: if puppet.displayname and puppet.displayname == displayname:
return puppet return puppet
puppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none() dbpuppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
if puppet: if dbpuppet:
return cls.from_db(puppet) return cls.from_db(dbpuppet)
return None return None
# endregion # endregion
def init(context: "Context") -> List[Awaitable[int]]: def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError]
global config global config
Puppet.az, Puppet.db, config, Puppet.loop, _ = context Puppet.az, Puppet.db, config, Puppet.loop, _ = context.core
Puppet.mx = context.mx Puppet.mx = context.mx
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}") Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
Puppet.hs_domain = config["homeserver"]["domain"] Puppet.hs_domain = config["homeserver"]["domain"]
@@ -40,7 +40,7 @@ telematrix_db_engine.dispose()
portals = {} portals = {}
chats = {} chats = {}
messages = {} messages = {}
puppets = {} puppets = {} # Dict[int, Puppet]
for chat_link in chat_links: for chat_link in chat_links:
if type(chat_link.tg_room) is str: if type(chat_link.tg_room) is str:
+15 -13
View File
@@ -20,37 +20,39 @@ from sqlalchemy import orm
from mautrix_appservice import StateStore from mautrix_appservice import StateStore
from .types import MatrixUserId, MatrixRoomId
from . import puppet as pu from . import puppet as pu
from .db import RoomState, UserProfile from .db import RoomState, UserProfile
class SQLStateStore(StateStore): class SQLStateStore(StateStore):
def __init__(self, db) -> None: def __init__(self, db: orm.Session) -> None:
super().__init__() super().__init__()
self.db = db # type: orm.Session self.db = db # type: orm.Session
self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile] self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile]
self.room_state_cache = {} # type: Dict[str, RoomState] self.room_state_cache = {} # type: Dict[str, RoomState]
@staticmethod @staticmethod
def is_registered(user: str) -> bool: def is_registered(user: MatrixUserId) -> bool:
puppet = pu.Puppet.get_by_mxid(user) puppet = pu.Puppet.get_by_mxid(user)
return puppet.is_registered if puppet else False return puppet.is_registered if puppet else False
@staticmethod @staticmethod
def registered(user: str) -> None: def registered(user: MatrixUserId) -> None:
puppet = pu.Puppet.get_by_mxid(user) puppet = pu.Puppet.get_by_mxid(user)
if puppet: if puppet:
puppet.is_registered = True puppet.is_registered = True
puppet.save() puppet.save()
def update_state(self, event: dict) -> None: def update_state(self, event: Dict) -> None:
event_type = event["type"] event_type = event["type"]
if event_type == "m.room.power_levels": if event_type == "m.room.power_levels":
self.set_power_levels(event["room_id"], event["content"]) self.set_power_levels(event["room_id"], event["content"])
elif event_type == "m.room.member": elif event_type == "m.room.member":
self.set_member(event["room_id"], event["state_key"], event["content"]) self.set_member(event["room_id"], event["state_key"], event["content"])
def _get_user_profile(self, room_id: str, user_id: str, create: bool = True) -> UserProfile: def _get_user_profile(self, room_id: MatrixRoomId, user_id: MatrixUserId, create: bool = True
) -> UserProfile:
key = (room_id, user_id) key = (room_id, user_id)
try: try:
return self.profile_cache[key] return self.profile_cache[key]
@@ -67,22 +69,22 @@ class SQLStateStore(StateStore):
self.profile_cache[key] = profile self.profile_cache[key] = profile
return profile return profile
def get_member(self, room: str, user: str) -> dict: def get_member(self, room: MatrixRoomId, user: MatrixUserId) -> Dict:
return self._get_user_profile(room, user).dict() return self._get_user_profile(room, user).dict()
def set_member(self, room: str, user: str, member: dict) -> None: def set_member(self, room: MatrixRoomId, user: MatrixUserId, member: Dict) -> None:
profile = self._get_user_profile(room, user) profile = self._get_user_profile(room, user)
profile.membership = member.get("membership", profile.membership or "leave") profile.membership = member.get("membership", profile.membership or "leave")
profile.displayname = member.get("displayname", profile.displayname) profile.displayname = member.get("displayname", profile.displayname)
profile.avatar_url = member.get("avatar_url", profile.avatar_url) profile.avatar_url = member.get("avatar_url", profile.avatar_url)
self.db.commit() self.db.commit()
def set_membership(self, room: str, user: str, membership: str) -> None: def set_membership(self, room: MatrixRoomId, user: MatrixUserId, membership: str) -> None:
self.set_member(room, user, { self.set_member(room, user, {
"membership": membership, "membership": membership,
}) })
def _get_room_state(self, room_id: str, create: bool = True) -> RoomState: def _get_room_state(self, room_id: MatrixRoomId, create: bool = True) -> RoomState:
try: try:
return self.room_state_cache[room_id] return self.room_state_cache[room_id]
except KeyError: except KeyError:
@@ -96,13 +98,13 @@ class SQLStateStore(StateStore):
self.room_state_cache[room_id] = room self.room_state_cache[room_id] = room
return room return room
def has_power_levels(self, room: str) -> bool: def has_power_levels(self, room: MatrixRoomId) -> bool:
return self._get_room_state(room).has_power_levels return self._get_room_state(room).has_power_levels
def get_power_levels(self, room: str) -> dict: def get_power_levels(self, room: MatrixRoomId) -> Dict:
return self._get_room_state(room).power_levels return self._get_room_state(room).power_levels
def set_power_level(self, room: str, user: str, level: int) -> None: def set_power_level(self, room: MatrixRoomId, user: MatrixUserId, level: int) -> None:
room_state = self._get_room_state(room) room_state = self._get_room_state(room)
power_levels = room_state.power_levels power_levels = room_state.power_levels
if not power_levels: if not power_levels:
@@ -114,7 +116,7 @@ class SQLStateStore(StateStore):
room_state.power_levels = power_levels room_state.power_levels = power_levels
self.db.commit() self.db.commit()
def set_power_levels(self, room: str, content: dict) -> None: def set_power_levels(self, room: MatrixRoomId, content: Dict) -> None:
state = self._get_room_state(room) state = self._get_room_state(room)
state.power_levels = content state.power_levels = content
self.db.commit() self.db.commit()
+10
View File
@@ -0,0 +1,10 @@
from typing import Dict, NewType
# MatrixId = NewType('MatrixId', str)
MatrixUserId = NewType('MatrixUserId', str)
MatrixRoomId = NewType('MatrixRoomId', str)
MatrixEventId = NewType('MatrixEventId', str)
MatrixEvent = NewType('MatrixEvent', Dict)
TelegramId = NewType('TelegramId', int)
+26 -22
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Dict, List, Match, Optional, Tuple, TYPE_CHECKING from typing import Coroutine, Dict, List, Match, Optional, Tuple, cast, TYPE_CHECKING
import logging import logging
import asyncio import asyncio
import re import re
@@ -28,6 +28,7 @@ from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from telethon.tl.functions.account import UpdateStatusRequest from telethon.tl.functions.account import UpdateStatusRequest
from mautrix_appservice import MatrixRequestError from mautrix_appservice import MatrixRequestError
from .types import MatrixUserId, TelegramId
from .db import User as DBUser, Contact as DBContact, Portal as DBPortal from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
from .abstract_user import AbstractUser from .abstract_user import AbstractUser
from . import portal as po, puppet as pu from . import portal as po, puppet as pu
@@ -46,23 +47,23 @@ class User(AbstractUser):
by_mxid = {} # type: Dict[str, User] by_mxid = {} # type: Dict[str, User]
by_tgid = {} # type: Dict[int, User] by_tgid = {} # type: Dict[int, User]
def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None, def __init__(self, mxid: MatrixUserId, tgid: Optional[TelegramId] = None,
db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0, username: Optional[str] = None, db_contacts: Optional[List[DBContact]] = None,
is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None, saved_contacts: int = 0, is_bot: bool = False, db_portals: List[DBPortal] = [],
db_instance: Optional[DBUser] = None) -> None: db_instance: Optional[DBUser] = None) -> None:
super().__init__() super().__init__()
self.mxid = mxid # type: str self.mxid = mxid # type: MatrixUserId
self.tgid = tgid # type: int self.tgid = tgid # type: TelegramId
self.is_bot = is_bot # type: bool self.is_bot = is_bot # type: bool
self.username = username # type: str self.username = username # type: str
self.contacts = [] # type: List[pu.Puppet] self.contacts = [] # type: List[pu.Puppet]
self.saved_contacts = saved_contacts # type: int self.saved_contacts = saved_contacts # type: int
self.db_contacts = db_contacts # type: List[DBContact] self.db_contacts = db_contacts # type: List[DBContact]
self.portals = {} # type: Dict[Tuple[int, int], po.Portal] self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
self.db_portals = db_portals # type: List[DBPortal] self.db_portals = db_portals or [] # type: List[DBPortal]
self._db_instance = db_instance # type: DBUser self._db_instance = db_instance # type: Optional[DBUser]
self.command_status = None # type: dict self.command_status = None # type: Dict
(self.relaybot_whitelisted, (self.relaybot_whitelisted,
self.whitelisted, self.whitelisted,
@@ -169,9 +170,9 @@ class User(AbstractUser):
except Exception: except Exception:
self.log.exception("Failed to run post-login functions for %s", self.mxid) self.log.exception("Failed to run post-login functions for %s", self.mxid)
async def update(self, update: TypeUpdate) -> None: async def update(self, update: TypeUpdate) -> bool:
if not self.is_bot: if not self.is_bot:
return return False
if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)): if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
message = update.message message = update.message
@@ -185,19 +186,22 @@ class User(AbstractUser):
elif isinstance(update, UpdateShortMessage): elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user") portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
else: else:
return return False
self.register_portal(portal) if portal:
self.register_portal(portal)
return True
# endregion # endregion
# region Telegram actions that need custom methods # region Telegram actions that need custom methods
def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]": def ensure_started(self, even_if_no_session: bool = False) -> Coroutine[None, None, 'User']:
return super().ensure_started(even_if_no_session) return cast(Coroutine[None, None, 'User'], super().ensure_started(even_if_no_session))
def set_presence(self, online: bool = True) -> None: def set_presence(self, online: bool = True) -> bool:
if self.is_bot: if self.is_bot:
return return False
return self.client(UpdateStatusRequest(offline=not online)) return self.client(UpdateStatusRequest(offline=not online))
async def update_info(self, info: TLUser = None) -> None: async def update_info(self, info: TLUser = None) -> None:
@@ -215,7 +219,7 @@ class User(AbstractUser):
if changed: if changed:
self.save() self.save()
async def log_out(self) -> None: async def log_out(self) -> bool:
puppet = pu.Puppet.get(self.tgid) puppet = pu.Puppet.get(self.tgid)
if puppet.is_real_user: if puppet.is_real_user:
await puppet.switch_mxid(None, None) await puppet.switch_mxid(None, None)
@@ -328,7 +332,7 @@ class User(AbstractUser):
# region Class instance lookup # region Class instance lookup
@classmethod @classmethod
def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]": def get_by_mxid(cls, mxid: MatrixUserId, create: bool = True) -> Optional['User']:
if not mxid: if not mxid:
raise ValueError("Matrix ID can't be empty") raise ValueError("Matrix ID can't be empty")
@@ -351,7 +355,7 @@ class User(AbstractUser):
return None return None
@classmethod @classmethod
def get_by_tgid(cls, tgid: int) -> "Optional[User]": def get_by_tgid(cls, tgid: int) -> Optional['User']:
try: try:
return cls.by_tgid[tgid] return cls.by_tgid[tgid]
except KeyError: except KeyError:
@@ -365,7 +369,7 @@ class User(AbstractUser):
return None return None
@classmethod @classmethod
def find_by_username(cls, username: str) -> "Optional[User]": def find_by_username(cls, username: str) -> Optional['User']:
if not username: if not username:
return None return None
@@ -381,7 +385,7 @@ class User(AbstractUser):
# endregion # endregion
def init(context: "Context") -> List[Awaitable[User]]: def init(context: 'Context') -> List[Coroutine]: # [None, None, AbstractUser]
global config global config
config = context.config config = context.config
+2 -2
View File
@@ -17,10 +17,10 @@
def format_duration(seconds: int) -> str: def format_duration(seconds: int) -> str:
def pluralize(count, singular) -> None: def pluralize(count: int, singular: str) -> str:
return singular if count == 1 else singular + "s" return singular if count == 1 else singular + "s"
def include(count, word) -> None: def include(count: int, word: str) -> str:
return f"{count} {pluralize(count, word)}" if count > 0 else "" return f"{count} {pluralize(count, word)}" if count > 0 else ""
minutes, seconds = divmod(seconds, 60) minutes, seconds = divmod(seconds, 60)
+6 -6
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional from typing import Dict, Optional
import json import json
import base64 import base64
import hashlib import hashlib
@@ -28,13 +28,13 @@ def _get_checksum(key: str, payload: bytes) -> str:
return checksum return checksum
def sign_token(key: str, payload: dict) -> str: def sign_token(key: str, payload: Dict) -> str:
payload = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8")) payload_b64 = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
checksum = _get_checksum(key, payload) checksum = _get_checksum(key, payload_b64)
return f"{checksum}:{payload.decode('utf-8')}" return f"{checksum}:{payload_b64.decode('utf-8')}"
def verify_token(key: str, data: str) -> Optional[dict]: def verify_token(key: str, data: str) -> Optional[Dict]:
if not data: if not data:
return None return None
+4 -3
View File
@@ -23,7 +23,7 @@ from telethon.errors import *
from ...commands.auth import enter_password from ...commands.auth import enter_password
from ...util import format_duration from ...util import format_duration
from ...puppet import Puppet from ...puppet import Puppet, PuppetError
from ...user import User from ...user import User
@@ -51,12 +51,13 @@ class AuthAPI(abc.ABC):
"account.", errcode="already-logged-in") "account.", errcode="already-logged-in")
resp = await puppet.switch_mxid(token, user.mxid) resp = await puppet.switch_mxid(token, user.mxid)
if resp == 2: if resp == PuppetError.OnlyLoginSelf:
return self.get_mx_login_response(status=403, errcode="only-login-self", return self.get_mx_login_response(status=403, errcode="only-login-self",
error="You can only log in as your own Matrix user.") error="You can only log in as your own Matrix user.")
elif resp == 1: elif resp == PuppetError.InvalidAccessToken:
return self.get_mx_login_response(status=401, errcode="invalid-access-token", return self.get_mx_login_response(status=401, errcode="invalid-access-token",
error="Failed to verify access token.") error="Failed to verify access token.")
assert resp == PuppetError.Success, "Encountered an unhandled PuppetError."
return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in") return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in")
@@ -15,7 +15,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web from aiohttp import web
from typing import Tuple, Optional, Callable, Awaitable, TYPE_CHECKING from typing import Awaitable, Callable, Dict, Optional, Tuple, TYPE_CHECKING
import asyncio import asyncio
import logging import logging
import json import json
@@ -24,6 +24,7 @@ from telethon.utils import get_peer_id, resolve_id
from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat
from mautrix_appservice import AppService, MatrixRequestError, IntentError from mautrix_appservice import AppService, MatrixRequestError, IntentError
from ...types import MatrixUserId
from ...user import User from ...user import User
from ...portal import Portal from ...portal import Portal
from ...commands.portal import user_has_power_level, get_initial_state from ...commands.portal import user_has_power_level, get_initial_state
@@ -36,7 +37,7 @@ if TYPE_CHECKING:
class ProvisioningAPI(AuthAPI): class ProvisioningAPI(AuthAPI):
log = logging.getLogger("mau.web.provisioning") log = logging.getLogger("mau.web.provisioning")
def __init__(self, context: "Context"): def __init__(self, context: "Context") -> None:
super().__init__(context.loop) super().__init__(context.loop)
self.secret = context.config["appservice.provisioning.shared_secret"] self.secret = context.config["appservice.provisioning.shared_secret"]
self.az = context.az # type: AppService self.az = context.az # type: AppService
@@ -411,7 +412,7 @@ class ProvisioningAPI(AuthAPI):
except json.JSONDecodeError: except json.JSONDecodeError:
return None return None
async def get_user(self, mxid: str, expect_logged_in: Optional[bool] = False, async def get_user(self, mxid: MatrixUserId, expect_logged_in: Optional[bool] = False,
require_puppeting: bool = True, require_user: bool = True require_puppeting: bool = True, require_user: bool = True
) -> Tuple[Optional[User], Optional[web.Response]]: ) -> Tuple[Optional[User], Optional[web.Response]]:
if not mxid: if not mxid:
@@ -439,7 +440,7 @@ class ProvisioningAPI(AuthAPI):
expect_logged_in: Optional[bool] = False, expect_logged_in: Optional[bool] = False,
require_puppeting: bool = False, require_puppeting: bool = False,
want_data: bool = True, want_data: bool = True,
) -> (Tuple[Optional[dict], ) -> (Tuple[Optional[Dict],
Optional[User], Optional[User],
Optional[web.Response]]): Optional[web.Response]]):
err = self.check_authorization(request) err = self.check_authorization(request)