Add more type hints
This commit is contained in:
@@ -14,34 +14,33 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from typing import Optional
|
||||||
import argparse
|
import argparse
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
import logging.config
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging.config
|
||||||
|
import sys
|
||||||
|
|
||||||
import sqlalchemy as sql
|
|
||||||
from sqlalchemy import orm
|
from sqlalchemy import orm
|
||||||
|
import sqlalchemy as sql
|
||||||
|
|
||||||
from alchemysession import AlchemySessionContainer
|
|
||||||
from mautrix_appservice import AppService
|
from mautrix_appservice import AppService
|
||||||
|
from alchemysession import AlchemySessionContainer
|
||||||
|
|
||||||
from .base import Base
|
from .web.provisioning import ProvisioningAPI
|
||||||
from .config import Config
|
from .web.public import PublicBridgeWebsite
|
||||||
from .matrix import MatrixHandler
|
|
||||||
|
|
||||||
from . import __version__
|
|
||||||
from .db import init as init_db
|
|
||||||
from .abstract_user import init as init_abstract_user
|
from .abstract_user import init as init_abstract_user
|
||||||
from .user import init as init_user, User
|
from .base import Base
|
||||||
from .bot import init as init_bot
|
from .bot import init as init_bot
|
||||||
|
from .config import Config
|
||||||
|
from .context import Context
|
||||||
|
from .db import init as init_db
|
||||||
|
from .formatter import init as init_formatter
|
||||||
|
from .matrix import MatrixHandler
|
||||||
from .portal import init as init_portal
|
from .portal import init as init_portal
|
||||||
from .puppet import init as init_puppet
|
from .puppet import init as init_puppet
|
||||||
from .formatter import init as init_formatter
|
|
||||||
from .web.public import PublicBridgeWebsite
|
|
||||||
from .web.provisioning import ProvisioningAPI
|
|
||||||
from .context import Context
|
|
||||||
from .sqlstatestore import SQLStateStore
|
from .sqlstatestore import SQLStateStore
|
||||||
|
from .user import User, init as init_user
|
||||||
|
from . import __version__
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="A Matrix-Telegram puppeting bridge.",
|
description="A Matrix-Telegram puppeting bridge.",
|
||||||
@@ -68,7 +67,7 @@ if args.generate_registration:
|
|||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
logging.config.dictConfig(config["logging"])
|
logging.config.dictConfig(config["logging"])
|
||||||
log = logging.getLogger("mau.init")
|
log = logging.getLogger("mau.init") # type: logging.Logger
|
||||||
log.debug(f"Initializing mautrix-telegram {__version__}")
|
log.debug(f"Initializing mautrix-telegram {__version__}")
|
||||||
|
|
||||||
db_engine = sql.create_engine(config["appservice.database"] or "sqlite:///mautrix-telegram.db")
|
db_engine = sql.create_engine(config["appservice.database"] or "sqlite:///mautrix-telegram.db")
|
||||||
@@ -80,7 +79,7 @@ session_container = AlchemySessionContainer(engine=db_engine, session=db_session
|
|||||||
table_base=Base, table_prefix="telethon_",
|
table_base=Base, table_prefix="telethon_",
|
||||||
manage_tables=False)
|
manage_tables=False)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop() # type: asyncio.AbstractEventLoop
|
||||||
|
|
||||||
state_store = SQLStateStore(db_session)
|
state_store = SQLStateStore(db_session)
|
||||||
appserv = AppService(config["homeserver.address"], config["homeserver.domain"],
|
appserv = AppService(config["homeserver.address"], config["homeserver.domain"],
|
||||||
@@ -89,8 +88,8 @@ appserv = AppService(config["homeserver.address"], config["homeserver.domain"],
|
|||||||
verify_ssl=config["homeserver.verify_ssl"], state_store=state_store,
|
verify_ssl=config["homeserver.verify_ssl"], state_store=state_store,
|
||||||
real_user_content_key="net.maunium.telegram.puppet")
|
real_user_content_key="net.maunium.telegram.puppet")
|
||||||
|
|
||||||
public_website = None
|
public_website = None # type: Optional[PublicBridgeWebsite]
|
||||||
provisioning_api = None
|
provisioning_api = None # type: Optional[ProvisioningAPI]
|
||||||
|
|
||||||
if config["appservice.public.enabled"]:
|
if config["appservice.public.enabled"]:
|
||||||
public_website = PublicBridgeWebsite(loop)
|
public_website = PublicBridgeWebsite(loop)
|
||||||
|
|||||||
@@ -14,26 +14,48 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from typing import Tuple, Optional, List, Union, TYPE_CHECKING
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
from telethon.tl.types import *
|
from sqlalchemy import orm
|
||||||
from mautrix_appservice import MatrixRequestError
|
from telethon.tl.types import Channel, ChannelForbidden, Chat, ChatForbidden, Message, \
|
||||||
|
MessageActionChannelMigrateFrom, MessageService, PeerUser, TypeUpdate, \
|
||||||
|
UpdateChannelPinnedMessage, UpdateChatAdmins, UpdateChatParticipantAdmin, \
|
||||||
|
UpdateChatParticipants, UpdateChatUserTyping, UpdateDeleteChannelMessages, \
|
||||||
|
UpdateDeleteMessages, UpdateEditChannelMessage, UpdateEditMessage, UpdateNewChannelMessage, \
|
||||||
|
UpdateNewMessage, UpdateReadHistoryOutbox, UpdateShortChatMessage, UpdateShortMessage, \
|
||||||
|
UpdateUserName, UpdateUserPhoto, UpdateUserStatus, UpdateUserTyping, User, UserStatusOffline, \
|
||||||
|
UserStatusOnline
|
||||||
|
|
||||||
|
from mautrix_appservice import MatrixRequestError, AppService
|
||||||
|
from alchemysession import AlchemySessionContainer
|
||||||
|
|
||||||
from .tgclient import MautrixTelegramClient
|
|
||||||
from .db import Message as DBMessage
|
|
||||||
from . import portal as po, puppet as pu, __version__
|
from . import portal as po, puppet as pu, __version__
|
||||||
|
from .db import Message as DBMessage
|
||||||
|
from .tgclient import MautrixTelegramClient
|
||||||
|
|
||||||
config = None
|
if TYPE_CHECKING:
|
||||||
|
from .context import Context
|
||||||
|
from .config import Config
|
||||||
|
|
||||||
|
config = None # type: Config
|
||||||
# Value updated from config in init()
|
# Value updated from config in init()
|
||||||
MAX_DELETIONS = 10
|
MAX_DELETIONS = 10 # type: int
|
||||||
|
|
||||||
|
UpdateMessage = Union[UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage,
|
||||||
|
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage]
|
||||||
|
UpdateMessageContent = Union[UpdateShortMessage, UpdateShortChatMessage, Message, MessageService]
|
||||||
|
|
||||||
|
|
||||||
class AbstractUser:
|
class AbstractUser(ABC):
|
||||||
session_container = None
|
session_container = None # type: AlchemySessionContainer
|
||||||
loop = None
|
loop = None # type: asyncio.AbstractEventLoop
|
||||||
log = None
|
log = None # type: logging.Logger
|
||||||
db = None
|
db = None # type: orm.Session
|
||||||
az = None
|
az = None # type: AppService
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.puppet_whitelisted = False # type: bool
|
self.puppet_whitelisted = False # type: bool
|
||||||
@@ -47,22 +69,22 @@ class AbstractUser:
|
|||||||
self.is_bot = False # type: bool
|
self.is_bot = False # type: bool
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def connected(self):
|
def connected(self) -> bool:
|
||||||
return self.client and self.client.is_connected()
|
return self.client and self.client.is_connected()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _proxy_settings(self):
|
def _proxy_settings(self) -> Optional[Tuple[int, str, str, str, str, str]]:
|
||||||
type = config["telegram.proxy.type"].lower()
|
proxy_type = config["telegram.proxy.type"].lower()
|
||||||
if type == "disabled":
|
if proxy_type == "disabled":
|
||||||
return None
|
return None
|
||||||
elif type == "socks4":
|
elif proxy_type == "socks4":
|
||||||
type = 1
|
proxy_type = 1
|
||||||
elif type == "socks5":
|
elif proxy_type == "socks5":
|
||||||
type = 2
|
proxy_type = 2
|
||||||
elif type == "http":
|
elif proxy_type == "http":
|
||||||
type = 3
|
proxy_type = 3
|
||||||
|
|
||||||
return (type,
|
return (proxy_type,
|
||||||
config["telegram.proxy.address"], config["telegram.proxy.port"],
|
config["telegram.proxy.address"], config["telegram.proxy.port"],
|
||||||
config["telegram.proxy.rdns"],
|
config["telegram.proxy.rdns"],
|
||||||
config["telegram.proxy.username"], config["telegram.proxy.password"])
|
config["telegram.proxy.username"], config["telegram.proxy.password"])
|
||||||
@@ -83,20 +105,30 @@ class AbstractUser:
|
|||||||
proxy=self._proxy_settings)
|
proxy=self._proxy_settings)
|
||||||
self.client.add_event_handler(self._update_catch)
|
self.client.add_event_handler(self._update_catch)
|
||||||
|
|
||||||
async def update(self, update):
|
@abstractmethod
|
||||||
|
async def update(self, update: TypeUpdate) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
async def post_login(self):
|
async def post_login(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def _update_catch(self, update):
|
@abstractmethod
|
||||||
|
def register_portal(self, portal: po.Portal):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def unregister_portal(self, portal: po.Portal):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def _update_catch(self, update: TypeUpdate):
|
||||||
try:
|
try:
|
||||||
if not await self.update(update):
|
if not await self.update(update):
|
||||||
await self._update(update)
|
await self._update(update)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.log.exception("Failed to handle Telegram update")
|
self.log.exception("Failed to handle Telegram update")
|
||||||
|
|
||||||
async def get_dialogs(self, limit=None) -> List[Union[Chat, Channel]]:
|
async def get_dialogs(self, limit: int = None) -> List[Union[Chat, Channel]]:
|
||||||
if self.is_bot:
|
if self.is_bot:
|
||||||
return []
|
return []
|
||||||
dialogs = await self.client.get_dialogs(limit=limit)
|
dialogs = await self.client.get_dialogs(limit=limit)
|
||||||
@@ -106,18 +138,19 @@ class AbstractUser:
|
|||||||
and (dialog.entity.deactivated or dialog.entity.left)))]
|
and (dialog.entity.deactivated or dialog.entity.left)))]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
@abstractmethod
|
||||||
|
def name(self) -> str:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def is_logged_in(self):
|
async def is_logged_in(self) -> bool:
|
||||||
return self.client and await self.client.is_user_authorized()
|
return self.client and await self.client.is_user_authorized()
|
||||||
|
|
||||||
async def has_full_access(self, allow_bot=False):
|
async def has_full_access(self, allow_bot: bool = False) -> bool:
|
||||||
return (self.puppet_whitelisted
|
return (self.puppet_whitelisted
|
||||||
and (not self.is_bot or allow_bot)
|
and (not self.is_bot or allow_bot)
|
||||||
and await self.is_logged_in())
|
and await self.is_logged_in())
|
||||||
|
|
||||||
async def start(self, delete_unless_authenticated=False):
|
async def start(self, delete_unless_authenticated: bool = False) -> "AbstractUser":
|
||||||
if not self.client:
|
if not self.client:
|
||||||
self._init_client()
|
self._init_client()
|
||||||
await self.client.connect()
|
await self.client.connect()
|
||||||
@@ -144,7 +177,7 @@ class AbstractUser:
|
|||||||
|
|
||||||
# region Telegram update handling
|
# region Telegram update handling
|
||||||
|
|
||||||
async def _update(self, update):
|
async def _update(self, update: TypeUpdate):
|
||||||
if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage,
|
if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage,
|
||||||
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage)):
|
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage)):
|
||||||
await self.update_message(update)
|
await self.update_message(update)
|
||||||
@@ -169,17 +202,19 @@ class AbstractUser:
|
|||||||
else:
|
else:
|
||||||
self.log.debug("Unhandled update: %s", update)
|
self.log.debug("Unhandled update: %s", update)
|
||||||
|
|
||||||
async def update_pinned_messages(self, update):
|
@staticmethod
|
||||||
|
async def update_pinned_messages(update: UpdateChannelPinnedMessage):
|
||||||
portal = po.Portal.get_by_tgid(update.channel_id)
|
portal = po.Portal.get_by_tgid(update.channel_id)
|
||||||
if portal and portal.mxid:
|
if portal and portal.mxid:
|
||||||
await portal.receive_telegram_pin_id(update.id)
|
await portal.receive_telegram_pin_id(update.id)
|
||||||
|
|
||||||
async def update_participants(self, update):
|
@staticmethod
|
||||||
|
async def update_participants(update: UpdateChatParticipants):
|
||||||
portal = po.Portal.get_by_tgid(update.participants.chat_id)
|
portal = po.Portal.get_by_tgid(update.participants.chat_id)
|
||||||
if portal and portal.mxid:
|
if portal and portal.mxid:
|
||||||
await portal.update_telegram_participants(update.participants.participants)
|
await portal.update_telegram_participants(update.participants.participants)
|
||||||
|
|
||||||
async def update_read_receipt(self, update):
|
async def update_read_receipt(self, update: UpdateReadHistoryOutbox):
|
||||||
if not isinstance(update.peer, PeerUser):
|
if not isinstance(update.peer, PeerUser):
|
||||||
self.log.debug("Unexpected read receipt peer: %s", update.peer)
|
self.log.debug("Unexpected read receipt peer: %s", update.peer)
|
||||||
return
|
return
|
||||||
@@ -196,7 +231,7 @@ class AbstractUser:
|
|||||||
puppet = pu.Puppet.get(update.peer.user_id)
|
puppet = pu.Puppet.get(update.peer.user_id)
|
||||||
await puppet.intent.mark_read(portal.mxid, message.mxid)
|
await puppet.intent.mark_read(portal.mxid, message.mxid)
|
||||||
|
|
||||||
async def update_admin(self, update):
|
async def update_admin(self, update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]):
|
||||||
# TODO duplication not checked
|
# TODO duplication not checked
|
||||||
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
|
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
|
||||||
if isinstance(update, UpdateChatAdmins):
|
if isinstance(update, UpdateChatAdmins):
|
||||||
@@ -206,7 +241,7 @@ class AbstractUser:
|
|||||||
else:
|
else:
|
||||||
self.log.warning("Unexpected admin status update: %s", update)
|
self.log.warning("Unexpected admin status update: %s", update)
|
||||||
|
|
||||||
async def update_typing(self, update):
|
async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]):
|
||||||
if isinstance(update, UpdateUserTyping):
|
if isinstance(update, UpdateUserTyping):
|
||||||
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
|
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
|
||||||
else:
|
else:
|
||||||
@@ -214,7 +249,7 @@ class AbstractUser:
|
|||||||
sender = pu.Puppet.get(update.user_id)
|
sender = pu.Puppet.get(update.user_id)
|
||||||
await portal.handle_telegram_typing(sender, update)
|
await portal.handle_telegram_typing(sender, update)
|
||||||
|
|
||||||
async def update_others_info(self, update):
|
async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]):
|
||||||
# TODO duplication not checked
|
# TODO duplication not checked
|
||||||
puppet = pu.Puppet.get(update.user_id)
|
puppet = pu.Puppet.get(update.user_id)
|
||||||
if isinstance(update, UpdateUserName):
|
if isinstance(update, UpdateUserName):
|
||||||
@@ -226,7 +261,7 @@ class AbstractUser:
|
|||||||
else:
|
else:
|
||||||
self.log.warning("Unexpected other user info update: %s", update)
|
self.log.warning("Unexpected other user info update: %s", update)
|
||||||
|
|
||||||
async def update_status(self, update):
|
async def update_status(self, update: UpdateUserStatus):
|
||||||
puppet = pu.Puppet.get(update.user_id)
|
puppet = pu.Puppet.get(update.user_id)
|
||||||
if isinstance(update.status, UserStatusOnline):
|
if isinstance(update.status, UserStatusOnline):
|
||||||
await puppet.default_mxid_intent.set_presence("online")
|
await puppet.default_mxid_intent.set_presence("online")
|
||||||
@@ -236,7 +271,9 @@ class AbstractUser:
|
|||||||
self.log.warning("Unexpected user status update: %s", update)
|
self.log.warning("Unexpected user status update: %s", update)
|
||||||
return
|
return
|
||||||
|
|
||||||
def get_message_details(self, update):
|
def get_message_details(self, update: UpdateMessage) -> Tuple[UpdateMessageContent,
|
||||||
|
Optional[pu.Puppet],
|
||||||
|
Optional[po.Portal]]:
|
||||||
if isinstance(update, UpdateShortChatMessage):
|
if isinstance(update, UpdateShortChatMessage):
|
||||||
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
|
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
|
||||||
sender = pu.Puppet.get(update.from_id)
|
sender = pu.Puppet.get(update.from_id)
|
||||||
@@ -259,7 +296,7 @@ class AbstractUser:
|
|||||||
return update, sender, portal
|
return update, sender, portal
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _try_redact(portal, message):
|
async def _try_redact(portal: po.Portal, message: DBMessage):
|
||||||
if not portal:
|
if not portal:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
@@ -267,7 +304,7 @@ class AbstractUser:
|
|||||||
except MatrixRequestError:
|
except MatrixRequestError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def delete_message(self, update):
|
async def delete_message(self, update: UpdateDeleteMessages):
|
||||||
if len(update.messages) > MAX_DELETIONS:
|
if len(update.messages) > MAX_DELETIONS:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -283,7 +320,7 @@ class AbstractUser:
|
|||||||
await self._try_redact(portal, message)
|
await self._try_redact(portal, message)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
async def delete_channel_message(self, update):
|
async def delete_channel_message(self, update: UpdateDeleteChannelMessages):
|
||||||
if len(update.messages) > MAX_DELETIONS:
|
if len(update.messages) > MAX_DELETIONS:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -299,7 +336,7 @@ class AbstractUser:
|
|||||||
await self._try_redact(portal, message)
|
await self._try_redact(portal, message)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
async def update_message(self, original_update):
|
async def update_message(self, original_update: UpdateMessage):
|
||||||
update, sender, portal = self.get_message_details(original_update)
|
update, sender, portal = self.get_message_details(original_update)
|
||||||
|
|
||||||
if isinstance(update, MessageService):
|
if isinstance(update, MessageService):
|
||||||
@@ -325,7 +362,7 @@ class AbstractUser:
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
def init(context):
|
def init(context: "Context"):
|
||||||
global config, MAX_DELETIONS
|
global config, MAX_DELETIONS
|
||||||
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, _ = context
|
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, _ = context
|
||||||
AbstractUser.session_container = context.session_container
|
AbstractUser.session_container = context.session_container
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
Base = declarative_base()
|
Base = declarative_base() # type: declarative_base
|
||||||
|
|||||||
+22
-18
@@ -14,7 +14,7 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Awaitable, Callable
|
from typing import Awaitable, Callable, Pattern, Dict, TYPE_CHECKING
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@@ -27,27 +27,31 @@ from .abstract_user import AbstractUser
|
|||||||
from .db import BotChat
|
from .db import BotChat
|
||||||
from . import puppet as pu, portal as po, user as u
|
from . import puppet as pu, portal as po, user as u
|
||||||
|
|
||||||
config = None
|
if TYPE_CHECKING:
|
||||||
|
from .config import Config
|
||||||
|
|
||||||
|
config = None # type: Config
|
||||||
|
|
||||||
ReplyFunc = Callable[[str], Awaitable[Message]]
|
ReplyFunc = Callable[[str], Awaitable[Message]]
|
||||||
|
|
||||||
|
|
||||||
class Bot(AbstractUser):
|
class Bot(AbstractUser):
|
||||||
log = logging.getLogger("mau.bot")
|
log = logging.getLogger("mau.bot") # type: logging.Logger
|
||||||
mxid_regex = re.compile("@.+:.+")
|
mxid_regex = re.compile("@.+:.+") # type: Pattern
|
||||||
|
|
||||||
def __init__(self, token: str):
|
def __init__(self, token: str):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.token = token
|
self.token = token # type: str
|
||||||
self.puppet_whitelisted = True
|
self.puppet_whitelisted = True # type: bool
|
||||||
self.whitelisted = True
|
self.whitelisted = True # type: bool
|
||||||
self.relaybot_whitelisted = True
|
self.relaybot_whitelisted = True # type: bool
|
||||||
self.username = None
|
self.username = None # type: str
|
||||||
self.is_relaybot = True
|
self.is_relaybot = True # type: bool
|
||||||
self.is_bot = True
|
self.is_bot = True # type: bool
|
||||||
self.chats = {chat.id: chat.type for chat in BotChat.query.all()}
|
self.chats = {chat.id: chat.type for chat in BotChat.query.all()} # type: Dict[int, str]
|
||||||
self.tg_whitelist = []
|
self.tg_whitelist = [] # type: List[int]
|
||||||
self.whitelist_group_admins = config["bridge.relaybot.whitelist_group_admins"] or False
|
self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"]
|
||||||
|
or False) # type: bool
|
||||||
|
|
||||||
async def init_permissions(self):
|
async def init_permissions(self):
|
||||||
whitelist = config["bridge.relaybot.whitelist"] or []
|
whitelist = config["bridge.relaybot.whitelist"] or []
|
||||||
@@ -61,7 +65,7 @@ class Bot(AbstractUser):
|
|||||||
if isinstance(id, int):
|
if isinstance(id, int):
|
||||||
self.tg_whitelist.append(id)
|
self.tg_whitelist.append(id)
|
||||||
|
|
||||||
async def start(self, delete_unless_authenticated=False):
|
async def start(self, delete_unless_authenticated: bool = False) -> "Bot":
|
||||||
await super().start(delete_unless_authenticated)
|
await super().start(delete_unless_authenticated)
|
||||||
if not await self.is_logged_in():
|
if not await self.is_logged_in():
|
||||||
await self.client.sign_in(bot_token=self.token)
|
await self.client.sign_in(bot_token=self.token)
|
||||||
@@ -118,7 +122,7 @@ class Bot(AbstractUser):
|
|||||||
self.db.delete(existing_chat)
|
self.db.delete(existing_chat)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
async def _can_use_commands(self, chat, tgid):
|
async def _can_use_commands(self, chat: TypePeer, tgid: int) -> bool:
|
||||||
if tgid in self.tg_whitelist:
|
if tgid in self.tg_whitelist:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -138,7 +142,7 @@ class Bot(AbstractUser):
|
|||||||
if p.user_id == tgid:
|
if p.user_id == tgid:
|
||||||
return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin))
|
return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin))
|
||||||
|
|
||||||
async def check_can_use_commands(self, event: Message, reply: ReplyFunc):
|
async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
|
||||||
if not await self._can_use_commands(event.to_id, event.from_id):
|
if not await self._can_use_commands(event.to_id, event.from_id):
|
||||||
await reply("You do not have the permission to use that command.")
|
await reply("You do not have the permission to use that command.")
|
||||||
return False
|
return False
|
||||||
@@ -262,7 +266,7 @@ class Bot(AbstractUser):
|
|||||||
return "bot"
|
return "bot"
|
||||||
|
|
||||||
|
|
||||||
def init(context):
|
def init(context) -> Optional[Bot]:
|
||||||
global config
|
global config
|
||||||
config = context.config
|
config = context.config
|
||||||
token = config["telegram.bot_token"]
|
token = config["telegram.bot_token"]
|
||||||
|
|||||||
@@ -23,15 +23,14 @@ from .. import puppet as pu, portal as po
|
|||||||
|
|
||||||
ManagementRoomList = List[Tuple[str, str]]
|
ManagementRoomList = List[Tuple[str, str]]
|
||||||
RoomIDList = List[str]
|
RoomIDList = List[str]
|
||||||
PortalList = List[po.Portal]
|
|
||||||
|
|
||||||
|
|
||||||
async def _find_rooms(intent: IntentAPI) -> Tuple[
|
async def _find_rooms(intent: IntentAPI) -> Tuple[ManagementRoomList, RoomIDList,
|
||||||
ManagementRoomList, RoomIDList, PortalList, PortalList]:
|
List["po.Portal"], List["po.Portal"]]:
|
||||||
management_rooms = [] # type: ManagementRoomList
|
management_rooms = [] # type: ManagementRoomList
|
||||||
unidentified_rooms = [] # type: RoomIDList
|
unidentified_rooms = [] # type: RoomIDList
|
||||||
portals = [] # type: PortalList
|
portals = [] # type: List[po.Portal]
|
||||||
empty_portals = [] # type: PortalList
|
empty_portals = [] # type: List[po.Portal]
|
||||||
|
|
||||||
rooms = await intent.get_joined_rooms()
|
rooms = await intent.get_joined_rooms()
|
||||||
for room in rooms:
|
for room in rooms:
|
||||||
@@ -108,8 +107,8 @@ async def clean_rooms(evt: CommandEvent):
|
|||||||
|
|
||||||
|
|
||||||
async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
|
async def set_rooms_to_clean(evt, management_rooms: ManagementRoomList,
|
||||||
unidentified_rooms: RoomIDList, portals: PortalList,
|
unidentified_rooms: RoomIDList, portals: List["po.Portal"],
|
||||||
empty_portals: PortalList):
|
empty_portals: List["po.Portal"]):
|
||||||
command = evt.args[0]
|
command = evt.args[0]
|
||||||
rooms_to_clean = []
|
rooms_to_clean = []
|
||||||
if command == "clean-recommended":
|
if command == "clean-recommended":
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ async def bridge(evt: CommandEvent):
|
|||||||
"chat to this room, use `$cmdprefix+sp continue`")
|
"chat to this room, use `$cmdprefix+sp continue`")
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: po.Portal):
|
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal"):
|
||||||
if not portal.mxid:
|
if not portal.mxid:
|
||||||
await evt.reply("The portal seems to have lost its Matrix room between you"
|
await evt.reply("The portal seems to have lost its Matrix room between you"
|
||||||
"calling `$cmdprefix+sp bridge` and this command.\n\n"
|
"calling `$cmdprefix+sp bridge` and this command.\n\n"
|
||||||
|
|||||||
+22
-21
@@ -14,6 +14,7 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from typing import Tuple, Any, Optional
|
||||||
from ruamel.yaml import YAML
|
from ruamel.yaml import YAML
|
||||||
from ruamel.yaml.comments import CommentedMap
|
from ruamel.yaml.comments import CommentedMap
|
||||||
import random
|
import random
|
||||||
@@ -24,28 +25,28 @@ yaml.indent(4)
|
|||||||
|
|
||||||
|
|
||||||
class DictWithRecursion:
|
class DictWithRecursion:
|
||||||
def __init__(self, data=None):
|
def __init__(self, data: CommentedMap = None):
|
||||||
self._data = data or CommentedMap()
|
self._data = data or CommentedMap() # type: CommentedMap
|
||||||
|
|
||||||
def _recursive_get(self, data, key, default_value):
|
def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
|
||||||
if '.' in key:
|
if '.' in key:
|
||||||
key, next_key = key.split('.', 1)
|
key, next_key = key.split('.', 1)
|
||||||
next_data = data.get(key, CommentedMap())
|
next_data = data.get(key, CommentedMap())
|
||||||
return self._recursive_get(next_data, next_key, default_value)
|
return self._recursive_get(next_data, next_key, default_value)
|
||||||
return data.get(key, default_value)
|
return data.get(key, default_value)
|
||||||
|
|
||||||
def get(self, key, default_value, allow_recursion=True):
|
def get(self, key: str, default_value: Any, allow_recursion: bool = True) -> Any:
|
||||||
if allow_recursion and '.' in key:
|
if allow_recursion and '.' in key:
|
||||||
return self._recursive_get(self._data, key, default_value)
|
return self._recursive_get(self._data, key, default_value)
|
||||||
return self._data.get(key, default_value)
|
return self._data.get(key, default_value)
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key: str) -> Any:
|
||||||
return self.get(key, None)
|
return self.get(key, None)
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key: str) -> bool:
|
||||||
return self[key] is not None
|
return self[key] is not None
|
||||||
|
|
||||||
def _recursive_set(self, data, key, value):
|
def _recursive_set(self, data: CommentedMap, key: str, value: Any):
|
||||||
if '.' in key:
|
if '.' in key:
|
||||||
key, next_key = key.split('.', 1)
|
key, next_key = key.split('.', 1)
|
||||||
if key not in data:
|
if key not in data:
|
||||||
@@ -55,16 +56,16 @@ class DictWithRecursion:
|
|||||||
return
|
return
|
||||||
data[key] = value
|
data[key] = value
|
||||||
|
|
||||||
def set(self, key, value, allow_recursion=True):
|
def set(self, key: str, value: Any, allow_recursion: bool = True):
|
||||||
if allow_recursion and '.' in key:
|
if allow_recursion and '.' in key:
|
||||||
self._recursive_set(self._data, key, value)
|
self._recursive_set(self._data, key, value)
|
||||||
return
|
return
|
||||||
self._data[key] = value
|
self._data[key] = value
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key: str, value: Any):
|
||||||
self.set(key, value)
|
self.set(key, value)
|
||||||
|
|
||||||
def _recursive_del(self, data, key):
|
def _recursive_del(self, data: CommentedMap, key: str):
|
||||||
if '.' in key:
|
if '.' in key:
|
||||||
key, next_key = key.split('.', 1)
|
key, next_key = key.split('.', 1)
|
||||||
if key not in data:
|
if key not in data:
|
||||||
@@ -78,7 +79,7 @@ class DictWithRecursion:
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def delete(self, key, allow_recursion=True):
|
def delete(self, key: str, allow_recursion: bool = True):
|
||||||
if allow_recursion and '.' in key:
|
if allow_recursion and '.' in key:
|
||||||
self._recursive_del(self._data, key)
|
self._recursive_del(self._data, key)
|
||||||
return
|
return
|
||||||
@@ -88,23 +89,23 @@ class DictWithRecursion:
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __delitem__(self, key):
|
def __delitem__(self, key: str):
|
||||||
self.delete(key)
|
self.delete(key)
|
||||||
|
|
||||||
|
|
||||||
class Config(DictWithRecursion):
|
class Config(DictWithRecursion):
|
||||||
def __init__(self, path, registration_path, base_path):
|
def __init__(self, path: str, registration_path: str, base_path: str):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.path = path
|
self.path = path # type: str
|
||||||
self.registration_path = registration_path
|
self.registration_path = registration_path # type: str
|
||||||
self.base_path = base_path
|
self.base_path = base_path # type: str
|
||||||
self._registration = None
|
self._registration = None # type: dict
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
with open(self.path, 'r') as stream:
|
with open(self.path, 'r') as stream:
|
||||||
self._data = yaml.load(stream)
|
self._data = yaml.load(stream)
|
||||||
|
|
||||||
def load_base(self):
|
def load_base(self) -> Optional[DictWithRecursion]:
|
||||||
try:
|
try:
|
||||||
with open(self.base_path, 'r') as stream:
|
with open(self.base_path, 'r') as stream:
|
||||||
return DictWithRecursion(yaml.load(stream))
|
return DictWithRecursion(yaml.load(stream))
|
||||||
@@ -120,7 +121,7 @@ class Config(DictWithRecursion):
|
|||||||
yaml.dump(self._registration, stream)
|
yaml.dump(self._registration, stream)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _new_token():
|
def _new_token() -> str:
|
||||||
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
|
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
@@ -246,7 +247,7 @@ class Config(DictWithRecursion):
|
|||||||
self._data = base._data
|
self._data = base._data
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def _get_permissions(self, key):
|
def _get_permissions(self, key: str) -> Tuple[bool, bool, bool, bool, bool]:
|
||||||
level = self["bridge.permissions"].get(key, "")
|
level = self["bridge.permissions"].get(key, "")
|
||||||
admin = level == "admin"
|
admin = level == "admin"
|
||||||
puppeting = level == "full" or admin
|
puppeting = level == "full" or admin
|
||||||
@@ -254,7 +255,7 @@ class Config(DictWithRecursion):
|
|||||||
relaybot = level == "relaybot" or user
|
relaybot = level == "relaybot" or user
|
||||||
return relaybot, user, puppeting, admin, level
|
return relaybot, user, puppeting, admin, level
|
||||||
|
|
||||||
def get_permissions(self, mxid):
|
def get_permissions(self, mxid: str) -> Tuple[bool, bool, bool, bool, bool]:
|
||||||
permissions = self["bridge.permissions"] or {}
|
permissions = self["bridge.permissions"] or {}
|
||||||
if mxid in permissions:
|
if mxid in permissions:
|
||||||
return self._get_permissions(mxid)
|
return self._get_permissions(mxid)
|
||||||
|
|||||||
+18
-12
@@ -14,21 +14,27 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Tuple
|
from typing import TYPE_CHECKING
|
||||||
import asyncio
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from sqlalchemy.orm import scoped_session
|
||||||
|
|
||||||
|
from alchemysession import AlchemySessionContainer
|
||||||
|
from mautrix_appservice import AppService
|
||||||
|
|
||||||
|
from .web import PublicBridgeWebsite, ProvisioningAPI
|
||||||
|
from .config import Config
|
||||||
|
from .bot import Bot
|
||||||
|
from .matrix import MatrixHandler
|
||||||
|
|
||||||
from sqlalchemy.orm import scoped_session
|
|
||||||
from alchemysession import AlchemySessionContainer
|
|
||||||
from mautrix_appservice import AppService
|
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
def __init__(self, az, db, config, loop, bot, mx, session_container, public_website,
|
def __init__(self, az: "AppService", db: "scoped_session", config: "Config",
|
||||||
provisioning_api):
|
loop: "asyncio.AbstractEventLoop", bot: "Bot", mx: "MatrixHandler",
|
||||||
from .web import PublicBridgeWebsite, ProvisioningAPI
|
session_container: "AlchemySessionContainer",
|
||||||
from .config import Config
|
public_website: "PublicBridgeWebsite", provisioning_api: "ProvisioningAPI"):
|
||||||
from .bot import Bot
|
|
||||||
from .matrix import MatrixHandler
|
|
||||||
|
|
||||||
self.az = az # type: AppService
|
self.az = az # type: AppService
|
||||||
self.db = db # type: scoped_session
|
self.db = db # type: scoped_session
|
||||||
self.config = config # type: Config
|
self.config = config # type: Config
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ class Portal(Base):
|
|||||||
about = Column(String, nullable=True)
|
about = Column(String, nullable=True)
|
||||||
photo_id = Column(String, nullable=True)
|
photo_id = Column(String, nullable=True)
|
||||||
|
|
||||||
|
|
||||||
class Message(Base):
|
class Message(Base):
|
||||||
query = None # type: Query
|
query = None # type: Query
|
||||||
__tablename__ = "message"
|
__tablename__ = "message"
|
||||||
|
|||||||
@@ -14,10 +14,10 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from typing import (Optional, List, Tuple, Type, Callable, Dict, Any, Pattern, Deque, Match, TYPE_CHECKING)
|
||||||
from html import unescape
|
from html import unescape
|
||||||
from html.parser import HTMLParser
|
from html.parser import HTMLParser
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Optional, List, Tuple, Type, Callable, Dict, Any
|
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
@@ -27,37 +27,40 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, M
|
|||||||
MessageEntityItalic, MessageEntityCode, MessageEntityPre,
|
MessageEntityItalic, MessageEntityCode, MessageEntityPre,
|
||||||
MessageEntityBotCommand, TypeMessageEntity)
|
MessageEntityBotCommand, TypeMessageEntity)
|
||||||
|
|
||||||
from .. import user as u, puppet as pu, portal as po, context as c
|
from .. import user as u, puppet as pu, portal as po
|
||||||
from ..db import Message as DBMessage
|
from ..db import Message as DBMessage
|
||||||
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
|
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
|
||||||
trim_reply_fallback_text, html_to_unicode)
|
trim_reply_fallback_text, html_to_unicode)
|
||||||
|
|
||||||
log = logging.getLogger("mau.fmt.mx")
|
if TYPE_CHECKING:
|
||||||
should_bridge_plaintext_highlights = False
|
from ..context import Context
|
||||||
|
|
||||||
|
log = logging.getLogger("mau.fmt.mx") # type: logging.Logger
|
||||||
|
should_bridge_plaintext_highlights = False # type: bool
|
||||||
|
|
||||||
|
|
||||||
class MatrixParser(HTMLParser):
|
class MatrixParser(HTMLParser):
|
||||||
mention_regex = re.compile("https://matrix.to/#/(@.+:.+)")
|
mention_regex = re.compile("https://matrix.to/#/(@.+:.+)") # type: Pattern
|
||||||
room_regex = re.compile("https://matrix.to/#/(#.+:.+)")
|
room_regex = re.compile("https://matrix.to/#/(#.+:.+)") # type: Pattern
|
||||||
block_tags = ("br", "p", "pre", "blockquote",
|
block_tags = ("br", "p", "pre", "blockquote",
|
||||||
"ol", "ul", "li",
|
"ol", "ul", "li",
|
||||||
"h1", "h2", "h3", "h4", "h5", "h6",
|
"h1", "h2", "h3", "h4", "h5", "h6",
|
||||||
"div", "hr", "table")
|
"div", "hr", "table") # type: Tuple[str, ...]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.text = ""
|
self.text = "" # type: str
|
||||||
self.entities = []
|
self.entities = [] # type: List[TypeMessageEntity]
|
||||||
self._building_entities = {}
|
self._building_entities = {} # type: Dict[str, TypeMessageEntity]
|
||||||
self._list_counter = 0
|
self._list_counter = 0 # type: int
|
||||||
self._open_tags = deque()
|
self._open_tags = deque() # type: Deque[str]
|
||||||
self._open_tags_meta = deque()
|
self._open_tags_meta = deque() # type: Deque[Any]
|
||||||
self._line_is_new = True
|
self._line_is_new = True # type: bool
|
||||||
self._list_entry_is_new = False
|
self._list_entry_is_new = False # type: bool
|
||||||
|
|
||||||
def _parse_url(self, url: str, args: Dict[str, Any]
|
def _parse_url(self, url: str, args: Dict[str, Any]
|
||||||
) -> Tuple[Optional[Type[TypeMessageEntity]], Optional[str]]:
|
) -> Tuple[Optional[Type[TypeMessageEntity]], Optional[str]]:
|
||||||
mention = self.mention_regex.match(url)
|
mention = self.mention_regex.match(url) # type: Match
|
||||||
if mention:
|
if mention:
|
||||||
mxid = mention.group(1)
|
mxid = mention.group(1)
|
||||||
user = (pu.Puppet.get_by_mxid(mxid)
|
user = (pu.Puppet.get_by_mxid(mxid)
|
||||||
@@ -72,7 +75,7 @@ class MatrixParser(HTMLParser):
|
|||||||
else:
|
else:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
room = self.room_regex.match(url)
|
room = self.room_regex.match(url) # type: Match
|
||||||
if room:
|
if room:
|
||||||
username = po.Portal.get_username_from_mx_alias(room.group(1))
|
username = po.Portal.get_username_from_mx_alias(room.group(1))
|
||||||
portal = po.Portal.find_by_username(username)
|
portal = po.Portal.find_by_username(username)
|
||||||
@@ -92,8 +95,8 @@ class MatrixParser(HTMLParser):
|
|||||||
self._open_tags_meta.appendleft(0)
|
self._open_tags_meta.appendleft(0)
|
||||||
|
|
||||||
attrs = dict(attrs)
|
attrs = dict(attrs)
|
||||||
entity_type = None
|
entity_type = None # type: type(TypeMessageEntity)
|
||||||
args = {}
|
args = {} # type: Dict[str, Any]
|
||||||
if tag in ("strong", "b"):
|
if tag in ("strong", "b"):
|
||||||
entity_type = MessageEntityBold
|
entity_type = MessageEntityBold
|
||||||
elif tag in ("em", "i"):
|
elif tag in ("em", "i"):
|
||||||
@@ -243,12 +246,12 @@ class MatrixParser(HTMLParser):
|
|||||||
self._newline(allow_multi=tag == "br")
|
self._newline(allow_multi=tag == "br")
|
||||||
|
|
||||||
|
|
||||||
command_regex = re.compile(r"^!([A-Za-z0-9@]+)")
|
command_regex = re.compile(r"^!([A-Za-z0-9@]+)") # type: Pattern
|
||||||
not_command_regex = re.compile(r"^\\(![A-Za-z0-9@]+)")
|
not_command_regex = re.compile(r"^\\(![A-Za-z0-9@]+)") # type: Pattern
|
||||||
plain_mention_regex = None
|
plain_mention_regex = None # type: Pattern
|
||||||
|
|
||||||
|
|
||||||
def plain_mention_to_html(match):
|
def plain_mention_to_html(match: Match) -> str:
|
||||||
puppet = pu.Puppet.find_by_displayname(match.group(2))
|
puppet = pu.Puppet.find_by_displayname(match.group(2))
|
||||||
if puppet:
|
if puppet:
|
||||||
return (f"{match.group(1)}"
|
return (f"{match.group(1)}"
|
||||||
@@ -351,7 +354,7 @@ def plain_mention_to_text() -> Tuple[List[TypeMessageEntity], Callable[[str], st
|
|||||||
return entities, replacer
|
return entities, replacer
|
||||||
|
|
||||||
|
|
||||||
def init_mx(context: c.Context):
|
def init_mx(context: "Context"):
|
||||||
global plain_mention_regex, should_bridge_plaintext_highlights
|
global plain_mention_regex, should_bridge_plaintext_highlights
|
||||||
config = context.config
|
config = context.config
|
||||||
dn_template = config.get("bridge.displayname_template", "{displayname} (Telegram)")
|
dn_template = config.get("bridge.displayname_template", "{displayname} (Telegram)")
|
||||||
|
|||||||
@@ -14,13 +14,8 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from typing import Optional, List, Tuple, TYPE_CHECKING
|
||||||
from html import escape
|
from html import escape
|
||||||
from typing import Optional, List, Tuple
|
|
||||||
|
|
||||||
try:
|
|
||||||
from lxml.html.diff import htmldiff
|
|
||||||
except ImportError:
|
|
||||||
htmldiff = None # type: function
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@@ -33,16 +28,26 @@ from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName,
|
|||||||
from mautrix_appservice import MatrixRequestError
|
from mautrix_appservice import MatrixRequestError
|
||||||
from mautrix_appservice.intent_api import IntentAPI
|
from mautrix_appservice.intent_api import IntentAPI
|
||||||
|
|
||||||
from .. import user as u, puppet as pu, portal as po, context as c
|
from .. import user as u, puppet as pu, portal as po
|
||||||
from ..db import Message as DBMessage
|
from ..db import Message as DBMessage
|
||||||
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
|
from .util import (add_surrogates, remove_surrogates, trim_reply_fallback_html,
|
||||||
trim_reply_fallback_text, unicode_to_html)
|
trim_reply_fallback_text, unicode_to_html)
|
||||||
|
|
||||||
log = logging.getLogger("mau.fmt.tg")
|
if TYPE_CHECKING:
|
||||||
should_highlight_edits = False
|
from ..abstract_user import AbstractUser
|
||||||
|
from ..context import Context
|
||||||
|
|
||||||
|
try:
|
||||||
|
from lxml.html.diff import htmldiff
|
||||||
|
except ImportError:
|
||||||
|
htmldiff = None # type: function
|
||||||
|
|
||||||
|
|
||||||
def telegram_reply_to_matrix(evt: Message, source: u.User) -> dict:
|
log = logging.getLogger("mau.fmt.tg") # type: logging.Logger
|
||||||
|
should_highlight_edits = False # type: bool
|
||||||
|
|
||||||
|
|
||||||
|
def telegram_reply_to_matrix(evt: Message, source: "AbstractUser") -> dict:
|
||||||
if evt.reply_to_msg_id:
|
if evt.reply_to_msg_id:
|
||||||
space = (evt.to_id.channel_id
|
space = (evt.to_id.channel_id
|
||||||
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
|
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
|
||||||
@@ -78,7 +83,7 @@ async def _add_forward_header(source, text: str, html: Optional[str],
|
|||||||
if not fwd_from_text:
|
if not fwd_from_text:
|
||||||
user = await source.client.get_entity(PeerUser(fwd_from.from_id))
|
user = await source.client.get_entity(PeerUser(fwd_from.from_id))
|
||||||
if user:
|
if user:
|
||||||
fwd_from_text = pu.Puppet.get_displayname(user, format=False)
|
fwd_from_text = pu.Puppet.get_displayname(user, False)
|
||||||
fwd_from_html = f"<b>{fwd_from_text}</b>"
|
fwd_from_html = f"<b>{fwd_from_text}</b>"
|
||||||
|
|
||||||
if not fwd_from_text:
|
if not fwd_from_text:
|
||||||
@@ -110,8 +115,9 @@ def highlight_edits(new_html: str, old_html: str) -> str:
|
|||||||
return new_html
|
return new_html
|
||||||
|
|
||||||
|
|
||||||
async def _add_reply_header(source: u.User, text: str, html: str, evt: Message, relates_to: dict,
|
async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: Message,
|
||||||
main_intent: IntentAPI, is_edit: bool) -> Tuple[str, str]:
|
relates_to: dict, main_intent: IntentAPI, is_edit: bool
|
||||||
|
) -> Tuple[str, str]:
|
||||||
space = (evt.to_id.channel_id
|
space = (evt.to_id.channel_id
|
||||||
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
|
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
|
||||||
else source.tgid)
|
else source.tgid)
|
||||||
@@ -142,7 +148,7 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message,
|
|||||||
|
|
||||||
if is_edit and should_highlight_edits:
|
if is_edit and should_highlight_edits:
|
||||||
html = highlight_edits(html or escape(text), r_html_body)
|
html = highlight_edits(html or escape(text), r_html_body)
|
||||||
except (ValueError, KeyError, MatrixRequestError) as e:
|
except (ValueError, KeyError, MatrixRequestError):
|
||||||
r_sender_link = "unknown user"
|
r_sender_link = "unknown user"
|
||||||
r_displayname = "unknown user"
|
r_displayname = "unknown user"
|
||||||
r_text_body = "Failed to fetch message"
|
r_text_body = "Failed to fetch message"
|
||||||
@@ -154,8 +160,9 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message,
|
|||||||
|
|
||||||
r_keyword = "In reply to" if not is_edit else "Edit to"
|
r_keyword = "In reply to" if not is_edit else "Edit to"
|
||||||
r_msg_link = f"<a href='https://matrix.to/#/{msg.mx_room}/{msg.mxid}'>{r_keyword}</a>"
|
r_msg_link = f"<a href='https://matrix.to/#/{msg.mx_room}/{msg.mxid}'>{r_keyword}</a>"
|
||||||
html = (f"<mx-reply><blockquote>{r_msg_link} {r_sender_link}\n{r_html_body}</blockquote></mx-reply>"
|
html = (
|
||||||
+ (html or escape(text)))
|
f"<mx-reply><blockquote>{r_msg_link} {r_sender_link}\n{r_html_body}</blockquote></mx-reply>"
|
||||||
|
+ (html or escape(text)))
|
||||||
|
|
||||||
lines = r_text_body.strip().split("\n")
|
lines = r_text_body.strip().split("\n")
|
||||||
text_with_quote = f"> <{r_displayname}> {lines.pop(0)}"
|
text_with_quote = f"> <{r_displayname}> {lines.pop(0)}"
|
||||||
@@ -167,7 +174,8 @@ async def _add_reply_header(source: u.User, text: str, html: str, evt: Message,
|
|||||||
return text_with_quote, html
|
return text_with_quote, html
|
||||||
|
|
||||||
|
|
||||||
async def telegram_to_matrix(evt: Message, source: u.User, main_intent: Optional[IntentAPI] = None,
|
async def telegram_to_matrix(evt: Message, source: "AbstractUser",
|
||||||
|
main_intent: Optional[IntentAPI] = None,
|
||||||
is_edit: bool = False, prefix_text: Optional[str] = None,
|
is_edit: bool = False, prefix_text: Optional[str] = None,
|
||||||
prefix_html: Optional[str] = None) -> Tuple[str, str, dict]:
|
prefix_html: Optional[str] = None) -> Tuple[str, str, dict]:
|
||||||
text = add_surrogates(evt.message)
|
text = add_surrogates(evt.message)
|
||||||
@@ -320,6 +328,6 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def init_tg(context: c.Context):
|
def init_tg(context: "Context"):
|
||||||
global should_highlight_edits
|
global should_highlight_edits
|
||||||
should_highlight_edits = htmldiff and context.config["bridge.highlight_edits"]
|
should_highlight_edits = htmldiff and context.config["bridge.highlight_edits"]
|
||||||
|
|||||||
@@ -14,8 +14,8 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from typing import Optional, Pattern
|
||||||
from html import escape
|
from html import escape
|
||||||
from typing import Optional
|
|
||||||
import struct
|
import struct
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ def trim_reply_fallback_text(text: str) -> str:
|
|||||||
|
|
||||||
html_reply_fallback_regex = re.compile("^<mx-reply>"
|
html_reply_fallback_regex = re.compile("^<mx-reply>"
|
||||||
r"[\s\S]+?"
|
r"[\s\S]+?"
|
||||||
"</mx-reply>")
|
"</mx-reply>") # type: Pattern
|
||||||
|
|
||||||
|
|
||||||
def trim_reply_fallback_html(html: str) -> str:
|
def trim_reply_fallback_html(html: str) -> str:
|
||||||
|
|||||||
+114
-108
@@ -14,26 +14,23 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import List, Dict
|
from typing import List, Dict, Tuple, Set, Match
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from mautrix_appservice import MatrixRequestError, IntentError
|
from mautrix_appservice import MatrixRequestError, IntentError
|
||||||
|
|
||||||
from .user import User
|
from . import user as u, portal as po, puppet as pu, commands as com
|
||||||
from .portal import Portal
|
|
||||||
from .puppet import Puppet
|
|
||||||
from .commands import CommandProcessor
|
|
||||||
|
|
||||||
|
|
||||||
class MatrixHandler:
|
class MatrixHandler:
|
||||||
log = logging.getLogger("mau.mx")
|
log = logging.getLogger("mau.mx") # type: logging.Logger
|
||||||
|
|
||||||
def __init__(self, context):
|
def __init__(self, context):
|
||||||
self.az, self.db, self.config, _, self.tgbot = context
|
self.az, self.db, self.config, _, self.tgbot = context
|
||||||
self.commands = CommandProcessor(context)
|
self.commands = com.CommandProcessor(context) # type: com.CommandProcessor
|
||||||
self.previously_typing = []
|
self.previously_typing = [] # type: List[str]
|
||||||
|
|
||||||
self.az.matrix_event_handler(self.handle_event)
|
self.az.matrix_event_handler(self.handle_event)
|
||||||
|
|
||||||
@@ -53,68 +50,68 @@ class MatrixHandler:
|
|||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
self.log.exception("TimeoutError when trying to set avatar")
|
self.log.exception("TimeoutError when trying to set avatar")
|
||||||
|
|
||||||
async def handle_puppet_invite(self, room, puppet, inviter):
|
async def handle_puppet_invite(self, room_id, puppet: pu.Puppet, inviter: u.User):
|
||||||
intent = puppet.default_mxid_intent
|
intent = puppet.default_mxid_intent
|
||||||
self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room}")
|
self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}")
|
||||||
if not await inviter.is_logged_in():
|
if not await inviter.is_logged_in():
|
||||||
await intent.error_and_leave(
|
await intent.error_and_leave(
|
||||||
room, text="Please log in before inviting Telegram puppets.")
|
room_id, text="Please log in before inviting Telegram puppets.")
|
||||||
return
|
return
|
||||||
portal = Portal.get_by_mxid(room)
|
portal = po.Portal.get_by_mxid(room_id)
|
||||||
if portal:
|
if portal:
|
||||||
if portal.peer_type == "user":
|
if portal.peer_type == "user":
|
||||||
await intent.error_and_leave(
|
await intent.error_and_leave(
|
||||||
room, text="You can not invite additional users to private chats.")
|
room_id, text="You can not invite additional users to private chats.")
|
||||||
return
|
return
|
||||||
await portal.invite_telegram(inviter, puppet)
|
await portal.invite_telegram(inviter, puppet)
|
||||||
await intent.join_room(room)
|
await intent.join_room(room_id)
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
members = await self.az.intent.get_room_members(room)
|
members = await self.az.intent.get_room_members(room_id)
|
||||||
except MatrixRequestError:
|
except MatrixRequestError:
|
||||||
members = []
|
members = []
|
||||||
if self.az.bot_mxid not in members:
|
if self.az.bot_mxid not in members:
|
||||||
if len(members) > 1:
|
if len(members) > 1:
|
||||||
await intent.error_and_leave(room, text=None, html=(
|
await intent.error_and_leave(room_id, text=None, html=(
|
||||||
f"Please invite "
|
f"Please invite "
|
||||||
f"<a href='https://matrix.to/#/{self.az.bot_mxid}'>the bridge bot</a> "
|
f"<a href='https://matrix.to/#/{self.az.bot_mxid}'>the bridge bot</a> "
|
||||||
f"first if you want to create a Telegram chat."))
|
f"first if you want to create a Telegram chat."))
|
||||||
return
|
return
|
||||||
|
|
||||||
await intent.join_room(room)
|
await intent.join_room(room_id)
|
||||||
portal = Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
|
portal = po.Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
|
||||||
if portal.mxid:
|
if portal.mxid:
|
||||||
try:
|
try:
|
||||||
await intent.invite(portal.mxid, inviter.mxid)
|
await intent.invite(portal.mxid, inviter.mxid)
|
||||||
await intent.send_notice(room, text=None, html=(
|
await intent.send_notice(room_id, text=None, html=(
|
||||||
"You already have a private chat with me: "
|
"You already have a private chat with me: "
|
||||||
f"<a href='https://matrix.to/#/{portal.mxid}'>"
|
f"<a href='https://matrix.to/#/{portal.mxid}'>"
|
||||||
"Link to room"
|
"Link to room"
|
||||||
"</a>"))
|
"</a>"))
|
||||||
await intent.leave_room(room)
|
await intent.leave_room(room_id)
|
||||||
return
|
return
|
||||||
except MatrixRequestError:
|
except MatrixRequestError:
|
||||||
pass
|
pass
|
||||||
portal.mxid = room
|
portal.mxid = room_id
|
||||||
portal.save()
|
portal.save()
|
||||||
inviter.register_portal(portal)
|
inviter.register_portal(portal)
|
||||||
await intent.send_notice(room, "Portal to private chat created.")
|
await intent.send_notice(room_id, "po.Portal to private chat created.")
|
||||||
else:
|
else:
|
||||||
await intent.join_room(room)
|
await intent.join_room(room_id)
|
||||||
await intent.send_notice(room, "This puppet will remain inactive until a "
|
await intent.send_notice(room_id, "This puppet will remain inactive until a "
|
||||||
"Telegram chat is created for this room.")
|
"Telegram chat is created for this room.")
|
||||||
|
|
||||||
async def accept_bot_invite(self, room, inviter):
|
async def accept_bot_invite(self, room_id: str, inviter: u.User):
|
||||||
tries = 0
|
tries = 0
|
||||||
while tries < 5:
|
while tries < 5:
|
||||||
try:
|
try:
|
||||||
await self.az.intent.join_room(room)
|
await self.az.intent.join_room(room_id)
|
||||||
break
|
break
|
||||||
except (IntentError, MatrixRequestError) as e:
|
except (IntentError, MatrixRequestError):
|
||||||
tries += 1
|
tries += 1
|
||||||
wait_for_seconds = (tries + 1) * 10
|
wait_for_seconds = (tries + 1) * 10
|
||||||
if tries < 5:
|
if tries < 5:
|
||||||
self.log.exception(f"Failed to join room {room} with bridge bot, "
|
self.log.exception(f"Failed to join room {room_id} with bridge bot, "
|
||||||
f"retrying in {wait_for_seconds} seconds...")
|
f"retrying in {wait_for_seconds} seconds...")
|
||||||
await asyncio.sleep(wait_for_seconds)
|
await asyncio.sleep(wait_for_seconds)
|
||||||
else:
|
else:
|
||||||
@@ -123,81 +120,81 @@ class MatrixHandler:
|
|||||||
|
|
||||||
if not inviter.whitelisted:
|
if not inviter.whitelisted:
|
||||||
await self.az.intent.send_notice(
|
await self.az.intent.send_notice(
|
||||||
room, text=None,
|
room_id, text=None,
|
||||||
html="You are not whitelisted to use this bridge.<br/><br/>"
|
html="You are not whitelisted to use this bridge.<br/><br/>"
|
||||||
"If you are the owner of this bridge, see the "
|
"If you are the owner of this bridge, see the "
|
||||||
"<code>bridge.permissions</code> section in your config file.")
|
"<code>bridge.permissions</code> section in your config file.")
|
||||||
await self.az.intent.leave_room(room)
|
await self.az.intent.leave_room(room_id)
|
||||||
|
|
||||||
async def handle_invite(self, room, user, inviter):
|
async def handle_invite(self, room_id: str, user_id: str, inviter_mxid: str):
|
||||||
self.log.debug(f"{inviter} invited {user} to {room}")
|
self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}")
|
||||||
inviter = await User.get_by_mxid(inviter).ensure_started()
|
inviter = await u.User.get_by_mxid(inviter_mxid).ensure_started()
|
||||||
if user == self.az.bot_mxid:
|
if user_id == self.az.bot_mxid:
|
||||||
return await self.accept_bot_invite(room, inviter)
|
return await self.accept_bot_invite(room_id, inviter)
|
||||||
elif not inviter.whitelisted:
|
elif not inviter.whitelisted:
|
||||||
return
|
return
|
||||||
|
|
||||||
puppet = Puppet.get_by_mxid(user)
|
puppet = pu.Puppet.get_by_mxid(user_id)
|
||||||
if puppet:
|
if puppet:
|
||||||
await self.handle_puppet_invite(room, puppet, inviter)
|
await self.handle_puppet_invite(room_id, puppet, inviter)
|
||||||
return
|
return
|
||||||
|
|
||||||
user = User.get_by_mxid(user, create=False)
|
user = u.User.get_by_mxid(user_id, create=False)
|
||||||
if not user:
|
if not user:
|
||||||
return
|
return
|
||||||
await user.ensure_started()
|
await user.ensure_started()
|
||||||
portal = Portal.get_by_mxid(room)
|
portal = po.Portal.get_by_mxid(room_id)
|
||||||
if user and await user.has_full_access(allow_bot=True) and portal:
|
if user and await user.has_full_access(allow_bot=True) and portal:
|
||||||
await portal.invite_telegram(inviter, user)
|
await portal.invite_telegram(inviter, user)
|
||||||
return
|
return
|
||||||
|
|
||||||
# The rest can probably be ignored
|
# The rest can probably be ignored
|
||||||
|
|
||||||
async def handle_join(self, room, user, event_id):
|
async def handle_join(self, room_id: str, user_id: str, event_id: str):
|
||||||
user = await User.get_by_mxid(user).ensure_started()
|
user = await u.User.get_by_mxid(user_id).ensure_started()
|
||||||
|
|
||||||
portal = Portal.get_by_mxid(room)
|
portal = po.Portal.get_by_mxid(room_id)
|
||||||
if not portal:
|
if not portal:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not user.relaybot_whitelisted:
|
if not user.relaybot_whitelisted:
|
||||||
await portal.main_intent.kick(room, user.mxid,
|
await portal.main_intent.kick(room_id, user.mxid,
|
||||||
"You are not whitelisted on this Telegram bridge.")
|
"You are not whitelisted on this Telegram bridge.")
|
||||||
return
|
return
|
||||||
elif not await user.is_logged_in() and not portal.has_bot:
|
elif not await user.is_logged_in() and not portal.has_bot:
|
||||||
await portal.main_intent.kick(room, user.mxid,
|
await portal.main_intent.kick(room_id, user.mxid,
|
||||||
"This chat does not have a bot relaying "
|
"This chat does not have a bot relaying "
|
||||||
"messages for unauthenticated users.")
|
"messages for unauthenticated users.")
|
||||||
return
|
return
|
||||||
|
|
||||||
self.log.debug(f"{user} joined {room}")
|
self.log.debug(f"{user} joined {room_id}")
|
||||||
if await user.is_logged_in() or portal.has_bot:
|
if await user.is_logged_in() or portal.has_bot:
|
||||||
await portal.join_matrix(user, event_id)
|
await portal.join_matrix(user, event_id)
|
||||||
|
|
||||||
async def handle_part(self, room, user, sender, event_id):
|
async def handle_part(self, room_id: str, user_id, sender_mxid: str, event_id: str):
|
||||||
self.log.debug(f"{user} left {room}")
|
self.log.debug(f"{user_id} left {room_id}")
|
||||||
|
|
||||||
sender = User.get_by_mxid(sender, create=False)
|
sender = u.User.get_by_mxid(sender_mxid, create=False)
|
||||||
if not sender:
|
if not sender:
|
||||||
return
|
return
|
||||||
await sender.ensure_started()
|
await sender.ensure_started()
|
||||||
|
|
||||||
portal = Portal.get_by_mxid(room)
|
portal = po.Portal.get_by_mxid(room_id)
|
||||||
if not portal:
|
if not portal:
|
||||||
return
|
return
|
||||||
|
|
||||||
puppet = Puppet.get_by_mxid(user)
|
puppet = pu.Puppet.get_by_mxid(user_id)
|
||||||
if sender and puppet:
|
if sender and puppet:
|
||||||
await portal.leave_matrix(puppet, sender, event_id)
|
await portal.leave_matrix(puppet, sender, event_id)
|
||||||
|
|
||||||
user = User.get_by_mxid(user, create=False)
|
user = u.User.get_by_mxid(user_id, create=False)
|
||||||
if not user:
|
if not user:
|
||||||
return
|
return
|
||||||
await user.ensure_started()
|
await user.ensure_started()
|
||||||
if await user.is_logged_in() or portal.has_bot:
|
if await user.is_logged_in() or portal.has_bot:
|
||||||
await portal.leave_matrix(user, sender, event_id)
|
await portal.leave_matrix(user, sender, event_id)
|
||||||
|
|
||||||
def is_command(self, message):
|
def is_command(self, message: dict) -> Tuple[bool, str]:
|
||||||
text = message.get("body", "")
|
text = message.get("body", "")
|
||||||
prefix = self.config["bridge.command_prefix"]
|
prefix = self.config["bridge.command_prefix"]
|
||||||
is_command = text.startswith(prefix)
|
is_command = text.startswith(prefix)
|
||||||
@@ -207,14 +204,14 @@ class MatrixHandler:
|
|||||||
|
|
||||||
async def handle_message(self, room, sender, message, event_id):
|
async def handle_message(self, room, sender, message, event_id):
|
||||||
is_command, text = self.is_command(message)
|
is_command, text = self.is_command(message)
|
||||||
sender = await User.get_by_mxid(sender).ensure_started()
|
sender = await u.User.get_by_mxid(sender).ensure_started()
|
||||||
if not sender.relaybot_whitelisted:
|
if not sender.relaybot_whitelisted:
|
||||||
self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:"
|
self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:"
|
||||||
" User is not whitelisted.")
|
" u.User is not whitelisted.")
|
||||||
return
|
return
|
||||||
self.log.debug(f"Received Matrix event \"{message}\" from {sender} in {room}")
|
self.log.debug(f"Received Matrix event \"{message}\" from {sender} in {room}")
|
||||||
|
|
||||||
portal = Portal.get_by_mxid(room)
|
portal = po.Portal.get_by_mxid(room)
|
||||||
if not is_command and portal and (await sender.is_logged_in() or portal.has_bot):
|
if not is_command and portal and (await sender.is_logged_in() or portal.has_bot):
|
||||||
await portal.handle_matrix_message(sender, message, event_id)
|
await portal.handle_matrix_message(sender, message, event_id)
|
||||||
return
|
return
|
||||||
@@ -239,39 +236,44 @@ class MatrixHandler:
|
|||||||
await self.commands.handle(room, sender, command, args, is_management,
|
await self.commands.handle(room, sender, command, args, is_management,
|
||||||
is_portal=portal is not None)
|
is_portal=portal is not None)
|
||||||
|
|
||||||
async def handle_redaction(self, room, sender, event_id):
|
@staticmethod
|
||||||
sender = await User.get_by_mxid(sender).ensure_started()
|
async def handle_redaction(room_id: str, sender_mxid: str, event_id: str):
|
||||||
|
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
|
||||||
if not sender.relaybot_whitelisted:
|
if not sender.relaybot_whitelisted:
|
||||||
return
|
return
|
||||||
|
|
||||||
portal = Portal.get_by_mxid(room)
|
portal = po.Portal.get_by_mxid(room_id)
|
||||||
if not portal:
|
if not portal:
|
||||||
return
|
return
|
||||||
|
|
||||||
await portal.handle_matrix_deletion(sender, event_id)
|
await portal.handle_matrix_deletion(sender, event_id)
|
||||||
|
|
||||||
async def handle_power_levels(self, room, sender, new, old):
|
@staticmethod
|
||||||
portal = Portal.get_by_mxid(room)
|
async def handle_power_levels(room_id: str, sender_mxid: str, new: dict, old: dict):
|
||||||
sender = await User.get_by_mxid(sender).ensure_started()
|
portal = po.Portal.get_by_mxid(room_id)
|
||||||
|
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
|
||||||
if await sender.has_full_access(allow_bot=True) and portal:
|
if await sender.has_full_access(allow_bot=True) and portal:
|
||||||
await portal.handle_matrix_power_levels(sender, new["users"], old["users"])
|
await portal.handle_matrix_power_levels(sender, new["users"], old["users"])
|
||||||
|
|
||||||
async def handle_room_meta(self, type, room, sender, content):
|
@staticmethod
|
||||||
portal = Portal.get_by_mxid(room)
|
async def handle_room_meta(evt_type: str, room_id: str, sender_mxid: str, content: dict):
|
||||||
sender = await User.get_by_mxid(sender).ensure_started()
|
portal = po.Portal.get_by_mxid(room_id)
|
||||||
|
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
|
||||||
if await sender.has_full_access(allow_bot=True) and portal:
|
if await sender.has_full_access(allow_bot=True) and portal:
|
||||||
handler, content_key = {
|
handler, content_key = {
|
||||||
"m.room.name": (portal.handle_matrix_title, "name"),
|
"m.room.name": (portal.handle_matrix_title, "name"),
|
||||||
"m.room.topic": (portal.handle_matrix_about, "topic"),
|
"m.room.topic": (portal.handle_matrix_about, "topic"),
|
||||||
"m.room.avatar": (portal.handle_matrix_avatar, "url"),
|
"m.room.avatar": (portal.handle_matrix_avatar, "url"),
|
||||||
}[type]
|
}[evt_type]
|
||||||
if content_key not in content:
|
if content_key not in content:
|
||||||
return
|
return
|
||||||
await handler(sender, content[content_key])
|
await handler(sender, content[content_key])
|
||||||
|
|
||||||
async def handle_room_pin(self, room, sender, new_events, old_events):
|
@staticmethod
|
||||||
portal = Portal.get_by_mxid(room)
|
async def handle_room_pin(room_id: str, sender_mxid: str, new_events: Set[str],
|
||||||
sender = await User.get_by_mxid(sender).ensure_started()
|
old_events: Set[str]):
|
||||||
|
portal = po.Portal.get_by_mxid(room_id)
|
||||||
|
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
|
||||||
if await sender.has_full_access(allow_bot=True) and portal:
|
if await sender.has_full_access(allow_bot=True) and portal:
|
||||||
events = new_events - old_events
|
events = new_events - old_events
|
||||||
if len(events) > 0:
|
if len(events) > 0:
|
||||||
@@ -281,12 +283,14 @@ class MatrixHandler:
|
|||||||
# All pinned events removed, remove pinned event in Telegram.
|
# All pinned events removed, remove pinned event in Telegram.
|
||||||
await portal.handle_matrix_pin(sender, None)
|
await portal.handle_matrix_pin(sender, None)
|
||||||
|
|
||||||
async def handle_name_change(self, room, user, displayname, prev_displayname, event_id):
|
@staticmethod
|
||||||
portal = Portal.get_by_mxid(room)
|
async def handle_name_change(room_id: str, user_id: str, displayname: str,
|
||||||
|
prev_displayname: str, event_id: str):
|
||||||
|
portal = po.Portal.get_by_mxid(room_id)
|
||||||
if not portal or not portal.has_bot:
|
if not portal or not portal.has_bot:
|
||||||
return
|
return
|
||||||
|
|
||||||
user = await User.get_by_mxid(user).ensure_started()
|
user = await u.User.get_by_mxid(user_id).ensure_started()
|
||||||
if await user.needs_relaybot(portal):
|
if await user.needs_relaybot(portal):
|
||||||
await portal.name_change_matrix(user, displayname, prev_displayname, event_id)
|
await portal.name_change_matrix(user, displayname, prev_displayname, event_id)
|
||||||
|
|
||||||
@@ -296,25 +300,27 @@ class MatrixHandler:
|
|||||||
for event_id, receipts in content.items()
|
for event_id, receipts in content.items()
|
||||||
for user_id in receipts.get("m.read", {})}
|
for user_id in receipts.get("m.read", {})}
|
||||||
|
|
||||||
async def handle_read_receipts(self, room_id: str, receipts: Dict[str, str]):
|
@staticmethod
|
||||||
portal = Portal.get_by_mxid(room_id)
|
async def handle_read_receipts(room_id: str, receipts: Dict[str, str]):
|
||||||
|
portal = po.Portal.get_by_mxid(room_id)
|
||||||
if not portal:
|
if not portal:
|
||||||
return
|
return
|
||||||
|
|
||||||
for user_id, event_id in receipts.items():
|
for user_id, event_id in receipts.items():
|
||||||
user = await User.get_by_mxid(user_id).ensure_started()
|
user = await u.User.get_by_mxid(user_id).ensure_started()
|
||||||
if not await user.is_logged_in():
|
if not await user.is_logged_in():
|
||||||
continue
|
continue
|
||||||
await portal.mark_read(user, event_id)
|
await portal.mark_read(user, event_id)
|
||||||
|
|
||||||
async def handle_presence(self, user: str, presence: str):
|
@staticmethod
|
||||||
user = await User.get_by_mxid(user).ensure_started()
|
async def handle_presence(user_id: str, presence: str):
|
||||||
|
user = await u.User.get_by_mxid(user_id).ensure_started()
|
||||||
if not await user.is_logged_in():
|
if not await user.is_logged_in():
|
||||||
return
|
return
|
||||||
await user.set_presence(presence == "online")
|
await user.set_presence(presence == "online")
|
||||||
|
|
||||||
async def handle_typing(self, room_id: str, now_typing: List[str]):
|
async def handle_typing(self, room_id: str, now_typing: List[str]):
|
||||||
portal = Portal.get_by_mxid(room_id)
|
portal = po.Portal.get_by_mxid(room_id)
|
||||||
if not portal:
|
if not portal:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -324,7 +330,7 @@ class MatrixHandler:
|
|||||||
if is_typing and was_typing:
|
if is_typing and was_typing:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
user = await User.get_by_mxid(user_id).ensure_started()
|
user = await u.User.get_by_mxid(user_id).ensure_started()
|
||||||
if not await user.is_logged_in():
|
if not await user.is_logged_in():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -332,38 +338,38 @@ class MatrixHandler:
|
|||||||
|
|
||||||
self.previously_typing = now_typing
|
self.previously_typing = now_typing
|
||||||
|
|
||||||
def filter_matrix_event(self, event):
|
def filter_matrix_event(self, event: dict):
|
||||||
sender = event.get("sender", None)
|
sender = event.get("sender", None)
|
||||||
if not sender:
|
if not sender:
|
||||||
return False
|
return False
|
||||||
return (sender == self.az.bot_mxid
|
return (sender == self.az.bot_mxid
|
||||||
or Puppet.get_id_from_mxid(sender) is not None)
|
or pu.Puppet.get_id_from_mxid(sender) is not None)
|
||||||
|
|
||||||
async def try_handle_event(self, evt):
|
async def try_handle_event(self, evt: dict):
|
||||||
try:
|
try:
|
||||||
await self.handle_event(evt)
|
await self.handle_event(evt)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.log.exception("Error handling manually received Matrix event")
|
self.log.exception("Error handling manually received Matrix event")
|
||||||
|
|
||||||
async def handle_event(self, evt):
|
async def handle_event(self, evt: dict):
|
||||||
if self.filter_matrix_event(evt):
|
if self.filter_matrix_event(evt):
|
||||||
return
|
return
|
||||||
self.log.debug("Received event: %s", evt)
|
self.log.debug("Received event: %s", evt)
|
||||||
type = evt.get("type", "m.unknown")
|
evt_type = evt.get("type", "m.unknown") # type: str
|
||||||
room_id = evt.get("room_id", None)
|
room_id = evt.get("room_id", None) # type: str
|
||||||
event_id = evt.get("event_id", None)
|
event_id = evt.get("event_id", None) # type: str
|
||||||
sender = evt.get("sender", None)
|
sender = evt.get("sender", None) # type: str
|
||||||
content = evt.get("content", {})
|
content = evt.get("content", {}) # type: dict
|
||||||
if type == "m.room.member":
|
if evt_type == "m.room.member":
|
||||||
state_key = evt["state_key"]
|
state_key = evt["state_key"] # type: str
|
||||||
prev_content = evt.get("unsigned", {}).get("prev_content", {})
|
prev_content = evt.get("unsigned", {}).get("prev_content", {}) # type: dict
|
||||||
membership = content.get("membership", "")
|
membership = content.get("membership", "") # type: str
|
||||||
prev_membership = prev_content.get("membership", "leave")
|
prev_membership = prev_content.get("membership", "leave") # type: str
|
||||||
if membership == prev_membership:
|
if membership == prev_membership:
|
||||||
match = re.compile("@(.+):(.+)").match(state_key)
|
match = re.compile("@(.+):(.+)").match(state_key) # type: Match
|
||||||
localpart = match.group(1)
|
localpart = match.group(1) # type: str
|
||||||
displayname = content.get("displayname", localpart)
|
displayname = content.get("displayname", localpart) # type: str
|
||||||
prev_displayname = prev_content.get("displayname", localpart)
|
prev_displayname = prev_content.get("displayname", localpart) # type: str
|
||||||
if displayname != prev_displayname:
|
if displayname != prev_displayname:
|
||||||
await self.handle_name_change(room_id, state_key, displayname,
|
await self.handle_name_change(room_id, state_key, displayname,
|
||||||
prev_displayname, event_id)
|
prev_displayname, event_id)
|
||||||
@@ -373,26 +379,26 @@ class MatrixHandler:
|
|||||||
await self.handle_part(room_id, state_key, sender, event_id)
|
await self.handle_part(room_id, state_key, sender, event_id)
|
||||||
elif membership == "join":
|
elif membership == "join":
|
||||||
await self.handle_join(room_id, state_key, event_id)
|
await self.handle_join(room_id, state_key, event_id)
|
||||||
elif type in ("m.room.message", "m.sticker"):
|
elif evt_type in ("m.room.message", "m.sticker"):
|
||||||
if type != "m.room.message":
|
if evt_type != "m.room.message":
|
||||||
content["msgtype"] = type
|
content["msgtype"] = evt_type
|
||||||
await self.handle_message(room_id, sender, content, event_id)
|
await self.handle_message(room_id, sender, content, event_id)
|
||||||
elif type == "m.room.redaction":
|
elif evt_type == "m.room.redaction":
|
||||||
await self.handle_redaction(room_id, sender, evt["redacts"])
|
await self.handle_redaction(room_id, sender, evt["redacts"])
|
||||||
elif type == "m.room.power_levels":
|
elif evt_type == "m.room.power_levels":
|
||||||
await self.handle_power_levels(room_id, sender, evt["content"], evt["prev_content"])
|
await self.handle_power_levels(room_id, sender, evt["content"], evt["prev_content"])
|
||||||
elif type in ("m.room.name", "m.room.avatar", "m.room.topic"):
|
elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"):
|
||||||
await self.handle_room_meta(type, room_id, sender, evt["content"])
|
await self.handle_room_meta(evt_type, room_id, sender, evt["content"])
|
||||||
elif type == "m.room.pinned_events":
|
elif evt_type == "m.room.pinned_events":
|
||||||
new_events = set(evt["content"]["pinned"])
|
new_events = set(evt["content"]["pinned"])
|
||||||
try:
|
try:
|
||||||
old_events = set(evt["unsigned"]["prev_content"]["pinned"])
|
old_events = set(evt["unsigned"]["prev_content"]["pinned"])
|
||||||
except KeyError:
|
except KeyError:
|
||||||
old_events = set()
|
old_events = set()
|
||||||
await self.handle_room_pin(room_id, sender, new_events, old_events)
|
await self.handle_room_pin(room_id, sender, new_events, old_events)
|
||||||
elif type == "m.receipt":
|
elif evt_type == "m.receipt":
|
||||||
await self.handle_read_receipts(room_id, self.parse_read_receipts(content))
|
await self.handle_read_receipts(room_id, self.parse_read_receipts(content))
|
||||||
elif type == "m.presence":
|
elif evt_type == "m.presence":
|
||||||
await self.handle_presence(sender, content.get("presence", "offline"))
|
await self.handle_presence(sender, content.get("presence", "offline"))
|
||||||
elif type == "m.typing":
|
elif evt_type == "m.typing":
|
||||||
await self.handle_typing(room_id, content.get("user_ids", []))
|
await self.handle_typing(room_id, content.get("user_ids", []))
|
||||||
|
|||||||
+249
-196
File diff suppressed because it is too large
Load Diff
+34
-26
@@ -14,32 +14,39 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING
|
||||||
from difflib import SequenceMatcher
|
from difflib import SequenceMatcher
|
||||||
from typing import Optional, Awaitable
|
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from sqlalchemy import orm
|
||||||
|
|
||||||
from telethon.tl.types import UserProfilePhoto
|
from telethon.tl.types import UserProfilePhoto
|
||||||
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
|
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
|
||||||
|
|
||||||
from .db import Puppet as DBPuppet
|
from .db import Puppet as DBPuppet
|
||||||
from . import util, matrix
|
from . import util
|
||||||
|
|
||||||
config = None
|
if TYPE_CHECKING:
|
||||||
|
from .matrix import MatrixHandler
|
||||||
|
from .config import Config
|
||||||
|
from .context import Context
|
||||||
|
|
||||||
|
config = None # type: Config
|
||||||
|
|
||||||
|
|
||||||
class Puppet:
|
class Puppet:
|
||||||
log = logging.getLogger("mau.puppet")
|
log = logging.getLogger("mau.puppet") # type: logging.Logger
|
||||||
db = None
|
db = None # type: orm.Session
|
||||||
az = None # type: AppService
|
az = None # type: AppService
|
||||||
mx = None # type: matrix.MatrixHandler
|
mx = None # type: MatrixHandler
|
||||||
loop = None # type: asyncio.AbstractEventLoop
|
loop = None # type: asyncio.AbstractEventLoop
|
||||||
mxid_regex = None
|
mxid_regex = None # type: Pattern
|
||||||
username_template = None
|
username_template = None # type: str
|
||||||
hs_domain = None
|
hs_domain = None # type: str
|
||||||
cache = {}
|
cache = {} # type: Dict[str, Puppet]
|
||||||
by_custom_mxid = {}
|
by_custom_mxid = {} # type: Dict[str, Puppet]
|
||||||
|
|
||||||
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
|
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
|
||||||
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
|
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
|
||||||
@@ -71,7 +78,8 @@ class Puppet:
|
|||||||
def tgid(self):
|
def tgid(self):
|
||||||
return self.id
|
return self.id
|
||||||
|
|
||||||
async def is_logged_in(self):
|
@staticmethod
|
||||||
|
async def is_logged_in():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# region Custom puppet management
|
# region Custom puppet management
|
||||||
@@ -154,12 +162,12 @@ class Puppet:
|
|||||||
def filter_events(self, events):
|
def filter_events(self, events):
|
||||||
new_events = []
|
new_events = []
|
||||||
for event in events:
|
for event in events:
|
||||||
type = event.get("type", None)
|
evt_type = event.get("type", None)
|
||||||
event.setdefault("content", {})
|
event.setdefault("content", {})
|
||||||
if type == "m.typing":
|
if evt_type == "m.typing":
|
||||||
is_typing = self.custom_mxid in event["content"].get("user_ids", [])
|
is_typing = self.custom_mxid in event["content"].get("user_ids", [])
|
||||||
event["content"]["user_ids"] = [self.custom_mxid] if is_typing else []
|
event["content"]["user_ids"] = [self.custom_mxid] if is_typing else []
|
||||||
elif type == "m.receipt":
|
elif evt_type == "m.receipt":
|
||||||
val = None
|
val = None
|
||||||
evt = None
|
evt = None
|
||||||
for event_id in event["content"]:
|
for event_id in event["content"]:
|
||||||
@@ -273,7 +281,7 @@ class Puppet:
|
|||||||
return round(similarity * 1000) / 10
|
return round(similarity * 1000) / 10
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_displayname(info, format=True):
|
def get_displayname(info, enable_format=True):
|
||||||
data = {
|
data = {
|
||||||
"phone number": info.phone if hasattr(info, "phone") else None,
|
"phone number": info.phone if hasattr(info, "phone") else None,
|
||||||
"username": info.username,
|
"username": info.username,
|
||||||
@@ -295,7 +303,7 @@ class Puppet:
|
|||||||
elif not name:
|
elif not name:
|
||||||
name = info.id
|
name = info.id
|
||||||
|
|
||||||
if not format:
|
if not enable_format:
|
||||||
return name
|
return name
|
||||||
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
|
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
|
||||||
displayname=name)
|
displayname=name)
|
||||||
@@ -347,18 +355,18 @@ class Puppet:
|
|||||||
# region Getters
|
# region Getters
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, id, create=True) -> "Optional[Puppet]":
|
def get(cls, tgid, create=True) -> "Optional[Puppet]":
|
||||||
try:
|
try:
|
||||||
return cls.cache[id]
|
return cls.cache[tgid]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
puppet = DBPuppet.query.get(id)
|
puppet = DBPuppet.query.get(tgid)
|
||||||
if puppet:
|
if puppet:
|
||||||
return cls.from_db(puppet)
|
return cls.from_db(puppet)
|
||||||
|
|
||||||
if create:
|
if create:
|
||||||
puppet = cls(id)
|
puppet = cls(tgid)
|
||||||
cls.db.add(puppet.db_instance)
|
cls.db.add(puppet.db_instance)
|
||||||
cls.db.commit()
|
cls.db.commit()
|
||||||
return puppet
|
return puppet
|
||||||
@@ -402,8 +410,8 @@ class Puppet:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_mxid_from_id(cls, id):
|
def get_mxid_from_id(cls, tgid):
|
||||||
return f"@{cls.username_template.format(userid=id)}:{cls.hs_domain}"
|
return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_by_username(cls, username) -> "Optional[Puppet]":
|
def find_by_username(cls, username) -> "Optional[Puppet]":
|
||||||
@@ -437,12 +445,12 @@ class Puppet:
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
def init(context):
|
def init(context: "Context") -> List[Awaitable[int]]:
|
||||||
global config
|
global config
|
||||||
Puppet.az, Puppet.db, config, Puppet.loop, _ = context
|
Puppet.az, Puppet.db, config, Puppet.loop, _ = context
|
||||||
Puppet.mx = context.mx
|
Puppet.mx = context.mx
|
||||||
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
|
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
|
||||||
Puppet.hs_domain = config["homeserver"]["domain"]
|
Puppet.hs_domain = config["homeserver"]["domain"]
|
||||||
localpart = Puppet.username_template.format(userid="(.+)")
|
Puppet.mxid_regex = re.compile(
|
||||||
Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}")
|
f"@{Puppet.username_template.format(userid='(.+)')}:{Puppet.hs_domain}")
|
||||||
return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
|
return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
|
||||||
|
|||||||
@@ -16,6 +16,8 @@
|
|||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
from sqlalchemy import orm
|
||||||
|
|
||||||
from mautrix_appservice import StateStore
|
from mautrix_appservice import StateStore
|
||||||
|
|
||||||
from . import puppet as pu
|
from . import puppet as pu
|
||||||
@@ -25,15 +27,17 @@ from .db import RoomState, UserProfile
|
|||||||
class SQLStateStore(StateStore):
|
class SQLStateStore(StateStore):
|
||||||
def __init__(self, db):
|
def __init__(self, db):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.db = db
|
self.db = db # type: orm.Session
|
||||||
self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile]
|
self.profile_cache = {} # type: Dict[Tuple[str, str], UserProfile]
|
||||||
self.room_state_cache = {} # type: Dict[str, RoomState]
|
self.room_state_cache = {} # type: Dict[str, RoomState]
|
||||||
|
|
||||||
def is_registered(self, user: str) -> bool:
|
@staticmethod
|
||||||
|
def is_registered(user: str) -> bool:
|
||||||
puppet = pu.Puppet.get_by_mxid(user)
|
puppet = pu.Puppet.get_by_mxid(user)
|
||||||
return puppet.is_registered if puppet else False
|
return puppet.is_registered if puppet else False
|
||||||
|
|
||||||
def registered(self, user: str):
|
@staticmethod
|
||||||
|
def registered(user: str):
|
||||||
puppet = pu.Puppet.get_by_mxid(user)
|
puppet = pu.Puppet.get_by_mxid(user)
|
||||||
if puppet:
|
if puppet:
|
||||||
puppet.is_registered = True
|
puppet.is_registered = True
|
||||||
|
|||||||
@@ -17,10 +17,14 @@
|
|||||||
from telethon import TelegramClient, utils
|
from telethon import TelegramClient, utils
|
||||||
from telethon.tl.functions.messages import SendMediaRequest
|
from telethon.tl.functions.messages import SendMediaRequest
|
||||||
from telethon.tl.types import *
|
from telethon.tl.types import *
|
||||||
|
from telethon.tl import custom
|
||||||
|
|
||||||
|
|
||||||
class MautrixTelegramClient(TelegramClient):
|
class MautrixTelegramClient(TelegramClient):
|
||||||
async def upload_file(self, file, mime_type=None, attributes=None, file_name=None):
|
async def upload_file_direct(self, file: bytes, mime_type: str = None,
|
||||||
|
attributes: List[TypeDocumentAttribute] = None,
|
||||||
|
file_name: str = None
|
||||||
|
) -> Union[InputMediaUploadedDocument, InputMediaUploadedPhoto]:
|
||||||
file_handle = await super().upload_file(file, file_name=file_name, use_cache=False)
|
file_handle = await super().upload_file(file, file_name=file_name, use_cache=False)
|
||||||
|
|
||||||
if mime_type == "image/png" or mime_type == "image/jpeg":
|
if mime_type == "image/png" or mime_type == "image/jpeg":
|
||||||
@@ -34,7 +38,10 @@ class MautrixTelegramClient(TelegramClient):
|
|||||||
mime_type=mime_type or "application/octet-stream",
|
mime_type=mime_type or "application/octet-stream",
|
||||||
attributes=list(attr_dict.values()))
|
attributes=list(attr_dict.values()))
|
||||||
|
|
||||||
async def send_media(self, entity, media, caption=None, entities=None, reply_to=None):
|
async def send_media(self, entity: Union[TypeInputPeer, TypePeer],
|
||||||
|
media: Union[TypeInputMedia, TypeMessageMedia],
|
||||||
|
caption: str = None, entities: List[TypeMessageEntity] = None,
|
||||||
|
reply_to: int = None) -> Optional[custom.Message]:
|
||||||
entity = await self.get_input_entity(entity)
|
entity = await self.get_input_entity(entity)
|
||||||
reply_to = utils.get_message_id(reply_to)
|
reply_to = utils.get_message_id(reply_to)
|
||||||
request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [],
|
request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [],
|
||||||
|
|||||||
+58
-54
@@ -14,42 +14,51 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Dict, Awaitable, Optional
|
from typing import Dict, Awaitable, Optional, Match, Tuple, TYPE_CHECKING
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from telethon.tl.types import *
|
from telethon.tl.types import *
|
||||||
|
from telethon.tl.types import User as TLUser
|
||||||
from telethon.tl.types.contacts import ContactsNotModified
|
from telethon.tl.types.contacts import ContactsNotModified
|
||||||
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
|
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
|
||||||
from telethon.tl.functions.account import UpdateStatusRequest
|
from telethon.tl.functions.account import UpdateStatusRequest
|
||||||
from mautrix_appservice import MatrixRequestError
|
from mautrix_appservice import MatrixRequestError
|
||||||
|
|
||||||
from .db import User as DBUser, Contact as DBContact
|
from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
|
||||||
from .abstract_user import AbstractUser
|
from .abstract_user import AbstractUser
|
||||||
from . import portal as po, puppet as pu
|
from . import portal as po, puppet as pu
|
||||||
|
|
||||||
config = None
|
if TYPE_CHECKING:
|
||||||
|
from .config import Config
|
||||||
|
from .context import Context
|
||||||
|
|
||||||
|
config = None # type: Config
|
||||||
|
|
||||||
|
SearchResults = List[Tuple["pu.Puppet", int]]
|
||||||
|
|
||||||
|
|
||||||
class User(AbstractUser):
|
class User(AbstractUser):
|
||||||
log = logging.getLogger("mau.user")
|
log = logging.getLogger("mau.user") # type: logging.Logger
|
||||||
by_mxid = {}
|
by_mxid = {} # type: Dict[str, User]
|
||||||
by_tgid = {}
|
by_tgid = {} # type: Dict[int, User]
|
||||||
|
|
||||||
def __init__(self, mxid, tgid=None, username=None, db_contacts=None, saved_contacts=0,
|
def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None,
|
||||||
is_bot=False, db_portals=None, db_instance=None):
|
db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0,
|
||||||
|
is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None,
|
||||||
|
db_instance: Optional[DBUser] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mxid = mxid # type: str
|
self.mxid = mxid # type: str
|
||||||
self.tgid = tgid # type: int
|
self.tgid = tgid # type: int
|
||||||
self.is_bot = is_bot # type: bool
|
self.is_bot = is_bot # type: bool
|
||||||
self.username = username # type: str
|
self.username = username # type: str
|
||||||
self.contacts = []
|
self.contacts = [] # type: List[pu.Puppet]
|
||||||
self.saved_contacts = saved_contacts
|
self.saved_contacts = saved_contacts # type: int
|
||||||
self.db_contacts = db_contacts
|
self.db_contacts = db_contacts # type: List[DBContact]
|
||||||
self.portals = {} # type: Dict[str, po.Portal]
|
self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
|
||||||
self.db_portals = db_portals
|
self.db_portals = db_portals # type: List[DBPortal]
|
||||||
self._db_instance = db_instance
|
self._db_instance = db_instance # type: DBUser
|
||||||
|
|
||||||
self.command_status = None # type: dict
|
self.command_status = None # type: dict
|
||||||
|
|
||||||
@@ -64,53 +73,47 @@ class User(AbstractUser):
|
|||||||
self.by_tgid[tgid] = self
|
self.by_tgid[tgid] = self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return self.mxid
|
return self.mxid
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mxid_localpart(self):
|
def mxid_localpart(self) -> str:
|
||||||
match = re.compile("@(.+):(.+)").match(self.mxid)
|
match = re.compile("@(.+):(.+)").match(self.mxid) # type: Match
|
||||||
return match.group(1)
|
return match.group(1)
|
||||||
|
|
||||||
# TODO replace with proper displayname getting everywhere
|
# TODO replace with proper displayname getting everywhere
|
||||||
@property
|
@property
|
||||||
def displayname(self):
|
def displayname(self) -> str:
|
||||||
return self.mxid_localpart
|
return self.mxid_localpart
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db_contacts(self):
|
def db_contacts(self) -> List[DBContact]:
|
||||||
return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id))
|
return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id))
|
||||||
for puppet in self.contacts]
|
for puppet in self.contacts]
|
||||||
|
|
||||||
@db_contacts.setter
|
@db_contacts.setter
|
||||||
def db_contacts(self, contacts):
|
def db_contacts(self, contacts: List[DBContact]):
|
||||||
if contacts:
|
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
|
||||||
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts]
|
|
||||||
else:
|
|
||||||
self.contacts = []
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db_portals(self):
|
def db_portals(self) -> List[DBPortal]:
|
||||||
return [portal.db_instance for portal in self.portals.values()]
|
return [portal.db_instance for portal in self.portals.values()]
|
||||||
|
|
||||||
@db_portals.setter
|
@db_portals.setter
|
||||||
def db_portals(self, portals):
|
def db_portals(self, portals: List[DBPortal]):
|
||||||
if portals:
|
self.portals = {(portal.tgid, portal.tg_receiver):
|
||||||
self.portals = {(portal.tgid, portal.tg_receiver):
|
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
|
||||||
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
|
for portal in portals} if portals else {}
|
||||||
for portal in portals}
|
|
||||||
else:
|
|
||||||
self.portals = {}
|
|
||||||
|
|
||||||
# region Database conversion
|
# region Database conversion
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db_instance(self):
|
def db_instance(self) -> DBUser:
|
||||||
if not self._db_instance:
|
if not self._db_instance:
|
||||||
self._db_instance = self.new_db_instance()
|
self._db_instance = self.new_db_instance()
|
||||||
return self._db_instance
|
return self._db_instance
|
||||||
|
|
||||||
def new_db_instance(self):
|
def new_db_instance(self) -> DBUser:
|
||||||
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
|
return DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
|
||||||
contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0,
|
contacts=self.db_contacts, saved_contacts=self.saved_contacts or 0,
|
||||||
portals=self.db_portals)
|
portals=self.db_portals)
|
||||||
@@ -134,14 +137,14 @@ class User(AbstractUser):
|
|||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_db(cls, db_user):
|
def from_db(cls, db_user: DBUser) -> "User":
|
||||||
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts,
|
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts,
|
||||||
False, db_user.saved_contacts, db_user.portals, db_instance=db_user)
|
False, db_user.saved_contacts, db_user.portals, db_instance=db_user)
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
# region Telegram connection management
|
# region Telegram connection management
|
||||||
|
|
||||||
async def start(self, delete_unless_authenticated=False):
|
async def start(self, delete_unless_authenticated: bool = False) -> "User":
|
||||||
await super().start()
|
await super().start()
|
||||||
if await self.is_logged_in():
|
if await self.is_logged_in():
|
||||||
self.log.debug(f"Ensuring post_login() for {self.name}")
|
self.log.debug(f"Ensuring post_login() for {self.name}")
|
||||||
@@ -152,7 +155,7 @@ class User(AbstractUser):
|
|||||||
self.client.session.delete()
|
self.client.session.delete()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def post_login(self, info=None):
|
async def post_login(self, info: TLUser = None):
|
||||||
try:
|
try:
|
||||||
await self.update_info(info)
|
await self.update_info(info)
|
||||||
if not self.is_bot:
|
if not self.is_bot:
|
||||||
@@ -163,7 +166,7 @@ class User(AbstractUser):
|
|||||||
except Exception:
|
except Exception:
|
||||||
self.log.exception("Failed to run post-login functions for %s", self.mxid)
|
self.log.exception("Failed to run post-login functions for %s", self.mxid)
|
||||||
|
|
||||||
async def update(self, update):
|
async def update(self, update: TypeUpdate):
|
||||||
if not self.is_bot:
|
if not self.is_bot:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -186,7 +189,7 @@ class User(AbstractUser):
|
|||||||
# endregion
|
# endregion
|
||||||
# region Telegram actions that need custom methods
|
# region Telegram actions that need custom methods
|
||||||
|
|
||||||
def ensure_started(self, even_if_no_session=False) -> "Awaitable[User]":
|
def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]":
|
||||||
return super().ensure_started(even_if_no_session)
|
return super().ensure_started(even_if_no_session)
|
||||||
|
|
||||||
def set_presence(self, online: bool = True):
|
def set_presence(self, online: bool = True):
|
||||||
@@ -194,7 +197,7 @@ class User(AbstractUser):
|
|||||||
return
|
return
|
||||||
return self.client(UpdateStatusRequest(offline=not online))
|
return self.client(UpdateStatusRequest(offline=not online))
|
||||||
|
|
||||||
async def update_info(self, info: User = None):
|
async def update_info(self, info: TLUser = None):
|
||||||
info = info or await self.client.get_me()
|
info = info or await self.client.get_me()
|
||||||
changed = False
|
changed = False
|
||||||
if self.is_bot != info.bot:
|
if self.is_bot != info.bot:
|
||||||
@@ -233,8 +236,9 @@ class User(AbstractUser):
|
|||||||
self.delete()
|
self.delete()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _search_local(self, query, max_results=5, min_similarity=45):
|
def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45
|
||||||
results = []
|
) -> SearchResults:
|
||||||
|
results = [] # type: SearchResults
|
||||||
for contact in self.contacts:
|
for contact in self.contacts:
|
||||||
similarity = contact.similarity(query)
|
similarity = contact.similarity(query)
|
||||||
if similarity >= min_similarity:
|
if similarity >= min_similarity:
|
||||||
@@ -242,11 +246,11 @@ class User(AbstractUser):
|
|||||||
results.sort(key=lambda tup: tup[1], reverse=True)
|
results.sort(key=lambda tup: tup[1], reverse=True)
|
||||||
return results[0:max_results]
|
return results[0:max_results]
|
||||||
|
|
||||||
async def _search_remote(self, query, max_results=5):
|
async def _search_remote(self, query: str, max_results: int = 5) -> SearchResults:
|
||||||
if len(query) < 5:
|
if len(query) < 5:
|
||||||
return []
|
return []
|
||||||
server_results = await self.client(SearchRequest(q=query, limit=max_results))
|
server_results = await self.client(SearchRequest(q=query, limit=max_results))
|
||||||
results = []
|
results = [] # type: SearchResults
|
||||||
for user in server_results.users:
|
for user in server_results.users:
|
||||||
puppet = pu.Puppet.get(user.id)
|
puppet = pu.Puppet.get(user.id)
|
||||||
await puppet.update_info(self, user)
|
await puppet.update_info(self, user)
|
||||||
@@ -254,7 +258,7 @@ class User(AbstractUser):
|
|||||||
results.sort(key=lambda tup: tup[1], reverse=True)
|
results.sort(key=lambda tup: tup[1], reverse=True)
|
||||||
return results[0:max_results]
|
return results[0:max_results]
|
||||||
|
|
||||||
async def search(self, query, force_remote=False):
|
async def search(self, query: str, force_remote: bool = False) -> Tuple[SearchResults, bool]:
|
||||||
if force_remote:
|
if force_remote:
|
||||||
return await self._search_remote(query), True
|
return await self._search_remote(query), True
|
||||||
|
|
||||||
@@ -264,7 +268,7 @@ class User(AbstractUser):
|
|||||||
|
|
||||||
return await self._search_remote(query), True
|
return await self._search_remote(query), True
|
||||||
|
|
||||||
async def sync_dialogs(self, synchronous_create=False):
|
async def sync_dialogs(self, synchronous_create: bool = False):
|
||||||
creators = []
|
creators = []
|
||||||
for entity in await self.get_dialogs(limit=30):
|
for entity in await self.get_dialogs(limit=30):
|
||||||
portal = po.Portal.get_by_entity(entity)
|
portal = po.Portal.get_by_entity(entity)
|
||||||
@@ -275,7 +279,7 @@ class User(AbstractUser):
|
|||||||
self.save()
|
self.save()
|
||||||
await asyncio.gather(*creators, loop=self.loop)
|
await asyncio.gather(*creators, loop=self.loop)
|
||||||
|
|
||||||
def register_portal(self, portal):
|
def register_portal(self, portal: po.Portal):
|
||||||
try:
|
try:
|
||||||
if self.portals[portal.tgid_full] == portal:
|
if self.portals[portal.tgid_full] == portal:
|
||||||
return
|
return
|
||||||
@@ -284,18 +288,18 @@ class User(AbstractUser):
|
|||||||
self.portals[portal.tgid_full] = portal
|
self.portals[portal.tgid_full] = portal
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def unregister_portal(self, portal):
|
def unregister_portal(self, portal: po.Portal):
|
||||||
try:
|
try:
|
||||||
del self.portals[portal.tgid_full]
|
del self.portals[portal.tgid_full]
|
||||||
self.save()
|
self.save()
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def needs_relaybot(self, portal):
|
async def needs_relaybot(self, portal: po.Portal) -> bool:
|
||||||
return not await self.is_logged_in() or (
|
return not await self.is_logged_in() or (
|
||||||
self.is_bot and portal.tgid_full not in self.portals)
|
self.is_bot and portal.tgid_full not in self.portals)
|
||||||
|
|
||||||
def _hash_contacts(self):
|
def _hash_contacts(self) -> int:
|
||||||
acc = 0
|
acc = 0
|
||||||
for id in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]):
|
for id in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]):
|
||||||
acc = (acc * 20261 + id) & 0xffffffff
|
acc = (acc * 20261 + id) & 0xffffffff
|
||||||
@@ -318,7 +322,7 @@ class User(AbstractUser):
|
|||||||
# region Class instance lookup
|
# region Class instance lookup
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_by_mxid(cls, mxid, create=True) -> "Optional[User]":
|
def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]":
|
||||||
if not mxid:
|
if not mxid:
|
||||||
raise ValueError("Matrix ID can't be empty")
|
raise ValueError("Matrix ID can't be empty")
|
||||||
|
|
||||||
@@ -341,7 +345,7 @@ class User(AbstractUser):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_by_tgid(cls, tgid) -> "Optional[User]":
|
def get_by_tgid(cls, tgid: int) -> "Optional[User]":
|
||||||
try:
|
try:
|
||||||
return cls.by_tgid[tgid]
|
return cls.by_tgid[tgid]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@@ -355,7 +359,7 @@ class User(AbstractUser):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_by_username(cls, username) -> "Optional[User]":
|
def find_by_username(cls, username: str) -> "Optional[User]":
|
||||||
if not username:
|
if not username:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -371,7 +375,7 @@ class User(AbstractUser):
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
def init(context):
|
def init(context: "Context") -> List[Awaitable[User]]:
|
||||||
global config
|
global config
|
||||||
config = context.config
|
config = context.config
|
||||||
|
|
||||||
|
|||||||
@@ -14,15 +14,25 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from typing import Optional, Tuple, Union, Dict
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import magic
|
import magic
|
||||||
|
from sqlalchemy import orm
|
||||||
from sqlalchemy.exc import IntegrityError, InvalidRequestError
|
from sqlalchemy.exc import IntegrityError, InvalidRequestError
|
||||||
from sqlalchemy.orm.exc import FlushError
|
from sqlalchemy.orm.exc import FlushError
|
||||||
|
|
||||||
|
from telethon.tl.types import (Document, FileLocation, InputFileLocation,
|
||||||
|
InputDocumentFileLocation, PhotoSize, PhotoCachedSize)
|
||||||
|
from telethon.errors import *
|
||||||
|
from mautrix_appservice import IntentAPI
|
||||||
|
|
||||||
|
from ..tgclient import MautrixTelegramClient
|
||||||
|
from ..db import TelegramFile as DBTelegramFile
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -36,20 +46,18 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
VideoFileClip = random = string = os = mimetypes = None
|
VideoFileClip = random = string = os = mimetypes = None
|
||||||
|
|
||||||
from telethon.tl.types import (Document, FileLocation, InputFileLocation,
|
log = logging.getLogger("mau.util") # type: logging.Logger
|
||||||
InputDocumentFileLocation, PhotoSize, PhotoCachedSize)
|
|
||||||
from telethon.errors import *
|
|
||||||
|
|
||||||
from ..db import TelegramFile as DBTelegramFile
|
TypeLocation = Union[Document, InputDocumentFileLocation, FileLocation, InputFileLocation]
|
||||||
|
|
||||||
log = logging.getLogger("mau.util")
|
|
||||||
|
|
||||||
|
|
||||||
def convert_image(file, source_mime="image/webp", target_type="png", thumbnail_to=None):
|
def convert_image(file: bytes, source_mime: str = "image/webp", target_type: str = "png",
|
||||||
|
thumbnail_to: Optional[Tuple[int, int]] = None
|
||||||
|
) -> Tuple[str, bytes, Optional[int], Optional[int]]:
|
||||||
if not Image:
|
if not Image:
|
||||||
return source_mime, file, None, None
|
return source_mime, file, None, None
|
||||||
try:
|
try:
|
||||||
image = Image.open(BytesIO(file)).convert("RGBA")
|
image = Image.open(BytesIO(file)).convert("RGBA") # type: Image.Image
|
||||||
if thumbnail_to:
|
if thumbnail_to:
|
||||||
image.thumbnail(thumbnail_to, Image.ANTIALIAS)
|
image.thumbnail(thumbnail_to, Image.ANTIALIAS)
|
||||||
new_file = BytesIO()
|
new_file = BytesIO()
|
||||||
@@ -61,13 +69,14 @@ def convert_image(file, source_mime="image/webp", target_type="png", thumbnail_t
|
|||||||
return source_mime, file, None, None
|
return source_mime, file, None, None
|
||||||
|
|
||||||
|
|
||||||
def _temp_file_name(ext):
|
def _temp_file_name(ext: str) -> str:
|
||||||
return ("/tmp/mxtg-video-"
|
return ("/tmp/mxtg-video-"
|
||||||
+ "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
|
+ "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
|
||||||
+ ext)
|
+ ext)
|
||||||
|
|
||||||
|
|
||||||
def _read_video_thumbnail(data, video_ext="mp4", frame_ext="png", max_size=(1024, 720)):
|
def _read_video_thumbnail(data: bytes, video_ext: str = "mp4", frame_ext: str = "png",
|
||||||
|
max_size: Tuple[int, int] = (1024, 720)) -> Tuple[bytes, int, int]:
|
||||||
# We don't have any way to read the video from memory, so save it to disk.
|
# We don't have any way to read the video from memory, so save it to disk.
|
||||||
temp_file = _temp_file_name(video_ext)
|
temp_file = _temp_file_name(video_ext)
|
||||||
with open(temp_file, "wb") as file:
|
with open(temp_file, "wb") as file:
|
||||||
@@ -90,21 +99,21 @@ def _read_video_thumbnail(data, video_ext="mp4", frame_ext="png", max_size=(1024
|
|||||||
return thumbnail_file.getvalue(), w, h
|
return thumbnail_file.getvalue(), w, h
|
||||||
|
|
||||||
|
|
||||||
def _location_to_id(location):
|
def _location_to_id(location: TypeLocation) -> str:
|
||||||
if isinstance(location, (Document, InputDocumentFileLocation)):
|
if isinstance(location, (Document, InputDocumentFileLocation)):
|
||||||
return f"{location.id}-{location.version}"
|
return f"{location.id}-{location.version}"
|
||||||
elif isinstance(location, (FileLocation, InputFileLocation)):
|
elif isinstance(location, (FileLocation, InputFileLocation)):
|
||||||
return f"{location.volume_id}-{location.local_id}"
|
return f"{location.volume_id}-{location.local_id}"
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def transfer_thumbnail_to_matrix(client, intent, thumbnail_loc, video, mime):
|
async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
|
||||||
|
thumbnail_loc: TypeLocation, video: bytes,
|
||||||
|
mime: str) -> Optional[DBTelegramFile]:
|
||||||
if not Image or not VideoFileClip:
|
if not Image or not VideoFileClip:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
id = _location_to_id(thumbnail_loc)
|
loc_id = _location_to_id(thumbnail_loc)
|
||||||
if not id:
|
if not loc_id:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
video_ext = mimetypes.guess_extension(mime)
|
video_ext = mimetypes.guess_extension(mime)
|
||||||
@@ -121,36 +130,40 @@ async def transfer_thumbnail_to_matrix(client, intent, thumbnail_loc, video, mim
|
|||||||
|
|
||||||
content_uri = await intent.upload_file(file, mime_type)
|
content_uri = await intent.upload_file(file, mime_type)
|
||||||
|
|
||||||
return DBTelegramFile(id=id, mxc=content_uri, mime_type=mime_type,
|
return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type,
|
||||||
was_converted=False, timestamp=int(time.time()), size=len(file),
|
was_converted=False, timestamp=int(time.time()), size=len(file),
|
||||||
width=width, height=height)
|
width=width, height=height)
|
||||||
|
|
||||||
|
|
||||||
transfer_locks = {}
|
transfer_locks = {} # type: Dict[str, asyncio.Lock]
|
||||||
transfer_locks_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
async def transfer_file_to_matrix(db, client, intent, location, thumbnail=None, is_sticker=False):
|
async def transfer_file_to_matrix(db: orm.Session, client: MautrixTelegramClient, intent: IntentAPI,
|
||||||
id = _location_to_id(location)
|
location: TypeLocation, thumbnail: Optional[TypeLocation] = None,
|
||||||
if not id:
|
is_sticker: bool = False) -> Optional[DBTelegramFile]:
|
||||||
|
location_id = _location_to_id(location)
|
||||||
|
if not location_id:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
db_file = DBTelegramFile.query.get(id)
|
db_file = DBTelegramFile.query.get(location_id)
|
||||||
if db_file:
|
if db_file:
|
||||||
return db_file
|
return db_file
|
||||||
|
|
||||||
async with transfer_locks_lock:
|
try:
|
||||||
try:
|
lock = transfer_locks[location_id]
|
||||||
lock = transfer_locks[id]
|
except KeyError:
|
||||||
except KeyError:
|
lock = asyncio.Lock()
|
||||||
lock = asyncio.Lock()
|
transfer_locks[location_id] = lock
|
||||||
transfer_locks[id] = lock
|
|
||||||
async with lock:
|
async with lock:
|
||||||
return await _unlocked_transfer_file_to_matrix(db, client, intent, id, location, thumbnail, is_sticker)
|
return await _unlocked_transfer_file_to_matrix(db, client, intent, location_id, location,
|
||||||
|
thumbnail, is_sticker)
|
||||||
|
|
||||||
|
|
||||||
async def _unlocked_transfer_file_to_matrix(db, client, intent, id, location, thumbnail, is_sticker):
|
async def _unlocked_transfer_file_to_matrix(db: orm.Session, client: MautrixTelegramClient,
|
||||||
db_file = DBTelegramFile.query.get(id)
|
intent: IntentAPI, loc_id: str, location: TypeLocation,
|
||||||
|
thumbnail: Optional[TypeLocation],
|
||||||
|
is_sticker: bool) -> Optional[DBTelegramFile]:
|
||||||
|
db_file = DBTelegramFile.query.get(loc_id)
|
||||||
if db_file:
|
if db_file:
|
||||||
return db_file
|
return db_file
|
||||||
|
|
||||||
@@ -167,15 +180,16 @@ async def _unlocked_transfer_file_to_matrix(db, client, intent, id, location, th
|
|||||||
|
|
||||||
image_converted = False
|
image_converted = False
|
||||||
if mime_type == "image/webp":
|
if mime_type == "image/webp":
|
||||||
new_mime_type, file, width, height = convert_image(file, source_mime="image/webp", target_type="png", thumbnail_to=(
|
new_mime_type, file, width, height = convert_image(
|
||||||
256, 256) if is_sticker else None)
|
file, source_mime="image/webp", target_type="png",
|
||||||
|
thumbnail_to=(256, 256) if is_sticker else None)
|
||||||
image_converted = new_mime_type != mime_type
|
image_converted = new_mime_type != mime_type
|
||||||
mime_type = new_mime_type
|
mime_type = new_mime_type
|
||||||
thumbnail = None
|
thumbnail = None
|
||||||
|
|
||||||
content_uri = await intent.upload_file(file, mime_type)
|
content_uri = await intent.upload_file(file, mime_type)
|
||||||
|
|
||||||
db_file = DBTelegramFile(id=id, mxc=content_uri,
|
db_file = DBTelegramFile(id=loc_id, mxc=content_uri,
|
||||||
mime_type=mime_type, was_converted=image_converted,
|
mime_type=mime_type, was_converted=image_converted,
|
||||||
timestamp=int(time.time()), size=len(file),
|
timestamp=int(time.time()), size=len(file),
|
||||||
width=width, height=height)
|
width=width, height=height)
|
||||||
|
|||||||
@@ -16,10 +16,12 @@
|
|||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
|
||||||
def format_duration(seconds):
|
def format_duration(seconds: int) -> str:
|
||||||
def pluralize(count, singular): return singular if count == 1 else singular + "s"
|
def pluralize(count, singular):
|
||||||
|
return singular if count == 1 else singular + "s"
|
||||||
|
|
||||||
def include(count, word): return f"{count} {pluralize(count, word)}" if count > 0 else ""
|
def include(count, word):
|
||||||
|
return f"{count} {pluralize(count, word)}" if count > 0 else ""
|
||||||
|
|
||||||
minutes, seconds = divmod(seconds, 60)
|
minutes, seconds = divmod(seconds, 60)
|
||||||
hours, minutes = divmod(minutes, 60)
|
hours, minutes = divmod(minutes, 60)
|
||||||
|
|||||||
Reference in New Issue
Block a user