Add more type hints

This commit is contained in:
Tulir Asokan
2018-07-25 10:40:31 -04:00
parent ae334b9a04
commit dbfb980bde
20 changed files with 751 additions and 595 deletions
+19 -20
View File
@@ -14,34 +14,33 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
import argparse import argparse
import sys
import logging
import logging.config
import asyncio import asyncio
import logging.config
import sys
import sqlalchemy as sql
from sqlalchemy import orm from sqlalchemy import orm
import sqlalchemy as sql
from alchemysession import AlchemySessionContainer
from mautrix_appservice import AppService from mautrix_appservice import AppService
from alchemysession import AlchemySessionContainer
from .base import Base from .web.provisioning import ProvisioningAPI
from .config import Config from .web.public import PublicBridgeWebsite
from .matrix import MatrixHandler
from . import __version__
from .db import init as init_db
from .abstract_user import init as init_abstract_user from .abstract_user import init as init_abstract_user
from .user import init as init_user, User from .base import Base
from .bot import init as init_bot from .bot import init as init_bot
from .config import Config
from .context import Context
from .db import init as init_db
from .formatter import init as init_formatter
from .matrix import MatrixHandler
from .portal import init as init_portal from .portal import init as init_portal
from .puppet import init as init_puppet from .puppet import init as init_puppet
from .formatter import init as init_formatter
from .web.public import PublicBridgeWebsite
from .web.provisioning import ProvisioningAPI
from .context import Context
from .sqlstatestore import SQLStateStore from .sqlstatestore import SQLStateStore
from .user import User, init as init_user
from . import __version__
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="A Matrix-Telegram puppeting bridge.", description="A Matrix-Telegram puppeting bridge.",
@@ -68,7 +67,7 @@ if args.generate_registration:
sys.exit(0) sys.exit(0)
logging.config.dictConfig(config["logging"]) logging.config.dictConfig(config["logging"])
log = logging.getLogger("mau.init") log = logging.getLogger("mau.init") # type: logging.Logger
log.debug(f"Initializing mautrix-telegram {__version__}") log.debug(f"Initializing mautrix-telegram {__version__}")
db_engine = sql.create_engine(config["appservice.database"] or "sqlite:///mautrix-telegram.db") db_engine = sql.create_engine(config["appservice.database"] or "sqlite:///mautrix-telegram.db")
@@ -80,7 +79,7 @@ session_container = AlchemySessionContainer(engine=db_engine, session=db_session
table_base=Base, table_prefix="telethon_", table_base=Base, table_prefix="telethon_",
manage_tables=False) manage_tables=False)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop() # type: asyncio.AbstractEventLoop
state_store = SQLStateStore(db_session) state_store = SQLStateStore(db_session)
appserv = AppService(config["homeserver.address"], config["homeserver.domain"], appserv = AppService(config["homeserver.address"], config["homeserver.domain"],
@@ -89,8 +88,8 @@ appserv = AppService(config["homeserver.address"], config["homeserver.domain"],
verify_ssl=config["homeserver.verify_ssl"], state_store=state_store, verify_ssl=config["homeserver.verify_ssl"], state_store=state_store,
real_user_content_key="net.maunium.telegram.puppet") real_user_content_key="net.maunium.telegram.puppet")
public_website = None public_website = None # type: Optional[PublicBridgeWebsite]
provisioning_api = None provisioning_api = None # type: Optional[ProvisioningAPI]
if config["appservice.public.enabled"]: if config["appservice.public.enabled"]:
public_website = PublicBridgeWebsite(loop) public_website = PublicBridgeWebsite(loop)
+81 -44
View File
@@ -14,26 +14,48 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Optional, List, Union, TYPE_CHECKING
from abc import ABC, abstractmethod
import asyncio
import logging
import platform import platform
from telethon.tl.types import * from sqlalchemy import orm
from mautrix_appservice import MatrixRequestError from telethon.tl.types import Channel, ChannelForbidden, Chat, ChatForbidden, Message, \
MessageActionChannelMigrateFrom, MessageService, PeerUser, TypeUpdate, \
UpdateChannelPinnedMessage, UpdateChatAdmins, UpdateChatParticipantAdmin, \
UpdateChatParticipants, UpdateChatUserTyping, UpdateDeleteChannelMessages, \
UpdateDeleteMessages, UpdateEditChannelMessage, UpdateEditMessage, UpdateNewChannelMessage, \
UpdateNewMessage, UpdateReadHistoryOutbox, UpdateShortChatMessage, UpdateShortMessage, \
UpdateUserName, UpdateUserPhoto, UpdateUserStatus, UpdateUserTyping, User, UserStatusOffline, \
UserStatusOnline
from mautrix_appservice import MatrixRequestError, AppService
from alchemysession import AlchemySessionContainer
from .tgclient import MautrixTelegramClient
from .db import Message as DBMessage
from . import portal as po, puppet as pu, __version__ from . import portal as po, puppet as pu, __version__
from .db import Message as DBMessage
from .tgclient import MautrixTelegramClient
config = None if TYPE_CHECKING:
from .context import Context
from .config import Config
config = None # type: Config
# Value updated from config in init() # Value updated from config in init()
MAX_DELETIONS = 10 MAX_DELETIONS = 10 # type: int
UpdateMessage = Union[UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage,
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage]
UpdateMessageContent = Union[UpdateShortMessage, UpdateShortChatMessage, Message, MessageService]
class AbstractUser: class AbstractUser(ABC):
session_container = None session_container = None # type: AlchemySessionContainer
loop = None loop = None # type: asyncio.AbstractEventLoop
log = None log = None # type: logging.Logger
db = None db = None # type: orm.Session
az = None az = None # type: AppService
def __init__(self): def __init__(self):
self.puppet_whitelisted = False # type: bool self.puppet_whitelisted = False # type: bool
@@ -47,22 +69,22 @@ class AbstractUser:
self.is_bot = False # type: bool self.is_bot = False # type: bool
@property @property
def connected(self): def connected(self) -> bool:
return self.client and self.client.is_connected() return self.client and self.client.is_connected()
@property @property
def _proxy_settings(self): def _proxy_settings(self) -> Optional[Tuple[int, str, str, str, str, str]]:
type = config["telegram.proxy.type"].lower() proxy_type = config["telegram.proxy.type"].lower()
if type == "disabled": if proxy_type == "disabled":
return None return None
elif type == "socks4": elif proxy_type == "socks4":
type = 1 proxy_type = 1
elif type == "socks5": elif proxy_type == "socks5":
type = 2 proxy_type = 2
elif type == "http": elif proxy_type == "http":
type = 3 proxy_type = 3
return (type, return (proxy_type,
config["telegram.proxy.address"], config["telegram.proxy.port"], config["telegram.proxy.address"], config["telegram.proxy.port"],
config["telegram.proxy.rdns"], config["telegram.proxy.rdns"],
config["telegram.proxy.username"], config["telegram.proxy.password"]) config["telegram.proxy.username"], config["telegram.proxy.password"])
@@ -83,20 +105,30 @@ class AbstractUser:
proxy=self._proxy_settings) proxy=self._proxy_settings)
self.client.add_event_handler(self._update_catch) self.client.add_event_handler(self._update_catch)
async def update(self, update): @abstractmethod
async def update(self, update: TypeUpdate) -> bool:
return False return False
@abstractmethod
async def post_login(self): async def post_login(self):
raise NotImplementedError() raise NotImplementedError()
async def _update_catch(self, update): @abstractmethod
def register_portal(self, portal: po.Portal):
raise NotImplementedError()
@abstractmethod
def unregister_portal(self, portal: po.Portal):
raise NotImplementedError()
async def _update_catch(self, update: TypeUpdate):
try: try:
if not await self.update(update): if not await self.update(update):
await self._update(update) await self._update(update)
except Exception: except Exception:
self.log.exception("Failed to handle Telegram update") self.log.exception("Failed to handle Telegram update")
async def get_dialogs(self, limit=None) -> List[Union[Chat, Channel]]: async def get_dialogs(self, limit: int = None) -> List[Union[Chat, Channel]]:
if self.is_bot: if self.is_bot:
return [] return []
dialogs = await self.client.get_dialogs(limit=limit) dialogs = await self.client.get_dialogs(limit=limit)
@@ -106,18 +138,19 @@ class AbstractUser:
and (dialog.entity.deactivated or dialog.entity.left)))] and (dialog.entity.deactivated or dialog.entity.left)))]
@property @property
def name(self): @abstractmethod
def name(self) -> str:
raise NotImplementedError() raise NotImplementedError()
async def is_logged_in(self): async def is_logged_in(self) -> bool:
return self.client and await self.client.is_user_authorized() return self.client and await self.client.is_user_authorized()
async def has_full_access(self, allow_bot=False): async def has_full_access(self, allow_bot: bool = False) -> bool:
return (self.puppet_whitelisted return (self.puppet_whitelisted
and (not self.is_bot or allow_bot) and (not self.is_bot or allow_bot)
and await self.is_logged_in()) and await self.is_logged_in())
async def start(self, delete_unless_authenticated=False): async def start(self, delete_unless_authenticated: bool = False) -> "AbstractUser":
if not self.client: if not self.client:
self._init_client() self._init_client()
await self.client.connect() await self.client.connect()
@@ -144,7 +177,7 @@ class AbstractUser:
# region Telegram update handling # region Telegram update handling
async def _update(self, update): async def _update(self, update: TypeUpdate):
if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage, if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage,
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage)): UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage)):
await self.update_message(update) await self.update_message(update)
@@ -169,17 +202,19 @@ class AbstractUser:
else: else:
self.log.debug("Unhandled update: %s", update) self.log.debug("Unhandled update: %s", update)
async def update_pinned_messages(self, update): @staticmethod
async def update_pinned_messages(update: UpdateChannelPinnedMessage):
portal = po.Portal.get_by_tgid(update.channel_id) portal = po.Portal.get_by_tgid(update.channel_id)
if portal and portal.mxid: if portal and portal.mxid:
await portal.receive_telegram_pin_id(update.id) await portal.receive_telegram_pin_id(update.id)
async def update_participants(self, update): @staticmethod
async def update_participants(update: UpdateChatParticipants):
portal = po.Portal.get_by_tgid(update.participants.chat_id) portal = po.Portal.get_by_tgid(update.participants.chat_id)
if portal and portal.mxid: if portal and portal.mxid:
await portal.update_telegram_participants(update.participants.participants) await portal.update_telegram_participants(update.participants.participants)
async def update_read_receipt(self, update): async def update_read_receipt(self, update: UpdateReadHistoryOutbox):
if not isinstance(update.peer, PeerUser): if not isinstance(update.peer, PeerUser):
self.log.debug("Unexpected read receipt peer: %s", update.peer) self.log.debug("Unexpected read receipt peer: %s", update.peer)
return return
@@ -196,7 +231,7 @@ class AbstractUser:
puppet = pu.Puppet.get(update.peer.user_id) puppet = pu.Puppet.get(update.peer.user_id)
await puppet.intent.mark_read(portal.mxid, message.mxid) await puppet.intent.mark_read(portal.mxid, message.mxid)
async def update_admin(self, update): async def update_admin(self, update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]):
# TODO duplication not checked # TODO duplication not checked
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat") portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
if isinstance(update, UpdateChatAdmins): if isinstance(update, UpdateChatAdmins):
@@ -206,7 +241,7 @@ class AbstractUser:
else: else:
self.log.warning("Unexpected admin status update: %s", update) self.log.warning("Unexpected admin status update: %s", update)
async def update_typing(self, update): async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]):
if isinstance(update, UpdateUserTyping): if isinstance(update, UpdateUserTyping):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user") portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
else: else:
@@ -214,7 +249,7 @@ class AbstractUser:
sender = pu.Puppet.get(update.user_id) sender = pu.Puppet.get(update.user_id)
await portal.handle_telegram_typing(sender, update) await portal.handle_telegram_typing(sender, update)
async def update_others_info(self, update): async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]):
# TODO duplication not checked # TODO duplication not checked
puppet = pu.Puppet.get(update.user_id) puppet = pu.Puppet.get(update.user_id)
if isinstance(update, UpdateUserName): if isinstance(update, UpdateUserName):
@@ -226,7 +261,7 @@ class AbstractUser:
else: else:
self.log.warning("Unexpected other user info update: %s", update) self.log.warning("Unexpected other user info update: %s", update)
async def update_status(self, update): async def update_status(self, update: UpdateUserStatus):
puppet = pu.Puppet.get(update.user_id) puppet = pu.Puppet.get(update.user_id)
if isinstance(update.status, UserStatusOnline): if isinstance(update.status, UserStatusOnline):
await puppet.default_mxid_intent.set_presence("online") await puppet.default_mxid_intent.set_presence("online")
@@ -236,7 +271,9 @@ class AbstractUser:
self.log.warning("Unexpected user status update: %s", update) self.log.warning("Unexpected user status update: %s", update)
return return
def get_message_details(self, update): def get_message_details(self, update: UpdateMessage) -> Tuple[UpdateMessageContent,
Optional[pu.Puppet],
Optional[po.Portal]]:
if isinstance(update, UpdateShortChatMessage): if isinstance(update, UpdateShortChatMessage):
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat") portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
sender = pu.Puppet.get(update.from_id) sender = pu.Puppet.get(update.from_id)
@@ -259,7 +296,7 @@ class AbstractUser:
return update, sender, portal return update, sender, portal
@staticmethod @staticmethod
async def _try_redact(portal, message): async def _try_redact(portal: po.Portal, message: DBMessage):
if not portal: if not portal:
return return
try: try:
@@ -267,7 +304,7 @@ class AbstractUser:
except MatrixRequestError: except MatrixRequestError:
pass pass
async def delete_message(self, update): async def delete_message(self, update: UpdateDeleteMessages):
if len(update.messages) > MAX_DELETIONS: if len(update.messages) > MAX_DELETIONS:
return return
@@ -283,7 +320,7 @@ class AbstractUser:
await self._try_redact(portal, message) await self._try_redact(portal, message)
self.db.commit() self.db.commit()
async def delete_channel_message(self, update): async def delete_channel_message(self, update: UpdateDeleteChannelMessages):
if len(update.messages) > MAX_DELETIONS: if len(update.messages) > MAX_DELETIONS:
return return
@@ -299,7 +336,7 @@ class AbstractUser:
await self._try_redact(portal, message) await self._try_redact(portal, message)
self.db.commit() self.db.commit()
async def update_message(self, original_update): async def update_message(self, original_update: UpdateMessage):
update, sender, portal = self.get_message_details(original_update) update, sender, portal = self.get_message_details(original_update)
if isinstance(update, MessageService): if isinstance(update, MessageService):
@@ -325,7 +362,7 @@ class AbstractUser:
# endregion # endregion
def init(context): def init(context: "Context"):
global config, MAX_DELETIONS global config, MAX_DELETIONS
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, _ = context AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, _ = context
AbstractUser.session_container = context.session_container AbstractUser.session_container = context.session_container
+1 -1
View File
@@ -1,2 +1,2 @@
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base() Base = declarative_base() # type: declarative_base
+22 -18
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Callable from typing import Awaitable, Callable, Pattern, Dict, TYPE_CHECKING
import logging import logging
import re import re
@@ -27,27 +27,31 @@ from .abstract_user import AbstractUser
from .db import BotChat from .db import BotChat
from . import puppet as pu, portal as po, user as u from . import puppet as pu, portal as po, user as u
config = None if TYPE_CHECKING:
from .config import Config
config = None # type: Config
ReplyFunc = Callable[[str], Awaitable[Message]] ReplyFunc = Callable[[str], Awaitable[Message]]
class Bot(AbstractUser): class Bot(AbstractUser):
log = logging.getLogger("mau.bot") log = logging.getLogger("mau.bot") # type: logging.Logger
mxid_regex = re.compile("@.+:.+") mxid_regex = re.compile("@.+:.+") # type: Pattern
def __init__(self, token: str): def __init__(self, token: str):
super().__init__() super().__init__()
self.token = token self.token = token # type: str
self.puppet_whitelisted = True self.puppet_whitelisted = True # type: bool
self.whitelisted = True self.whitelisted = True # type: bool
self.relaybot_whitelisted = True self.relaybot_whitelisted = True # type: bool
self.username = None self.username = None # type: str
self.is_relaybot = True self.is_relaybot = True # type: bool
self.is_bot = True self.is_bot = True # type: bool
self.chats = {chat.id: chat.type for chat in BotChat.query.all()} self.chats = {chat.id: chat.type for chat in BotChat.query.all()} # type: Dict[int, str]
self.tg_whitelist = [] self.tg_whitelist = [] # type: List[int]
self.whitelist_group_admins = config["bridge.relaybot.whitelist_group_admins"] or False self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"]
or False) # type: bool
async def init_permissions(self): async def init_permissions(self):
whitelist = config["bridge.relaybot.whitelist"] or [] whitelist = config["bridge.relaybot.whitelist"] or []
@@ -61,7 +65,7 @@ class Bot(AbstractUser):
if isinstance(id, int): if isinstance(id, int):
self.tg_whitelist.append(id) self.tg_whitelist.append(id)
async def start(self, delete_unless_authenticated=False): async def start(self, delete_unless_authenticated: bool = False) -> "Bot":
await super().start(delete_unless_authenticated) await super().start(delete_unless_authenticated)
if not await self.is_logged_in(): if not await self.is_logged_in():
await self.client.sign_in(bot_token=self.token) await self.client.sign_in(bot_token=self.token)
@@ -118,7 +122,7 @@ class Bot(AbstractUser):
self.db.delete(existing_chat) self.db.delete(existing_chat)
self.db.commit() self.db.commit()
async def _can_use_commands(self, chat, tgid): async def _can_use_commands(self, chat: TypePeer, tgid: int) -> bool:
if tgid in self.tg_whitelist: if tgid in self.tg_whitelist:
return True return True
@@ -138,7 +142,7 @@ class Bot(AbstractUser):
if p.user_id == tgid: if p.user_id == tgid:
return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin)) return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin))
async def check_can_use_commands(self, event: Message, reply: ReplyFunc): async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
if not await self._can_use_commands(event.to_id, event.from_id): if not await self._can_use_commands(event.to_id, event.from_id):
await reply("You do not have the permission to use that command.") await reply("You do not have the permission to use that command.")
return False return False
@@ -262,7 +266,7 @@ class Bot(AbstractUser):
return "bot" return "bot"
def init(context): def init(context) -> Optional[Bot]:
global config global config
config = context.config config = context.config
token = config["telegram.bot_token"] token = config["telegram.bot_token"]
+6 -7
View File
@@ -23,15 +23,14 @@ from .. import puppet as pu, portal as po
ManagementRoomList = List[Tuple[str, str]] ManagementRoomList = List[Tuple[str, str]]
RoomIDList = List[str] RoomIDList = List[str]
PortalList = List[po.Portal]
async def _find_rooms(intent: IntentAPI) -> Tuple[ async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList,
ManagementRoomList, RoomIDList, PortalList, PortalList]: List["po.Portal"], List["po.Portal"]]:
management_rooms = [] # type: ManagementRoomList management_rooms = [] # type: ManagementRoomList
unidentified_rooms = [] # type: RoomIDList unidentified_rooms = [] # type: RoomIDList
portals = [] # type: PortalList portals = [] # type: List[po.Portal]
empty_portals = [] # type: PortalList empty_portals = [] # type: List[po.Portal]
rooms = await intent.get_joined_rooms() rooms = await intent.get_joined_rooms()
for room in rooms: for room in rooms:
@@ -108,8 +107,8 @@ async def clean_rooms(evt: CommandEvent):
async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList, async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
unidentified_rooms: RoomIDList, portals: PortalList, unidentified_rooms: RoomIDList, portals: List["po.Portal"],
empty_portals: PortalList): empty_portals: List["po.Portal"]):
command = evt.args[0] command = evt.args[0]
rooms_to_clean = [] rooms_to_clean = []
if command == "clean-recommended": if command == "clean-recommended":
+1 -1
View File
@@ -222,7 +222,7 @@ async def bridge(evt: CommandEvent):
"chat to this room, use `$cmdprefix+sp continue`") "chat to this room, use `$cmdprefix+sp continue`")
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: po.Portal): async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"):
if not portal.mxid: if not portal.mxid:
await evt.reply("The portal seems to have lost its Matrix room between you" await evt.reply("The portal seems to have lost its Matrix room between you"
"calling `$cmdprefix+sp bridge` and this command.\n\n" "calling `$cmdprefix+sp bridge` and this command.\n\n"
+22 -21
View File
@@ -14,6 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Any, Optional
from ruamel.yaml import YAML from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap from ruamel.yaml.comments import CommentedMap
import random import random
@@ -24,28 +25,28 @@ yaml.indent(4)
class DictWithRecursion: class DictWithRecursion:
def __init__(self, data=None): def __init__(self, data: CommentedMap = None):
self._data = data or CommentedMap() self._data = data or CommentedMap() # type: CommentedMap
def _recursive_get(self, data, key, default_value): def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
if '.' in key: if '.' in key:
key, next_key = key.split('.', 1) key, next_key = key.split('.', 1)
next_data = data.get(key, CommentedMap()) next_data = data.get(key, CommentedMap())
return self._recursive_get(next_data, next_key, default_value) return self._recursive_get(next_data, next_key, default_value)
return data.get(key, default_value) return data.get(key, default_value)
def get(self, key, default_value, allow_recursion=True): def get(self, key: str, default_value: Any, allow_recursion: bool = True) -> Any:
if allow_recursion and '.' in key: if allow_recursion and '.' in key:
return self._recursive_get(self._data, key, default_value) return self._recursive_get(self._data, key, default_value)
return self._data.get(key, default_value) return self._data.get(key, default_value)
def __getitem__(self, key): def __getitem__(self, key: str) -> Any:
return self.get(key, None) return self.get(key, None)
def __contains__(self, key): def __contains__(self, key: str) -> bool:
return self[key] is not None return self[key] is not None
def _recursive_set(self, data, key, value): def _recursive_set(self, data: CommentedMap, key: str, value: Any):
if '.' in key: if '.' in key:
key, next_key = key.split('.', 1) key, next_key = key.split('.', 1)
if key not in data: if key not in data:
@@ -55,16 +56,16 @@ class DictWithRecursion:
return return
data[key] = value data[key] = value
def set(self, key, value, allow_recursion=True): def set(self, key: str, value: Any, allow_recursion: bool = True):
if allow_recursion and '.' in key: if allow_recursion and '.' in key:
self._recursive_set(self._data, key, value) self._recursive_set(self._data, key, value)
return return
self._data[key] = value self._data[key] = value
def __setitem__(self, key, value): def __setitem__(self, key: str, value: Any):
self.set(key, value) self.set(key, value)
def _recursive_del(self, data, key): def _recursive_del(self, data: CommentedMap, key: str):
if '.' in key: if '.' in key:
key, next_key = key.split('.', 1) key, next_key = key.split('.', 1)
if key not in data: if key not in data:
@@ -78,7 +79,7 @@ class DictWithRecursion:
except KeyError: except KeyError:
pass pass
def delete(self, key, allow_recursion=True): def delete(self, key: str, allow_recursion: bool = True):
if allow_recursion and '.' in key: if allow_recursion and '.' in key:
self._recursive_del(self._data, key) self._recursive_del(self._data, key)
return return
@@ -88,23 +89,23 @@ class DictWithRecursion:
except KeyError: except KeyError:
pass pass
def __delitem__(self, key): def __delitem__(self, key: str):
self.delete(key) self.delete(key)
class Config(DictWithRecursion): class Config(DictWithRecursion):
def __init__(self, path, registration_path, base_path): def __init__(self, path: str, registration_path: str, base_path: str):
super().__init__() super().__init__()
self.path = path self.path = path # type: str
self.registration_path = registration_path self.registration_path = registration_path # type: str
self.base_path = base_path self.base_path = base_path # type: str
self._registration = None self._registration = None # type: dict
def load(self): def load(self):
with open(self.path, 'r') as stream: with open(self.path, 'r') as stream:
self._data = yaml.load(stream) self._data = yaml.load(stream)
def load_base(self): def load_base(self) -> Optional[DictWithRecursion]:
try: try:
with open(self.base_path, 'r') as stream: with open(self.base_path, 'r') as stream:
return DictWithRecursion(yaml.load(stream)) return DictWithRecursion(yaml.load(stream))
@@ -120,7 +121,7 @@ class Config(DictWithRecursion):
yaml.dump(self._registration, stream) yaml.dump(self._registration, stream)
@staticmethod @staticmethod
def _new_token(): def _new_token() -> str:
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
def update(self): def update(self):
@@ -246,7 +247,7 @@ class Config(DictWithRecursion):
self._data = base._data self._data = base._data
self.save() self.save()
def _get_permissions(self, key): def _get_permissions(self, key: str) -> Tuple[bool, bool, bool, bool, bool]:
level = self["bridge.permissions"].get(key, "") level = self["bridge.permissions"].get(key, "")
admin = level == "admin" admin = level == "admin"
puppeting = level == "full" or admin puppeting = level == "full" or admin
@@ -254,7 +255,7 @@ class Config(DictWithRecursion):
relaybot = level == "relaybot" or user relaybot = level == "relaybot" or user
return relaybot, user, puppeting, admin, level return relaybot, user, puppeting, admin, level
def get_permissions(self, mxid): def get_permissions(self, mxid: str) -> Tuple[bool, bool, bool, bool, bool]:
permissions = self["bridge.permissions"] or {} permissions = self["bridge.permissions"] or {}
if mxid in permissions: if mxid in permissions:
return self._get_permissions(mxid) return self._get_permissions(mxid)
+18 -12
View File
@@ -14,21 +14,27 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple from typing import TYPE_CHECKING
import asyncio
if TYPE_CHECKING:
import asyncio
from sqlalchemy.orm import scoped_session
from alchemysession import AlchemySessionContainer
from mautrix_appservice import AppService
from .web import PublicBridgeWebsite, ProvisioningAPI
from .config import Config
from .bot import Bot
from .matrix import MatrixHandler
from sqlalchemy.orm import scoped_session
from alchemysession import AlchemySessionContainer
from mautrix_appservice import AppService
class Context: class Context:
def __init__(self, az, db, config, loop, bot, mx, session_container, public_website, def __init__(self, az: "AppService", db: "scoped_session", config: "Config",
provisioning_api): loop: "asyncio.AbstractEventLoop", bot: "Bot", mx: "MatrixHandler",
from .web import PublicBridgeWebsite, ProvisioningAPI session_container: "AlchemySessionContainer",
from .config import Config public_website: "PublicBridgeWebsite", provisioning_api: "ProvisioningAPI"):
from .bot import Bot
from .matrix import MatrixHandler
self.az = az # type: AppService self.az = az # type: AppService
self.db = db # type: scoped_session self.db = db # type: scoped_session
self.config = config # type: Config self.config = config # type: Config
+1
View File
@@ -42,6 +42,7 @@ class Portal(Base):
about = Column(String, nullable=True) about = Column(String, nullable=True)
photo_id = Column(String, nullable=True) photo_id = Column(String, nullable=True)
class Message(Base): class Message(Base):
query = None # type: Query query = None # type: Query
__tablename__ = "message" __tablename__ = "message"
+27 -24
View File
@@ -14,10 +14,10 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (Optional, List, Tuple, Type, Callable, Dict, Any, Pattern, Deque, Match, TYPE_CHECKING)
from html import unescape from html import unescape
from html.parser import HTMLParser from html.parser import HTMLParser
from collections import deque from collections import deque
from typing import Optional, List, Tuple, Type, Callable, Dict, Any
import math import math
import re import re
import logging import logging
@@ -27,37 +27,40 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, M
MessageEntityItalic, MessageEntityCode, MessageEntityPre, MessageEntityItalic, MessageEntityCode, MessageEntityPre,
MessageEntityBotCommand, TypeMessageEntity) MessageEntityBotCommand, TypeMessageEntity)
from .. import user as u, puppet as pu, portal as po, context as c from .. import user as u, puppet as pu, portal as po
from ..db import Message as DBMessage from ..db import Message as DBMessage
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html, from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
trim_reply_fallback_text, html_to_unicode) trim_reply_fallback_text, html_to_unicode)
log = logging.getLogger("mau.fmt.mx") if TYPE_CHECKING:
should_bridge_plaintext_highlights = False from ..context import Context
log = logging.getLogger("mau.fmt.mx") # type: logging.Logger
should_bridge_plaintext_highlights = False # type: bool
class MatrixParser(HTMLParser): class MatrixParser(HTMLParser):
mention_regex = re.compile("https://matrix.to/#/(@.+:.+)") mention_regex = re.compile("https://matrix.to/#/(@.+:.+)") # type: Pattern
room_regex = re.compile("https://matrix.to/#/(#.+:.+)") room_regex = re.compile("https://matrix.to/#/(#.+:.+)") # type: Pattern
block_tags = ("br", "p", "pre", "blockquote", block_tags = ("br", "p", "pre", "blockquote",
"ol", "ul", "li", "ol", "ul", "li",
"h1", "h2", "h3", "h4", "h5", "h6", "h1", "h2", "h3", "h4", "h5", "h6",
"div", "hr", "table") "div", "hr", "table") # type: Tuple[str, ...]
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.text = "" self.text = "" # type: str
self.entities = [] self.entities = [] # type: List[TypeMessageEntity]
self._building_entities = {} self._building_entities = {} # type: Dict[str, TypeMessageEntity]
self._list_counter = 0 self._list_counter = 0 # type: int
self._open_tags = deque() self._open_tags = deque() # type: Deque[str]
self._open_tags_meta = deque() self._open_tags_meta = deque() # type: Deque[Any]
self._line_is_new = True self._line_is_new = True # type: bool
self._list_entry_is_new = False self._list_entry_is_new = False # type: bool
def _parse_url(self, url: str, args: Dict[str, Any] def _parse_url(self, url: str, args: Dict[str, Any]
) -> Tuple[Optional[Type[TypeMessageEntity]], Optional[str]]: ) -> Tuple[Optional[Type[TypeMessageEntity]], Optional[str]]:
mention = self.mention_regex.match(url) mention = self.mention_regex.match(url) # type: Match
if mention: if mention:
mxid = mention.group(1) mxid = mention.group(1)
user = (pu.Puppet.get_by_mxid(mxid) user = (pu.Puppet.get_by_mxid(mxid)
@@ -72,7 +75,7 @@ class MatrixParser(HTMLParser):
else: else:
return None, None return None, None
room = self.room_regex.match(url) room = self.room_regex.match(url) # type: Match
if room: if room:
username = po.Portal.get_username_from_mx_alias(room.group(1)) username = po.Portal.get_username_from_mx_alias(room.group(1))
portal = po.Portal.find_by_username(username) portal = po.Portal.find_by_username(username)
@@ -92,8 +95,8 @@ class MatrixParser(HTMLParser):
self._open_tags_meta.appendleft(0) self._open_tags_meta.appendleft(0)
attrs = dict(attrs) attrs = dict(attrs)
entity_type = None entity_type = None # type: type(TypeMessageEntity)
args = {} args = {} # type: Dict[str, Any]
if tag in ("strong", "b"): if tag in ("strong", "b"):
entity_type = MessageEntityBold entity_type = MessageEntityBold
elif tag in ("em", "i"): elif tag in ("em", "i"):
@@ -243,12 +246,12 @@ class MatrixParser(HTMLParser):
self._newline(allow_multi=tag == "br") self._newline(allow_multi=tag == "br")
command_regex = re.compile(r"^!([A-Za-z0-9@]+)") command_regex = re.compile(r"^!([A-Za-z0-9@]+)") # type: Pattern
not_command_regex = re.compile(r"^\\(![A-Za-z0-9@]+)") not_command_regex = re.compile(r"^\\(![A-Za-z0-9@]+)") # type: Pattern
plain_mention_regex = None plain_mention_regex = None # type: Pattern
def plain_mention_to_html(match): def plain_mention_to_html(match: Match) -> str:
puppet = pu.Puppet.find_by_displayname(match.group(2)) puppet = pu.Puppet.find_by_displayname(match.group(2))
if puppet: if puppet:
return (f"{match.group(1)}" return (f"{match.group(1)}"
@@ -351,7 +354,7 @@ def plain_mention_to_text() -> Tuple[List[TypeMessageEntity], Callable[[str], st
return entities, replacer return entities, replacer
def init_mx(context: c.Context): def init_mx(context: "Context"):
global plain_mention_regex, should_bridge_plaintext_highlights global plain_mention_regex, should_bridge_plaintext_highlights
config = context.config config = context.config
dn_template = config.get("bridge.displayname_template", "{displayname} (Telegram)") dn_template = config.get("bridge.displayname_template", "{displayname} (Telegram)")
+26 -18
View File
@@ -14,13 +14,8 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, Tuple, TYPE_CHECKING
from html import escape from html import escape
from typing import Optional, List, Tuple
try:
from lxml.html.diff import htmldiff
except ImportError:
htmldiff = None # type: function
import logging import logging
import re import re
@@ -33,16 +28,26 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName,
from mautrix_appservice import MatrixRequestError from mautrix_appservice import MatrixRequestError
from mautrix_appservice.intent_api import IntentAPI from mautrix_appservice.intent_api import IntentAPI
from .. import user as u, puppet as pu, portal as po, context as c from .. import user as u, puppet as pu, portal as po
from ..db import Message as DBMessage from ..db import Message as DBMessage
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html, from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
trim_reply_fallback_text, unicode_to_html) trim_reply_fallback_text, unicode_to_html)
log = logging.getLogger("mau.fmt.tg") if TYPE_CHECKING:
should_highlight_edits = False from ..abstract_user import AbstractUser
from ..context import Context
try:
from lxml.html.diff import htmldiff
except ImportError:
htmldiff = None # type: function
def telegram_reply_to_matrix(evt: Message, source: u.User) -> dict: log = logging.getLogger("mau.fmt.tg") # type: logging.Logger
should_highlight_edits = False # type: bool
def telegram_reply_to_matrix(evt: Message, source: "AbstractUser") -> dict:
if evt.reply_to_msg_id: if evt.reply_to_msg_id:
space = (evt.to_id.channel_id space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel) if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
@@ -78,7 +83,7 @@ async def _add_forward_header(source, text: str, html: Optional[str],
if not fwd_from_text: if not fwd_from_text:
user = await source.client.get_entity(PeerUser(fwd_from.from_id)) user = await source.client.get_entity(PeerUser(fwd_from.from_id))
if user: if user:
fwd_from_text = pu.Puppet.get_displayname(user, format=False) fwd_from_text = pu.Puppet.get_displayname(user, False)
fwd_from_html = f"<b>{fwd_from_text}</b>" fwd_from_html = f"<b>{fwd_from_text}</b>"
if not fwd_from_text: if not fwd_from_text:
@@ -110,8 +115,9 @@ def highlight_edits(new_html: str, old_html: str) -> str:
return new_html return new_html
async def _add_reply_header(source: u.User, text: str, html: str, evt: Message, relates_to: dict, async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message,
main_intent: IntentAPI, is_edit: bool) -> Tuple[str, str]: relates_to: dict, main_intent: IntentAPI, is_edit: bool
) -> Tuple[str, str]:
space = (evt.to_id.channel_id space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel) if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
else source.tgid) else source.tgid)
@@ -142,7 +148,7 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message,
if is_edit and should_highlight_edits: if is_edit and should_highlight_edits:
html = highlight_edits(html or escape(text), r_html_body) html = highlight_edits(html or escape(text), r_html_body)
except (ValueError, KeyError, MatrixRequestError) as e: except (ValueError, KeyError, MatrixRequestError):
r_sender_link = "unknown user" r_sender_link = "unknown user"
r_displayname = "unknown user" r_displayname = "unknown user"
r_text_body = "Failed to fetch message" r_text_body = "Failed to fetch message"
@@ -154,8 +160,9 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message,
r_keyword = "In reply to" if not is_edit else "Edit to" r_keyword = "In reply to" if not is_edit else "Edit to"
r_msg_link = f"<a href='https://matrix.to/#/{msg.mx_room}/{msg.mxid}'>{r_keyword}</a>" r_msg_link = f"<a href='https://matrix.to/#/{msg.mx_room}/{msg.mxid}'>{r_keyword}</a>"
html = (f"<mx-reply><blockquote>{r_msg_link} {r_sender_link}\n{r_html_body}</blockquote></mx-reply>" html = (
+ (html or escape(text))) f"<mx-reply><blockquote>{r_msg_link} {r_sender_link}\n{r_html_body}</blockquote></mx-reply>"
+ (html or escape(text)))
lines = r_text_body.strip().split("\n") lines = r_text_body.strip().split("\n")
text_with_quote = f"> <{r_displayname}> {lines.pop(0)}" text_with_quote = f"> <{r_displayname}> {lines.pop(0)}"
@@ -167,7 +174,8 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message,
return text_with_quote, html return text_with_quote, html
async def telegram_to_matrix(evt: Message, source: u.User, main_intent: Optional[IntentAPI] = None, async def telegram_to_matrix(evt: Message, source: "AbstractUser",
main_intent: Optional[IntentAPI] = None,
is_edit: bool = False, prefix_text: Optional[str] = None, is_edit: bool = False, prefix_text: Optional[str] = None,
prefix_html: Optional[str] = None) -> Tuple[str, str, dict]: prefix_html: Optional[str] = None) -> Tuple[str, str, dict]:
text = add_surrogates(evt.message) text = add_surrogates(evt.message)
@@ -320,6 +328,6 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
return False return False
def init_tg(context: c.Context): def init_tg(context: "Context"):
global should_highlight_edits global should_highlight_edits
should_highlight_edits = htmldiff and context.config["bridge.highlight_edits"] should_highlight_edits = htmldiff and context.config["bridge.highlight_edits"]
+2 -2
View File
@@ -14,8 +14,8 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Pattern
from html import escape from html import escape
from typing import Optional
import struct import struct
import re import re
@@ -47,7 +47,7 @@ def trim_reply_fallback_text(text: str) -> str:
html_reply_fallback_regex = re.compile("^<mx-reply>" html_reply_fallback_regex = re.compile("^<mx-reply>"
r"[\s\S]+?" r"[\s\S]+?"
"</mx-reply>") "</mx-reply>") # type: Pattern
def trim_reply_fallback_html(html: str) -> str: def trim_reply_fallback_html(html: str) -> str:
+114 -108
View File
@@ -14,26 +14,23 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict from typing import List, Dict, Tuple, Set, Match
import logging import logging
import asyncio import asyncio
import re import re
from mautrix_appservice import MatrixRequestError, IntentError from mautrix_appservice import MatrixRequestError, IntentError
from .user import User from . import user as u, portal as po, puppet as pu, commands as com
from .portal import Portal
from .puppet import Puppet
from .commands import CommandProcessor
class MatrixHandler: class MatrixHandler:
log = logging.getLogger("mau.mx") log = logging.getLogger("mau.mx") # type: logging.Logger
def __init__(self, context): def __init__(self, context):
self.az, self.db, self.config, _, self.tgbot = context self.az, self.db, self.config, _, self.tgbot = context
self.commands = CommandProcessor(context) self.commands = com.CommandProcessor(context) # type: com.CommandProcessor
self.previously_typing = [] self.previously_typing = [] # type: List[str]
self.az.matrix_event_handler(self.handle_event) self.az.matrix_event_handler(self.handle_event)
@@ -53,68 +50,68 @@ class MatrixHandler:
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.log.exception("TimeoutError when trying to set avatar") self.log.exception("TimeoutError when trying to set avatar")
async def handle_puppet_invite(self, room, puppet, inviter): async def handle_puppet_invite(self, room_id, puppet: pu.Puppet, inviter: u.User):
intent = puppet.default_mxid_intent intent = puppet.default_mxid_intent
self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room}") self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}")
if not await inviter.is_logged_in(): if not await inviter.is_logged_in():
await intent.error_and_leave( await intent.error_and_leave(
room, text="Please log in before inviting Telegram puppets.") room_id, text="Please log in before inviting Telegram puppets.")
return return
portal = Portal.get_by_mxid(room) portal = po.Portal.get_by_mxid(room_id)
if portal: if portal:
if portal.peer_type == "user": if portal.peer_type == "user":
await intent.error_and_leave( await intent.error_and_leave(
room, text="You can not invite additional users to private chats.") room_id, text="You can not invite additional users to private chats.")
return return
await portal.invite_telegram(inviter, puppet) await portal.invite_telegram(inviter, puppet)
await intent.join_room(room) await intent.join_room(room_id)
return return
try: try:
members = await self.az.intent.get_room_members(room) members = await self.az.intent.get_room_members(room_id)
except MatrixRequestError: except MatrixRequestError:
members = [] members = []
if self.az.bot_mxid not in members: if self.az.bot_mxid not in members:
if len(members) > 1: if len(members) > 1:
await intent.error_and_leave(room, text=None, html=( await intent.error_and_leave(room_id, text=None, html=(
f"Please invite " f"Please invite "
f"<a href='https://matrix.to/#/{self.az.bot_mxid}'>the bridge bot</a> " f"<a href='https://matrix.to/#/{self.az.bot_mxid}'>the bridge bot</a> "
f"first if you want to create a Telegram chat.")) f"first if you want to create a Telegram chat."))
return return
await intent.join_room(room) await intent.join_room(room_id)
portal = Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user") portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
if portal.mxid: if portal.mxid:
try: try:
await intent.invite(portal.mxid, inviter.mxid) await intent.invite(portal.mxid, inviter.mxid)
await intent.send_notice(room, text=None, html=( await intent.send_notice(room_id, text=None, html=(
"You already have a private chat with me: " "You already have a private chat with me: "
f"<a href='https://matrix.to/#/{portal.mxid}'>" f"<a href='https://matrix.to/#/{portal.mxid}'>"
"Link to room" "Link to room"
"</a>")) "</a>"))
await intent.leave_room(room) await intent.leave_room(room_id)
return return
except MatrixRequestError: except MatrixRequestError:
pass pass
portal.mxid = room portal.mxid = room_id
portal.save() portal.save()
inviter.register_portal(portal) inviter.register_portal(portal)
await intent.send_notice(room, "Portal to private chat created.") await intent.send_notice(room_id, "po.Portal to private chat created.")
else: else:
await intent.join_room(room) await intent.join_room(room_id)
await intent.send_notice(room, "This puppet will remain inactive until a " await intent.send_notice(room_id, "This puppet will remain inactive until a "
"Telegram chat is created for this room.") "Telegram chat is created for this room.")
async def accept_bot_invite(self, room, inviter): async def accept_bot_invite(self, room_id: str, inviter: u.User):
tries = 0 tries = 0
while tries < 5: while tries < 5:
try: try:
await self.az.intent.join_room(room) await self.az.intent.join_room(room_id)
break break
except (IntentError, MatrixRequestError) as e: except (IntentError, MatrixRequestError):
tries += 1 tries += 1
wait_for_seconds = (tries + 1) * 10 wait_for_seconds = (tries + 1) * 10
if tries < 5: if tries < 5:
self.log.exception(f"Failed to join room {room} with bridge bot, " self.log.exception(f"Failed to join room {room_id} with bridge bot, "
f"retrying in {wait_for_seconds} seconds...") f"retrying in {wait_for_seconds} seconds...")
await asyncio.sleep(wait_for_seconds) await asyncio.sleep(wait_for_seconds)
else: else:
@@ -123,81 +120,81 @@ class MatrixHandler:
if not inviter.whitelisted: if not inviter.whitelisted:
await self.az.intent.send_notice( await self.az.intent.send_notice(
room, text=None, room_id, text=None,
html="You are not whitelisted to use this bridge.<br/><br/>" html="You are not whitelisted to use this bridge.<br/><br/>"
"If you are the owner of this bridge, see the " "If you are the owner of this bridge, see the "
"<code>bridge.permissions</code> section in your config file.") "<code>bridge.permissions</code> section in your config file.")
await self.az.intent.leave_room(room) await self.az.intent.leave_room(room_id)
async def handle_invite(self, room, user, inviter): async def handle_invite(self, room_id: str, user_id: str, inviter_mxid: str):
self.log.debug(f"{inviter} invited {user} to {room}") self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}")
inviter = await User.get_by_mxid(inviter).ensure_started() inviter = await u.User.get_by_mxid(inviter_mxid).ensure_started()
if user == self.az.bot_mxid: if user_id == self.az.bot_mxid:
return await self.accept_bot_invite(room, inviter) return await self.accept_bot_invite(room_id, inviter)
elif not inviter.whitelisted: elif not inviter.whitelisted:
return return
puppet = Puppet.get_by_mxid(user) puppet = pu.Puppet.get_by_mxid(user_id)
if puppet: if puppet:
await self.handle_puppet_invite(room, puppet, inviter) await self.handle_puppet_invite(room_id, puppet, inviter)
return return
user = User.get_by_mxid(user, create=False) user = u.User.get_by_mxid(user_id, create=False)
if not user: if not user:
return return
await user.ensure_started() await user.ensure_started()
portal = Portal.get_by_mxid(room) portal = po.Portal.get_by_mxid(room_id)
if user and await user.has_full_access(allow_bot=True) and portal: if user and await user.has_full_access(allow_bot=True) and portal:
await portal.invite_telegram(inviter, user) await portal.invite_telegram(inviter, user)
return return
# The rest can probably be ignored # The rest can probably be ignored
async def handle_join(self, room, user, event_id): async def handle_join(self, room_id: str, user_id: str, event_id: str):
user = await User.get_by_mxid(user).ensure_started() user = await u.User.get_by_mxid(user_id).ensure_started()
portal = Portal.get_by_mxid(room) portal = po.Portal.get_by_mxid(room_id)
if not portal: if not portal:
return return
if not user.relaybot_whitelisted: if not user.relaybot_whitelisted:
await portal.main_intent.kick(room, user.mxid, await portal.main_intent.kick(room_id, user.mxid,
"You are not whitelisted on this Telegram bridge.") "You are not whitelisted on this Telegram bridge.")
return return
elif not await user.is_logged_in() and not portal.has_bot: elif not await user.is_logged_in() and not portal.has_bot:
await portal.main_intent.kick(room, user.mxid, await portal.main_intent.kick(room_id, user.mxid,
"This chat does not have a bot relaying " "This chat does not have a bot relaying "
"messages for unauthenticated users.") "messages for unauthenticated users.")
return return
self.log.debug(f"{user} joined {room}") self.log.debug(f"{user} joined {room_id}")
if await user.is_logged_in() or portal.has_bot: if await user.is_logged_in() or portal.has_bot:
await portal.join_matrix(user, event_id) await portal.join_matrix(user, event_id)
async def handle_part(self, room, user, sender, event_id): async def handle_part(self, room_id: str, user_id, sender_mxid: str, event_id: str):
self.log.debug(f"{user} left {room}") self.log.debug(f"{user_id} left {room_id}")
sender = User.get_by_mxid(sender, create=False) sender = u.User.get_by_mxid(sender_mxid, create=False)
if not sender: if not sender:
return return
await sender.ensure_started() await sender.ensure_started()
portal = Portal.get_by_mxid(room) portal = po.Portal.get_by_mxid(room_id)
if not portal: if not portal:
return return
puppet = Puppet.get_by_mxid(user) puppet = pu.Puppet.get_by_mxid(user_id)
if sender and puppet: if sender and puppet:
await portal.leave_matrix(puppet, sender, event_id) await portal.leave_matrix(puppet, sender, event_id)
user = User.get_by_mxid(user, create=False) user = u.User.get_by_mxid(user_id, create=False)
if not user: if not user:
return return
await user.ensure_started() await user.ensure_started()
if await user.is_logged_in() or portal.has_bot: if await user.is_logged_in() or portal.has_bot:
await portal.leave_matrix(user, sender, event_id) await portal.leave_matrix(user, sender, event_id)
def is_command(self, message): def is_command(self, message: dict) -> Tuple[bool, str]:
text = message.get("body", "") text = message.get("body", "")
prefix = self.config["bridge.command_prefix"] prefix = self.config["bridge.command_prefix"]
is_command = text.startswith(prefix) is_command = text.startswith(prefix)
@@ -207,14 +204,14 @@ class MatrixHandler:
async def handle_message(self, room, sender, message, event_id): async def handle_message(self, room, sender, message, event_id):
is_command, text = self.is_command(message) is_command, text = self.is_command(message)
sender = await User.get_by_mxid(sender).ensure_started() sender = await u.User.get_by_mxid(sender).ensure_started()
if not sender.relaybot_whitelisted: if not sender.relaybot_whitelisted:
self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:" self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:"
" User is not whitelisted.") " u.User is not whitelisted.")
return return
self.log.debug(f"Received Matrix event \"{message}\" from {sender} in {room}") self.log.debug(f"Received Matrix event \"{message}\" from {sender} in {room}")
portal = Portal.get_by_mxid(room) portal = po.Portal.get_by_mxid(room)
if not is_command and portal and (await sender.is_logged_in() or portal.has_bot): if not is_command and portal and (await sender.is_logged_in() or portal.has_bot):
await portal.handle_matrix_message(sender, message, event_id) await portal.handle_matrix_message(sender, message, event_id)
return return
@@ -239,39 +236,44 @@ class MatrixHandler:
await self.commands.handle(room, sender, command, args, is_management, await self.commands.handle(room, sender, command, args, is_management,
is_portal=portal is not None) is_portal=portal is not None)
async def handle_redaction(self, room, sender, event_id): @staticmethod
sender = await User.get_by_mxid(sender).ensure_started() async def handle_redaction(room_id: str, sender_mxid: str, event_id: str):
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if not sender.relaybot_whitelisted: if not sender.relaybot_whitelisted:
return return
portal = Portal.get_by_mxid(room) portal = po.Portal.get_by_mxid(room_id)
if not portal: if not portal:
return return
await portal.handle_matrix_deletion(sender, event_id) await portal.handle_matrix_deletion(sender, event_id)
async def handle_power_levels(self, room, sender, new, old): @staticmethod
portal = Portal.get_by_mxid(room) async def handle_power_levels(room_id: str, sender_mxid: str, new: dict, old: dict):
sender = await User.get_by_mxid(sender).ensure_started() portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal: if await sender.has_full_access(allow_bot=True) and portal:
await portal.handle_matrix_power_levels(sender, new["users"], old["users"]) await portal.handle_matrix_power_levels(sender, new["users"], old["users"])
async def handle_room_meta(self, type, room, sender, content): @staticmethod
portal = Portal.get_by_mxid(room) async def handle_room_meta(evt_type: str, room_id: str, sender_mxid: str, content: dict):
sender = await User.get_by_mxid(sender).ensure_started() portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal: if await sender.has_full_access(allow_bot=True) and portal:
handler, content_key = { handler, content_key = {
"m.room.name": (portal.handle_matrix_title, "name"), "m.room.name": (portal.handle_matrix_title, "name"),
"m.room.topic": (portal.handle_matrix_about, "topic"), "m.room.topic": (portal.handle_matrix_about, "topic"),
"m.room.avatar": (portal.handle_matrix_avatar, "url"), "m.room.avatar": (portal.handle_matrix_avatar, "url"),
}[type] }[evt_type]
if content_key not in content: if content_key not in content:
return return
await handler(sender, content[content_key]) await handler(sender, content[content_key])
async def handle_room_pin(self, room, sender, new_events, old_events): @staticmethod
portal = Portal.get_by_mxid(room) async def handle_room_pin(room_id: str, sender_mxid: str, new_events: Set[str],
sender = await User.get_by_mxid(sender).ensure_started() old_events: Set[str]):
portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal: if await sender.has_full_access(allow_bot=True) and portal:
events = new_events - old_events events = new_events - old_events
if len(events) > 0: if len(events) > 0:
@@ -281,12 +283,14 @@ class MatrixHandler:
# All pinned events removed, remove pinned event in Telegram. # All pinned events removed, remove pinned event in Telegram.
await portal.handle_matrix_pin(sender, None) await portal.handle_matrix_pin(sender, None)
async def handle_name_change(self, room, user, displayname, prev_displayname, event_id): @staticmethod
portal = Portal.get_by_mxid(room) async def handle_name_change(room_id: str, user_id: str, displayname: str,
prev_displayname: str, event_id: str):
portal = po.Portal.get_by_mxid(room_id)
if not portal or not portal.has_bot: if not portal or not portal.has_bot:
return return
user = await User.get_by_mxid(user).ensure_started() user = await u.User.get_by_mxid(user_id).ensure_started()
if await user.needs_relaybot(portal): if await user.needs_relaybot(portal):
await portal.name_change_matrix(user, displayname, prev_displayname, event_id) await portal.name_change_matrix(user, displayname, prev_displayname, event_id)
@@ -296,25 +300,27 @@ class MatrixHandler:
for event_id, receipts in content.items() for event_id, receipts in content.items()
for user_id in receipts.get("m.read", {})} for user_id in receipts.get("m.read", {})}
async def handle_read_receipts(self, room_id: str, receipts: Dict[str, str]): @staticmethod
portal = Portal.get_by_mxid(room_id) async def handle_read_receipts(room_id: str, receipts: Dict[str, str]):
portal = po.Portal.get_by_mxid(room_id)
if not portal: if not portal:
return return
for user_id, event_id in receipts.items(): for user_id, event_id in receipts.items():
user = await User.get_by_mxid(user_id).ensure_started() user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in(): if not await user.is_logged_in():
continue continue
await portal.mark_read(user, event_id) await portal.mark_read(user, event_id)
async def handle_presence(self, user: str, presence: str): @staticmethod
user = await User.get_by_mxid(user).ensure_started() async def handle_presence(user_id: str, presence: str):
user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in(): if not await user.is_logged_in():
return return
await user.set_presence(presence == "online") await user.set_presence(presence == "online")
async def handle_typing(self, room_id: str, now_typing: List[str]): async def handle_typing(self, room_id: str, now_typing: List[str]):
portal = Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
if not portal: if not portal:
return return
@@ -324,7 +330,7 @@ class MatrixHandler:
if is_typing and was_typing: if is_typing and was_typing:
continue continue
user = await User.get_by_mxid(user_id).ensure_started() user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in(): if not await user.is_logged_in():
continue continue
@@ -332,38 +338,38 @@ class MatrixHandler:
self.previously_typing = now_typing self.previously_typing = now_typing
def filter_matrix_event(self, event): def filter_matrix_event(self, event: dict):
sender = event.get("sender", None) sender = event.get("sender", None)
if not sender: if not sender:
return False return False
return (sender == self.az.bot_mxid return (sender == self.az.bot_mxid
or Puppet.get_id_from_mxid(sender) is not None) or pu.Puppet.get_id_from_mxid(sender) is not None)
async def try_handle_event(self, evt): async def try_handle_event(self, evt: dict):
try: try:
await self.handle_event(evt) await self.handle_event(evt)
except Exception: except Exception:
self.log.exception("Error handling manually received Matrix event") self.log.exception("Error handling manually received Matrix event")
async def handle_event(self, evt): async def handle_event(self, evt: dict):
if self.filter_matrix_event(evt): if self.filter_matrix_event(evt):
return return
self.log.debug("Received event: %s", evt) self.log.debug("Received event: %s", evt)
type = evt.get("type", "m.unknown") evt_type = evt.get("type", "m.unknown") # type: str
room_id = evt.get("room_id", None) room_id = evt.get("room_id", None) # type: str
event_id = evt.get("event_id", None) event_id = evt.get("event_id", None) # type: str
sender = evt.get("sender", None) sender = evt.get("sender", None) # type: str
content = evt.get("content", {}) content = evt.get("content", {}) # type: dict
if type == "m.room.member": if evt_type == "m.room.member":
state_key = evt["state_key"] state_key = evt["state_key"] # type: str
prev_content = evt.get("unsigned", {}).get("prev_content", {}) prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict
membership = content.get("membership", "") membership = content.get("membership", "") # type: str
prev_membership = prev_content.get("membership", "leave") prev_membership = prev_content.get("membership", "leave") # type: str
if membership == prev_membership: if membership == prev_membership:
match = re.compile("@(.+):(.+)").match(state_key) match = re.compile("@(.+):(.+)").match(state_key) # type: Match
localpart = match.group(1) localpart = match.group(1) # type: str
displayname = content.get("displayname", localpart) displayname = content.get("displayname", localpart) # type: str
prev_displayname = prev_content.get("displayname", localpart) prev_displayname = prev_content.get("displayname", localpart) # type: str
if displayname != prev_displayname: if displayname != prev_displayname:
await self.handle_name_change(room_id, state_key, displayname, await self.handle_name_change(room_id, state_key, displayname,
prev_displayname, event_id) prev_displayname, event_id)
@@ -373,26 +379,26 @@ class MatrixHandler:
await self.handle_part(room_id, state_key, sender, event_id) await self.handle_part(room_id, state_key, sender, event_id)
elif membership == "join": elif membership == "join":
await self.handle_join(room_id, state_key, event_id) await self.handle_join(room_id, state_key, event_id)
elif type in ("m.room.message", "m.sticker"): elif evt_type in ("m.room.message", "m.sticker"):
if type != "m.room.message": if evt_type != "m.room.message":
content["msgtype"] = type content["msgtype"] = evt_type
await self.handle_message(room_id, sender, content, event_id) await self.handle_message(room_id, sender, content, event_id)
elif type == "m.room.redaction": elif evt_type == "m.room.redaction":
await self.handle_redaction(room_id, sender, evt["redacts"]) await self.handle_redaction(room_id, sender, evt["redacts"])
elif type == "m.room.power_levels": elif evt_type == "m.room.power_levels":
await self.handle_power_levels(room_id, sender, evt["content"], evt["prev_content"]) await self.handle_power_levels(room_id, sender, evt["content"], evt["prev_content"])
elif type in ("m.room.name", "m.room.avatar", "m.room.topic"): elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"):
await self.handle_room_meta(type, room_id, sender, evt["content"]) await self.handle_room_meta(evt_type, room_id, sender, evt["content"])
elif type == "m.room.pinned_events": elif evt_type == "m.room.pinned_events":
new_events = set(evt["content"]["pinned"]) new_events = set(evt["content"]["pinned"])
try: try:
old_events = set(evt["unsigned"]["prev_content"]["pinned"]) old_events = set(evt["unsigned"]["prev_content"]["pinned"])
except KeyError: except KeyError:
old_events = set() old_events = set()
await self.handle_room_pin(room_id, sender, new_events, old_events) await self.handle_room_pin(room_id, sender, new_events, old_events)
elif type == "m.receipt": elif evt_type == "m.receipt":
await self.handle_read_receipts(room_id, self.parse_read_receipts(content)) await self.handle_read_receipts(room_id, self.parse_read_receipts(content))
elif type == "m.presence": elif evt_type == "m.presence":
await self.handle_presence(sender, content.get("presence", "offline")) await self.handle_presence(sender, content.get("presence", "offline"))
elif type == "m.typing": elif evt_type == "m.typing":
await self.handle_typing(room_id, content.get("user_ids", [])) await self.handle_typing(room_id, content.get("user_ids", []))
+249 -196
View File
File diff suppressed because it is too large Load Diff
+34 -26
View File
@@ -14,32 +14,39 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING
from difflib import SequenceMatcher from difflib import SequenceMatcher
from typing import Optional, Awaitable
import re import re
import logging import logging
import asyncio import asyncio
from sqlalchemy import orm
from telethon.tl.types import UserProfilePhoto from telethon.tl.types import UserProfilePhoto
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
from .db import Puppet as DBPuppet from .db import Puppet as DBPuppet
from . import util, matrix from . import util
config = None if TYPE_CHECKING:
from .matrix import MatrixHandler
from .config import Config
from .context import Context
config = None # type: Config
class Puppet: class Puppet:
log = logging.getLogger("mau.puppet") log = logging.getLogger("mau.puppet") # type: logging.Logger
db = None db = None # type: orm.Session
az = None # type: AppService az = None # type: AppService
mx = None # type: matrix.MatrixHandler mx = None # type: MatrixHandler
loop = None # type: asyncio.AbstractEventLoop loop = None # type: asyncio.AbstractEventLoop
mxid_regex = None mxid_regex = None # type: Pattern
username_template = None username_template = None # type: str
hs_domain = None hs_domain = None # type: str
cache = {} cache = {} # type: Dict[str, Puppet]
by_custom_mxid = {} by_custom_mxid = {} # type: Dict[str, Puppet]
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None, def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
displayname=None, displayname_source=None, photo_id=None, is_bot=None, displayname=None, displayname_source=None, photo_id=None, is_bot=None,
@@ -71,7 +78,8 @@ class Puppet:
def tgid(self): def tgid(self):
return self.id return self.id
async def is_logged_in(self): @staticmethod
async def is_logged_in():
return True return True
# region Custom puppet management # region Custom puppet management
@@ -154,12 +162,12 @@ class Puppet:
def filter_events(self, events): def filter_events(self, events):
new_events = [] new_events = []
for event in events: for event in events:
type = event.get("type", None) evt_type = event.get("type", None)
event.setdefault("content", {}) event.setdefault("content", {})
if type == "m.typing": if evt_type == "m.typing":
is_typing = self.custom_mxid in event["content"].get("user_ids", []) is_typing = self.custom_mxid in event["content"].get("user_ids", [])
event["content"]["user_ids"] = [self.custom_mxid] if is_typing else [] event["content"]["user_ids"] = [self.custom_mxid] if is_typing else []
elif type == "m.receipt": elif evt_type == "m.receipt":
val = None val = None
evt = None evt = None
for event_id in event["content"]: for event_id in event["content"]:
@@ -273,7 +281,7 @@ class Puppet:
return round(similarity * 1000) / 10 return round(similarity * 1000) / 10
@staticmethod @staticmethod
def get_displayname(info, format=True): def get_displayname(info, enable_format=True):
data = { data = {
"phone number": info.phone if hasattr(info, "phone") else None, "phone number": info.phone if hasattr(info, "phone") else None,
"username": info.username, "username": info.username,
@@ -295,7 +303,7 @@ class Puppet:
elif not name: elif not name:
name = info.id name = info.id
if not format: if not enable_format:
return name return name
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format( return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
displayname=name) displayname=name)
@@ -347,18 +355,18 @@ class Puppet:
# region Getters # region Getters
@classmethod @classmethod
def get(cls, id, create=True) -> "Optional[Puppet]": def get(cls, tgid, create=True) -> "Optional[Puppet]":
try: try:
return cls.cache[id] return cls.cache[tgid]
except KeyError: except KeyError:
pass pass
puppet = DBPuppet.query.get(id) puppet = DBPuppet.query.get(tgid)
if puppet: if puppet:
return cls.from_db(puppet) return cls.from_db(puppet)
if create: if create:
puppet = cls(id) puppet = cls(tgid)
cls.db.add(puppet.db_instance) cls.db.add(puppet.db_instance)
cls.db.commit() cls.db.commit()
return puppet return puppet
@@ -402,8 +410,8 @@ class Puppet:
return None return None
@classmethod @classmethod
def get_mxid_from_id(cls, id): def get_mxid_from_id(cls, tgid):
return f"@{cls.username_template.format(userid=id)}:{cls.hs_domain}" return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}"
@classmethod @classmethod
def find_by_username(cls, username) -> "Optional[Puppet]": def find_by_username(cls, username) -> "Optional[Puppet]":
@@ -437,12 +445,12 @@ class Puppet:
# endregion # endregion
def init(context): def init(context: "Context") -> List[Awaitable[int]]:
global config global config
Puppet.az, Puppet.db, config, Puppet.loop, _ = context Puppet.az, Puppet.db, config, Puppet.loop, _ = context
Puppet.mx = context.mx Puppet.mx = context.mx
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}") Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
Puppet.hs_domain = config["homeserver"]["domain"] Puppet.hs_domain = config["homeserver"]["domain"]
localpart = Puppet.username_template.format(userid="(.+)") Puppet.mxid_regex = re.compile(
Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}") f"@{Puppet.username_template.format(userid='(.+)')}:{Puppet.hs_domain}")
return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()] return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
+7 -3
View File
@@ -16,6 +16,8 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Tuple from typing import Dict, Tuple
from sqlalchemy import orm
from mautrix_appservice import StateStore from mautrix_appservice import StateStore
from . import puppet as pu from . import puppet as pu
@@ -25,15 +27,17 @@ from .db import RoomState, UserProfile
class SQLStateStore(StateStore): class SQLStateStore(StateStore):
def __init__(self, db): def __init__(self, db):
super().__init__() super().__init__()
self.db = db self.db = db # type: orm.Session
self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile] self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile]
self.room_state_cache = {} # type: Dict[str, RoomState] self.room_state_cache = {} # type: Dict[str, RoomState]
def is_registered(self, user: str) -> bool: @staticmethod
def is_registered(user: str) -> bool:
puppet = pu.Puppet.get_by_mxid(user) puppet = pu.Puppet.get_by_mxid(user)
return puppet.is_registered if puppet else False return puppet.is_registered if puppet else False
def registered(self, user: str): @staticmethod
def registered(user: str):
puppet = pu.Puppet.get_by_mxid(user) puppet = pu.Puppet.get_by_mxid(user)
if puppet: if puppet:
puppet.is_registered = True puppet.is_registered = True
+9 -2
View File
@@ -17,10 +17,14 @@
from telethon import TelegramClient, utils from telethon import TelegramClient, utils
from telethon.tl.functions.messages import SendMediaRequest from telethon.tl.functions.messages import SendMediaRequest
from telethon.tl.types import * from telethon.tl.types import *
from telethon.tl import custom
class MautrixTelegramClient(TelegramClient): class MautrixTelegramClient(TelegramClient):
async def upload_file(self, file, mime_type=None, attributes=None, file_name=None): async def upload_file_direct(self, file: bytes, mime_type: str = None,
attributes: List[TypeDocumentAttribute] = None,
file_name: str = None
) -> Union[InputMediaUploadedDocument, InputMediaUploadedPhoto]:
file_handle = await super().upload_file(file, file_name=file_name, use_cache=False) file_handle = await super().upload_file(file, file_name=file_name, use_cache=False)
if mime_type == "image/png" or mime_type == "image/jpeg": if mime_type == "image/png" or mime_type == "image/jpeg":
@@ -34,7 +38,10 @@ class MautrixTelegramClient(TelegramClient):
mime_type=mime_type or "application/octet-stream", mime_type=mime_type or "application/octet-stream",
attributes=list(attr_dict.values())) attributes=list(attr_dict.values()))
async def send_media(self, entity, media, caption=None, entities=None, reply_to=None): async def send_media(self, entity: Union[TypeInputPeer, TypePeer],
media: Union[TypeInputMedia, TypeMessageMedia],
caption: str = None, entities: List[TypeMessageEntity] = None,
reply_to: int = None) -> Optional[custom.Message]:
entity = await self.get_input_entity(entity) entity = await self.get_input_entity(entity)
reply_to = utils.get_message_id(reply_to) reply_to = utils.get_message_id(reply_to)
request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [], request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [],
+58 -54
View File
@@ -14,42 +14,51 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Awaitable, Optional from typing import Dict, Awaitable, Optional, Match, Tuple, TYPE_CHECKING
import logging import logging
import asyncio import asyncio
import re import re
from telethon.tl.types import * from telethon.tl.types import *
from telethon.tl.types import User as TLUser
from telethon.tl.types.contacts import ContactsNotModified from telethon.tl.types.contacts import ContactsNotModified
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from telethon.tl.functions.account import UpdateStatusRequest from telethon.tl.functions.account import UpdateStatusRequest
from mautrix_appservice import MatrixRequestError from mautrix_appservice import MatrixRequestError
from .db import User as DBUser, Contact as DBContact from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
from .abstract_user import AbstractUser from .abstract_user import AbstractUser
from . import portal as po, puppet as pu from . import portal as po, puppet as pu
config = None if TYPE_CHECKING:
from .config import Config
from .context import Context
config = None # type: Config
SearchResults = List[Tuple["pu.Puppet", int]]
class User(AbstractUser): class User(AbstractUser):
log = logging.getLogger("mau.user") log = logging.getLogger("mau.user") # type: logging.Logger
by_mxid = {} by_mxid = {} # type: Dict[str, User]
by_tgid = {} by_tgid = {} # type: Dict[int, User]
def __init__(self, mxid, tgid=None, username=None, db_contacts=None, saved_contacts=0, def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None,
is_bot=False, db_portals=None, db_instance=None): db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0,
is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None,
db_instance: Optional[DBUser] = None):
super().__init__() super().__init__()
self.mxid = mxid # type: str self.mxid = mxid # type: str
self.tgid = tgid # type: int self.tgid = tgid # type: int
self.is_bot = is_bot # type: bool self.is_bot = is_bot # type: bool
self.username = username # type: str self.username = username # type: str
self.contacts = [] self.contacts = [] # type: List[pu.Puppet]
self.saved_contacts = saved_contacts self.saved_contacts = saved_contacts # type: int
self.db_contacts = db_contacts self.db_contacts = db_contacts # type: List[DBContact]
self.portals = {} # type: Dict[str, po.Portal] self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
self.db_portals = db_portals self.db_portals = db_portals # type: List[DBPortal]
self._db_instance = db_instance self._db_instance = db_instance # type: DBUser
self.command_status = None # type: dict self.command_status = None # type: dict
@@ -64,53 +73,47 @@ class User(AbstractUser):
self.by_tgid[tgid] = self self.by_tgid[tgid] = self
@property @property
def name(self): def name(self) -> str:
return self.mxid return self.mxid
@property @property
def mxid_localpart(self): def mxid_localpart(self) -> str:
match = re.compile("@(.+):(.+)").match(self.mxid) match = re.compile("@(.+):(.+)").match(self.mxid) # type: Match
return match.group(1) return match.group(1)
# TODO replace with proper displayname getting everywhere # TODO replace with proper displayname getting everywhere
@property @property
def displayname(self): def displayname(self) -> str:
return self.mxid_localpart return self.mxid_localpart
@property @property
def db_contacts(self): def db_contacts(self) -> List[DBContact]:
return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id)) return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id))
for puppet in self.contacts] for puppet in self.contacts]
@db_contacts.setter @db_contacts.setter
def db_contacts(self, contacts): def db_contacts(self, contacts: List[DBContact]):
if contacts: self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts]
else:
self.contacts = []
@property @property
def db_portals(self): def db_portals(self) -> List[DBPortal]:
return [portal.db_instance for portal in self.portals.values()] return [portal.db_instance for portal in self.portals.values()]
@db_portals.setter @db_portals.setter
def db_portals(self, portals): def db_portals(self, portals: List[DBPortal]):
if portals: self.portals = {(portal.tgid, portal.tg_receiver):
self.portals = {(portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver) for portal in portals} if portals else {}
for portal in portals}
else:
self.portals = {}
# region Database conversion # region Database conversion
@property @property
def db_instance(self): def db_instance(self) -> DBUser:
if not self._db_instance: if not self._db_instance:
self._db_instance = self.new_db_instance() self._db_instance = self.new_db_instance()
return self._db_instance return self._db_instance
def new_db_instance(self): def new_db_instance(self) -> DBUser:
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username, return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0, contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0,
portals=self.db_portals) portals=self.db_portals)
@@ -134,14 +137,14 @@ class User(AbstractUser):
self.db.commit() self.db.commit()
@classmethod @classmethod
def from_db(cls, db_user): def from_db(cls, db_user: DBUser) -> "User":
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts, return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts,
False, db_user.saved_contacts, db_user.portals, db_instance=db_user) False, db_user.saved_contacts, db_user.portals, db_instance=db_user)
# endregion # endregion
# region Telegram connection management # region Telegram connection management
async def start(self, delete_unless_authenticated=False): async def start(self, delete_unless_authenticated: bool = False) -> "User":
await super().start() await super().start()
if await self.is_logged_in(): if await self.is_logged_in():
self.log.debug(f"Ensuring post_login() for {self.name}") self.log.debug(f"Ensuring post_login() for {self.name}")
@@ -152,7 +155,7 @@ class User(AbstractUser):
self.client.session.delete() self.client.session.delete()
return self return self
async def post_login(self, info=None): async def post_login(self, info: TLUser = None):
try: try:
await self.update_info(info) await self.update_info(info)
if not self.is_bot: if not self.is_bot:
@@ -163,7 +166,7 @@ class User(AbstractUser):
except Exception: except Exception:
self.log.exception("Failed to run post-login functions for %s", self.mxid) self.log.exception("Failed to run post-login functions for %s", self.mxid)
async def update(self, update): async def update(self, update: TypeUpdate):
if not self.is_bot: if not self.is_bot:
return return
@@ -186,7 +189,7 @@ class User(AbstractUser):
# endregion # endregion
# region Telegram actions that need custom methods # region Telegram actions that need custom methods
def ensure_started(self, even_if_no_session=False) -> "Awaitable[User]": def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]":
return super().ensure_started(even_if_no_session) return super().ensure_started(even_if_no_session)
def set_presence(self, online: bool = True): def set_presence(self, online: bool = True):
@@ -194,7 +197,7 @@ class User(AbstractUser):
return return
return self.client(UpdateStatusRequest(offline=not online)) return self.client(UpdateStatusRequest(offline=not online))
async def update_info(self, info: User = None): async def update_info(self, info: TLUser = None):
info = info or await self.client.get_me() info = info or await self.client.get_me()
changed = False changed = False
if self.is_bot != info.bot: if self.is_bot != info.bot:
@@ -233,8 +236,9 @@ class User(AbstractUser):
self.delete() self.delete()
return True return True
def _search_local(self, query, max_results=5, min_similarity=45): def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45
results = [] ) -> SearchResults:
results = [] # type: SearchResults
for contact in self.contacts: for contact in self.contacts:
similarity = contact.similarity(query) similarity = contact.similarity(query)
if similarity >= min_similarity: if similarity >= min_similarity:
@@ -242,11 +246,11 @@ class User(AbstractUser):
results.sort(key=lambda tup: tup[1], reverse=True) results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results] return results[0:max_results]
async def _search_remote(self, query, max_results=5): async def _search_remote(self, query: str, max_results: int = 5) -> SearchResults:
if len(query) < 5: if len(query) < 5:
return [] return []
server_results = await self.client(SearchRequest(q=query, limit=max_results)) server_results = await self.client(SearchRequest(q=query, limit=max_results))
results = [] results = [] # type: SearchResults
for user in server_results.users: for user in server_results.users:
puppet = pu.Puppet.get(user.id) puppet = pu.Puppet.get(user.id)
await puppet.update_info(self, user) await puppet.update_info(self, user)
@@ -254,7 +258,7 @@ class User(AbstractUser):
results.sort(key=lambda tup: tup[1], reverse=True) results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results] return results[0:max_results]
async def search(self, query, force_remote=False): async def search(self, query: str, force_remote: bool = False) -> Tuple[SearchResults, bool]:
if force_remote: if force_remote:
return await self._search_remote(query), True return await self._search_remote(query), True
@@ -264,7 +268,7 @@ class User(AbstractUser):
return await self._search_remote(query), True return await self._search_remote(query), True
async def sync_dialogs(self, synchronous_create=False): async def sync_dialogs(self, synchronous_create: bool = False):
creators = [] creators = []
for entity in await self.get_dialogs(limit=30): for entity in await self.get_dialogs(limit=30):
portal = po.Portal.get_by_entity(entity) portal = po.Portal.get_by_entity(entity)
@@ -275,7 +279,7 @@ class User(AbstractUser):
self.save() self.save()
await asyncio.gather(*creators, loop=self.loop) await asyncio.gather(*creators, loop=self.loop)
def register_portal(self, portal): def register_portal(self, portal: po.Portal):
try: try:
if self.portals[portal.tgid_full] == portal: if self.portals[portal.tgid_full] == portal:
return return
@@ -284,18 +288,18 @@ class User(AbstractUser):
self.portals[portal.tgid_full] = portal self.portals[portal.tgid_full] = portal
self.save() self.save()
def unregister_portal(self, portal): def unregister_portal(self, portal: po.Portal):
try: try:
del self.portals[portal.tgid_full] del self.portals[portal.tgid_full]
self.save() self.save()
except KeyError: except KeyError:
pass pass
async def needs_relaybot(self, portal): async def needs_relaybot(self, portal: po.Portal) -> bool:
return not await self.is_logged_in() or ( return not await self.is_logged_in() or (
self.is_bot and portal.tgid_full not in self.portals) self.is_bot and portal.tgid_full not in self.portals)
def _hash_contacts(self): def _hash_contacts(self) -> int:
acc = 0 acc = 0
for id in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]): for id in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]):
acc = (acc * 20261 + id) & 0xffffffff acc = (acc * 20261 + id) & 0xffffffff
@@ -318,7 +322,7 @@ class User(AbstractUser):
# region Class instance lookup # region Class instance lookup
@classmethod @classmethod
def get_by_mxid(cls, mxid, create=True) -> "Optional[User]": def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]":
if not mxid: if not mxid:
raise ValueError("Matrix ID can't be empty") raise ValueError("Matrix ID can't be empty")
@@ -341,7 +345,7 @@ class User(AbstractUser):
return None return None
@classmethod @classmethod
def get_by_tgid(cls, tgid) -> "Optional[User]": def get_by_tgid(cls, tgid: int) -> "Optional[User]":
try: try:
return cls.by_tgid[tgid] return cls.by_tgid[tgid]
except KeyError: except KeyError:
@@ -355,7 +359,7 @@ class User(AbstractUser):
return None return None
@classmethod @classmethod
def find_by_username(cls, username) -> "Optional[User]": def find_by_username(cls, username: str) -> "Optional[User]":
if not username: if not username:
return None return None
@@ -371,7 +375,7 @@ class User(AbstractUser):
# endregion # endregion
def init(context): def init(context: "Context") -> List[Awaitable[User]]:
global config global config
config = context.config config = context.config
+49 -35
View File
@@ -14,15 +14,25 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Tuple, Union, Dict
from io import BytesIO from io import BytesIO
import time import time
import logging import logging
import asyncio import asyncio
import magic import magic
from sqlalchemy import orm
from sqlalchemy.exc import IntegrityError, InvalidRequestError from sqlalchemy.exc import IntegrityError, InvalidRequestError
from sqlalchemy.orm.exc import FlushError from sqlalchemy.orm.exc import FlushError
from telethon.tl.types import (Document, FileLocation, InputFileLocation,
InputDocumentFileLocation, PhotoSize, PhotoCachedSize)
from telethon.errors import *
from mautrix_appservice import IntentAPI
from ..tgclient import MautrixTelegramClient
from ..db import TelegramFile as DBTelegramFile
try: try:
from PIL import Image from PIL import Image
except ImportError: except ImportError:
@@ -36,20 +46,18 @@ try:
except ImportError: except ImportError:
VideoFileClip = random = string = os = mimetypes = None VideoFileClip = random = string = os = mimetypes = None
from telethon.tl.types import (Document, FileLocation, InputFileLocation, log = logging.getLogger("mau.util") # type: logging.Logger
InputDocumentFileLocation, PhotoSize, PhotoCachedSize)
from telethon.errors import *
from ..db import TelegramFile as DBTelegramFile TypeLocation = Union[Document, InputDocumentFileLocation, FileLocation, InputFileLocation]
log = logging.getLogger("mau.util")
def convert_image(file, source_mime="image/webp", target_type="png", thumbnail_to=None): def convert_image(file: bytes, source_mime: str = "image/webp", target_type: str = "png",
thumbnail_to: Optional[Tuple[int, int]] = None
) -> Tuple[str, bytes, Optional[int], Optional[int]]:
if not Image: if not Image:
return source_mime, file, None, None return source_mime, file, None, None
try: try:
image = Image.open(BytesIO(file)).convert("RGBA") image = Image.open(BytesIO(file)).convert("RGBA") # type: Image.Image
if thumbnail_to: if thumbnail_to:
image.thumbnail(thumbnail_to, Image.ANTIALIAS) image.thumbnail(thumbnail_to, Image.ANTIALIAS)
new_file = BytesIO() new_file = BytesIO()
@@ -61,13 +69,14 @@ def convert_image(file, source_mime="image/webp", target_type="png", thumbnail_t
return source_mime, file, None, None return source_mime, file, None, None
def _temp_file_name(ext): def _temp_file_name(ext: str) -> str:
return ("/tmp/mxtg-video-" return ("/tmp/mxtg-video-"
+ "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) + "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
+ ext) + ext)
def _read_video_thumbnail(data, video_ext="mp4", frame_ext="png", max_size=(1024, 720)): def _read_video_thumbnail(data: bytes, video_ext: str = "mp4", frame_ext: str = "png",
max_size: Tuple[int, int] = (1024, 720)) -> Tuple[bytes, int, int]:
# We don't have any way to read the video from memory, so save it to disk. # We don't have any way to read the video from memory, so save it to disk.
temp_file = _temp_file_name(video_ext) temp_file = _temp_file_name(video_ext)
with open(temp_file, "wb") as file: with open(temp_file, "wb") as file:
@@ -90,21 +99,21 @@ def _read_video_thumbnail(data, video_ext="mp4", frame_ext="png", max_size=(1024
return thumbnail_file.getvalue(), w, h return thumbnail_file.getvalue(), w, h
def _location_to_id(location): def _location_to_id(location: TypeLocation) -> str:
if isinstance(location, (Document, InputDocumentFileLocation)): if isinstance(location, (Document, InputDocumentFileLocation)):
return f"{location.id}-{location.version}" return f"{location.id}-{location.version}"
elif isinstance(location, (FileLocation, InputFileLocation)): elif isinstance(location, (FileLocation, InputFileLocation)):
return f"{location.volume_id}-{location.local_id}" return f"{location.volume_id}-{location.local_id}"
else:
return None
async def transfer_thumbnail_to_matrix(client, intent, thumbnail_loc, video, mime): async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
thumbnail_loc: TypeLocation, video: bytes,
mime: str) -> Optional[DBTelegramFile]:
if not Image or not VideoFileClip: if not Image or not VideoFileClip:
return None return None
id = _location_to_id(thumbnail_loc) loc_id = _location_to_id(thumbnail_loc)
if not id: if not loc_id:
return None return None
video_ext = mimetypes.guess_extension(mime) video_ext = mimetypes.guess_extension(mime)
@@ -121,36 +130,40 @@ async def transfer_thumbnail_to_matrix(client, intent, thumbnail_loc, video, mim
content_uri = await intent.upload_file(file, mime_type) content_uri = await intent.upload_file(file, mime_type)
return DBTelegramFile(id=id, mxc=content_uri, mime_type=mime_type, return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type,
was_converted=False, timestamp=int(time.time()), size=len(file), was_converted=False, timestamp=int(time.time()), size=len(file),
width=width, height=height) width=width, height=height)
transfer_locks = {} transfer_locks = {} # type: Dict[str, asyncio.Lock]
transfer_locks_lock = asyncio.Lock()
async def transfer_file_to_matrix(db, client, intent, location, thumbnail=None, is_sticker=False): async def transfer_file_to_matrix(db: orm.Session, client: MautrixTelegramClient, intent: IntentAPI,
id = _location_to_id(location) location: TypeLocation, thumbnail: Optional[TypeLocation] = None,
if not id: is_sticker: bool = False) -> Optional[DBTelegramFile]:
location_id = _location_to_id(location)
if not location_id:
return None return None
db_file = DBTelegramFile.query.get(id) db_file = DBTelegramFile.query.get(location_id)
if db_file: if db_file:
return db_file return db_file
async with transfer_locks_lock: try:
try: lock = transfer_locks[location_id]
lock = transfer_locks[id] except KeyError:
except KeyError: lock = asyncio.Lock()
lock = asyncio.Lock() transfer_locks[location_id] = lock
transfer_locks[id] = lock
async with lock: async with lock:
return await _unlocked_transfer_file_to_matrix(db, client, intent, id, location, thumbnail, is_sticker) return await _unlocked_transfer_file_to_matrix(db, client, intent, location_id, location,
thumbnail, is_sticker)
async def _unlocked_transfer_file_to_matrix(db, client, intent, id, location, thumbnail, is_sticker): async def _unlocked_transfer_file_to_matrix(db: orm.Session, client: MautrixTelegramClient,
db_file = DBTelegramFile.query.get(id) intent: IntentAPI, loc_id: str, location: TypeLocation,
thumbnail: Optional[TypeLocation],
is_sticker: bool) -> Optional[DBTelegramFile]:
db_file = DBTelegramFile.query.get(loc_id)
if db_file: if db_file:
return db_file return db_file
@@ -167,15 +180,16 @@ async def _unlocked_transfer_file_to_matrix(db, client, intent, id, location, th
image_converted = False image_converted = False
if mime_type == "image/webp": if mime_type == "image/webp":
new_mime_type, file, width, height = convert_image(file, source_mime="image/webp", target_type="png", thumbnail_to=( new_mime_type, file, width, height = convert_image(
256, 256) if is_sticker else None) file, source_mime="image/webp", target_type="png",
thumbnail_to=(256, 256) if is_sticker else None)
image_converted = new_mime_type != mime_type image_converted = new_mime_type != mime_type
mime_type = new_mime_type mime_type = new_mime_type
thumbnail = None thumbnail = None
content_uri = await intent.upload_file(file, mime_type) content_uri = await intent.upload_file(file, mime_type)
db_file = DBTelegramFile(id=id, mxc=content_uri, db_file = DBTelegramFile(id=loc_id, mxc=content_uri,
mime_type=mime_type, was_converted=image_converted, mime_type=mime_type, was_converted=image_converted,
timestamp=int(time.time()), size=len(file), timestamp=int(time.time()), size=len(file),
width=width, height=height) width=width, height=height)
+5 -3
View File
@@ -16,10 +16,12 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
def format_duration(seconds): def format_duration(seconds: int) -> str:
def pluralize(count, singular): return singular if count == 1 else singular + "s" def pluralize(count, singular):
return singular if count == 1 else singular + "s"
def include(count, word): return f"{count} {pluralize(count, word)}" if count > 0 else "" def include(count, word):
return f"{count} {pluralize(count, word)}" if count > 0 else ""
minutes, seconds = divmod(seconds, 60) minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60) hours, minutes = divmod(minutes, 60)