Blacken and isort code

This commit is contained in:
Tulir Asokan
2021-12-21 01:36:24 +02:00
parent f2af17d359
commit 6d25e9687e
55 changed files with 3752 additions and 2018 deletions
+18
View File
@@ -0,0 +1,18 @@
name: Python lint
on: [push, pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: "3.10"
- uses: isort/isort-action@master
with:
sortPaths: "./mautrix_telegram"
- uses: psf/black@21.12b0
with:
src: "./mautrix_telegram"
-16
View File
@@ -1,16 +0,0 @@
[settings]
line_length=99
indent=4
multi_line_output=5
sections=FUTURE,STDLIB,THIRDPARTY,TELETHON,MAUTRIX,FIRSTPARTY,LOCALFOLDER
no_lines_before=LOCALFOLDER
default_section=FIRSTPARTY
known_thirdparty=aiohttp,sqlalchemy,alembic,commonmark,ruamel.yaml,PIL,moviepy,prometheus_client,yarl,mako,pkg_resources
known_telethon=telethon,alchemysession,cryptg
known_mautrix=mautrix
balanced_wrapping=True
length_sort=True
+2
View File
@@ -3,6 +3,8 @@
[![License](https://img.shields.io/github/license/mautrix/telegram.svg)](LICENSE) [![License](https://img.shields.io/github/license/mautrix/telegram.svg)](LICENSE)
[![Release](https://img.shields.io/github/release/mautrix/telegram/all.svg)](https://github.com/mautrix/telegram/releases) [![Release](https://img.shields.io/github/release/mautrix/telegram/all.svg)](https://github.com/mautrix/telegram/releases)
[![GitLab CI](https://mau.dev/mautrix/telegram/badges/master/pipeline.svg)](https://mau.dev/mautrix/telegram/container_registry) [![GitLab CI](https://mau.dev/mautrix/telegram/badges/master/pipeline.svg)](https://mau.dev/mautrix/telegram/container_registry)
[![Code style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Imports](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
A Matrix-Telegram hybrid puppeting/relaybot bridge. A Matrix-Telegram hybrid puppeting/relaybot bridge.
+6 -5
View File
@@ -19,12 +19,9 @@ from typing import Any
from telethon import __version__ as __telethon_version__ from telethon import __version__ as __telethon_version__
from mautrix.types import UserID, RoomID
from mautrix.bridge import Bridge from mautrix.bridge import Bridge
from mautrix.types import RoomID, UserID
from .web.provisioning import ProvisioningAPI
from .web.public import PublicBridgeWebsite
from .abstract_user import AbstractUser
from .bot import Bot from .bot import Bot
from .config import Config from .config import Config
from .db import init as init_db, upgrade_table from .db import init as init_db, upgrade_table
@@ -32,7 +29,11 @@ from .matrix import MatrixHandler
from .portal import Portal from .portal import Portal
from .puppet import Puppet from .puppet import Puppet
from .user import User from .user import User
from .version import version, linkified_version from .version import linkified_version, version
from .web.provisioning import ProvisioningAPI
from .web.public import PublicBridgeWebsite
from .abstract_user import AbstractUser # isort: skip
class TelegramBridge(Bridge): class TelegramBridge(Bridge):
+150 -73
View File
@@ -15,45 +15,81 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import Type, Any, Union, TYPE_CHECKING from typing import TYPE_CHECKING, Any, Type, Union
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import platform
import asyncio import asyncio
import logging import logging
import platform
import time import time
from telethon.network import (
Connection,
ConnectionTcpFull,
ConnectionTcpMTProxyRandomizedIntermediate,
)
from telethon.sessions import Session from telethon.sessions import Session
from telethon.network import (ConnectionTcpMTProxyRandomizedIntermediate, ConnectionTcpFull, from telethon.tl.patched import Message, MessageService
Connection)
from telethon.tl.patched import MessageService, Message
from telethon.tl.types import ( from telethon.tl.types import (
Channel, Chat, MessageActionChannelMigrateFrom, PeerUser, TypeUpdate, UpdatePinnedMessages, Channel,
UpdatePinnedChannelMessages, UpdateChatParticipantAdmin, UpdateChatParticipants, PeerChat, Chat,
UpdateChatUserTyping, UpdateDeleteChannelMessages, UpdateNewMessage, UpdateDeleteMessages, MessageActionChannelMigrateFrom,
UpdateEditChannelMessage, UpdateEditMessage, UpdateNewChannelMessage, UpdateReadHistoryOutbox, MessageEmpty,
UpdateShortChatMessage, UpdateShortMessage, UpdateUserName, UpdateUserPhoto, UpdateUserStatus, PeerChat,
UpdateUserTyping, User, UserStatusOffline, UserStatusOnline, UpdateReadHistoryInbox, PeerUser,
UpdateReadChannelInbox, MessageEmpty, UpdateFolderPeers, UpdatePinnedDialogs, TypeUpdate,
UpdateNotifySettings, UpdateChannelUserTyping) UpdateChannelUserTyping,
UpdateChatParticipantAdmin,
UpdateChatParticipants,
UpdateChatUserTyping,
UpdateDeleteChannelMessages,
UpdateDeleteMessages,
UpdateEditChannelMessage,
UpdateEditMessage,
UpdateFolderPeers,
UpdateNewChannelMessage,
UpdateNewMessage,
UpdateNotifySettings,
UpdatePinnedChannelMessages,
UpdatePinnedDialogs,
UpdatePinnedMessages,
UpdateReadChannelInbox,
UpdateReadHistoryInbox,
UpdateReadHistoryOutbox,
UpdateShortChatMessage,
UpdateShortMessage,
UpdateUserName,
UpdateUserPhoto,
UpdateUserStatus,
UpdateUserTyping,
User,
UserStatusOffline,
UserStatusOnline,
)
from mautrix.types import UserID, PresenceState
from mautrix.errors import MatrixError
from mautrix.appservice import AppService from mautrix.appservice import AppService
from mautrix.errors import MatrixError
from mautrix.types import PresenceState, UserID
from mautrix.util.logging import TraceLogger from mautrix.util.logging import TraceLogger
from mautrix.util.opt_prometheus import Histogram, Counter from mautrix.util.opt_prometheus import Counter, Histogram
from . import portal as po, puppet as pu, __version__ from . import __version__, portal as po, puppet as pu
from .db import Message as DBMessage, PgSession
from .types import TelegramID
from .tgclient import MautrixTelegramClient
from .config import Config from .config import Config
from .db import Message as DBMessage, PgSession
from .tgclient import MautrixTelegramClient
from .types import TelegramID
if TYPE_CHECKING: if TYPE_CHECKING:
from .bot import Bot
from .__main__ import TelegramBridge from .__main__ import TelegramBridge
from .bot import Bot
UpdateMessage = Union[UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage, UpdateMessage = Union[
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage] UpdateShortChatMessage,
UpdateShortMessage,
UpdateNewChannelMessage,
UpdateNewMessage,
UpdateEditMessage,
UpdateEditChannelMessage,
]
UpdateMessageContent = Union[ UpdateMessageContent = Union[
UpdateShortMessage, UpdateShortChatMessage, Message, MessageService, MessageEmpty UpdateShortMessage, UpdateShortChatMessage, Message, MessageService, MessageEmpty
] ]
@@ -74,9 +110,9 @@ class AbstractUser(ABC):
loop: asyncio.AbstractEventLoop = None loop: asyncio.AbstractEventLoop = None
log: TraceLogger log: TraceLogger
az: AppService az: AppService
bridge: 'TelegramBridge' bridge: "TelegramBridge"
config: Config config: Config
relaybot: 'Bot' relaybot: "Bot"
ignore_incoming_bot_events: bool = True ignore_incoming_bot_events: bool = True
max_deletions: int = 10 max_deletions: int = 10
@@ -113,11 +149,13 @@ class AbstractUser(ABC):
def _proxy_settings(self) -> tuple[Type[Connection], tuple[Any, ...] | None]: def _proxy_settings(self) -> tuple[Type[Connection], tuple[Any, ...] | None]:
proxy_type = self.config["telegram.proxy.type"].lower() proxy_type = self.config["telegram.proxy.type"].lower()
connection = ConnectionTcpFull connection = ConnectionTcpFull
connection_data = (self.config["telegram.proxy.address"], connection_data = (
self.config["telegram.proxy.port"], self.config["telegram.proxy.address"],
self.config["telegram.proxy.rdns"], self.config["telegram.proxy.port"],
self.config["telegram.proxy.username"], self.config["telegram.proxy.rdns"],
self.config["telegram.proxy.password"]) self.config["telegram.proxy.username"],
self.config["telegram.proxy.password"],
)
if proxy_type == "disabled": if proxy_type == "disabled":
connection_data = None connection_data = None
elif proxy_type == "socks4": elif proxy_type == "socks4":
@@ -133,7 +171,7 @@ class AbstractUser(ABC):
return connection, connection_data return connection, connection_data
@classmethod @classmethod
def init_cls(cls, bridge: 'TelegramBridge') -> None: def init_cls(cls, bridge: "TelegramBridge") -> None:
cls.bridge = bridge cls.bridge = bridge
cls.config = bridge.config cls.config = bridge.config
cls.loop = bridge.loop cls.loop = bridge.loop
@@ -146,9 +184,11 @@ class AbstractUser(ABC):
session = await PgSession.get(self.name) session = await PgSession.get(self.name)
if self.config["telegram.server.enabled"]: if self.config["telegram.server.enabled"]:
session.set_dc(self.config["telegram.server.dc"], session.set_dc(
self.config["telegram.server.ip"], self.config["telegram.server.dc"],
self.config["telegram.server.port"]) self.config["telegram.server.ip"],
self.config["telegram.server.port"],
)
if self.is_relaybot: if self.is_relaybot:
base_logger = logging.getLogger("telethon.relaybot") base_logger = logging.getLogger("telethon.relaybot")
@@ -164,16 +204,15 @@ class AbstractUser(ABC):
self.client = MautrixTelegramClient( self.client = MautrixTelegramClient(
session=session, session=session,
api_id=self.config["telegram.api_id"], api_id=self.config["telegram.api_id"],
api_hash=self.config["telegram.api_hash"], api_hash=self.config["telegram.api_hash"],
app_version=__version__ if appversion == "auto" else appversion, app_version=__version__ if appversion == "auto" else appversion,
system_version=(MautrixTelegramClient.__version__ system_version=(
if sysversion == "auto" else sysversion), MautrixTelegramClient.__version__ if sysversion == "auto" else sysversion
device_model=(f"{platform.system()} {platform.release()}" ),
if device == "auto" else device), device_model=(
f"{platform.system()} {platform.release()}" if device == "auto" else device
),
timeout=self.config["telegram.connection.timeout"], timeout=self.config["telegram.connection.timeout"],
connection_retries=self.config["telegram.connection.retries"], connection_retries=self.config["telegram.connection.retries"],
retry_delay=self.config["telegram.connection.retry_delay"], retry_delay=self.config["telegram.connection.retry_delay"],
@@ -182,9 +221,8 @@ class AbstractUser(ABC):
connection=connection, connection=connection,
proxy=proxy, proxy=proxy,
raise_last_call_error=True, raise_last_call_error=True,
loop=self.loop, loop=self.loop,
base_logger=base_logger base_logger=base_logger,
) )
self.client.add_event_handler(self._update_catch) self.client.add_event_handler(self._update_catch)
@@ -221,13 +259,16 @@ class AbstractUser(ABC):
raise NotImplementedError() raise NotImplementedError()
async def is_logged_in(self) -> bool: async def is_logged_in(self) -> bool:
return (self.client and self.client.is_connected() return (
and await self.client.is_user_authorized()) self.client and self.client.is_connected() and await self.client.is_user_authorized()
)
async def has_full_access(self, allow_bot: bool = False) -> bool: async def has_full_access(self, allow_bot: bool = False) -> bool:
return (self.puppet_whitelisted return (
and (not self.is_bot or allow_bot) self.puppet_whitelisted
and await self.is_logged_in()) and (not self.is_bot or allow_bot)
and await self.is_logged_in()
)
async def start(self, delete_unless_authenticated: bool = False) -> AbstractUser: async def start(self, delete_unless_authenticated: bool = False) -> AbstractUser:
if not self.client: if not self.client:
@@ -240,8 +281,10 @@ class AbstractUser(ABC):
if self.connected: if self.connected:
return self return self
if even_if_no_session or await PgSession.has(self.mxid): if even_if_no_session or await PgSession.has(self.mxid):
self.log.debug("Starting client due to ensure_started" self.log.debug(
f"(even_if_no_session={even_if_no_session})") "Starting client due to ensure_started"
f"(even_if_no_session={even_if_no_session})"
)
await self.start(delete_unless_authenticated=not even_if_no_session) await self.start(delete_unless_authenticated=not even_if_no_session)
return self return self
@@ -253,8 +296,17 @@ class AbstractUser(ABC):
async def _update(self, update: TypeUpdate) -> None: async def _update(self, update: TypeUpdate) -> None:
asyncio.create_task(self._handle_entity_updates(getattr(update, "_entities", {}))) asyncio.create_task(self._handle_entity_updates(getattr(update, "_entities", {})))
if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage, if isinstance(
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage)): update,
(
UpdateShortChatMessage,
UpdateShortMessage,
UpdateNewChannelMessage,
UpdateNewMessage,
UpdateEditMessage,
UpdateEditChannelMessage,
),
):
await self.update_message(update) await self.update_message(update)
elif isinstance(update, UpdateDeleteMessages): elif isinstance(update, UpdateDeleteMessages):
await self.delete_message(update) await self.delete_message(update)
@@ -302,8 +354,9 @@ class AbstractUser(ABC):
else: else:
portal = await po.Portal.get_by_tgid(TelegramID(update.channel_id)) portal = await po.Portal.get_by_tgid(TelegramID(update.channel_id))
if portal and portal.mxid: if portal and portal.mxid:
await portal.receive_telegram_pin_ids(update.messages, self.tgid, await portal.receive_telegram_pin_ids(
remove=not update.pinned) update.messages, self.tgid, remove=not update.pinned
)
@staticmethod @staticmethod
async def update_participants(update: UpdateChatParticipants) -> None: async def update_participants(update: UpdateChatParticipants) -> None:
@@ -323,8 +376,9 @@ class AbstractUser(ABC):
return return
# We check that these are user read receipts, so tg_space is always the user ID. # We check that these are user read receipts, so tg_space is always the user ID.
message = await DBMessage.get_one_by_tgid(TelegramID(update.max_id), self.tgid, message = await DBMessage.get_one_by_tgid(
edit_index=-1) TelegramID(update.max_id), self.tgid, edit_index=-1
)
if not message: if not message:
return return
@@ -354,8 +408,9 @@ class AbstractUser(ABC):
return return
tg_space = portal.tgid if portal.peer_type == "channel" else self.tgid tg_space = portal.tgid if portal.peer_type == "channel" else self.tgid
message = await DBMessage.get_one_by_tgid(TelegramID(update.max_id), tg_space, message = await DBMessage.get_one_by_tgid(
edit_index=-1) TelegramID(update.max_id), tg_space, edit_index=-1
)
if not message: if not message:
return return
@@ -400,8 +455,9 @@ class AbstractUser(ABC):
try: try:
users = (entity for entity in entities.values() if isinstance(entity, User)) users = (entity for entity in entities.values() if isinstance(entity, User))
puppets = ((await pu.Puppet.get_by_tgid(TelegramID(user.id)), user) for user in users) puppets = ((await pu.Puppet.get_by_tgid(TelegramID(user.id)), user) for user in users)
await asyncio.gather(*[puppet.try_update_info(self, info) await asyncio.gather(
async for puppet, info in puppets if puppet]) *[puppet.try_update_info(self, info) async for puppet, info in puppets if puppet]
)
except Exception: except Exception:
self.log.exception("Failed to handle entity updates") self.log.exception("Failed to handle entity updates")
@@ -441,8 +497,15 @@ class AbstractUser(ABC):
TelegramID(update.user_id), tg_receiver=self.tgid, peer_type="user" TelegramID(update.user_id), tg_receiver=self.tgid, peer_type="user"
) )
sender = await pu.Puppet.get_by_tgid(self.tgid if update.out else update.user_id) sender = await pu.Puppet.get_by_tgid(self.tgid if update.out else update.user_id)
elif isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage, elif isinstance(
UpdateEditMessage, UpdateEditChannelMessage)): update,
(
UpdateNewMessage,
UpdateNewChannelMessage,
UpdateEditMessage,
UpdateEditChannelMessage,
),
):
update = update.message update = update.message
if isinstance(update, MessageEmpty): if isinstance(update, MessageEmpty):
return update, None, None return update, None, None
@@ -454,8 +517,9 @@ class AbstractUser(ABC):
else: else:
sender = None sender = None
else: else:
self.log.warning("Unexpected message type in User#get_message_details: " self.log.warning(
f"{type(update)}") f"Unexpected message type in User#get_message_details: {type(update)}"
)
return update, None, None return update, None, None
return update, sender, portal return update, sender, portal
@@ -509,26 +573,39 @@ class AbstractUser(ABC):
self.log.debug(f"Ignoring private message to bot from {sender.id}") self.log.debug(f"Ignoring private message to bot from {sender.id}")
return return
elif not portal.mxid and self.config["bridge.relaybot.ignore_unbridged_group_chat"]: elif not portal.mxid and self.config["bridge.relaybot.ignore_unbridged_group_chat"]:
self.log.debug("Ignoring message received by bot" self.log.debug(
f" in unbridged chat {portal.tgid_log}") f"Ignoring message received by bot in unbridged chat {portal.tgid_log}"
)
return return
if ((self.ignore_incoming_bot_events and self.relaybot if (
and sender and sender.id == self.relaybot.tgid)): self.ignore_incoming_bot_events
self.log.debug(f"Ignoring relaybot-sent message %s to %s", update.id, portal.tgid_log) and self.relaybot
and sender
and sender.id == self.relaybot.tgid
):
self.log.debug("Ignoring relaybot-sent message %s to %s", update.id, portal.tgid_log)
return return
await portal.backfill_lock.wait(f"update {update.id}") await portal.backfill_lock.wait(f"update {update.id}")
if isinstance(update, MessageService): if isinstance(update, MessageService):
if isinstance(update.action, MessageActionChannelMigrateFrom): if isinstance(update.action, MessageActionChannelMigrateFrom):
self.log.trace(f"Received %s in %s by %d, unregistering portal...", self.log.trace(
update.action, portal.tgid_log, sender.id) "Received %s in %s by %d, unregistering portal...",
update.action,
portal.tgid_log,
sender.id,
)
await self.unregister_portal(update.action.chat_id, update.action.chat_id) await self.unregister_portal(update.action.chat_id, update.action.chat_id)
await self.register_portal(portal) await self.register_portal(portal)
return return
self.log.trace("Handling action %s to %s by %d", update.action, portal.tgid_log, self.log.trace(
(sender.id if sender else 0)) "Handling action %s to %s by %d",
update.action,
portal.tgid_log,
(sender.id if sender else 0),
)
return await portal.handle_telegram_action(self, sender, update) return await portal.handle_telegram_action(self, sender, update)
if isinstance(original_update, (UpdateEditMessage, UpdateEditChannelMessage)): if isinstance(original_update, (UpdateEditMessage, UpdateEditChannelMessage)):
+56 -32
View File
@@ -13,26 +13,40 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING from typing import Awaitable, Callable, Dict, List, Optional, Tuple
import logging import logging
from telethon.errors import ChannelInvalidError, ChannelPrivateError
from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest
from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest
from telethon.tl.patched import Message, MessageService from telethon.tl.patched import Message, MessageService
from telethon.tl.types import ( from telethon.tl.types import (
ChannelParticipantAdmin, ChannelParticipantCreator, ChatForbidden, ChatParticipantAdmin, ChannelParticipantAdmin,
ChatParticipantCreator, InputChannel, InputUser, MessageActionChatAddUser, PeerUser, ChannelParticipantCreator,
MessageActionChatDeleteUser, MessageEntityBotCommand, PeerChannel, PeerChat, TypePeer, ChatForbidden,
UpdateNewChannelMessage, UpdateNewMessage, MessageActionChatMigrateTo, User) ChatParticipantAdmin,
from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest ChatParticipantCreator,
from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest InputChannel,
from telethon.errors import ChannelInvalidError, ChannelPrivateError InputUser,
MessageActionChatAddUser,
MessageActionChatDeleteUser,
MessageActionChatMigrateTo,
MessageEntityBotCommand,
PeerChannel,
PeerChat,
PeerUser,
TypePeer,
UpdateNewChannelMessage,
UpdateNewMessage,
User,
)
from mautrix.types import UserID from mautrix.types import UserID
from . import portal as po, puppet as pu, user as u
from .abstract_user import AbstractUser from .abstract_user import AbstractUser
from .db import BotChat from .db import BotChat
from .types import TelegramID from .types import TelegramID
from . import puppet as pu, portal as po, user as u
ReplyFunc = Callable[[str], Awaitable[Message]] ReplyFunc = Callable[[str], Awaitable[Message]]
@@ -60,8 +74,9 @@ class Bot(AbstractUser):
self.is_bot = True self.is_bot = True
self.chats = {} self.chats = {}
self.tg_whitelist = [] self.tg_whitelist = []
self.whitelist_group_admins = (self.config["bridge.relaybot.whitelist_group_admins"] self.whitelist_group_admins = (
or False) self.config["bridge.relaybot.whitelist_group_admins"] or False
)
self._me_info = None self._me_info = None
self._me_mxid = None self._me_mxid = None
@@ -83,7 +98,7 @@ class Bot(AbstractUser):
if isinstance(user_id, int): if isinstance(user_id, int):
self.tg_whitelist.append(user_id) self.tg_whitelist.append(user_id)
async def start(self, delete_unless_authenticated: bool = False) -> 'Bot': async def start(self, delete_unless_authenticated: bool = False) -> "Bot":
self.chats = {chat.id: chat.type for chat in await BotChat.all()} self.chats = {chat.id: chat.type for chat in await BotChat.all()}
await super().start(delete_unless_authenticated) await super().start(delete_unless_authenticated)
if not await self.is_logged_in(): if not await self.is_logged_in():
@@ -104,9 +119,11 @@ class Bot(AbstractUser):
if isinstance(chat, ChatForbidden) or chat.left or chat.deactivated: if isinstance(chat, ChatForbidden) or chat.left or chat.deactivated:
await self.remove_chat(TelegramID(chat.id)) await self.remove_chat(TelegramID(chat.id))
channel_ids = [InputChannel(chat_id, 0) channel_ids = [
for chat_id, chat_type in self.chats.items() InputChannel(chat_id, 0)
if chat_type == "channel"] for chat_id, chat_type in self.chats.items()
if chat_type == "channel"
]
for channel_id in channel_ids: for channel_id in channel_ids:
try: try:
await self.client(GetChannelsRequest([channel_id])) await self.client(GetChannelsRequest([channel_id]))
@@ -143,7 +160,9 @@ class Bot(AbstractUser):
if self.whitelist_group_admins: if self.whitelist_group_admins:
if isinstance(chat, PeerChannel): if isinstance(chat, PeerChannel):
p = await self.client(GetParticipantRequest(chat, tgid)) p = await self.client(GetParticipantRequest(chat, tgid))
return isinstance(p.participant, (ChannelParticipantCreator, ChannelParticipantAdmin)) return isinstance(
p.participant, (ChannelParticipantCreator, ChannelParticipantAdmin)
)
elif isinstance(chat, PeerChat): elif isinstance(chat, PeerChat):
chat = await self.client(GetFullChatRequest(chat.chat_id)) chat = await self.client(GetFullChatRequest(chat.chat_id))
participants = chat.full_chat.participants.participants participants = chat.full_chat.participants.participants
@@ -170,27 +189,29 @@ class Bot(AbstractUser):
if portal.mxid: if portal.mxid:
if portal.username: if portal.username:
return await reply( return await reply(
f"Portal is public: [{portal.alias}](https://matrix.to/#/{portal.alias})") f"Portal is public: [{portal.alias}](https://matrix.to/#/{portal.alias})"
)
else: else:
return await reply( return await reply("Portal is not public. Use `/invite <mxid>` to get an invite.")
"Portal is not public. Use `/invite <mxid>` to get an invite.")
async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc, async def handle_command_invite(
mxid_input: UserID) -> Message: self, portal: po.Portal, reply: ReplyFunc, mxid_input: UserID
) -> Message:
if len(mxid_input) == 0: if len(mxid_input) == 0:
return await reply("Usage: `/invite <mxid>`") return await reply("Usage: `/invite <mxid>`")
elif not portal.mxid: elif not portal.mxid:
return await reply("Portal does not have Matrix room. " return await reply("Portal does not have Matrix room. Create one with /portal first.")
"Create one with /portal first.") if mxid_input[0] != "@" or mxid_input.find(":") < 2:
if mxid_input[0] != '@' or mxid_input.find(':') < 2:
return await reply("That doesn't look like a Matrix ID.") return await reply("That doesn't look like a Matrix ID.")
user = await u.User.get_and_start_by_mxid(mxid_input) user = await u.User.get_and_start_by_mxid(mxid_input)
if not user.relaybot_whitelisted: if not user.relaybot_whitelisted:
return await reply("That user is not whitelisted to use the bridge.") return await reply("That user is not whitelisted to use the bridge.")
elif await user.is_logged_in(): elif await user.is_logged_in():
displayname = f"@{user.tg_username}" if user.tg_username else user.displayname displayname = f"@{user.tg_username}" if user.tg_username else user.displayname
return await reply("That user seems to be logged in. " return await reply(
f"Just invite [{displayname}](tg://user?id={user.tgid})") "That user seems to be logged in. "
f"Just invite [{displayname}](tg://user?id={user.tgid})"
)
else: else:
await portal.invite_to_matrix(user.mxid) await portal.invite_to_matrix(user.mxid)
return await reply(f"Invited `{user.mxid}` to the portal.") return await reply(f"Invited `{user.mxid}` to the portal.")
@@ -251,7 +272,7 @@ class Bot(AbstractUser):
await self.handle_command_portal(portal, reply) await self.handle_command_portal(portal, reply)
elif is_invite_cmd: elif is_invite_cmd:
try: try:
mxid = text[text.index(" ") + 1:] mxid = text[text.index(" ") + 1 :]
except ValueError: except ValueError:
mxid = "" mxid = ""
await self.handle_command_invite(portal, reply, mxid_input=UserID(mxid)) await self.handle_command_invite(portal, reply, mxid_input=UserID(mxid))
@@ -283,10 +304,13 @@ class Bot(AbstractUser):
await self.handle_service_message(update.message) await self.handle_service_message(update.message)
return False return False
is_command = (isinstance(update.message, Message) is_command = (
and update.message.entities and len(update.message.entities) > 0 isinstance(update.message, Message)
and isinstance(update.message.entities[0], MessageEntityBotCommand) and update.message.entities
and update.message.entities[0].offset == 0) and len(update.message.entities) > 0
and isinstance(update.message.entities[0], MessageEntityBotCommand)
and update.message.entities[0].offset == 0
)
if is_command: if is_command:
await self.handle_command(update.message) await self.handle_command(update.message)
return False return False
+25 -7
View File
@@ -1,8 +1,26 @@
from .handler import (command_handler, CommandHandler, CommandProcessor, CommandEvent, from .handler import (
SECTION_AUTH, SECTION_CREATING_PORTALS, SECTION_PORTAL_MANAGEMENT, SECTION_ADMIN,
SECTION_MISC, SECTION_ADMIN) SECTION_AUTH,
from . import portal, telegram, matrix_auth SECTION_CREATING_PORTALS,
SECTION_MISC,
SECTION_PORTAL_MANAGEMENT,
CommandEvent,
CommandHandler,
CommandProcessor,
command_handler,
)
__all__ = ["command_handler", "CommandHandler", "CommandProcessor", "CommandEvent", # This has to happen after the handler imports
"SECTION_AUTH", "SECTION_MISC", "SECTION_ADMIN", "SECTION_CREATING_PORTALS", from . import matrix_auth, portal, telegram # isort: skip
"SECTION_PORTAL_MANAGEMENT"]
__all__ = [
"command_handler",
"CommandHandler",
"CommandProcessor",
"CommandEvent",
"SECTION_AUTH",
"SECTION_MISC",
"SECTION_ADMIN",
"SECTION_CREATING_PORTALS",
"SECTION_PORTAL_MANAGEMENT",
]
+106 -40
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -15,18 +15,22 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import Awaitable, Callable, NamedTuple, Any, TYPE_CHECKING from typing import TYPE_CHECKING, Any, Awaitable, Callable, NamedTuple
from telethon.errors import FloodWaitError from telethon.errors import FloodWaitError
from mautrix.types import RoomID, EventID, MessageEventContent from mautrix.bridge.commands import (
from mautrix.bridge.commands import (HelpSection, CommandEvent as BaseCommandEvent, CommandEvent as BaseCommandEvent,
CommandHandler as BaseCommandHandler, CommandHandler as BaseCommandHandler,
CommandProcessor as BaseCommandProcessor, CommandHandlerFunc,
CommandHandlerFunc, command_handler as base_command_handler) CommandProcessor as BaseCommandProcessor,
HelpSection,
command_handler as base_command_handler,
)
from mautrix.types import EventID, MessageEventContent, RoomID
from mautrix.util.format_duration import format_duration from mautrix.util.format_duration import format_duration
from .. import user as u, portal as po from .. import portal as po, user as u
if TYPE_CHECKING: if TYPE_CHECKING:
from ..__main__ import TelegramBridge from ..__main__ import TelegramBridge
@@ -52,11 +56,31 @@ class CommandEvent(BaseCommandEvent):
sender: u.User sender: u.User
portal: po.Portal portal: po.Portal
def __init__(self, processor: CommandProcessor, room_id: RoomID, event_id: EventID, def __init__(
sender: u.User, command: str, args: list[str], content: MessageEventContent, self,
portal: po.Portal | None, is_management: bool, has_bridge_bot: bool) -> None: processor: CommandProcessor,
super().__init__(processor, room_id, event_id, sender, command, args, content, room_id: RoomID,
portal, is_management, has_bridge_bot) event_id: EventID,
sender: u.User,
command: str,
args: list[str],
content: MessageEventContent,
portal: po.Portal | None,
is_management: bool,
has_bridge_bot: bool,
) -> None:
super().__init__(
processor,
room_id,
event_id,
sender,
command,
args,
content,
portal,
is_management,
has_bridge_bot,
)
self.bridge = processor.bridge self.bridge = processor.bridge
self.tgbot = processor.tgbot self.tgbot = processor.tgbot
self.config = processor.config self.config = processor.config
@@ -67,9 +91,14 @@ class CommandEvent(BaseCommandEvent):
return self.sender.is_admin return self.sender.is_admin
async def get_help_key(self) -> HelpCacheKey: async def get_help_key(self) -> HelpCacheKey:
return HelpCacheKey(self.is_management, self.portal is not None, return HelpCacheKey(
self.sender.puppet_whitelisted, self.sender.matrix_puppet_whitelisted, self.is_management,
self.sender.is_admin, await self.sender.is_logged_in()) self.portal is not None,
self.sender.puppet_whitelisted,
self.sender.matrix_puppet_whitelisted,
self.sender.is_admin,
await self.sender.is_logged_in(),
)
class CommandHandler(BaseCommandHandler): class CommandHandler(BaseCommandHandler):
@@ -78,14 +107,33 @@ class CommandHandler(BaseCommandHandler):
needs_puppeting: bool needs_puppeting: bool
needs_matrix_puppeting: bool needs_matrix_puppeting: bool
def __init__(self, handler: Callable[[CommandEvent], Awaitable[EventID]], def __init__(
management_only: bool, name: str, help_text: str, help_args: str, self,
help_section: HelpSection, needs_auth: bool, needs_puppeting: bool, handler: Callable[[CommandEvent], Awaitable[EventID]],
needs_matrix_puppeting: bool, needs_admin: bool, **kwargs) -> None: management_only: bool,
super().__init__(handler, management_only, name, help_text, help_args, help_section, name: str,
needs_auth=needs_auth, needs_puppeting=needs_puppeting, help_text: str,
needs_matrix_puppeting=needs_matrix_puppeting, needs_admin=needs_admin, help_args: str,
**kwargs) help_section: HelpSection,
needs_auth: bool,
needs_puppeting: bool,
needs_matrix_puppeting: bool,
needs_admin: bool,
**kwargs,
) -> None:
super().__init__(
handler,
management_only,
name,
help_text,
help_args,
help_section,
needs_auth=needs_auth,
needs_puppeting=needs_puppeting,
needs_matrix_puppeting=needs_matrix_puppeting,
needs_admin=needs_admin,
**kwargs,
)
async def get_permission_error(self, evt: CommandEvent) -> str | None: async def get_permission_error(self, evt: CommandEvent) -> str | None:
if self.needs_puppeting and not evt.sender.puppet_whitelisted: if self.needs_puppeting and not evt.sender.puppet_whitelisted:
@@ -95,33 +143,51 @@ class CommandHandler(BaseCommandHandler):
return await super().get_permission_error(evt) return await super().get_permission_error(evt)
def has_permission(self, key: HelpCacheKey) -> bool: def has_permission(self, key: HelpCacheKey) -> bool:
return (super().has_permission(key) and return (
(not self.needs_puppeting or key.puppet_whitelisted) and super().has_permission(key)
(not self.needs_matrix_puppeting or key.matrix_puppet_whitelisted)) and (not self.needs_puppeting or key.puppet_whitelisted)
and (not self.needs_matrix_puppeting or key.matrix_puppet_whitelisted)
)
def command_handler(_func: CommandHandlerFunc | None = None, *, needs_auth: bool = True, def command_handler(
needs_puppeting: bool = True, needs_matrix_puppeting: bool = False, _func: CommandHandlerFunc | None = None,
needs_admin: bool = False, management_only: bool = False, *,
name: str | None = None, help_text: str = "", help_args: str = "", needs_auth: bool = True,
help_section: HelpSection = None) -> Callable[[CommandHandlerFunc], needs_puppeting: bool = True,
CommandHandler]: needs_matrix_puppeting: bool = False,
needs_admin: bool = False,
management_only: bool = False,
name: str | None = None,
help_text: str = "",
help_args: str = "",
help_section: HelpSection = None,
) -> Callable[[CommandHandlerFunc], CommandHandler]:
return base_command_handler( return base_command_handler(
_func, _handler_class=CommandHandler, name=name, help_text=help_text, help_args=help_args, _func,
help_section=help_section, management_only=management_only, needs_auth=needs_auth, _handler_class=CommandHandler,
needs_admin=needs_admin, needs_puppeting=needs_puppeting, name=name,
needs_matrix_puppeting=needs_matrix_puppeting) help_text=help_text,
help_args=help_args,
help_section=help_section,
management_only=management_only,
needs_auth=needs_auth,
needs_admin=needs_admin,
needs_puppeting=needs_puppeting,
needs_matrix_puppeting=needs_matrix_puppeting,
)
class CommandProcessor(BaseCommandProcessor): class CommandProcessor(BaseCommandProcessor):
def __init__(self, bridge: 'TelegramBridge') -> None: def __init__(self, bridge: "TelegramBridge") -> None:
super().__init__(event_class=CommandEvent, bridge=bridge) super().__init__(event_class=CommandEvent, bridge=bridge)
self.tgbot = bridge.bot self.tgbot = bridge.bot
self.public_website = bridge.public_website self.public_website = bridge.public_website
@staticmethod @staticmethod
async def _run_handler(handler: Callable[[CommandEvent], Awaitable[Any]], evt: CommandEvent async def _run_handler(
) -> Any: handler: Callable[[CommandEvent], Awaitable[Any]], evt: CommandEvent
) -> Any:
try: try:
return await handler(evt) return await handler(evt)
except FloodWaitError as e: except FloodWaitError as e:
+29 -57
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -13,33 +13,27 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.types import EventID
from mautrix.bridge import InvalidAccessToken, OnlyLoginSelf from mautrix.bridge import InvalidAccessToken, OnlyLoginSelf
from mautrix.types import EventID
from . import command_handler, CommandEvent, SECTION_AUTH
from .. import puppet as pu from .. import puppet as pu
from . import SECTION_AUTH, CommandEvent, command_handler
@command_handler(needs_auth=True, needs_matrix_puppeting=True, @command_handler(
help_section=SECTION_AUTH, help_text="Revert your Telegram account's Matrix " needs_auth=True,
"puppet to use the default Matrix account.") management_only=True,
async def logout_matrix(evt: CommandEvent) -> EventID: needs_matrix_puppeting=True,
puppet = await pu.Puppet.get_by_tgid(evt.sender.tgid) help_section=SECTION_AUTH,
if not puppet.is_real_user: help_text="Replace your Telegram account's Matrix puppet with your own Matrix account.",
return await evt.reply("You are not logged in with your Matrix account.") )
await puppet.switch_mxid(None, None)
return await evt.reply("Reverted your Telegram account's Matrix puppet back to the default.")
@command_handler(needs_auth=True, management_only=True, needs_matrix_puppeting=True,
help_section=SECTION_AUTH,
help_text="Replace your Telegram account's Matrix puppet with your own Matrix "
"account.")
async def login_matrix(evt: CommandEvent) -> EventID: async def login_matrix(evt: CommandEvent) -> EventID:
puppet = await pu.Puppet.get_by_tgid(evt.sender.tgid) puppet = await pu.Puppet.get_by_tgid(evt.sender.tgid)
if puppet.is_real_user: if puppet.is_real_user:
return await evt.reply("You have already logged in with your Matrix account. " return await evt.reply(
"Log out with `$cmdprefix+sp logout-matrix` first.") "You have already logged in with your Matrix account. "
"Log out with `$cmdprefix+sp logout-matrix` first."
)
allow_matrix_login = evt.config.get("bridge.allow_matrix_login", True) allow_matrix_login = evt.config.get("bridge.allow_matrix_login", True)
if allow_matrix_login: if allow_matrix_login:
evt.sender.command_status = { evt.sender.command_status = {
@@ -57,57 +51,35 @@ async def login_matrix(evt: CommandEvent) -> EventID:
"here.\n" "here.\n"
f"If you would like to log in outside of Matrix, [click here]({url}).\n\n" f"If you would like to log in outside of Matrix, [click here]({url}).\n\n"
"Logging in outside of Matrix is recommended, because in-Matrix login would save " "Logging in outside of Matrix is recommended, because in-Matrix login would save "
"your access token in the message history.") "your access token in the message history."
return await evt.reply("This bridge instance does not allow logging in inside Matrix.\n\n" )
f"Please visit [the login page]({url}) to log in.") return await evt.reply(
"This bridge instance does not allow logging in inside Matrix.\n\n"
f"Please visit [the login page]({url}) to log in."
)
elif allow_matrix_login: elif allow_matrix_login:
return await evt.reply( return await evt.reply(
"This bridge instance does not allow you to log in outside of Matrix.\n\n" "This bridge instance does not allow you to log in outside of Matrix.\n\n"
"Please send your Matrix access token here to log in.") "Please send your Matrix access token here to log in."
)
return await evt.reply("This bridge instance has been configured to not allow logging in.") return await evt.reply("This bridge instance has been configured to not allow logging in.")
@command_handler(needs_auth=True, needs_matrix_puppeting=True,
help_section=SECTION_AUTH,
help_text="Pings the server with the stored matrix authentication.")
async def ping_matrix(evt: CommandEvent) -> EventID:
puppet = await pu.Puppet.get_by_tgid(evt.sender.tgid)
if not puppet.is_real_user:
return await evt.reply("You are not logged in with your Matrix account.")
try:
await puppet.start()
except InvalidAccessToken:
return await evt.reply("Your access token is invalid.")
return await evt.reply("Your Matrix login is working.")
@command_handler(needs_auth=True, needs_matrix_puppeting=True, help_section=SECTION_AUTH,
help_text="Clear the Matrix sync token stored for your custom puppet.")
async def clear_cache_matrix(evt: CommandEvent) -> EventID:
puppet = await pu.Puppet.get_by_tgid(evt.sender.tgid)
if not puppet.is_real_user:
return await evt.reply("You are not logged in with your Matrix account.")
try:
puppet.stop()
puppet.next_batch = None
await puppet.start()
except InvalidAccessToken:
return await evt.reply("Your access token is invalid.")
return await evt.reply("Cleared cache successfully.")
async def enter_matrix_token(evt: CommandEvent) -> EventID: async def enter_matrix_token(evt: CommandEvent) -> EventID:
evt.sender.command_status = None evt.sender.command_status = None
puppet = await pu.Puppet.get_by_tgid(evt.sender.tgid) puppet = await pu.Puppet.get_by_tgid(evt.sender.tgid)
if puppet.is_real_user: if puppet.is_real_user:
return await evt.reply("You have already logged in with your Matrix account. " return await evt.reply(
"Log out with `$cmdprefix+sp logout-matrix` first.") "You have already logged in with your Matrix account. "
"Log out with `$cmdprefix+sp logout-matrix` first."
)
try: try:
await puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid) await puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid)
except OnlyLoginSelf: except OnlyLoginSelf:
return await evt.reply("You can only log in as your own Matrix user.") return await evt.reply("You can only log in as your own Matrix user.")
except InvalidAccessToken: except InvalidAccessToken:
return await evt.reply("Failed to verify access token.") return await evt.reply("Failed to verify access token.")
return await evt.reply("Replaced your Telegram account's Matrix puppet " return await evt.reply(
f"with {puppet.custom_mxid}.") "Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}."
)
+16 -13
View File
@@ -18,13 +18,16 @@ import asyncio
from mautrix.types import EventID from mautrix.types import EventID
from ... import portal as po, puppet as pu, user as u from ... import portal as po, puppet as pu, user as u
from .. import command_handler, CommandEvent, SECTION_ADMIN from .. import SECTION_ADMIN, CommandEvent, command_handler
@command_handler(needs_admin=True, needs_auth=False, @command_handler(
help_section=SECTION_ADMIN, needs_admin=True,
help_args="<`portal`|`puppet`|`user`>", needs_auth=False,
help_text="Clear internal bridge caches") help_section=SECTION_ADMIN,
help_args="<`portal`|`puppet`|`user`>",
help_text="Clear internal bridge caches",
)
async def clear_db_cache(evt: CommandEvent) -> EventID: async def clear_db_cache(evt: CommandEvent) -> EventID:
try: try:
section = evt.args[0].lower() section = evt.args[0].lower()
@@ -44,19 +47,19 @@ async def clear_db_cache(evt: CommandEvent) -> EventID:
) )
await evt.reply("Cleared puppet cache and restarted custom puppet syncers") await evt.reply("Cleared puppet cache and restarted custom puppet syncers")
elif section == "user": elif section == "user":
u.User.by_mxid = { u.User.by_mxid = {user.mxid: user for user in u.User.by_tgid.values()}
user.mxid: user
for user in u.User.by_tgid.values()
}
await evt.reply("Cleared non-logged-in user cache") await evt.reply("Cleared non-logged-in user cache")
else: else:
return await evt.reply("**Usage:** `$cmdprefix+sp clear-db-cache <section>`") return await evt.reply("**Usage:** `$cmdprefix+sp clear-db-cache <section>`")
@command_handler(needs_admin=True, needs_auth=False, @command_handler(
help_section=SECTION_ADMIN, needs_admin=True,
help_args="[_mxid_]", needs_auth=False,
help_text="Reload and reconnect a user") help_section=SECTION_ADMIN,
help_args="[_mxid_]",
help_text="Reload and reconnect a user",
)
async def reload_user(evt: CommandEvent) -> EventID: async def reload_user(evt: CommandEvent) -> EventID:
if len(evt.args) > 0: if len(evt.args) > 0:
mxid = evt.args[0] mxid = evt.args[0]
+90 -60
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -18,26 +18,31 @@ from __future__ import annotations
from typing import Awaitable from typing import Awaitable
import asyncio import asyncio
from telethon.tl.types import ChatForbidden, ChannelForbidden from telethon.tl.types import ChannelForbidden, ChatForbidden
from mautrix.types import EventID, RoomID from mautrix.types import EventID, RoomID
from ...types import TelegramID
from ... import portal as po from ... import portal as po
from .. import command_handler, CommandEvent, SECTION_CREATING_PORTALS from ...types import TelegramID
from .util import user_has_power_level, get_initial_state, warn_missing_power from .. import SECTION_CREATING_PORTALS, CommandEvent, command_handler
from .util import get_initial_state, user_has_power_level, warn_missing_power
@command_handler(needs_auth=False, needs_puppeting=False, @command_handler(
help_section=SECTION_CREATING_PORTALS, needs_auth=False,
help_args="[_id_]", needs_puppeting=False,
help_text="Bridge the current Matrix room to the Telegram chat with the given " help_section=SECTION_CREATING_PORTALS,
"ID. The ID must be the prefixed version that you get with the `/id` " help_args="[_id_]",
"command of the Telegram-side bot.") help_text=(
"Bridge the current Matrix room to the Telegram chat with the given ID. The ID must be "
"the prefixed version that you get with the `/id` command of the Telegram-side bot."
),
)
async def bridge(evt: CommandEvent) -> EventID: async def bridge(evt: CommandEvent) -> EventID:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** " return await evt.reply(
"`$cmdprefix+sp bridge <Telegram chat ID> [Matrix room ID]`") "**Usage:** `$cmdprefix+sp bridge <Telegram chat ID> [Matrix room ID]`"
)
force_use_bot = False force_use_bot = False
if evt.args[0] == "--usebot" and evt.sender.is_admin: if evt.args[0] == "--usebot" and evt.sender.is_admin:
force_use_bot = True force_use_bot = True
@@ -61,24 +66,30 @@ async def bridge(evt: CommandEvent) -> EventID:
tgid = TelegramID(-int(tgid_str)) tgid = TelegramID(-int(tgid_str))
peer_type = "chat" peer_type = "chat"
else: else:
return await evt.reply("That doesn't seem like a prefixed Telegram chat ID.\n\n" return await evt.reply(
"If you did not get the ID using the `/id` bot command, please " "That doesn't seem like a prefixed Telegram chat ID.\n\n"
"prefix channel IDs with `-100` and normal group IDs with `-`.\n\n" "If you did not get the ID using the `/id` bot command, please "
"Bridging private chats to existing rooms is not allowed.") "prefix channel IDs with `-100` and normal group IDs with `-`.\n\n"
"Bridging private chats to existing rooms is not allowed."
)
portal = await po.Portal.get_by_tgid(tgid, peer_type=peer_type) portal = await po.Portal.get_by_tgid(tgid, peer_type=peer_type)
if not portal.allow_bridging: if not portal.allow_bridging:
return await evt.reply("This bridge doesn't allow bridging that Telegram chat.\n" return await evt.reply(
"If you're the bridge admin, try " "This bridge doesn't allow bridging that Telegram chat.\n"
"`$cmdprefix+sp filter whitelist <Telegram chat ID>` first.") "If you're the bridge admin, try "
"`$cmdprefix+sp filter whitelist <Telegram chat ID>` first."
)
if portal.mxid: if portal.mxid:
has_portal_message = ( has_portal_message = (
"That Telegram chat already has a portal at " "That Telegram chat already has a portal at "
f"[{portal.alias or portal.mxid}](https://matrix.to/#/{portal.mxid}). ") f"[{portal.alias or portal.mxid}](https://matrix.to/#/{portal.mxid}). "
)
if not await user_has_power_level(portal.mxid, evt.az.intent, evt.sender, "unbridge"): if not await user_has_power_level(portal.mxid, evt.az.intent, evt.sender, "unbridge"):
return await evt.reply(f"{has_portal_message}" return await evt.reply(
"Additionally, you do not have the permissions to unbridge " f"{has_portal_message}"
"that room.") "Additionally, you do not have the permissions to unbridge that room."
)
evt.sender.command_status = { evt.sender.command_status = {
"next": confirm_bridge, "next": confirm_bridge,
"action": "Room bridging", "action": "Room bridging",
@@ -88,12 +99,14 @@ async def bridge(evt: CommandEvent) -> EventID:
"peer_type": portal.peer_type, "peer_type": portal.peer_type,
"force_use_bot": force_use_bot, "force_use_bot": force_use_bot,
} }
return await evt.reply(f"{has_portal_message}" return await evt.reply(
"However, you have the permissions to unbridge that room.\n\n" f"{has_portal_message}"
"To delete that portal completely and continue bridging, use " "However, you have the permissions to unbridge that room.\n\n"
"`$cmdprefix+sp delete-and-continue`. To unbridge the portal " "To delete that portal completely and continue bridging, use "
"without kicking Matrix users, use `$cmdprefix+sp unbridge-and-" "`$cmdprefix+sp delete-and-continue`. To unbridge the portal "
"continue`. To cancel, use `$cmdprefix+sp cancel`") "without kicking Matrix users, use `$cmdprefix+sp unbridge-and-"
"continue`. To cancel, use `$cmdprefix+sp cancel`"
)
evt.sender.command_status = { evt.sender.command_status = {
"next": confirm_bridge, "next": confirm_bridge,
"action": "Room bridging", "action": "Room bridging",
@@ -102,29 +115,36 @@ async def bridge(evt: CommandEvent) -> EventID:
"peer_type": portal.peer_type, "peer_type": portal.peer_type,
"force_use_bot": force_use_bot, "force_use_bot": force_use_bot,
} }
return await evt.reply("That Telegram chat has no existing portal. To confirm bridging the " return await evt.reply(
"chat to this room, use `$cmdprefix+sp continue`") "That Telegram chat has no existing portal. To confirm bridging the "
"chat to this room, use `$cmdprefix+sp continue`"
)
async def cleanup_old_portal_while_bridging(evt: CommandEvent, portal: "po.Portal" async def cleanup_old_portal_while_bridging(
) -> tuple[bool, Awaitable[None] | None]: evt: CommandEvent, portal: po.Portal
) -> tuple[bool, Awaitable[None] | None]:
if not portal.mxid: if not portal.mxid:
await evt.reply("The portal seems to have lost its Matrix room between you" await evt.reply(
"calling `$cmdprefix+sp bridge` and this command.\n\n" "The portal seems to have lost its Matrix room between you"
"Continuing without touching previous Matrix room...") "calling `$cmdprefix+sp bridge` and this command.\n\n"
"Continuing without touching previous Matrix room..."
)
return True, None return True, None
elif evt.args[0] == "delete-and-continue": elif evt.args[0] == "delete-and-continue":
return True, portal.cleanup_portal("Portal deleted (moving to another room)", delete=False) return True, portal.cleanup_portal("Portal deleted (moving to another room)", delete=False)
elif evt.args[0] == "unbridge-and-continue": elif evt.args[0] == "unbridge-and-continue":
return True, portal.cleanup_portal("Room unbridged (portal moving to another room)", return True, portal.cleanup_portal(
puppets_only=True, delete=False) "Room unbridged (portal moving to another room)", puppets_only=True, delete=False
)
else: else:
await evt.reply( await evt.reply(
"The chat you were trying to bridge already has a Matrix portal room.\n\n" "The chat you were trying to bridge already has a Matrix portal room.\n\n"
"Please use `$cmdprefix+sp delete-and-continue` or `$cmdprefix+sp unbridge-and-" "Please use `$cmdprefix+sp delete-and-continue` or `$cmdprefix+sp unbridge-and-"
"continue` to either delete or unbridge the existing room (respectively) and " "continue` to either delete or unbridge the existing room (respectively) and "
"continue with the bridging.\n\n" "continue with the bridging.\n\n"
"If you changed your mind, use `$cmdprefix+sp cancel` to cancel.") "If you changed your mind, use `$cmdprefix+sp cancel` to cancel."
)
return False, None return False, None
@@ -135,9 +155,10 @@ async def confirm_bridge(evt: CommandEvent) -> EventID | None:
bridge_to_mxid = status["bridge_to_mxid"] bridge_to_mxid = status["bridge_to_mxid"]
except KeyError: except KeyError:
evt.sender.command_status = None evt.sender.command_status = None
return await evt.reply("Fatal error: tgid or peer_type missing from command_status. " return await evt.reply(
"This shouldn't happen unless you're messing with the command " "Fatal error: tgid or peer_type missing from command_status. "
"handler code.") "This shouldn't happen unless you're messing with the command handler code."
)
is_logged_in = await evt.sender.is_logged_in() and not status["force_use_bot"] is_logged_in = await evt.sender.is_logged_in() and not status["force_use_bot"]
@@ -150,32 +171,41 @@ async def confirm_bridge(evt: CommandEvent) -> EventID | None:
await evt.reply("Cleaning up previous portal room...") await evt.reply("Cleaning up previous portal room...")
elif portal.mxid: elif portal.mxid:
evt.sender.command_status = None evt.sender.command_status = None
return await evt.reply("The portal seems to have created a Matrix room between you " return await evt.reply(
"calling `$cmdprefix+sp bridge` and this command.\n\n" "The portal seems to have created a Matrix room between you "
"Please start over by calling the bridge command again.") "calling `$cmdprefix+sp bridge` and this command.\n\n"
"Please start over by calling the bridge command again."
)
elif evt.args[0] != "continue": elif evt.args[0] != "continue":
return await evt.reply("Please use `$cmdprefix+sp continue` to confirm the bridging or " return await evt.reply(
"`$cmdprefix+sp cancel` to cancel.") "Please use `$cmdprefix+sp continue` to confirm the bridging or "
"`$cmdprefix+sp cancel` to cancel."
)
evt.sender.command_status = None evt.sender.command_status = None
async with portal._room_create_lock: async with portal._room_create_lock:
await _locked_confirm_bridge(evt, portal=portal, room_id=bridge_to_mxid, await _locked_confirm_bridge(
is_logged_in=is_logged_in) evt, portal=portal, room_id=bridge_to_mxid, is_logged_in=is_logged_in
)
async def _locked_confirm_bridge(evt: CommandEvent, portal: 'po.Portal', room_id: RoomID, async def _locked_confirm_bridge(
is_logged_in: bool) -> EventID | None: evt: CommandEvent, portal: po.Portal, room_id: RoomID, is_logged_in: bool
) -> EventID | None:
user = evt.sender if is_logged_in else evt.tgbot user = evt.sender if is_logged_in else evt.tgbot
try: try:
entity = await user.client.get_entity(portal.peer) entity = await user.client.get_entity(portal.peer)
except Exception: except Exception:
evt.log.exception("Failed to get_entity(%s) for manual bridging.", portal.peer) evt.log.exception("Failed to get_entity(%s) for manual bridging.", portal.peer)
if is_logged_in: if is_logged_in:
return await evt.reply("Failed to get info of telegram chat. " return await evt.reply(
"You are logged in, are you in that chat?") "Failed to get info of telegram chat. You are logged in, are you in that chat?"
)
else: else:
return await evt.reply("Failed to get info of telegram chat. " return await evt.reply(
"You're not logged in, is the relay bot in the chat?") "Failed to get info of telegram chat. "
"You're not logged in, is the relay bot in the chat?"
)
if isinstance(entity, (ChatForbidden, ChannelForbidden)): if isinstance(entity, (ChatForbidden, ChannelForbidden)):
if is_logged_in: if is_logged_in:
return await evt.reply("You don't seem to be in that chat.") return await evt.reply("You don't seem to be in that chat.")
@@ -184,14 +214,14 @@ async def _locked_confirm_bridge(evt: CommandEvent, portal: 'po.Portal', room_id
portal.mxid = room_id portal.mxid = room_id
portal.by_mxid[portal.mxid] = portal portal.by_mxid[portal.mxid] = portal
(portal.title, portal.about, levels, (portal.title, portal.about, levels, portal.encrypted) = await get_initial_state(
portal.encrypted) = await get_initial_state(evt.az.intent, evt.room_id) evt.az.intent, evt.room_id
)
portal.photo_id = "" portal.photo_id = ""
await portal.save() await portal.save()
await portal.update_bridge_info() await portal.update_bridge_info()
asyncio.ensure_future(portal.update_matrix_room(user, entity, direct=False, levels=levels), asyncio.create_task(portal.update_matrix_room(user, entity, direct=False, levels=levels))
loop=evt.loop)
await warn_missing_power(levels, evt) await warn_missing_power(levels, evt)
+37 -27
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -13,21 +13,26 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Any from __future__ import annotations
from typing import Any, Awaitable
from io import StringIO from io import StringIO
from ruamel.yaml import YAMLError from ruamel.yaml import YAMLError
from mautrix.util.config import yaml
from mautrix.types import EventID from mautrix.types import EventID
from mautrix.util.config import yaml
from ... import portal as po, util from ... import portal as po, util
from .. import command_handler, CommandEvent, SECTION_PORTAL_MANAGEMENT from .. import SECTION_PORTAL_MANAGEMENT, CommandEvent, command_handler
@command_handler(needs_auth=False, help_section=SECTION_PORTAL_MANAGEMENT, @command_handler(
help_text="View or change per-portal settings.", needs_auth=False,
help_args="<`help`|_subcommand_> [...]") help_section=SECTION_PORTAL_MANAGEMENT,
help_text="View or change per-portal settings.",
help_args="<`help`|_subcommand_> [...]",
)
async def config(evt: CommandEvent) -> None: async def config(evt: CommandEvent) -> None:
cmd = evt.args[0].lower() if len(evt.args) > 0 else "help" cmd = evt.args[0].lower() if len(evt.args) > 0 else "help"
if cmd not in ("view", "defaults", "set", "unset", "add", "del"): if cmd not in ("view", "defaults", "set", "unset", "add", "del"):
@@ -67,7 +72,8 @@ async def config(evt: CommandEvent) -> None:
def config_help(evt: CommandEvent) -> Awaitable[EventID]: def config_help(evt: CommandEvent) -> Awaitable[EventID]:
return evt.reply("""**Usage:** `$cmdprefix config <subcommand> [...]`. Subcommands: return evt.reply(
"""**Usage:** `$cmdprefix config <subcommand> [...]`. Subcommands:
* **help** - View this help text. * **help** - View this help text.
* **view** - View the current config data. * **view** - View the current config data.
@@ -76,7 +82,8 @@ def config_help(evt: CommandEvent) -> Awaitable[EventID]:
* **unset** <_key_> - Remove a config value. * **unset** <_key_> - Remove a config value.
* **add** <_key_> <_value_> - Add a value to an array. * **add** <_key_> <_value_> - Add a value to an array.
* **del** <_key_> <_value_> - Remove a value from an array. * **del** <_key_> <_value_> - Remove a value from an array.
""") """
)
def config_view(evt: CommandEvent, portal: po.Portal) -> Awaitable[EventID]: def config_view(evt: CommandEvent, portal: po.Portal) -> Awaitable[EventID]:
@@ -84,18 +91,20 @@ def config_view(evt: CommandEvent, portal: po.Portal) -> Awaitable[EventID]:
def config_defaults(evt: CommandEvent) -> Awaitable[EventID]: def config_defaults(evt: CommandEvent) -> Awaitable[EventID]:
value = _str_value({ value = _str_value(
"bridge_notices": { {
"default": evt.config["bridge.bridge_notices.default"], "bridge_notices": {
"exceptions": evt.config["bridge.bridge_notices.exceptions"], "default": evt.config["bridge.bridge_notices.default"],
}, "exceptions": evt.config["bridge.bridge_notices.exceptions"],
"bot_messages_as_notices": evt.config["bridge.bot_messages_as_notices"], },
"inline_images": evt.config["bridge.inline_images"], "bot_messages_as_notices": evt.config["bridge.bot_messages_as_notices"],
"message_formats": evt.config["bridge.message_formats"], "inline_images": evt.config["bridge.inline_images"],
"emote_format": evt.config["bridge.emote_format"], "message_formats": evt.config["bridge.message_formats"],
"state_event_formats": evt.config["bridge.state_event_formats"], "emote_format": evt.config["bridge.emote_format"],
"telegram_link_preview": evt.config["bridge.telegram_link_preview"], "state_event_formats": evt.config["bridge.state_event_formats"],
}) "telegram_link_preview": evt.config["bridge.telegram_link_preview"],
}
)
return evt.reply(f"Bridge instance wide config:\n{value.rstrip()}") return evt.reply(f"Bridge instance wide config:\n{value.rstrip()}")
@@ -115,8 +124,7 @@ def config_set(evt: CommandEvent, portal: po.Portal, key: str, value: Any) -> Aw
elif util.recursive_set(portal.local_config, key, value): elif util.recursive_set(portal.local_config, key, value):
return evt.reply(f"Successfully set the value of `{key}` to {_str_value(value)}".rstrip()) return evt.reply(f"Successfully set the value of `{key}` to {_str_value(value)}".rstrip())
else: else:
return evt.reply(f"Failed to set value of `{key}`. " return evt.reply(f"Failed to set value of `{key}`. Does the path contain non-map types?")
"Does the path contain non-map types?")
def config_unset(evt: CommandEvent, portal: po.Portal, key: str) -> Awaitable[EventID]: def config_unset(evt: CommandEvent, portal: po.Portal, key: str) -> Awaitable[EventID]:
@@ -128,15 +136,17 @@ def config_unset(evt: CommandEvent, portal: po.Portal, key: str) -> Awaitable[Ev
return evt.reply(f"`{key}` not found in config.") return evt.reply(f"`{key}` not found in config.")
def config_add_del(evt: CommandEvent, portal: po.Portal, key: str, value: str, cmd: str def config_add_del(
) -> Awaitable[EventID]: evt: CommandEvent, portal: po.Portal, key: str, value: str, cmd: str
) -> Awaitable[EventID]:
if not key or value is None: if not key or value is None:
return evt.reply(f"**Usage:** `$cmdprefix+sp config {cmd} <key> <value>`") return evt.reply(f"**Usage:** `$cmdprefix+sp config {cmd} <key> <value>`")
arr = util.recursive_get(portal.local_config, key) arr = util.recursive_get(portal.local_config, key)
if not arr: if not arr:
return evt.reply(f"`{key}` not found in config. " return evt.reply(
f"Maybe do `$cmdprefix+sp config set {key} []` first?") f"`{key}` not found in config. Maybe do `$cmdprefix+sp config set {key} []` first?"
)
elif not isinstance(arr, list): elif not isinstance(arr, list):
return evt.reply("`{key}` does not seem to be an array.") return evt.reply("`{key}` does not seem to be an array.")
elif cmd == "add": elif cmd == "add":
+29 -14
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -13,24 +13,30 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from mautrix.types import EventID from mautrix.types import EventID
from ... import portal as po from ... import portal as po
from ...types import TelegramID from ...types import TelegramID
from .. import command_handler, CommandEvent, SECTION_CREATING_PORTALS from .. import SECTION_CREATING_PORTALS, CommandEvent, command_handler
from .util import user_has_power_level, get_initial_state, warn_missing_power from .util import get_initial_state, user_has_power_level, warn_missing_power
@command_handler(help_section=SECTION_CREATING_PORTALS, @command_handler(
help_args="[_type_]", help_section=SECTION_CREATING_PORTALS,
help_text="Create a Telegram chat of the given type for the current Matrix room. " help_args="[_type_]",
"The type is either `group`, `supergroup` or `channel` (defaults to " help_text=(
"`supergroup`).") "Create a Telegram chat of the given type for the current Matrix room. "
"The type is either `group`, `supergroup` or `channel` (defaults to `supergroup`)."
),
)
async def create(evt: CommandEvent) -> EventID: async def create(evt: CommandEvent) -> EventID:
type = evt.args[0] if len(evt.args) > 0 else "supergroup" type = evt.args[0] if len(evt.args) > 0 else "supergroup"
if type not in ("chat", "group", "supergroup", "channel"): if type not in ("chat", "group", "supergroup", "channel"):
return await evt.reply( return await evt.reply(
"**Usage:** `$cmdprefix+sp create ['group'/'supergroup'/'channel']`") "**Usage:** `$cmdprefix+sp create ['group'/'supergroup'/'channel']`"
)
if await po.Portal.get_by_mxid(evt.room_id): if await po.Portal.get_by_mxid(evt.room_id):
return await evt.reply("This is already a portal room.") return await evt.reply("This is already a portal room.")
@@ -50,14 +56,23 @@ async def create(evt: CommandEvent) -> EventID:
"group": "chat", "group": "chat",
}[type] }[type]
portal = po.Portal(tgid=TelegramID(0), tg_receiver=TelegramID(0), peer_type=type, portal = po.Portal(
mxid=evt.room_id, title=title, about=about, encrypted=encrypted) tgid=TelegramID(0),
tg_receiver=TelegramID(0),
peer_type=type,
mxid=evt.room_id,
title=title,
about=about,
encrypted=encrypted,
)
invites, errors = await portal.get_telegram_users_in_matrix_room(evt.sender) invites, errors = await portal.get_telegram_users_in_matrix_room(evt.sender)
if len(errors) > 0: if len(errors) > 0:
error_list = "\n".join(f"* [{mxid}](https://matrix.to/#/{mxid})" for mxid in errors) error_list = "\n".join(f"* [{mxid}](https://matrix.to/#/{mxid})" for mxid in errors)
await evt.reply(f"Failed to add the following users to the chat:\n\n{error_list}\n\n" await evt.reply(
"You can try `$cmdprefix+sp search -r <username>` to help the bridge find " f"Failed to add the following users to the chat:\n\n{error_list}\n\n"
"those users.") "You can try `$cmdprefix+sp search -r <username>` to help the bridge find "
"those users."
)
await warn_missing_power(levels, evt) await warn_missing_power(levels, evt)
+28 -18
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -13,17 +13,20 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from mautrix.types import EventID from mautrix.types import EventID
from ... import portal as po from ... import portal as po
from .. import command_handler, CommandEvent, SECTION_ADMIN from .. import SECTION_ADMIN, CommandEvent, command_handler
@command_handler(needs_admin=True, @command_handler(
help_section=SECTION_ADMIN, needs_admin=True,
help_args="<`whitelist`|`blacklist`>", help_section=SECTION_ADMIN,
help_text="Change whether the bridge will allow or disallow bridging rooms by " help_args="<`whitelist`|`blacklist`>",
"default.") help_text="Change whether the bridge will allow or disallow bridging rooms by default.",
)
async def filter_mode(evt: CommandEvent) -> EventID: async def filter_mode(evt: CommandEvent) -> EventID:
try: try:
mode = evt.args[0] mode = evt.args[0]
@@ -36,19 +39,26 @@ async def filter_mode(evt: CommandEvent) -> EventID:
evt.config.save() evt.config.save()
po.Portal.filter_mode = mode po.Portal.filter_mode = mode
if mode == "whitelist": if mode == "whitelist":
return await evt.reply("The bridge will now disallow bridging chats by default.\n" return await evt.reply(
"To allow bridging a specific chat, use" "The bridge will now disallow bridging chats by default.\n"
"`!filter whitelist <chat ID>`.") "To allow bridging a specific chat, use"
"`!filter whitelist <chat ID>`."
)
else: else:
return await evt.reply("The bridge will now allow bridging chats by default.\n" return await evt.reply(
"To disallow bridging a specific chat, use" "The bridge will now allow bridging chats by default.\n"
"`!filter blacklist <chat ID>`.") "To disallow bridging a specific chat, use"
"`!filter blacklist <chat ID>`."
)
@command_handler(name="filter", needs_admin=True, @command_handler(
help_section=SECTION_ADMIN, name="filter",
help_args="<`whitelist`|`blacklist`> <_chat ID_>", needs_admin=True,
help_text="Allow or disallow bridging a specific chat.") help_section=SECTION_ADMIN,
help_args="<`whitelist`|`blacklist`> <_chat ID_>",
help_text="Allow or disallow bridging a specific chat.",
)
async def edit_filter(evt: CommandEvent) -> EventID: async def edit_filter(evt: CommandEvent) -> EventID:
try: try:
action = evt.args[0] action = evt.args[0]
@@ -67,7 +77,7 @@ async def edit_filter(evt: CommandEvent) -> EventID:
mode = evt.config["bridge.filter.mode"] mode = evt.config["bridge.filter.mode"]
if mode not in ("blacklist", "whitelist"): if mode not in ("blacklist", "whitelist"):
return await evt.reply(f"Unknown filter mode \"{mode}\". Please fix the bridge config.") return await evt.reply(f'Unknown filter mode "{mode}". Please fix the bridge config.')
filter_id_list = evt.config["bridge.filter.list"] filter_id_list = evt.config["bridge.filter.list"]
+58 -32
View File
@@ -15,24 +15,33 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from datetime import timedelta, datetime from datetime import datetime, timedelta
import re import re
from telethon.errors import (
ChatAdminRequiredError,
RPCError,
UsernameInvalidError,
UsernameNotModifiedError,
UsernameOccupiedError,
)
from telethon.tl.functions.channels import GetFullChannelRequest from telethon.tl.functions.channels import GetFullChannelRequest
from telethon.tl.functions.messages import GetFullChatRequest from telethon.tl.functions.messages import GetFullChatRequest
from telethon.errors import (ChatAdminRequiredError, UsernameInvalidError,
UsernameNotModifiedError, UsernameOccupiedError, RPCError)
from mautrix.types import EventID from mautrix.types import EventID
from ... import portal as po from ... import portal as po
from .. import command_handler, CommandEvent, SECTION_PORTAL_MANAGEMENT, SECTION_MISC from .. import SECTION_MISC, SECTION_PORTAL_MANAGEMENT, CommandEvent, command_handler
from .util import user_has_power_level from .util import user_has_power_level
@command_handler(needs_admin=False, needs_puppeting=False, needs_auth=False, @command_handler(
help_section=SECTION_MISC, needs_admin=False,
help_text="Fetch Matrix room state to ensure the bridge has up-to-date info.") needs_puppeting=False,
needs_auth=False,
help_section=SECTION_MISC,
help_text="Fetch Matrix room state to ensure the bridge has up-to-date info.",
)
async def sync_state(evt: CommandEvent) -> EventID: async def sync_state(evt: CommandEvent) -> EventID:
portal = await po.Portal.get_by_mxid(evt.room_id) portal = await po.Portal.get_by_mxid(evt.room_id)
if not portal: if not portal:
@@ -44,8 +53,9 @@ async def sync_state(evt: CommandEvent) -> EventID:
await evt.reply("Synchronization complete") await evt.reply("Synchronization complete")
@command_handler(needs_admin=False, needs_puppeting=False, needs_auth=False, @command_handler(
help_section=SECTION_MISC) needs_admin=False, needs_puppeting=False, needs_auth=False, help_section=SECTION_MISC
)
async def sync_full(evt: CommandEvent) -> EventID: async def sync_full(evt: CommandEvent) -> EventID:
portal = await po.Portal.get_by_mxid(evt.room_id) portal = await po.Portal.get_by_mxid(evt.room_id)
if not portal: if not portal:
@@ -70,9 +80,14 @@ async def sync_full(evt: CommandEvent) -> EventID:
return await evt.reply("Portal synced successfully.") return await evt.reply("Portal synced successfully.")
@command_handler(name="id", needs_admin=False, needs_puppeting=False, needs_auth=False, @command_handler(
help_section=SECTION_MISC, name="id",
help_text="Get the ID of the Telegram chat where this room is bridged.") needs_admin=False,
needs_puppeting=False,
needs_auth=False,
help_section=SECTION_MISC,
help_text="Get the ID of the Telegram chat where this room is bridged.",
)
async def get_id(evt: CommandEvent) -> EventID: async def get_id(evt: CommandEvent) -> EventID:
portal = await po.Portal.get_by_mxid(evt.room_id) portal = await po.Portal.get_by_mxid(evt.room_id)
if not portal: if not portal:
@@ -85,12 +100,14 @@ async def get_id(evt: CommandEvent) -> EventID:
await evt.reply(f"This room is bridged to Telegram chat ID `{tgid}`.") await evt.reply(f"This room is bridged to Telegram chat ID `{tgid}`.")
invite_link_usage = ("**Usage:** `$cmdprefix+sp invite-link [--uses=<amount>] [--expire=<delta>]`" invite_link_usage = (
"\n\n" "**Usage:** `$cmdprefix+sp invite-link [--uses=<amount>] [--expire=<delta>]`"
"* `--uses`: the number of times the invite link can be used." "\n\n"
" Defaults to unlimited.\n" "* `--uses`: the number of times the invite link can be used."
"* `--expire`: the duration after which the link will expire." " Defaults to unlimited.\n"
" A number suffixed with d(ay), h(our), m(inute) or s(econd)") "* `--expire`: the duration after which the link will expire."
" A number suffixed with d(ay), h(our), m(inute) or s(econd)"
)
def _parse_flag(args: list[str]) -> tuple[str, str]: def _parse_flag(args: list[str]) -> tuple[str, str]:
@@ -99,7 +116,7 @@ def _parse_flag(args: list[str]) -> tuple[str, str]:
value_start = arg.index("=") value_start = arg.index("=")
if value_start: if value_start:
flag = arg[2:value_start] flag = arg[2:value_start]
value = arg[value_start+1:] value = arg[value_start + 1 :]
else: else:
flag = arg[2:] flag = arg[2:]
value = args.pop(0).lower() value = args.pop(0).lower()
@@ -114,7 +131,9 @@ def _parse_flag(args: list[str]) -> tuple[str, str]:
return flag, value return flag, value
delta_regex = re.compile("([0-9]+)(w(?:eek)?|d(?:ay)?|h(?:our)?|m(?:in(?:ute)?)?|s(?:ec(?:ond)?)?)") delta_regex = re.compile(
"([0-9]+)(w(?:eek)?|d(?:ay)?|h(?:our)?|m(?:in(?:ute)?)?|s(?:ec(?:ond)?)?)"
)
def _parse_delta(value: str) -> timedelta | None: def _parse_delta(value: str) -> timedelta | None:
@@ -137,9 +156,11 @@ def _parse_delta(value: str) -> timedelta | None:
return None return None
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT, @command_handler(
help_text="Get a Telegram invite link to the current chat.", help_section=SECTION_PORTAL_MANAGEMENT,
help_args="[--uses=<amount>] [--expire=<time delta, e.g. 1d>]") help_text="Get a Telegram invite link to the current chat.",
help_args="[--uses=<amount>] [--expire=<time delta, e.g. 1d>]",
)
async def invite_link(evt: CommandEvent) -> EventID: async def invite_link(evt: CommandEvent) -> EventID:
# TODO once we switch to Python 3.9 minimum, use argparse with exit_on_error=False # TODO once we switch to Python 3.9 minimum, use argparse with exit_on_error=False
uses = None uses = None
@@ -176,8 +197,10 @@ async def invite_link(evt: CommandEvent) -> EventID:
return await evt.reply("You don't have the permission to create an invite link.") return await evt.reply("You don't have the permission to create an invite link.")
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT, @command_handler(
help_text="Upgrade a normal Telegram group to a supergroup.") help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Upgrade a normal Telegram group to a supergroup.",
)
async def upgrade(evt: CommandEvent) -> EventID: async def upgrade(evt: CommandEvent) -> EventID:
portal = await po.Portal.get_by_mxid(evt.room_id) portal = await po.Portal.get_by_mxid(evt.room_id)
if not portal: if not portal:
@@ -196,10 +219,13 @@ async def upgrade(evt: CommandEvent) -> EventID:
return await evt.reply(e.args[0]) return await evt.reply(e.args[0])
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT, @command_handler(
help_args="<_name_|`-`>", help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Change the username of a supergroup/channel. " help_args="<_name_|`-`>",
"To disable, use a dash (`-`) as the name.") help_text=(
"Change the username of a supergroup/channel. To disable, use a dash (`-`) as the name."
),
)
async def group_name(evt: CommandEvent) -> EventID: async def group_name(evt: CommandEvent) -> EventID:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp group-name <name/->`") return await evt.reply("**Usage:** `$cmdprefix+sp group-name <name/->`")
@@ -211,15 +237,15 @@ async def group_name(evt: CommandEvent) -> EventID:
return await evt.reply("Only channels and supergroups have usernames.") return await evt.reply("Only channels and supergroups have usernames.")
try: try:
await portal.set_telegram_username(evt.sender, await portal.set_telegram_username(evt.sender, evt.args[0] if evt.args[0] != "-" else "")
evt.args[0] if evt.args[0] != "-" else "")
if portal.username: if portal.username:
return await evt.reply(f"Username of channel changed to {portal.username}.") return await evt.reply(f"Username of channel changed to {portal.username}.")
else: else:
return await evt.reply(f"Channel is now private.") return await evt.reply(f"Channel is now private.")
except ChatAdminRequiredError: except ChatAdminRequiredError:
return await evt.reply( return await evt.reply(
"You don't have the permission to set the username of this channel.") "You don't have the permission to set the username of this channel."
)
except UsernameNotModifiedError: except UsernameNotModifiedError:
if portal.username: if portal.username:
return await evt.reply("That is already the username of this channel.") return await evt.reply("That is already the username of this channel.")
+45 -29
View File
@@ -17,10 +17,10 @@ from __future__ import annotations
from typing import Callable from typing import Callable
from mautrix.types import RoomID, EventID from mautrix.types import EventID, RoomID
from ... import portal as po from ... import portal as po
from .. import command_handler, CommandEvent, SECTION_PORTAL_MANAGEMENT from .. import SECTION_PORTAL_MANAGEMENT, CommandEvent, command_handler
from .util import user_has_power_level from .util import user_has_power_level
@@ -45,8 +45,9 @@ async def _get_portal_and_check_permission(evt: CommandEvent) -> po.Portal | Non
return portal return portal
def _get_portal_murder_function(action: str, room_id: str, function: Callable, command: str, def _get_portal_murder_function(
completed_message: str) -> dict: action: str, room_id: str, function: Callable, command: str, completed_message: str
) -> dict:
async def post_confirm(confirm) -> EventID | None: async def post_confirm(confirm) -> EventID | None:
confirm.sender.command_status = None confirm.sender.command_status = None
if len(confirm.args) > 0 and confirm.args[0] == f"confirm-{command}": if len(confirm.args) > 0 and confirm.args[0] == f"confirm-{command}":
@@ -63,40 +64,55 @@ def _get_portal_murder_function(action: str, room_id: str, function: Callable, c
} }
@command_handler(needs_auth=False, needs_puppeting=False, @command_handler(
help_section=SECTION_PORTAL_MANAGEMENT, needs_auth=False,
help_text="Remove all users from the current portal room and forget the portal. " needs_puppeting=False,
"Only works for group chats; to delete a private chat portal, simply " help_section=SECTION_PORTAL_MANAGEMENT,
"leave the room.") help_text=(
"Remove all users from the current portal room and forget the portal. "
"Only works for group chats; to delete a private chat portal, simply leave the room."
),
)
async def delete_portal(evt: CommandEvent) -> EventID | None: async def delete_portal(evt: CommandEvent) -> EventID | None:
portal = await _get_portal_and_check_permission(evt) portal = await _get_portal_and_check_permission(evt)
if not portal: if not portal:
return None return None
evt.sender.command_status = _get_portal_murder_function("Portal deletion", portal.mxid, evt.sender.command_status = _get_portal_murder_function(
portal.cleanup_and_delete, "delete", "Portal deletion",
"Portal successfully deleted.") portal.mxid,
return await evt.reply("Please confirm deletion of portal " portal.cleanup_and_delete,
f"[{portal.alias or portal.mxid}](https://matrix.to/#/{portal.mxid}) " "delete",
f"to Telegram chat \"{portal.title}\" " "Portal successfully deleted.",
"by typing `$cmdprefix+sp confirm-delete`" )
"\n\n" return await evt.reply(
"**WARNING:** If the bridge bot has the power level to do so, **this " "Please confirm deletion of portal "
"will kick ALL users** in the room. If you just want to remove the " f"[{portal.alias or portal.mxid}](https://matrix.to/#/{portal.mxid}) "
"bridge, use `$cmdprefix+sp unbridge` instead.") f'to Telegram chat "{portal.title}" '
"by typing `$cmdprefix+sp confirm-delete`"
"\n\n"
"**WARNING:** If the bridge bot has the power level to do so, **this "
"will kick ALL users** in the room. If you just want to remove the "
"bridge, use `$cmdprefix+sp unbridge` instead."
)
@command_handler(needs_auth=False, needs_puppeting=False, @command_handler(
help_section=SECTION_PORTAL_MANAGEMENT, needs_auth=False,
help_text="Remove puppets from the current portal room and forget the portal.") needs_puppeting=False,
help_section=SECTION_PORTAL_MANAGEMENT,
help_text="Remove puppets from the current portal room and forget the portal.",
)
async def unbridge(evt: CommandEvent) -> EventID | None: async def unbridge(evt: CommandEvent) -> EventID | None:
portal = await _get_portal_and_check_permission(evt) portal = await _get_portal_and_check_permission(evt)
if not portal: if not portal:
return None return None
evt.sender.command_status = _get_portal_murder_function("Room unbridging", portal.mxid, evt.sender.command_status = _get_portal_murder_function(
portal.unbridge, "unbridge", "Room unbridging", portal.mxid, portal.unbridge, "unbridge", "Room successfully unbridged."
"Room successfully unbridged.") )
return await evt.reply(f"Please confirm unbridging chat \"{portal.title}\" from room " return await evt.reply(
f"[{portal.alias or portal.mxid}](https://matrix.to/#/{portal.mxid}) " f'Please confirm unbridging chat "{portal.title}" from room '
"by typing `$cmdprefix+sp confirm-unbridge`") f"[{portal.alias or portal.mxid}](https://matrix.to/#/{portal.mxid}) "
"by typing `$cmdprefix+sp confirm-unbridge`"
)
+11 -9
View File
@@ -15,12 +15,12 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from mautrix.errors import MatrixRequestError
from mautrix.appservice import IntentAPI from mautrix.appservice import IntentAPI
from mautrix.types import RoomID, EventType, PowerLevelStateEventContent from mautrix.errors import MatrixRequestError
from .. import CommandEvent from mautrix.types import EventType, PowerLevelStateEventContent, RoomID
from ... import user as u from ... import user as u
from .. import CommandEvent
async def get_initial_state( async def get_initial_state(
@@ -51,14 +51,16 @@ async def get_initial_state(
async def warn_missing_power(levels: PowerLevelStateEventContent, evt: CommandEvent) -> None: async def warn_missing_power(levels: PowerLevelStateEventContent, evt: CommandEvent) -> None:
if levels.get_user_level(evt.az.bot_mxid) < levels.redact: if levels.get_user_level(evt.az.bot_mxid) < levels.redact:
await evt.reply("Warning: The bot does not have privileges to redact messages on Matrix. " await evt.reply(
"Message deletions from Telegram will not be bridged unless you give " "Warning: The bot does not have privileges to redact messages on Matrix. "
"redaction permissions to " "Message deletions from Telegram will not be bridged unless you give "
f"[{evt.az.bot_mxid}](https://matrix.to/#/{evt.az.bot_mxid})") f"redaction permissions to [{evt.az.bot_mxid}](https://matrix.to/#/{evt.az.bot_mxid})"
)
async def user_has_power_level(room_id: RoomID, intent: IntentAPI, sender: u.User, async def user_has_power_level(
event: str) -> bool: room_id: RoomID, intent: IntentAPI, sender: u.User, event: str
) -> bool:
if sender.is_admin: if sender.is_admin:
return True return True
# Make sure the state store contains the power levels. # Make sure the state store contains the power levels.
+70 -40
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -13,22 +13,36 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from telethon.errors import (UsernameInvalidError, UsernameNotModifiedError, UsernameOccupiedError, from __future__ import annotations
HashInvalidError, AuthKeyError, FirstNameInvalidError,
AboutTooLongError) from telethon.errors import (
AboutTooLongError,
AuthKeyError,
FirstNameInvalidError,
HashInvalidError,
UsernameInvalidError,
UsernameNotModifiedError,
UsernameOccupiedError,
)
from telethon.tl.functions.account import (
GetAuthorizationsRequest,
ResetAuthorizationRequest,
UpdateProfileRequest,
UpdateUsernameRequest,
)
from telethon.tl.types import Authorization from telethon.tl.types import Authorization
from telethon.tl.functions.account import (UpdateUsernameRequest, GetAuthorizationsRequest,
ResetAuthorizationRequest, UpdateProfileRequest)
from mautrix.types import EventID from mautrix.types import EventID
from .. import command_handler, CommandEvent, SECTION_AUTH from .. import SECTION_AUTH, CommandEvent, command_handler
@command_handler(needs_auth=True, @command_handler(
help_section=SECTION_AUTH, needs_auth=True,
help_args="<_new username_>", help_section=SECTION_AUTH,
help_text="Change your Telegram username.") help_args="<_new username_>",
help_text="Change your Telegram username.",
)
async def username(evt: CommandEvent) -> EventID: async def username(evt: CommandEvent) -> EventID:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp username <new username>`") return await evt.reply("**Usage:** `$cmdprefix+sp username <new username>`")
@@ -40,8 +54,9 @@ async def username(evt: CommandEvent) -> EventID:
try: try:
await evt.sender.client(UpdateUsernameRequest(username=new_name)) await evt.sender.client(UpdateUsernameRequest(username=new_name))
except UsernameInvalidError: except UsernameInvalidError:
return await evt.reply("Invalid username. Usernames must be between 5 and 30 alphanumeric " return await evt.reply(
"characters.") "Invalid username. Usernames must be between 5 and 30 alphanumeric characters."
)
except UsernameNotModifiedError: except UsernameNotModifiedError:
return await evt.reply("That is your current username.") return await evt.reply("That is your current username.")
except UsernameOccupiedError: except UsernameOccupiedError:
@@ -53,10 +68,12 @@ async def username(evt: CommandEvent) -> EventID:
await evt.reply(f"Username changed to {evt.sender.tg_username}") await evt.reply(f"Username changed to {evt.sender.tg_username}")
@command_handler(needs_auth=True, @command_handler(
help_section=SECTION_AUTH, needs_auth=True,
help_args="<_new about_>", help_section=SECTION_AUTH,
help_text="Change your Telegram about section.") help_args="<_new about_>",
help_text="Change your Telegram about section.",
)
async def about(evt: CommandEvent) -> EventID: async def about(evt: CommandEvent) -> EventID:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp about <new about>`") return await evt.reply("**Usage:** `$cmdprefix+sp about <new about>`")
@@ -72,17 +89,21 @@ async def about(evt: CommandEvent) -> EventID:
return await evt.reply("About section updated") return await evt.reply("About section updated")
@command_handler(needs_auth=True, help_section=SECTION_AUTH, help_args="<_new displayname_>", @command_handler(
help_text="Change your Telegram displayname.") needs_auth=True,
help_section=SECTION_AUTH,
help_args="<_new displayname_>",
help_text="Change your Telegram displayname.",
)
async def displayname(evt: CommandEvent) -> EventID: async def displayname(evt: CommandEvent) -> EventID:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp displayname <new displayname>`") return await evt.reply("**Usage:** `$cmdprefix+sp displayname <new displayname>`")
if evt.sender.is_bot: if evt.sender.is_bot:
return await evt.reply("Bots can't set their own displayname.") return await evt.reply("Bots can't set their own displayname.")
first_name, last_name = ((evt.args[0], "") first_name, last_name = (
if len(evt.args) == 1 (evt.args[0], "") if len(evt.args) == 1 else (" ".join(evt.args[:-1]), evt.args[-1])
else (" ".join(evt.args[:-1]), evt.args[-1])) )
try: try:
await evt.sender.client(UpdateProfileRequest(first_name=first_name, last_name=last_name)) await evt.sender.client(UpdateProfileRequest(first_name=first_name, last_name=last_name))
except FirstNameInvalidError: except FirstNameInvalidError:
@@ -92,16 +113,20 @@ async def displayname(evt: CommandEvent) -> EventID:
def _format_session(sess: Authorization) -> str: def _format_session(sess: Authorization) -> str:
return (f"**{sess.app_name} {sess.app_version}** \n" return (
f" **Platform:** {sess.device_model} {sess.platform} {sess.system_version} \n" f"**{sess.app_name} {sess.app_version}** \n"
f" **Active:** {sess.date_active} (created {sess.date_created}) \n" f" **Platform:** {sess.device_model} {sess.platform} {sess.system_version} \n"
f" **From:** {sess.ip} - {sess.region}, {sess.country}") f" **Active:** {sess.date_active} (created {sess.date_created}) \n"
f" **From:** {sess.ip} - {sess.region}, {sess.country}"
)
@command_handler(needs_auth=True, @command_handler(
help_section=SECTION_AUTH, needs_auth=True,
help_args="<`list`|`terminate`> [_hash_]", help_section=SECTION_AUTH,
help_text="View or delete other Telegram sessions.") help_args="<`list`|`terminate`> [_hash_]",
help_text="View or delete other Telegram sessions.",
)
async def session(evt: CommandEvent) -> EventID: async def session(evt: CommandEvent) -> EventID:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp session <list|terminate> [hash]`") return await evt.reply("**Usage:** `$cmdprefix+sp session <list|terminate> [hash]`")
@@ -113,14 +138,18 @@ async def session(evt: CommandEvent) -> EventID:
session_list = res.authorizations session_list = res.authorizations
current = [s for s in session_list if s.current][0] current = [s for s in session_list if s.current][0]
current_text = _format_session(current) current_text = _format_session(current)
other_text = "\n".join(f"* {_format_session(sess)} \n" other_text = "\n".join(
f" **Hash:** {sess.hash}" f"* {_format_session(sess)} \n **Hash:** {sess.hash}"
for sess in session_list if not sess.current) for sess in session_list
return await evt.reply(f"### Current session\n" if not sess.current
f"{current_text}\n" )
f"\n" return await evt.reply(
f"### Other active sessions\n" f"### Current session\n"
f"{other_text}") f"{current_text}\n"
f"\n"
f"### Other active sessions\n"
f"{other_text}"
)
elif cmd == "terminate" and len(evt.args) > 1: elif cmd == "terminate" and len(evt.args) > 1:
try: try:
session_hash = int(evt.args[1]) session_hash = int(evt.args[1])
@@ -132,8 +161,9 @@ async def session(evt: CommandEvent) -> EventID:
return await evt.reply("Invalid session hash.") return await evt.reply("Invalid session hash.")
except AuthKeyError as e: except AuthKeyError as e:
if e.message == "FRESH_RESET_AUTHORISATION_FORBIDDEN": if e.message == "FRESH_RESET_AUTHORISATION_FORBIDDEN":
return await evt.reply("New sessions can't terminate other sessions. " return await evt.reply(
"Please wait a while.") "New sessions can't terminate other sessions. Please wait a while."
)
raise raise
if ok: if ok:
return await evt.reply("Session terminated successfully.") return await evt.reply("Session terminated successfully.")
+161 -90
View File
@@ -20,33 +20,49 @@ import asyncio
import io import io
from telethon.errors import ( from telethon.errors import (
AccessTokenExpiredError, AccessTokenInvalidError, FirstNameInvalidError, FloodWaitError, AccessTokenExpiredError,
PasswordHashInvalidError, PhoneCodeExpiredError, PhoneCodeInvalidError, AccessTokenInvalidError,
PhoneNumberAppSignupForbiddenError, PhoneNumberBannedError, PhoneNumberFloodError, FirstNameInvalidError,
PhoneNumberOccupiedError, PhoneNumberUnoccupiedError, SessionPasswordNeededError, FloodWaitError,
PhoneNumberInvalidError) PasswordHashInvalidError,
PhoneCodeExpiredError,
PhoneCodeInvalidError,
PhoneNumberAppSignupForbiddenError,
PhoneNumberBannedError,
PhoneNumberFloodError,
PhoneNumberInvalidError,
PhoneNumberOccupiedError,
PhoneNumberUnoccupiedError,
SessionPasswordNeededError,
)
from telethon.tl.types import User from telethon.tl.types import User
from mautrix.types import (EventID, UserID, MediaMessageEventContent, ImageInfo, MessageType, from mautrix.types import (
TextMessageEventContent) EventID,
ImageInfo,
MediaMessageEventContent,
MessageType,
TextMessageEventContent,
UserID,
)
from mautrix.util.format_duration import format_duration as fmt_duration from mautrix.util.format_duration import format_duration as fmt_duration
from ... import user as u from ... import user as u
from ...commands import SECTION_AUTH, CommandEvent, command_handler
from ...types import TelegramID from ...types import TelegramID
from ...commands import command_handler, CommandEvent, SECTION_AUTH
try: try:
import qrcode
import PIL as _
from telethon.tl.custom import QRLogin from telethon.tl.custom import QRLogin
import PIL as _
import qrcode
except ImportError: except ImportError:
qrcode = None qrcode = None
QRLogin = None QRLogin = None
@command_handler(needs_auth=False, @command_handler(
help_section=SECTION_AUTH, needs_auth=False, help_section=SECTION_AUTH, help_text="Check if you're logged into Telegram."
help_text="Check if you're logged into Telegram.") )
async def ping(evt: CommandEvent) -> EventID: async def ping(evt: CommandEvent) -> EventID:
if await evt.sender.is_logged_in(): if await evt.sender.is_logged_in():
me = await evt.sender.get_me() me = await evt.sender.get_me()
@@ -59,22 +75,30 @@ async def ping(evt: CommandEvent) -> EventID:
return await evt.reply("You're not logged in.") return await evt.reply("You're not logged in.")
@command_handler(needs_auth=False, needs_puppeting=False, @command_handler(
help_section=SECTION_AUTH, needs_auth=False,
help_text="Get the info of the message relay Telegram bot.") needs_puppeting=False,
help_section=SECTION_AUTH,
help_text="Get the info of the message relay Telegram bot.",
)
async def ping_bot(evt: CommandEvent) -> EventID: async def ping_bot(evt: CommandEvent) -> EventID:
if not evt.tgbot: if not evt.tgbot:
return await evt.reply("Telegram message relay bot not configured.") return await evt.reply("Telegram message relay bot not configured.")
info, mxid = await evt.tgbot.get_me(use_cache=False) info, mxid = await evt.tgbot.get_me(use_cache=False)
return await evt.reply("Telegram message relay bot is active: " return await evt.reply(
f"[{info.first_name}](https://matrix.to/#/{mxid}) (ID {info.id})\n\n" "Telegram message relay bot is active: "
"To use the bot, simply invite it to a portal room.") f"[{info.first_name}](https://matrix.to/#/{mxid}) (ID {info.id})\n\n"
"To use the bot, simply invite it to a portal room."
)
@command_handler(needs_auth=False, management_only=True, @command_handler(
help_section=SECTION_AUTH, needs_auth=False,
help_args="<_phone_> <_full name_>", management_only=True,
help_text="Register to Telegram") help_section=SECTION_AUTH,
help_args="<_phone_> <_full name_>",
help_text="Register to Telegram",
)
async def register(evt: CommandEvent) -> EventID: async def register(evt: CommandEvent) -> EventID:
if await evt.sender.is_logged_in(): if await evt.sender.is_logged_in():
return await evt.reply("You are already logged in.") return await evt.reply("You are already logged in.")
@@ -87,13 +111,19 @@ async def register(evt: CommandEvent) -> EventID:
else: else:
full_name = " ".join(evt.args[1:-1]), evt.args[-1] full_name = " ".join(evt.args[1:-1]), evt.args[-1]
await _request_code(evt, phone_number, { await _request_code(
"next": enter_code_register, evt,
"action": "Register", phone_number,
"full_name": full_name, {
}) "next": enter_code_register,
return await evt.reply("By signing up for Telegram, you agree to " "action": "Register",
"the terms of service: https://telegram.org/tos") "full_name": full_name,
},
)
return await evt.reply(
"By signing up for Telegram, you agree to "
"the terms of service: https://telegram.org/tos"
)
async def enter_code_register(evt: CommandEvent) -> EventID: async def enter_code_register(evt: CommandEvent) -> EventID:
@@ -107,23 +137,31 @@ async def enter_code_register(evt: CommandEvent) -> EventID:
evt.sender.command_status = None evt.sender.command_status = None
return await evt.reply(f"Successfully registered to Telegram.") return await evt.reply(f"Successfully registered to Telegram.")
except PhoneNumberOccupiedError: except PhoneNumberOccupiedError:
return await evt.reply("That phone number has already been registered. " return await evt.reply(
"You can log in with `$cmdprefix+sp login`.") "That phone number has already been registered. "
"You can log in with `$cmdprefix+sp login`."
)
except FirstNameInvalidError: except FirstNameInvalidError:
return await evt.reply("Invalid name. Please set a Matrix displayname before registering.") return await evt.reply("Invalid name. Please set a Matrix displayname before registering.")
except PhoneCodeExpiredError: except PhoneCodeExpiredError:
return await evt.reply( return await evt.reply(
"Phone code expired. Try again with `$cmdprefix+sp register <phone>`.") "Phone code expired. Try again with `$cmdprefix+sp register <phone>`."
)
except PhoneCodeInvalidError: except PhoneCodeInvalidError:
return await evt.reply("Invalid phone code.") return await evt.reply("Invalid phone code.")
except Exception: except Exception:
evt.log.exception("Error sending phone code") evt.log.exception("Error sending phone code")
return await evt.reply("Unhandled exception while sending code. " return await evt.reply(
"Check console for more details.") "Unhandled exception while sending code. Check console for more details."
)
@command_handler(needs_auth=False, management_only=True, help_section=SECTION_AUTH, @command_handler(
help_text="Log in by scanning a QR code.") needs_auth=False,
management_only=True,
help_section=SECTION_AUTH,
help_text="Log in by scanning a QR code.",
)
async def login_qr(evt: CommandEvent) -> EventID: async def login_qr(evt: CommandEvent) -> EventID:
login_as = evt.sender login_as = evt.sender
if len(evt.args) > 0 and evt.sender.is_admin: if len(evt.args) > 0 and evt.sender.is_admin:
@@ -145,9 +183,12 @@ async def login_qr(evt: CommandEvent) -> EventID:
image.save(buffer, "PNG") image.save(buffer, "PNG")
qr = buffer.getvalue() qr = buffer.getvalue()
mxc = await evt.az.intent.upload_media(qr, "image/png", "login-qr.png", len(qr)) mxc = await evt.az.intent.upload_media(qr, "image/png", "login-qr.png", len(qr))
content = MediaMessageEventContent(body=qr_login.url, url=mxc, msgtype=MessageType.IMAGE, content = MediaMessageEventContent(
info=ImageInfo(mimetype="image/png", size=len(qr), body=qr_login.url,
width=size, height=size)) url=mxc,
msgtype=MessageType.IMAGE,
info=ImageInfo(mimetype="image/png", size=len(qr), width=size, height=size),
)
if qr_event_id: if qr_event_id:
content.set_edit(qr_event_id) content.set_edit(qr_event_id)
await evt.az.intent.send_message(evt.room_id, content) await evt.az.intent.send_message(evt.room_id, content)
@@ -170,8 +211,9 @@ async def login_qr(evt: CommandEvent) -> EventID:
"login_as": login_as if login_as != evt.sender else None, "login_as": login_as if login_as != evt.sender else None,
"action": "Login (password entry)", "action": "Login (password entry)",
} }
return await evt.reply("Your account has two-factor authentication. " return await evt.reply(
"Please send your password here.") "Your account has two-factor authentication. Please send your password here."
)
else: else:
timeout = TextMessageEventContent(body="Login timed out", msgtype=MessageType.TEXT) timeout = TextMessageEventContent(body="Login timed out", msgtype=MessageType.TEXT)
timeout.set_edit(qr_event_id) timeout.set_edit(qr_event_id)
@@ -180,9 +222,12 @@ async def login_qr(evt: CommandEvent) -> EventID:
return await _finish_sign_in(evt, user, login_as=login_as) return await _finish_sign_in(evt, user, login_as=login_as)
@command_handler(needs_auth=False, management_only=True, @command_handler(
help_section=SECTION_AUTH, needs_auth=False,
help_text="Get instructions on how to log in.") management_only=True,
help_section=SECTION_AUTH,
help_text="Get instructions on how to log in.",
)
async def login(evt: CommandEvent) -> EventID: async def login(evt: CommandEvent) -> EventID:
override_sender = False override_sender = False
if len(evt.args) > 0 and evt.sender.is_admin: if len(evt.args) > 0 and evt.sender.is_admin:
@@ -203,24 +248,32 @@ async def login(evt: CommandEvent) -> EventID:
prefix = evt.config["appservice.public.external"] prefix = evt.config["appservice.public.external"]
url = f"{prefix}/login?token={evt.public_website.make_token(evt.sender.mxid, '/login')}" url = f"{prefix}/login?token={evt.public_website.make_token(evt.sender.mxid, '/login')}"
if override_sender: if override_sender:
return await evt.reply(f"[Click here to log in]({url}) as " return await evt.reply(
f"[{evt.sender.mxid}](https://matrix.to/#/{evt.sender.mxid}).") f"[Click here to log in]({url}) as "
f"[{evt.sender.mxid}](https://matrix.to/#/{evt.sender.mxid})."
)
elif allow_matrix_login: elif allow_matrix_login:
return await evt.reply(f"[Click here to log in]({url}). Alternatively, send your phone" return await evt.reply(
f" number (or bot auth token) here to log in.\n\n{nb}") f"[Click here to log in]({url}). Alternatively, send your phone"
f" number (or bot auth token) here to log in.\n\n{nb}"
)
return await evt.reply(f"[Click here to log in]({url}).\n\n{nb}") return await evt.reply(f"[Click here to log in]({url}).\n\n{nb}")
elif allow_matrix_login: elif allow_matrix_login:
if override_sender: if override_sender:
return await evt.reply( return await evt.reply(
"This bridge instance does not allow you to log in outside of Matrix. " "This bridge instance does not allow you to log in outside of Matrix. "
"Logging in as another user inside Matrix is not currently possible.") "Logging in as another user inside Matrix is not currently possible."
return await evt.reply("Please send your phone number (or bot auth token) here to start " )
f"the login process.\n\n{nb}") return await evt.reply(
"Please send your phone number (or bot auth token) here to start "
f"the login process.\n\n{nb}"
)
return await evt.reply("This bridge instance has been configured to not allow logging in.") return await evt.reply("This bridge instance has been configured to not allow logging in.")
async def _request_code(evt: CommandEvent, phone_number: str, next_status: dict[str, Any] async def _request_code(
) -> EventID: evt: CommandEvent, phone_number: str, next_status: dict[str, Any]
) -> EventID:
ok = False ok = False
try: try:
await evt.sender.ensure_started(even_if_no_session=True) await evt.sender.ensure_started(even_if_no_session=True)
@@ -230,22 +283,29 @@ async def _request_code(evt: CommandEvent, phone_number: str, next_status: dict[
except PhoneNumberAppSignupForbiddenError: except PhoneNumberAppSignupForbiddenError:
return await evt.reply("Your phone number does not allow 3rd party apps to sign in.") return await evt.reply("Your phone number does not allow 3rd party apps to sign in.")
except PhoneNumberFloodError: except PhoneNumberFloodError:
return await evt.reply("Your phone number has been temporarily blocked for flooding. " return await evt.reply(
"The ban is usually applied for around a day.") "Your phone number has been temporarily blocked for flooding. "
"The ban is usually applied for around a day."
)
except FloodWaitError as e: except FloodWaitError as e:
return await evt.reply("Your phone number has been temporarily blocked for flooding. " return await evt.reply(
f"Please wait for {fmt_duration(e.seconds)} before trying again.") "Your phone number has been temporarily blocked for flooding. "
f"Please wait for {fmt_duration(e.seconds)} before trying again."
)
except PhoneNumberBannedError: except PhoneNumberBannedError:
return await evt.reply("Your phone number has been banned from Telegram.") return await evt.reply("Your phone number has been banned from Telegram.")
except PhoneNumberUnoccupiedError: except PhoneNumberUnoccupiedError:
return await evt.reply("That phone number has not been registered. " return await evt.reply(
"Please register with `$cmdprefix+sp register <phone>`.") "That phone number has not been registered. "
"Please register with `$cmdprefix+sp register <phone>`."
)
except PhoneNumberInvalidError: except PhoneNumberInvalidError:
return await evt.reply("That phone number is not valid.") return await evt.reply("That phone number is not valid.")
except Exception: except Exception:
evt.log.exception("Error requesting phone code") evt.log.exception("Error requesting phone code")
return await evt.reply("Unhandled exception while requesting code. " return await evt.reply(
"Check console for more details.") "Unhandled exception while requesting code. Check console for more details."
)
finally: finally:
evt.sender.command_status = next_status if ok else None evt.sender.command_status = next_status if ok else None
@@ -255,8 +315,10 @@ async def enter_phone_or_token(evt: CommandEvent) -> EventID | None:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-phone-or-token <phone-or-token>`") return await evt.reply("**Usage:** `$cmdprefix+sp enter-phone-or-token <phone-or-token>`")
elif not evt.config.get("bridge.allow_matrix_login", True): elif not evt.config.get("bridge.allow_matrix_login", True):
return await evt.reply("This bridge instance does not allow in-Matrix login. " return await evt.reply(
"Please use `$cmdprefix+sp login` to get login instructions") "This bridge instance does not allow in-Matrix login. "
"Please use `$cmdprefix+sp login` to get login instructions"
)
# phone numbers don't contain colons but telegram bot auth tokens do # phone numbers don't contain colons but telegram bot auth tokens do
if evt.args[0].find(":") > 0: if evt.args[0].find(":") > 0:
@@ -264,13 +326,11 @@ async def enter_phone_or_token(evt: CommandEvent) -> EventID | None:
await _sign_in(evt, bot_token=evt.args[0]) await _sign_in(evt, bot_token=evt.args[0])
except Exception: except Exception:
evt.log.exception("Error sending auth token") evt.log.exception("Error sending auth token")
return await evt.reply("Unhandled exception while sending auth token. " return await evt.reply(
"Check console for more details.") "Unhandled exception while sending auth token. Check console for more details."
)
else: else:
await _request_code(evt, evt.args[0], { await _request_code(evt, evt.args[0], {"next": enter_code, "action": "Login"})
"next": enter_code,
"action": "Login",
})
return None return None
@@ -279,14 +339,17 @@ async def enter_code(evt: CommandEvent) -> EventID | None:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-code <code>`") return await evt.reply("**Usage:** `$cmdprefix+sp enter-code <code>`")
elif not evt.config.get("bridge.allow_matrix_login", True): elif not evt.config.get("bridge.allow_matrix_login", True):
return await evt.reply("This bridge instance does not allow in-Matrix login. " return await evt.reply(
"Please use `$cmdprefix+sp login` to get login instructions") "This bridge instance does not allow in-Matrix login. "
"Please use `$cmdprefix+sp login` to get login instructions"
)
try: try:
await _sign_in(evt, code=evt.args[0]) await _sign_in(evt, code=evt.args[0])
except Exception: except Exception:
evt.log.exception("Error sending phone code") evt.log.exception("Error sending phone code")
return await evt.reply("Unhandled exception while sending code. " return await evt.reply(
"Check console for more details.") "Unhandled exception while sending code. Check console for more details."
)
return None return None
@@ -295,19 +358,25 @@ async def enter_password(evt: CommandEvent) -> EventID | None:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp enter-password <password>`") return await evt.reply("**Usage:** `$cmdprefix+sp enter-password <password>`")
elif not evt.config.get("bridge.allow_matrix_login", True): elif not evt.config.get("bridge.allow_matrix_login", True):
return await evt.reply("This bridge instance does not allow in-Matrix login. " return await evt.reply(
"Please use `$cmdprefix+sp login` to get login instructions") "This bridge instance does not allow in-Matrix login. "
"Please use `$cmdprefix+sp login` to get login instructions"
)
try: try:
await _sign_in(evt, login_as=evt.sender.command_status.get("login_as", None), await _sign_in(
password=" ".join(evt.args)) evt,
login_as=evt.sender.command_status.get("login_as", None),
password=" ".join(evt.args),
)
except AccessTokenInvalidError: except AccessTokenInvalidError:
return await evt.reply("That bot token is not valid.") return await evt.reply("That bot token is not valid.")
except AccessTokenExpiredError: except AccessTokenExpiredError:
return await evt.reply("That bot token has expired.") return await evt.reply("That bot token has expired.")
except Exception: except Exception:
evt.log.exception("Error sending password") evt.log.exception("Error sending password")
return await evt.reply("Unhandled exception while sending password. " return await evt.reply(
"Check console for more details.") "Unhandled exception while sending password. Check console for more details."
)
return None return None
@@ -328,8 +397,9 @@ async def _sign_in(evt: CommandEvent, login_as: u.User = None, **sign_in_info) -
"next": enter_password, "next": enter_password,
"action": "Login (password entry)", "action": "Login (password entry)",
} }
return await evt.reply("Your account has two-factor authentication. " return await evt.reply(
"Please send your password here.") "Your account has two-factor authentication. Please send your password here."
)
async def _finish_sign_in(evt: CommandEvent, user: User, login_as: u.User = None) -> EventID: async def _finish_sign_in(evt: CommandEvent, user: User, login_as: u.User = None) -> EventID:
@@ -337,23 +407,24 @@ async def _finish_sign_in(evt: CommandEvent, user: User, login_as: u.User = None
existing_user = await u.User.get_by_tgid(TelegramID(user.id)) existing_user = await u.User.get_by_tgid(TelegramID(user.id))
if existing_user and existing_user != login_as: if existing_user and existing_user != login_as:
await existing_user.log_out() await existing_user.log_out()
await evt.reply(f"[{existing_user.displayname}]" await evt.reply(
f"(https://matrix.to/#/{existing_user.mxid})" f"[{existing_user.displayname}] (https://matrix.to/#/{existing_user.mxid})"
" was logged out from the account.") " was logged out from the account."
)
asyncio.ensure_future(login_as.post_login(user, first_login=True), loop=evt.loop) asyncio.ensure_future(login_as.post_login(user, first_login=True), loop=evt.loop)
evt.sender.command_status = None evt.sender.command_status = None
name = f"@{user.username}" if user.username else f"+{user.phone}" name = f"@{user.username}" if user.username else f"+{user.phone}"
if login_as != evt.sender: if login_as != evt.sender:
msg = (f"Successfully logged in [{login_as.mxid}](https://matrix.to/#/{login_as.mxid})" msg = (
f" as {name}") f"Successfully logged in [{login_as.mxid}](https://matrix.to/#/{login_as.mxid})"
f" as {name}"
)
else: else:
msg = f"Successfully logged in as {name}" msg = f"Successfully logged in as {name}"
return await evt.reply(msg) return await evt.reply(msg)
@command_handler(needs_auth=False, @command_handler(needs_auth=False, help_section=SECTION_AUTH, help_text="Log out from Telegram.")
help_section=SECTION_AUTH,
help_text="Log out from Telegram.")
async def logout(evt: CommandEvent) -> EventID: async def logout(evt: CommandEvent) -> EventID:
if not evt.sender.tgid: if not evt.sender.tgid:
return await evt.reply("You're not logged in") return await evt.reply("You're not logged in")
+135 -72
View File
@@ -16,36 +16,59 @@
from __future__ import annotations from __future__ import annotations
from typing import cast from typing import cast
import codecs
import base64 import base64
import codecs
import re import re
from aiohttp import ClientSession, InvalidURL from aiohttp import ClientSession, InvalidURL
from telethon.errors import (
from telethon.errors import (InviteHashInvalidError, InviteHashExpiredError, OptionsTooMuchError, ChatIdInvalidError,
UserAlreadyParticipantError, ChatIdInvalidError, EmoticonInvalidError,
TakeoutInitDelayError, EmoticonInvalidError) InviteHashExpiredError,
from telethon.tl.patched import Message InviteHashInvalidError,
from telethon.tl.types import (User as TLUser, TypeUpdates, MessageMediaGame, MessageMediaPoll, OptionsTooMuchError,
TypeInputPeer, InputMediaDice) TakeoutInitDelayError,
from telethon.tl.types.messages import BotCallbackAnswer UserAlreadyParticipantError,
from telethon.tl.functions.messages import (ImportChatInviteRequest, CheckChatInviteRequest, )
GetBotCallbackAnswerRequest, SendVoteRequest)
from telethon.tl.functions.channels import JoinChannelRequest from telethon.tl.functions.channels import JoinChannelRequest
from telethon.tl.functions.messages import (
CheckChatInviteRequest,
GetBotCallbackAnswerRequest,
ImportChatInviteRequest,
SendVoteRequest,
)
from telethon.tl.patched import Message
from telethon.tl.types import (
InputMediaDice,
MessageMediaGame,
MessageMediaPoll,
TypeInputPeer,
TypeUpdates,
User as TLUser,
)
from telethon.tl.types.messages import BotCallbackAnswer
from mautrix.types import EventID, Format from mautrix.types import EventID, Format
from ... import puppet as pu, portal as po from ... import portal as po, puppet as pu
from ...abstract_user import AbstractUser from ...abstract_user import AbstractUser
from ...commands import (
SECTION_CREATING_PORTALS,
SECTION_MISC,
SECTION_PORTAL_MANAGEMENT,
CommandEvent,
command_handler,
)
from ...db import Message as DBMessage from ...db import Message as DBMessage
from ...types import TelegramID from ...types import TelegramID
from ...commands import (command_handler, CommandEvent, SECTION_MISC, SECTION_CREATING_PORTALS,
SECTION_PORTAL_MANAGEMENT)
@command_handler(needs_auth=False, @command_handler(
help_section=SECTION_MISC, help_args="<_caption_>", needs_auth=False,
help_text="Set a caption for the next image you send") help_section=SECTION_MISC,
help_args="<_caption_>",
help_text="Set a caption for the next image you send",
)
async def caption(evt: CommandEvent) -> EventID: async def caption(evt: CommandEvent) -> EventID:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp caption <caption>`") return await evt.reply("**Usage:** `$cmdprefix+sp caption <caption>`")
@@ -55,13 +78,17 @@ async def caption(evt: CommandEvent) -> EventID:
evt.content.formatted_body = evt.content.formatted_body.replace(prefix, "", 1) evt.content.formatted_body = evt.content.formatted_body.replace(prefix, "", 1)
evt.content.body = evt.content.body.replace(prefix, "", 1) evt.content.body = evt.content.body.replace(prefix, "", 1)
evt.sender.command_status = {"caption": evt.content, "action": "Caption"} evt.sender.command_status = {"caption": evt.content, "action": "Caption"}
return await evt.reply("Your next image or file will be sent with that caption. " return await evt.reply(
"Use `$cmdprefix+sp cancel` to cancel the caption.") "Your next image or file will be sent with that caption. "
"Use `$cmdprefix+sp cancel` to cancel the caption."
)
@command_handler(help_section=SECTION_MISC, @command_handler(
help_args="[_-r|--remote_] <_query_>", help_section=SECTION_MISC,
help_text="Search your contacts or the Telegram servers for users.") help_args="[_-r|--remote_] <_query_>",
help_text="Search your contacts or the Telegram servers for users.",
)
async def search(evt: CommandEvent) -> EventID: async def search(evt: CommandEvent) -> EventID:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] <query>`") return await evt.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] <query>`")
@@ -79,8 +106,9 @@ async def search(evt: CommandEvent) -> EventID:
if not results: if not results:
if len(query) < 5 and remote: if len(query) < 5 and remote:
return await evt.reply("No local results. " return await evt.reply(
"Minimum length of remote query is 5 characters.") "No local results. Minimum length of remote query is 5 characters."
)
return await evt.reply("No results 3:") return await evt.reply("No results 3:")
reply: list[str] = [] reply: list[str] = []
@@ -88,20 +116,27 @@ async def search(evt: CommandEvent) -> EventID:
reply += ["**Results from Telegram server:**", ""] reply += ["**Results from Telegram server:**", ""]
else: else:
reply += ["**Results in contacts:**", ""] reply += ["**Results in contacts:**", ""]
reply += [(f"* [{puppet.displayname}](https://matrix.to/#/{puppet.mxid}): " reply += [
f"{puppet.id} ({similarity}% match)") (
for puppet, similarity in results] f"* [{puppet.displayname}](https://matrix.to/#/{puppet.mxid}): "
f"{puppet.id} ({similarity}% match)"
)
for puppet, similarity in results
]
# TODO somehow show remote channel results when joining by alias is possible? # TODO somehow show remote channel results when joining by alias is possible?
return await evt.reply("\n".join(reply)) return await evt.reply("\n".join(reply))
@command_handler(help_section=SECTION_CREATING_PORTALS, help_args="<_identifier_>", @command_handler(
help_text="Open a private chat with the given Telegram user. The identifier is " help_section=SECTION_CREATING_PORTALS,
"either the internal user ID, the username or the phone number. " help_args="<_identifier_>",
"**N.B.** The phone numbers you start chats with must already be in " help_text="Open a private chat with the given Telegram user. The identifier is "
"your contacts.") "either the internal user ID, the username or the phone number. "
"**N.B.** The phone numbers you start chats with must already be in "
"your contacts.",
)
async def pm(evt: CommandEvent) -> EventID: async def pm(evt: CommandEvent) -> EventID:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp pm <user identifier>`") return await evt.reply("**Usage:** `$cmdprefix+sp pm <user identifier>`")
@@ -122,8 +157,9 @@ async def pm(evt: CommandEvent) -> EventID:
return await evt.reply(f"Created private chat room with {displayname}") return await evt.reply(f"Created private chat room with {displayname}")
async def _join(evt: CommandEvent, identifier: str, link_type: str async def _join(
) -> tuple[TypeUpdates | None, EventID | None]: evt: CommandEvent, identifier: str, link_type: str
) -> tuple[TypeUpdates | None, EventID | None]:
if link_type == "joinchat": if link_type == "joinchat":
try: try:
await evt.sender.client(CheckChatInviteRequest(identifier)) await evt.sender.client(CheckChatInviteRequest(identifier))
@@ -142,9 +178,11 @@ async def _join(evt: CommandEvent, identifier: str, link_type: str
return await evt.sender.client(JoinChannelRequest(channel)), None return await evt.sender.client(JoinChannelRequest(channel)), None
@command_handler(help_section=SECTION_CREATING_PORTALS, @command_handler(
help_args="<_link_>", help_section=SECTION_CREATING_PORTALS,
help_text="Join a chat with an invite link.") help_args="<_link_>",
help_text="Join a chat with an invite link.",
)
async def join(evt: CommandEvent) -> EventID | None: async def join(evt: CommandEvent) -> EventID | None:
if len(evt.args) == 0: if len(evt.args) == 0:
return await evt.reply("**Usage:** `$cmdprefix+sp join <invite link>`") return await evt.reply("**Usage:** `$cmdprefix+sp join <invite link>`")
@@ -157,8 +195,10 @@ async def join(evt: CommandEvent) -> EventID | None:
except InvalidURL: except InvalidURL:
return await evt.reply("That doesn't look like a Telegram invite link.") return await evt.reply("That doesn't look like a Telegram invite link.")
regex = re.compile(r"(?:https?://)?t(?:elegram)?\.(?:dog|me)" regex = re.compile(
r"(?:/(?P<type>joinchat|s))?/(?P<id>[^/]+)/?", flags=re.IGNORECASE) r"(?:https?://)?t(?:elegram)?\.(?:dog|me)(?:/(?P<type>joinchat|s))?/(?P<id>[^/]+)/?",
flags=re.IGNORECASE,
)
arg = regex.match(url) arg = regex.match(url)
if not arg: if not arg:
return await evt.reply("That doesn't look like a Telegram invite link.") return await evt.reply("That doesn't look like a Telegram invite link.")
@@ -182,16 +222,20 @@ async def join(evt: CommandEvent) -> EventID | None:
try: try:
await portal.create_matrix_room(evt.sender, chat, [evt.sender.mxid]) await portal.create_matrix_room(evt.sender, chat, [evt.sender.mxid])
except ChatIdInvalidError as e: except ChatIdInvalidError as e:
evt.log.trace("ChatIdInvalidError while creating portal from !tg join command: %s", evt.log.trace(
updates.stringify()) "ChatIdInvalidError while creating portal from !tg join command: %s",
updates.stringify(),
)
raise e raise e
return await evt.reply(f"Created room for {portal.title}") return await evt.reply(f"Created room for {portal.title}")
return None return None
@command_handler(help_section=SECTION_MISC, @command_handler(
help_args="[`chats`|`contacts`|`me`]", help_section=SECTION_MISC,
help_text="Synchronize your chat portals, contacts and/or own info.") help_args="[`chats`|`contacts`|`me`]",
help_text="Synchronize your chat portals, contacts and/or own info.",
)
async def sync(evt: CommandEvent) -> EventID: async def sync(evt: CommandEvent) -> EventID:
if len(evt.args) > 0: if len(evt.args) > 0:
sync_only = evt.args[0] sync_only = evt.args[0]
@@ -220,8 +264,9 @@ class MessageIDError(ValueError):
self.message = message self.message = message
async def _parse_encoded_msgid(user: AbstractUser, enc_id: str, type_name: str async def _parse_encoded_msgid(
) -> tuple[TypeInputPeer, Message]: user: AbstractUser, enc_id: str, type_name: str
) -> tuple[TypeInputPeer, Message]:
try: try:
enc_id += (4 - len(enc_id) % 4) * "=" enc_id += (4 - len(enc_id) % 4) * "="
enc_id = base64.b64decode(enc_id) enc_id = base64.b64decode(enc_id)
@@ -253,9 +298,9 @@ async def _parse_encoded_msgid(user: AbstractUser, enc_id: str, type_name: str
return peer, cast(Message, msg) return peer, cast(Message, msg)
@command_handler(help_section=SECTION_MISC, @command_handler(
help_args="<_play ID_>", help_section=SECTION_MISC, help_args="<_play ID_>", help_text="Play a Telegram game."
help_text="Play a Telegram game.") )
async def play(evt: CommandEvent) -> EventID: async def play(evt: CommandEvent) -> EventID:
if len(evt.args) < 1: if len(evt.args) < 1:
return await evt.reply("**Usage:** `$cmdprefix+sp play <play ID>`") return await evt.reply("**Usage:** `$cmdprefix+sp play <play ID>`")
@@ -273,17 +318,22 @@ async def play(evt: CommandEvent) -> EventID:
return await evt.reply("Invalid play ID (message doesn't look like a game)") return await evt.reply("Invalid play ID (message doesn't look like a game)")
game = await evt.sender.client( game = await evt.sender.client(
GetBotCallbackAnswerRequest(peer=peer, msg_id=msg.id, game=True)) GetBotCallbackAnswerRequest(peer=peer, msg_id=msg.id, game=True)
)
if not isinstance(game, BotCallbackAnswer): if not isinstance(game, BotCallbackAnswer):
return await evt.reply("Game request response invalid") return await evt.reply("Game request response invalid")
return await evt.reply(f"Click [here]({game.url}) to play {msg.media.game.title}:\n\n" return await evt.reply(
f"{msg.media.game.description}") f"Click [here]({game.url}) to play {msg.media.game.title}:\n\n"
f"{msg.media.game.description}"
)
@command_handler(help_section=SECTION_MISC, @command_handler(
help_args="<_poll ID_> <_choice number_>", help_section=SECTION_MISC,
help_text="Vote in a Telegram poll.") help_args="<_poll ID_> <_choice number_>",
help_text="Vote in a Telegram poll.",
)
async def vote(evt: CommandEvent) -> EventID | None: async def vote(evt: CommandEvent) -> EventID | None:
if len(evt.args) < 1: if len(evt.args) < 1:
return await evt.reply("**Usage:** `$cmdprefix+sp vote <poll ID> <choice number>`") return await evt.reply("**Usage:** `$cmdprefix+sp vote <poll ID> <choice number>`")
@@ -309,17 +359,20 @@ async def vote(evt: CommandEvent) -> EventID | None:
except ValueError: except ValueError:
option_index = None option_index = None
if option_index is None: if option_index is None:
return await evt.reply(f"Invalid option number \"{option}\"", return await evt.reply(
render_markdown=False, allow_html=False) f'Invalid option number "{option}"', render_markdown=False, allow_html=False
)
elif option_index < 0: elif option_index < 0:
return await evt.reply(f"Invalid option number {option}. " return await evt.reply(
f"Option numbers must be positive.") f"Invalid option number {option}. Option numbers must be positive."
)
elif option_index >= len(msg.media.poll.answers): elif option_index >= len(msg.media.poll.answers):
return await evt.reply(f"Invalid option number {option}. " return await evt.reply(
f"The poll only has {len(msg.media.poll.answers)} options.") f"Invalid option number {option}. "
f"The poll only has {len(msg.media.poll.answers)} options."
)
options.append(msg.media.poll.answers[option_index].option) options.append(msg.media.poll.answers[option_index].option)
options = [msg.media.poll.answers[int(option) - 1].option options = [msg.media.poll.answers[int(option) - 1].option for option in evt.args[1:]]
for option in evt.args[1:]]
try: try:
await evt.sender.client(SendVoteRequest(peer=peer, msg_id=msg.id, options=options)) await evt.sender.client(SendVoteRequest(peer=peer, msg_id=msg.id, options=options))
except OptionsTooMuchError: except OptionsTooMuchError:
@@ -328,9 +381,12 @@ async def vote(evt: CommandEvent) -> EventID | None:
return await evt.mark_read() return await evt.mark_read()
@command_handler(help_section=SECTION_MISC, help_args="<_emoji_>", @command_handler(
help_text="Roll a dice (\U0001F3B2), kick a football (\u26BD\uFE0F) or throw a " help_section=SECTION_MISC,
"dart (\U0001F3AF) or basketball (\U0001F3C0) on the Telegram servers.") help_args="<_emoji_>",
help_text="Roll a dice (\U0001F3B2), kick a football (\u26BD\uFE0F) or throw a "
"dart (\U0001F3AF) or basketball (\U0001F3C0) on the Telegram servers.",
)
async def random(evt: CommandEvent) -> EventID: async def random(evt: CommandEvent) -> EventID:
if not evt.is_portal: if not evt.is_portal:
return await evt.reply("You can only randomize values in portal rooms") return await evt.reply("You can only randomize values in portal rooms")
@@ -345,14 +401,18 @@ async def random(evt: CommandEvent) -> EventID:
"soccer": "\u26BD", "soccer": "\u26BD",
}.get(arg, arg) }.get(arg, arg)
try: try:
await evt.sender.client.send_media(await portal.get_input_entity(evt.sender), await evt.sender.client.send_media(
InputMediaDice(emoticon)) await portal.get_input_entity(evt.sender), InputMediaDice(emoticon)
)
except EmoticonInvalidError: except EmoticonInvalidError:
return await evt.reply("Invalid emoji for randomization") return await evt.reply("Invalid emoji for randomization")
@command_handler(help_section=SECTION_PORTAL_MANAGEMENT, help_args="[_limit_]", @command_handler(
help_text="Backfill messages from Telegram history.") help_section=SECTION_PORTAL_MANAGEMENT,
help_args="[_limit_]",
help_text="Backfill messages from Telegram history.",
)
async def backfill(evt: CommandEvent) -> None: async def backfill(evt: CommandEvent) -> None:
if not evt.is_portal: if not evt.is_portal:
await evt.reply("You can only use backfill in portal rooms") await evt.reply("You can only use backfill in portal rooms")
@@ -368,10 +428,13 @@ async def backfill(evt: CommandEvent) -> None:
try: try:
await portal.backfill(evt.sender, limit=limit) await portal.backfill(evt.sender, limit=limit)
except TakeoutInitDelayError: except TakeoutInitDelayError:
msg = ("Please accept the data export request from a mobile device, " msg = (
"then re-run the backfill command.") "Please accept the data export request from a mobile device, "
"then re-run the backfill command."
)
if portal.peer_type == "user": if portal.peer_type == "user":
from mautrix.appservice import IntentAPI from mautrix.appservice import IntentAPI
await portal.main_intent.send_notice(evt.room_id, msg) await portal.main_intent.send_notice(evt.room_id, msg)
else: else:
await evt.reply(msg) await evt.reply(msg)
+30 -14
View File
@@ -14,16 +14,24 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Any, List, NamedTuple from typing import Any, List, NamedTuple
from ruamel.yaml.comments import CommentedMap
import os import os
from mautrix.types import UserID from ruamel.yaml.comments import CommentedMap
from mautrix.client import Client
from mautrix.bridge.config import BaseBridgeConfig
from mautrix.util.config import ForbiddenKey, ForbiddenDefault, ConfigUpdateHelper
Permissions = NamedTuple("Permissions", relaybot=bool, user=bool, puppeting=bool, from mautrix.bridge.config import BaseBridgeConfig
matrix_puppeting=bool, admin=bool, level=str) from mautrix.client import Client
from mautrix.types import UserID
from mautrix.util.config import ConfigUpdateHelper, ForbiddenDefault, ForbiddenKey
Permissions = NamedTuple(
"Permissions",
relaybot=bool,
user=bool,
puppeting=bool,
matrix_puppeting=bool,
admin=bool,
level=str,
)
class Config(BaseBridgeConfig): class Config(BaseBridgeConfig):
@@ -37,8 +45,11 @@ class Config(BaseBridgeConfig):
def forbidden_defaults(self) -> List[ForbiddenDefault]: def forbidden_defaults(self) -> List[ForbiddenDefault]:
return [ return [
*super().forbidden_defaults, *super().forbidden_defaults,
ForbiddenDefault("appservice.public.external", "https://example.com/public", ForbiddenDefault(
condition="appservice.public.enabled"), "appservice.public.external",
"https://example.com/public",
condition="appservice.public.enabled",
),
ForbiddenDefault("bridge.permissions", ForbiddenKey("example.com")), ForbiddenDefault("bridge.permissions", ForbiddenKey("example.com")),
ForbiddenDefault("telegram.api_id", 12345), ForbiddenDefault("telegram.api_id", 12345),
ForbiddenDefault("telegram.api_hash", "tjyd5yge35lbodk1xwzw2jstp90k55qz"), ForbiddenDefault("telegram.api_hash", "tjyd5yge35lbodk1xwzw2jstp90k55qz"),
@@ -51,8 +62,11 @@ class Config(BaseBridgeConfig):
copy("homeserver.asmux") copy("homeserver.asmux")
if "appservice.protocol" in self and "appservice.address" not in self: if "appservice.protocol" in self and "appservice.address" not in self:
protocol, hostname, port = (self["appservice.protocol"], self["appservice.hostname"], protocol, hostname, port = (
self["appservice.port"]) self["appservice.protocol"],
self["appservice.hostname"],
self["appservice.port"],
)
base["appservice.address"] = f"{protocol}://{hostname}:{port}" base["appservice.address"] = f"{protocol}://{hostname}:{port}"
if "appservice.debug" in self and "logging" not in self: if "appservice.debug" in self and "logging" not in self:
level = "DEBUG" if self["appservice.debug"] else "INFO" level = "DEBUG" if self["appservice.debug"] else "INFO"
@@ -170,9 +184,11 @@ class Config(BaseBridgeConfig):
copy("bridge.command_prefix") copy("bridge.command_prefix")
migrate_permissions = ("bridge.permissions" not in self migrate_permissions = (
or "bridge.whitelist" in self "bridge.permissions" not in self
or "bridge.admins" in self) or "bridge.whitelist" in self
or "bridge.admins" in self
)
if migrate_permissions: if migrate_permissions:
permissions = self["bridge.permissions"] or CommentedMap() permissions = self["bridge.permissions"] or CommentedMap()
for entry in self["bridge.whitelist"] or []: for entry in self["bridge.whitelist"] or []:
+13 -5
View File
@@ -15,15 +15,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.async_db import Database from mautrix.util.async_db import Database
from .upgrade import upgrade_table
from .bot_chat import BotChat from .bot_chat import BotChat
from .message import Message from .message import Message
from .portal import Portal from .portal import Portal
from .puppet import Puppet from .puppet import Puppet
from .telegram_file import TelegramFile from .telegram_file import TelegramFile
from .user import User
from .telethon_session import PgSession from .telethon_session import PgSession
from .upgrade import upgrade_table
from .user import User
def init(db: Database) -> None: def init(db: Database) -> None:
@@ -31,5 +30,14 @@ def init(db: Database) -> None:
table.db = db table.db = db
__all__ = ["upgrade_table", "init", "Portal", "Message", "User", "Puppet", "TelegramFile", __all__ = [
"BotChat", "PgSession"] "upgrade_table",
"init",
"Portal",
"Message",
"User",
"Puppet",
"TelegramFile",
"BotChat",
"PgSession",
]
+1 -1
View File
@@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import ClassVar, TYPE_CHECKING from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record from asyncpg import Record
from attr import dataclass from attr import dataclass
+8 -5
View File
@@ -15,12 +15,12 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import ClassVar, TYPE_CHECKING from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record from asyncpg import Record
from attr import dataclass from attr import dataclass
from mautrix.types import RoomID, EventID from mautrix.types import EventID, RoomID
from mautrix.util.async_db import Database from mautrix.util.async_db import Database
from ..types import TelegramID from ..types import TelegramID
@@ -92,9 +92,12 @@ class Message:
@classmethod @classmethod
async def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int: async def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int:
return await cls.db.fetchval( return (
"SELECT COUNT(tg_space) FROM message WHERE mxid=$1 AND mx_room=$2", mxid, mx_room await cls.db.fetchval(
) or 0 "SELECT COUNT(tg_space) FROM message WHERE mxid=$1 AND mx_room=$2", mxid, mx_room
)
or 0
)
@classmethod @classmethod
async def find_last(cls, mx_room: RoomID, tg_space: TelegramID) -> Message | None: async def find_last(cls, mx_room: RoomID, tg_space: TelegramID) -> Message | None:
+16 -5
View File
@@ -15,14 +15,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import ClassVar, Any, TYPE_CHECKING from typing import TYPE_CHECKING, Any, ClassVar
import json import json
from asyncpg import Record from asyncpg import Record
from attr import dataclass from attr import dataclass
import attr import attr
from mautrix.types import RoomID, ContentURI from mautrix.types import ContentURI, RoomID
from mautrix.util.async_db import Database from mautrix.util.async_db import Database
from ..types import TelegramID from ..types import TelegramID
@@ -93,9 +93,20 @@ class Portal:
@property @property
def _values(self): def _values(self):
return (self.tgid, self.tg_receiver, self.peer_type, self.mxid, self.avatar_url, return (
self.encrypted, self.username, self.title, self.about, self.photo_id, self.tgid,
self.megagroup, json.dumps(self.local_config) if self.local_config else None) self.tg_receiver,
self.peer_type,
self.mxid,
self.avatar_url,
self.encrypted,
self.username,
self.title,
self.about,
self.photo_id,
self.megagroup,
json.dumps(self.local_config) if self.local_config else None,
)
async def save(self) -> None: async def save(self) -> None:
q = ( q = (
+18 -6
View File
@@ -15,13 +15,13 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import ClassVar, TYPE_CHECKING from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record from asyncpg import Record
from attr import dataclass from attr import dataclass
from yarl import URL from yarl import URL
from mautrix.types import UserID, SyncToken from mautrix.types import SyncToken, UserID
from mautrix.util.async_db import Database from mautrix.util.async_db import Database
from ..types import TelegramID from ..types import TelegramID
@@ -92,10 +92,22 @@ class Puppet:
@property @property
def _values(self): def _values(self):
return (self.id, self.is_registered, self.displayname, self.displayname_source, return (
self.displayname_contact, self.displayname_quality, self.disable_updates, self.id,
self.username, self.photo_id, self.is_bot, self.custom_mxid, self.access_token, self.is_registered,
self.next_batch, str(self.base_url) if self.base_url else None) self.displayname,
self.displayname_source,
self.displayname_contact,
self.displayname_quality,
self.disable_updates,
self.username,
self.photo_id,
self.is_bot,
self.custom_mxid,
self.access_token,
self.next_batch,
str(self.base_url) if self.base_url else None,
)
async def save(self) -> None: async def save(self) -> None:
q = ( q = (
+13 -5
View File
@@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import ClassVar, TYPE_CHECKING from typing import TYPE_CHECKING, ClassVar
from attr import dataclass from attr import dataclass
@@ -68,7 +68,15 @@ class TelegramFile:
" thumbnail, decryption_info) " " thumbnail, decryption_info) "
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
) )
await self.db.execute(q, self.id, self.mxc, self.mime_type, self.was_converted, self.size, await self.db.execute(
self.width, self.height, q,
self.thumbnail.id if self.thumbnail else None, self.id,
self.decryption_info.json() if self.decryption_info else None) self.mxc,
self.mime_type,
self.was_converted,
self.size,
self.width,
self.height,
self.thumbnail.id if self.thumbnail else None,
self.decryption_info.json() if self.decryption_info else None,
)
+10 -7
View File
@@ -15,14 +15,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import ClassVar, TYPE_CHECKING from typing import TYPE_CHECKING, ClassVar
import datetime
import asyncio import asyncio
import datetime
from telethon.sessions import MemorySession
from telethon.tl.types import updates, PeerUser, PeerChat, PeerChannel
from telethon.crypto import AuthKey
from telethon import utils from telethon import utils
from telethon.crypto import AuthKey
from telethon.sessions import MemorySession
from telethon.tl.types import PeerChannel, PeerChat, PeerUser, updates
from mautrix.util.async_db import Database from mautrix.util.async_db import Database
@@ -97,7 +97,10 @@ class PgSession(MemorySession):
) )
_tables: ClassVar[tuple[str, ...]] = ( _tables: ClassVar[tuple[str, ...]] = (
"telethon_sessions", "telethon_entities", "telethon_sent_files", "telethon_update_state" "telethon_sessions",
"telethon_entities",
"telethon_sent_files",
"telethon_update_state",
) )
async def delete(self) -> None: async def delete(self) -> None:
@@ -196,7 +199,7 @@ class PgSession(MemorySession):
ids = ( ids = (
utils.get_peer_id(PeerUser(key)), utils.get_peer_id(PeerUser(key)),
utils.get_peer_id(PeerChat(key)), utils.get_peer_id(PeerChat(key)),
utils.get_peer_id(PeerChannel(key)) utils.get_peer_id(PeerChannel(key)),
) )
if self.db.scheme == "postgres": if self.db.scheme == "postgres":
return await self._select_entity("id=ANY($1)", ids) return await self._select_entity("id=ANY($1)", ids)
@@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from asyncpg import Connection from asyncpg import Connection
from . import upgrade_table from . import upgrade_table
legacy_version_query = "SELECT version_num FROM alembic_version" legacy_version_query = "SELECT version_num FROM alembic_version"
@@ -40,8 +41,10 @@ async def upgrade_v1(conn: Connection, scheme: str) -> None:
async def migrate_legacy_to_v1(conn: Connection, scheme: str) -> None: async def migrate_legacy_to_v1(conn: Connection, scheme: str) -> None:
legacy_version = await conn.fetchval(legacy_version_query) legacy_version = await conn.fetchval(legacy_version_query)
if legacy_version != last_legacy_version: if legacy_version != last_legacy_version:
raise RuntimeError("Legacy database is not on last version. Please upgrade the old " raise RuntimeError(
"database with alembic or drop it completely first.") "Legacy database is not on last version. "
"Please upgrade the old database with alembic or drop it completely first."
)
if scheme != "sqlite": if scheme != "sqlite":
await conn.execute( await conn.execute(
""" """
@@ -128,13 +131,24 @@ async def varchar_to_text(conn: Connection) -> None:
columns_to_adjust = { columns_to_adjust = {
"user": ("mxid", "tg_username", "tg_phone"), "user": ("mxid", "tg_username", "tg_phone"),
"portal": ( "portal": (
"peer_type", "mxid", "username", "title", "about", "photo_id", "avatar_url", "config" "peer_type",
"mxid",
"username",
"title",
"about",
"photo_id",
"avatar_url",
"config",
), ),
"message": ("mxid", "mx_room"), "message": ("mxid", "mx_room"),
"puppet": ( "puppet": (
"displayname", "username", "photo_id", "displayname",
) + ( "username",
"access_token", "custom_mxid", "next_batch", "base_url" "photo_id",
"access_token",
"custom_mxid",
"next_batch",
"base_url",
), ),
"bot_chat": ("type",), "bot_chat": ("type",),
"telegram_file": ("id", "mxc", "mime_type", "thumbnail"), "telegram_file": ("id", "mxc", "mime_type", "thumbnail"),
+13 -6
View File
@@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import Iterable, ClassVar, TYPE_CHECKING from typing import TYPE_CHECKING, ClassVar, Iterable
from asyncpg import Record from asyncpg import Record
from attr import dataclass from attr import dataclass
@@ -73,20 +73,25 @@ class User:
@property @property
def _values(self): def _values(self):
return ( return (
self.mxid, self.tgid, self.tg_username, self.tg_phone, self.is_bot, self.saved_contacts self.mxid,
self.tgid,
self.tg_username,
self.tg_phone,
self.is_bot,
self.saved_contacts,
) )
async def save(self) -> None: async def save(self) -> None:
q = ( q = (
'UPDATE "user" SET tgid=$2, tg_username=$3, tg_phone=$4, is_bot=$5, saved_contacts=$6 ' 'UPDATE "user" SET tgid=$2, tg_username=$3, tg_phone=$4, is_bot=$5, saved_contacts=$6 '
'WHERE mxid=$1' "WHERE mxid=$1"
) )
await self.db.execute(q, *self._values) await self.db.execute(q, *self._values)
async def insert(self) -> None: async def insert(self) -> None:
q = ( q = (
'INSERT INTO "user" (mxid, tgid, tg_username, tg_phone, is_bot, saved_contacts) ' 'INSERT INTO "user" (mxid, tgid, tg_username, tg_phone, is_bot, saved_contacts) '
'VALUES ($1, $2, $3, $4, $5, $6)' "VALUES ($1, $2, $3, $4, $5, $6)"
) )
await self.db.execute(q, *self._values) await self.db.execute(q, *self._values)
@@ -122,8 +127,10 @@ class User:
await conn.executemany(q, records) await conn.executemany(q, records)
async def register_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None: async def register_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None:
q = ('INSERT INTO user_portal ("user", portal, portal_receiver) VALUES ($1, $2, $3) ' q = (
'ON CONFLICT ("user", portal, portal_receiver) DO NOTHING') 'INSERT INTO user_portal ("user", portal, portal_receiver) VALUES ($1, $2, $3) '
'ON CONFLICT ("user", portal, portal_receiver) DO NOTHING'
)
await self.db.execute(q, self.tgid, tgid, tg_receiver) await self.db.execute(q, self.tgid, tgid, tg_receiver)
async def unregister_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None: async def unregister_portal(self, tgid: TelegramID, tg_receiver: TelegramID) -> None:
@@ -17,14 +17,14 @@ from __future__ import annotations
import re import re
from telethon.tl.types import MessageEntityItalic, TypeMessageEntity
from telethon.helpers import add_surrogate, del_surrogate
from telethon import TelegramClient from telethon import TelegramClient
from telethon.helpers import add_surrogate, del_surrogate
from telethon.tl.types import MessageEntityItalic, TypeMessageEntity
from mautrix.types import RoomID, MessageEventContent from mautrix.types import MessageEventContent, RoomID
from ...types import TelegramID
from ...db import Message as DBMessage from ...db import Message as DBMessage
from ...types import TelegramID
from .parser import MatrixParser from .parser import MatrixParser
command_regex = re.compile(r"^!([A-Za-z0-9@]+)") command_regex = re.compile(r"^!([A-Za-z0-9@]+)")
@@ -19,13 +19,13 @@ import logging
from telethon import TelegramClient from telethon import TelegramClient
from mautrix.types import UserID, RoomID from mautrix.types import RoomID, UserID
from mautrix.util.formatter import MatrixParser as BaseMatrixParser, RecursionContext from mautrix.util.formatter import MatrixParser as BaseMatrixParser, RecursionContext
from mautrix.util.formatter.html_reader_htmlparser import read_html, HTMLNode from mautrix.util.formatter.html_reader_htmlparser import HTMLNode, read_html
from mautrix.util.logging import TraceLogger from mautrix.util.logging import TraceLogger
from ... import user as u, puppet as pu, portal as po from ... import portal as po, puppet as pu, user as u
from .telegram_message import TelegramMessage, TelegramEntityType from .telegram_message import TelegramEntityType, TelegramMessage
log: TraceLogger = logging.getLogger("mau.fmt.mx") log: TraceLogger = logging.getLogger("mau.fmt.mx")
@@ -48,8 +48,9 @@ class MatrixParser(BaseMatrixParser[TelegramMessage]):
return None return None
async def user_pill_to_fstring(self, msg: TelegramMessage, user_id: UserID) -> TelegramMessage: async def user_pill_to_fstring(self, msg: TelegramMessage, user_id: UserID) -> TelegramMessage:
user = (await pu.Puppet.get_by_mxid(user_id) user = await pu.Puppet.get_by_mxid(user_id) or await u.User.get_by_mxid(
or await u.User.get_by_mxid(user_id, create=False)) user_id, create=False
)
if not user: if not user:
return msg return msg
if user.tg_username: if user.tg_username:
@@ -18,20 +18,30 @@ from __future__ import annotations
from typing import Any, Type from typing import Any, Type
from enum import Enum from enum import Enum
from telethon.tl.types import (MessageEntityMention as Mention, MessageEntityBotCommand as Command, from telethon.tl.types import (
MessageEntityMentionName as MentionName, MessageEntityUrl as URL, InputMessageEntityMentionName as InputMentionName,
MessageEntityEmail as Email, MessageEntityTextUrl as TextURL, MessageEntityBlockquote as Blockquote,
MessageEntityBold as Bold, MessageEntityItalic as Italic, MessageEntityBold as Bold,
MessageEntityCode as Code, MessageEntityPre as Pre, MessageEntityBotCommand as Command,
MessageEntityStrike as Strike, MessageEntityUnderline as Underline, MessageEntityCode as Code,
MessageEntityBlockquote as Blockquote, TypeMessageEntity, MessageEntityEmail as Email,
InputMessageEntityMentionName as InputMentionName) MessageEntityItalic as Italic,
MessageEntityMention as Mention,
MessageEntityMentionName as MentionName,
MessageEntityPre as Pre,
MessageEntityStrike as Strike,
MessageEntityTextUrl as TextURL,
MessageEntityUnderline as Underline,
MessageEntityUrl as URL,
TypeMessageEntity,
)
from mautrix.util.formatter import EntityString, SemiAbstractEntity from mautrix.util.formatter import EntityString, SemiAbstractEntity
class TelegramEntityType(Enum): class TelegramEntityType(Enum):
"""EntityType is a Matrix formatting entity type.""" """EntityType is a Matrix formatting entity type."""
BOLD = Bold BOLD = Bold
ITALIC = Italic ITALIC = Italic
STRIKETHROUGH = Strike STRIKETHROUGH = Strike
@@ -54,8 +64,13 @@ class TelegramEntityType(Enum):
class TelegramEntity(SemiAbstractEntity): class TelegramEntity(SemiAbstractEntity):
internal: TypeMessageEntity internal: TypeMessageEntity
def __init__(self, type: TelegramEntityType | Type[TypeMessageEntity], def __init__(
offset: int, length: int, extra_info: dict[str, Any]) -> None: self,
type: TelegramEntityType | Type[TypeMessageEntity],
offset: int,
length: int,
extra_info: dict[str, Any],
) -> None:
if isinstance(type, TelegramEntityType): if isinstance(type, TelegramEntityType):
if isinstance(type.value, int): if isinstance(type.value, int):
raise ValueError(f"Can't create Entity with non-Telegram EntityType {type}") raise ValueError(f"Can't create Entity with non-Telegram EntityType {type}")
@@ -70,8 +85,12 @@ class TelegramEntity(SemiAbstractEntity):
extra_info["url"] = self.internal.url extra_info["url"] = self.internal.url
elif isinstance(self.internal, (MentionName, InputMentionName)): elif isinstance(self.internal, (MentionName, InputMentionName)):
extra_info["user_id"] = self.internal.user_id extra_info["user_id"] = self.internal.user_id
return TelegramEntity(type(self.internal), offset=self.internal.offset, return TelegramEntity(
length=self.internal.length, extra_info=extra_info) type(self.internal),
offset=self.internal.offset,
length=self.internal.length,
extra_info=extra_info,
)
def __repr__(self) -> str: def __repr__(self) -> str:
return str(self.internal) return str(self.internal)
+96 -51
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -19,41 +19,66 @@ from html import escape
import logging import logging
import re import re
from telethon.tl.types import (MessageEntityMention, MessageEntityMentionName, MessageEntityUrl,
MessageEntityEmail, MessageEntityTextUrl, MessageEntityBold,
MessageEntityItalic, MessageEntityCode, MessageEntityPre,
MessageEntityBotCommand, MessageEntityHashtag, MessageEntityCashtag,
MessageEntityPhone, TypeMessageEntity, PeerChannel, PeerChat,
MessageEntityBlockquote, MessageEntityStrike, MessageFwdHeader,
MessageEntityUnderline, PeerUser)
from telethon.tl.custom import Message
from telethon.errors import RPCError from telethon.errors import RPCError
from telethon.helpers import add_surrogate, del_surrogate from telethon.helpers import add_surrogate, del_surrogate
from telethon.tl.custom import Message
from telethon.tl.types import (
MessageEntityBlockquote,
MessageEntityBold,
MessageEntityBotCommand,
MessageEntityCashtag,
MessageEntityCode,
MessageEntityEmail,
MessageEntityHashtag,
MessageEntityItalic,
MessageEntityMention,
MessageEntityMentionName,
MessageEntityPhone,
MessageEntityPre,
MessageEntityStrike,
MessageEntityTextUrl,
MessageEntityUnderline,
MessageEntityUrl,
MessageFwdHeader,
PeerChannel,
PeerChat,
PeerUser,
TypeMessageEntity,
)
from mautrix.appservice import IntentAPI from mautrix.appservice import IntentAPI
from mautrix.types import (TextMessageEventContent, RelatesTo, RelationType, Format, MessageType, from mautrix.types import (
EventType) EventType,
Format,
MessageType,
RelatesTo,
RelationType,
TextMessageEventContent,
)
from .. import user as u, puppet as pu, portal as po, abstract_user as au from .. import abstract_user as au, portal as po, puppet as pu, user as u
from ..types import TelegramID
from ..db import Message as DBMessage from ..db import Message as DBMessage
from ..types import TelegramID
log: logging.Logger = logging.getLogger("mau.fmt.tg") log: logging.Logger = logging.getLogger("mau.fmt.tg")
async def telegram_reply_to_matrix(evt: Message, source: au.AbstractUser) -> RelatesTo | None: async def telegram_reply_to_matrix(evt: Message, source: au.AbstractUser) -> RelatesTo | None:
if evt.reply_to: if evt.reply_to:
space = (evt.peer_id.channel_id space = (
if isinstance(evt, Message) and isinstance(evt.peer_id, PeerChannel) evt.peer_id.channel_id
else source.tgid) if isinstance(evt, Message) and isinstance(evt.peer_id, PeerChannel)
else source.tgid
)
msg = await DBMessage.get_one_by_tgid(TelegramID(evt.reply_to.reply_to_msg_id), space) msg = await DBMessage.get_one_by_tgid(TelegramID(evt.reply_to.reply_to_msg_id), space)
if msg: if msg:
return RelatesTo(rel_type=RelationType.REPLY, event_id=msg.mxid) return RelatesTo(rel_type=RelationType.REPLY, event_id=msg.mxid)
return None return None
async def _add_forward_header(source: au.AbstractUser, content: TextMessageEventContent, async def _add_forward_header(
fwd_from: MessageFwdHeader) -> None: source: au.AbstractUser, content: TextMessageEventContent, fwd_from: MessageFwdHeader
) -> None:
if not content.formatted_body or content.format != Format.HTML: if not content.formatted_body or content.format != Format.HTML:
content.format = Format.HTML content.format = Format.HTML
content.formatted_body = escape(content.body) content.formatted_body = escape(content.body)
@@ -62,8 +87,9 @@ async def _add_forward_header(source: au.AbstractUser, content: TextMessageEvent
user = await u.User.get_by_tgid(TelegramID(fwd_from.from_id.user_id)) user = await u.User.get_by_tgid(TelegramID(fwd_from.from_id.user_id))
if user: if user:
fwd_from_text = user.displayname or user.mxid fwd_from_text = user.displayname or user.mxid
fwd_from_html = (f"<a href='https://matrix.to/#/{user.mxid}'>" fwd_from_html = (
f"{escape(fwd_from_text)}</a>") f"<a href='https://matrix.to/#/{user.mxid}'>{escape(fwd_from_text)}</a>"
)
if not fwd_from_text: if not fwd_from_text:
puppet = await pu.Puppet.get_by_tgid( puppet = await pu.Puppet.get_by_tgid(
@@ -71,8 +97,9 @@ async def _add_forward_header(source: au.AbstractUser, content: TextMessageEvent
) )
if puppet and puppet.displayname: if puppet and puppet.displayname:
fwd_from_text = puppet.displayname or puppet.mxid fwd_from_text = puppet.displayname or puppet.mxid
fwd_from_html = (f"<a href='https://matrix.to/#/{puppet.mxid}'>" fwd_from_html = (
f"{escape(fwd_from_text)}</a>") f"<a href='https://matrix.to/#/{puppet.mxid}'>{escape(fwd_from_text)}</a>"
)
if not fwd_from_text: if not fwd_from_text:
try: try:
@@ -83,14 +110,18 @@ async def _add_forward_header(source: au.AbstractUser, content: TextMessageEvent
except (ValueError, RPCError): except (ValueError, RPCError):
fwd_from_text = fwd_from_html = "unknown user" fwd_from_text = fwd_from_html = "unknown user"
elif isinstance(fwd_from.from_id, (PeerChannel, PeerChat)): elif isinstance(fwd_from.from_id, (PeerChannel, PeerChat)):
from_id = (fwd_from.from_id.chat_id if isinstance(fwd_from.from_id, PeerChat) from_id = (
else fwd_from.from_id.channel_id) fwd_from.from_id.chat_id
if isinstance(fwd_from.from_id, PeerChat)
else fwd_from.from_id.channel_id
)
portal = await po.Portal.get_by_tgid(TelegramID(from_id)) portal = await po.Portal.get_by_tgid(TelegramID(from_id))
if portal and portal.title: if portal and portal.title:
fwd_from_text = portal.title fwd_from_text = portal.title
if portal.alias: if portal.alias:
fwd_from_html = (f"<a href='https://matrix.to/#/{portal.alias}'>" fwd_from_html = (
f"{escape(fwd_from_text)}</a>") f"<a href='https://matrix.to/#/{portal.alias}'>{escape(fwd_from_text)}</a>"
)
else: else:
fwd_from_html = f"channel <b>{escape(fwd_from_text)}</b>" fwd_from_html = f"channel <b>{escape(fwd_from_text)}</b>"
else: else:
@@ -112,14 +143,18 @@ async def _add_forward_header(source: au.AbstractUser, content: TextMessageEvent
content.body = f"Forwarded from {fwd_from_text}:\n{content.body}" content.body = f"Forwarded from {fwd_from_text}:\n{content.body}"
content.formatted_body = ( content.formatted_body = (
f"Forwarded message from {fwd_from_html}<br/>" f"Forwarded message from {fwd_from_html}<br/>"
f"<tg-forward><blockquote>{content.formatted_body}</blockquote></tg-forward>") f"<tg-forward><blockquote>{content.formatted_body}</blockquote></tg-forward>"
)
async def _add_reply_header(source: au.AbstractUser, content: TextMessageEventContent, async def _add_reply_header(
evt: Message, main_intent: IntentAPI) -> None: source: au.AbstractUser, content: TextMessageEventContent, evt: Message, main_intent: IntentAPI
space = (evt.peer_id.channel_id ) -> None:
if isinstance(evt, Message) and isinstance(evt.peer_id, PeerChannel) space = (
else source.tgid) evt.peer_id.channel_id
if isinstance(evt, Message) and isinstance(evt.peer_id, PeerChannel)
else source.tgid
)
msg = await DBMessage.get_one_by_tgid(TelegramID(evt.reply_to.reply_to_msg_id), space) msg = await DBMessage.get_one_by_tgid(TelegramID(evt.reply_to.reply_to_msg_id), space)
if not msg: if not msg:
@@ -139,12 +174,16 @@ async def _add_reply_header(source: au.AbstractUser, content: TextMessageEventCo
log.exception("Failed to get event to add reply fallback") log.exception("Failed to get event to add reply fallback")
async def telegram_to_matrix(evt: Message, source: au.AbstractUser, async def telegram_to_matrix(
main_intent: IntentAPI | None = None, evt: Message,
prefix_text: str | None = None, prefix_html: str | None = None, source: au.AbstractUser,
override_text: str = None, main_intent: IntentAPI | None = None,
override_entities: list[TypeMessageEntity] = None, prefix_text: str | None = None,
no_reply_fallback: bool = False) -> TextMessageEventContent: prefix_html: str | None = None,
override_text: str = None,
override_entities: list[TypeMessageEntity] = None,
no_reply_fallback: bool = False,
) -> TextMessageEventContent:
content = TextMessageEventContent( content = TextMessageEventContent(
msgtype=MessageType.TEXT, msgtype=MessageType.TEXT,
body=add_surrogate(override_text or evt.message), body=add_surrogate(override_text or evt.message),
@@ -186,15 +225,15 @@ async def _telegram_entities_to_matrix_catch(text: str, entities: list[TypeMessa
try: try:
return await _telegram_entities_to_matrix(text, entities) return await _telegram_entities_to_matrix(text, entities)
except Exception: except Exception:
log.exception("Failed to convert Telegram format:\n" log.exception(
"message=%s\n" "Failed to convert Telegram format:\nmessage=%s\nentities=%s", text, entities
"entities=%s", )
text, entities)
return "[failed conversion in _telegram_entities_to_matrix]" return "[failed conversion in _telegram_entities_to_matrix]"
async def _telegram_entities_to_matrix(text: str, entities: list[TypeMessageEntity], async def _telegram_entities_to_matrix(
offset: int = 0, length: int = None) -> str: text: str, entities: list[TypeMessageEntity], offset: int = 0, length: int = None
) -> str:
if not entities: if not entities:
return escape(text) return escape(text)
if length is None: if length is None:
@@ -212,8 +251,11 @@ async def _telegram_entities_to_matrix(text: str, entities: list[TypeMessageEnti
skip_entity = False skip_entity = False
entity_text = await _telegram_entities_to_matrix( entity_text = await _telegram_entities_to_matrix(
text=text[relative_offset:relative_offset + entity.length], text=text[relative_offset : relative_offset + entity.length],
entities=entities[i + 1:], offset=entity.offset, length=entity.length) entities=entities[i + 1 :],
offset=entity.offset,
length=entity.length,
)
entity_type = type(entity) entity_type = type(entity)
if entity_type == MessageEntityBold: if entity_type == MessageEntityBold:
@@ -227,9 +269,11 @@ async def _telegram_entities_to_matrix(text: str, entities: list[TypeMessageEnti
elif entity_type == MessageEntityBlockquote: elif entity_type == MessageEntityBlockquote:
html.append(f"<blockquote>{entity_text}</blockquote>") html.append(f"<blockquote>{entity_text}</blockquote>")
elif entity_type == MessageEntityCode: elif entity_type == MessageEntityCode:
html.append(f"<pre><code>{entity_text}</code></pre>" html.append(
if "\n" in entity_text f"<pre><code>{entity_text}</code></pre>"
else f"<code>{entity_text}</code>") if "\n" in entity_text
else f"<code>{entity_text}</code>"
)
elif entity_type == MessageEntityPre: elif entity_type == MessageEntityPre:
skip_entity = _parse_pre(html, entity_text, entity.language) skip_entity = _parse_pre(html, entity_text, entity.language)
elif entity_type == MessageEntityMention: elif entity_type == MessageEntityMention:
@@ -293,8 +337,9 @@ async def _parse_name_mention(html: list[str], entity_text: str, user_id: Telegr
return False return False
message_link_regex = re.compile(r"https?://t(?:elegram)?\.(?:me|dog)/" message_link_regex = re.compile(
r"([A-Za-z][A-Za-z0-9_]{3,}[A-Za-z0-9])/([0-9]{1,50})") r"https?://t(?:elegram)?\.(?:me|dog)/([A-Za-z][A-Za-z0-9_]{3,}[A-Za-z0-9])/([0-9]{1,50})"
)
async def _parse_url(html: list[str], entity_text: str, url: str) -> bool: async def _parse_url(html: list[str], entity_text: str, url: str) -> bool:
+4 -4
View File
@@ -1,6 +1,6 @@
import subprocess
import shutil
import os import os
import shutil
import subprocess
from . import __version__ from . import __version__
@@ -15,6 +15,7 @@ cmd_env = {
def run(cmd): def run(cmd):
return subprocess.check_output(cmd, stderr=subprocess.DEVNULL, env=cmd_env) return subprocess.check_output(cmd, stderr=subprocess.DEVNULL, env=cmd_env)
if os.path.exists(".git") and shutil.which("git"): if os.path.exists(".git") and shutil.which("git"):
try: try:
git_revision = run(["git", "rev-parse", "HEAD"]).strip().decode("ascii") git_revision = run(["git", "rev-parse", "HEAD"]).strip().decode("ascii")
@@ -33,8 +34,7 @@ else:
git_revision_url = None git_revision_url = None
git_tag = None git_tag = None
git_tag_url = (f"https://github.com/mautrix/telegram/releases/tag/{git_tag}" git_tag_url = f"https://github.com/mautrix/telegram/releases/tag/{git_tag}" if git_tag else None
if git_tag else None)
if git_tag and __version__ == git_tag[1:].replace("-", ""): if git_tag and __version__ == git_tag[1:].replace("-", ""):
version = __version__ version = __version__
+133 -71
View File
@@ -15,20 +15,33 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import Iterable, TYPE_CHECKING from typing import TYPE_CHECKING, Iterable
from mautrix.bridge import BaseMatrixHandler from mautrix.bridge import BaseMatrixHandler
from mautrix.types import (Event, EventType, RoomID, UserID, EventID, ReceiptEvent, ReceiptType,
ReceiptEventContent, PresenceEvent, PresenceState, TypingEvent,
StateEvent, RedactionEvent,
RoomNameStateEventContent as NameContent,
RoomAvatarStateEventContent as AvatarContent,
RoomTopicStateEventContent as TopicContent,
MemberStateEventContent, TextMessageEventContent,
MessageType)
from mautrix.errors import MatrixError from mautrix.errors import MatrixError
from mautrix.types import (
Event,
EventID,
EventType,
MemberStateEventContent,
MessageType,
PresenceEvent,
PresenceState,
ReceiptEvent,
ReceiptEventContent,
ReceiptType,
RedactionEvent,
RoomAvatarStateEventContent as AvatarContent,
RoomID,
RoomNameStateEventContent as NameContent,
RoomTopicStateEventContent as TopicContent,
StateEvent,
TextMessageEventContent,
TypingEvent,
UserID,
)
from . import user as u, portal as po, puppet as pu, commands as com from . import commands as com, portal as po, puppet as pu, user as u
if TYPE_CHECKING: if TYPE_CHECKING:
from .__main__ import TelegramBridge from .__main__ import TelegramBridge
@@ -38,7 +51,7 @@ class MatrixHandler(BaseMatrixHandler):
commands: com.CommandProcessor commands: com.CommandProcessor
_previously_typing: dict[RoomID, set[UserID]] _previously_typing: dict[RoomID, set[UserID]]
def __init__(self, bridge: 'TelegramBridge') -> None: def __init__(self, bridge: "TelegramBridge") -> None:
prefix, suffix = bridge.config["bridge.username_template"].format(userid=":").split(":") prefix, suffix = bridge.config["bridge.username_template"].format(userid=":").split(":")
homeserver = bridge.config["homeserver.domain"] homeserver = bridge.config["homeserver.domain"]
self.user_id_prefix = f"@{prefix}" self.user_id_prefix = f"@{prefix}"
@@ -48,19 +61,22 @@ class MatrixHandler(BaseMatrixHandler):
self._previously_typing = {} self._previously_typing = {}
async def handle_puppet_invite(self, room_id: RoomID, puppet: pu.Puppet, inviter: u.User, async def handle_puppet_invite(
event_id: EventID) -> None: self, room_id: RoomID, puppet: pu.Puppet, inviter: u.User, event_id: EventID
) -> None:
intent = puppet.default_mxid_intent intent = puppet.default_mxid_intent
self.log.debug(f"{inviter.mxid} invited puppet for {puppet.tgid} to {room_id}") self.log.debug(f"{inviter.mxid} invited puppet for {puppet.tgid} to {room_id}")
if not await inviter.is_logged_in(): if not await inviter.is_logged_in():
await intent.error_and_leave( await intent.error_and_leave(
room_id, text="Please log in before inviting Telegram puppets.") room_id, text="Please log in before inviting Telegram puppets."
)
return return
portal = await po.Portal.get_by_mxid(room_id) portal = await po.Portal.get_by_mxid(room_id)
if portal: if portal:
if portal.peer_type == "user": if portal.peer_type == "user":
await intent.error_and_leave( await intent.error_and_leave(
room_id, text="You can not invite additional users to private chats.") room_id, text="You can not invite additional users to private chats."
)
return return
await portal.invite_telegram(inviter, puppet) await portal.invite_telegram(inviter, puppet)
await intent.join_room(room_id) await intent.join_room(room_id)
@@ -72,10 +88,15 @@ class MatrixHandler(BaseMatrixHandler):
return return
if self.az.bot_mxid not in members: if self.az.bot_mxid not in members:
if len(members) > 2: if len(members) > 2:
await intent.error_and_leave(room_id, text=None, html=( await intent.error_and_leave(
f"Please invite " room_id,
f"<a href='https://matrix.to/#/{self.az.bot_mxid}'>the bridge bot</a> " text=None,
f"first if you want to create a Telegram chat.")) html=(
f"Please invite "
f"<a href='https://matrix.to/#/{self.az.bot_mxid}'>the bridge bot</a> "
f"first if you want to create a Telegram chat."
),
)
return return
await intent.join_room(room_id) await intent.join_room(room_id)
@@ -86,9 +107,13 @@ class MatrixHandler(BaseMatrixHandler):
try: try:
await portal.invite_to_matrix(inviter.mxid) await portal.invite_to_matrix(inviter.mxid)
await intent.send_notice( await intent.send_notice(
room_id, text=f"You already have a private chat with me: {portal.mxid}", room_id,
html=("You already have a private chat with me: " text=f"You already have a private chat with me: {portal.mxid}",
f"<a href='https://matrix.to/#/{portal.mxid}'>Link to room</a>")) html=(
"You already have a private chat with me: "
f"<a href='https://matrix.to/#/{portal.mxid}'>Link to room</a>"
),
)
await intent.leave_room(room_id) await intent.leave_room(room_id)
return return
except MatrixError: except MatrixError:
@@ -99,10 +124,14 @@ class MatrixHandler(BaseMatrixHandler):
await inviter.register_portal(portal) await inviter.register_portal(portal)
if e2be_ok is True: if e2be_ok is True:
evt_type, content = await self.e2ee.encrypt( evt_type, content = await self.e2ee.encrypt(
room_id, EventType.ROOM_MESSAGE, room_id,
TextMessageEventContent(msgtype=MessageType.NOTICE, EventType.ROOM_MESSAGE,
body="Portal to private chat created and end-to-bridge" TextMessageEventContent(
" encryption enabled.")) msgtype=MessageType.NOTICE,
body="Portal to private chat created and end-to-bridge"
" encryption enabled.",
),
)
await intent.send_message_event(room_id, evt_type, content) await intent.send_message_event(room_id, evt_type, content)
else: else:
message = "Portal to private chat created." message = "Portal to private chat created."
@@ -112,11 +141,14 @@ class MatrixHandler(BaseMatrixHandler):
await portal.update_bridge_info() await portal.update_bridge_info()
else: else:
await intent.join_room(room_id) await intent.join_room(room_id)
await intent.send_notice(room_id, "This puppet will remain inactive until a " await intent.send_notice(
"Telegram chat is created for this room.") room_id,
"This puppet will remain inactive until a Telegram chat is created for this room.",
)
async def handle_invite(self, room_id: RoomID, user_id: UserID, inviter: u.User, async def handle_invite(
event_id: EventID) -> None: self, room_id: RoomID, user_id: UserID, inviter: u.User, event_id: EventID
) -> None:
user = await u.User.get_by_mxid(user_id, create=False) user = await u.User.get_by_mxid(user_id, create=False)
if not user: if not user:
return return
@@ -134,13 +166,16 @@ class MatrixHandler(BaseMatrixHandler):
return return
if not user.relaybot_whitelisted: if not user.relaybot_whitelisted:
await portal.main_intent.kick_user(room_id, user.mxid, await portal.main_intent.kick_user(
"You are not whitelisted on this Telegram bridge.") room_id, user.mxid, "You are not whitelisted on this Telegram bridge."
)
return return
elif not await user.is_logged_in() and not portal.has_bot: elif not await user.is_logged_in() and not portal.has_bot:
await portal.main_intent.kick_user(room_id, user.mxid, await portal.main_intent.kick_user(
"This chat does not have a bot relaying " room_id,
"messages for unauthenticated users.") user.mxid,
"This chat does not have a bot relaying messages for unauthenticated users.",
)
return return
self.log.debug(f"{user.mxid} joined {room_id}") self.log.debug(f"{user.mxid} joined {room_id}")
@@ -159,8 +194,15 @@ class MatrixHandler(BaseMatrixHandler):
await user.ensure_started() await user.ensure_started()
await portal.leave_matrix(user, event_id) await portal.leave_matrix(user, event_id)
async def handle_kick_ban(self, ban: bool, room_id: RoomID, user_id: UserID, sender: UserID, async def handle_kick_ban(
reason: str, event_id: EventID) -> None: self,
ban: bool,
room_id: RoomID,
user_id: UserID,
sender: UserID,
reason: str,
event_id: EventID,
) -> None:
action = "banned" if ban else "kicked" action = "banned" if ban else "kicked"
self.log.debug(f"{user_id} was {action} from {room_id} by {sender} for {reason}") self.log.debug(f"{user_id} was {action} from {room_id} by {sender} for {reason}")
portal = await po.Portal.get_by_mxid(room_id) portal = await po.Portal.get_by_mxid(room_id)
@@ -195,17 +237,20 @@ class MatrixHandler(BaseMatrixHandler):
else: else:
await portal.kick_matrix(user, sender) await portal.kick_matrix(user, sender)
async def handle_kick(self, room_id: RoomID, user_id: UserID, kicked_by: UserID, reason: str, async def handle_kick(
event_id: EventID) -> None: self, room_id: RoomID, user_id: UserID, kicked_by: UserID, reason: str, event_id: EventID
) -> None:
await self.handle_kick_ban(False, room_id, user_id, kicked_by, reason, event_id) await self.handle_kick_ban(False, room_id, user_id, kicked_by, reason, event_id)
async def handle_unban(self, room_id: RoomID, user_id: UserID, unbanned_by: UserID, async def handle_unban(
reason: str, event_id: EventID) -> None: self, room_id: RoomID, user_id: UserID, unbanned_by: UserID, reason: str, event_id: EventID
) -> None:
# TODO handle unbans properly instead of handling it as a kick # TODO handle unbans properly instead of handling it as a kick
await self.handle_kick_ban(False, room_id, user_id, unbanned_by, reason, event_id) await self.handle_kick_ban(False, room_id, user_id, unbanned_by, reason, event_id)
async def handle_ban(self, room_id: RoomID, user_id: UserID, banned_by: UserID, reason: str, async def handle_ban(
event_id: EventID) -> None: self, room_id: RoomID, user_id: UserID, banned_by: UserID, reason: str, event_id: EventID
) -> None:
await self.handle_kick_ban(True, room_id, user_id, banned_by, reason, event_id) await self.handle_kick_ban(True, room_id, user_id, banned_by, reason, event_id)
async def allow_message(self, user: u.User) -> bool: async def allow_message(self, user: u.User) -> bool:
@@ -235,9 +280,9 @@ class MatrixHandler(BaseMatrixHandler):
portal = await po.Portal.get_by_mxid(evt.room_id) portal = await po.Portal.get_by_mxid(evt.room_id)
sender = await u.User.get_and_start_by_mxid(evt.sender) sender = await u.User.get_and_start_by_mxid(evt.sender)
if await sender.has_full_access(allow_bot=True) and portal and portal.allow_bridging: if await sender.has_full_access(allow_bot=True) and portal and portal.allow_bridging:
await portal.handle_matrix_power_levels(sender, evt.content.users, await portal.handle_matrix_power_levels(
evt.unsigned.prev_content.users, sender, evt.content.users, evt.unsigned.prev_content.users, evt.event_id
evt.event_id) )
@staticmethod @staticmethod
async def handle_room_meta( async def handle_room_meta(
@@ -245,7 +290,7 @@ class MatrixHandler(BaseMatrixHandler):
room_id: RoomID, room_id: RoomID,
sender_mxid: UserID, sender_mxid: UserID,
content: NameContent | AvatarContent | TopicContent, content: NameContent | AvatarContent | TopicContent,
event_id: EventID event_id: EventID,
) -> None: ) -> None:
portal = await po.Portal.get_by_mxid(room_id) portal = await po.Portal.get_by_mxid(room_id)
sender = await u.User.get_and_start_by_mxid(sender_mxid) sender = await u.User.get_and_start_by_mxid(sender_mxid)
@@ -260,30 +305,40 @@ class MatrixHandler(BaseMatrixHandler):
await handler(sender, content[content_key], event_id) await handler(sender, content[content_key], event_id)
@staticmethod @staticmethod
async def handle_room_pin(room_id: RoomID, sender_mxid: UserID, async def handle_room_pin(
new_events: set[str], old_events: set[str], room_id: RoomID,
event_id: EventID) -> None: sender_mxid: UserID,
new_events: set[str],
old_events: set[str],
event_id: EventID,
) -> None:
portal = await po.Portal.get_by_mxid(room_id) portal = await po.Portal.get_by_mxid(room_id)
sender = await u.User.get_and_start_by_mxid(sender_mxid) sender = await u.User.get_and_start_by_mxid(sender_mxid)
if await sender.has_full_access(allow_bot=True) and portal and portal.allow_bridging: if await sender.has_full_access(allow_bot=True) and portal and portal.allow_bridging:
if not new_events: if not new_events:
await portal.handle_matrix_unpin_all(sender, event_id) await portal.handle_matrix_unpin_all(sender, event_id)
else: else:
changes = {event_id: event_id in new_events changes = {
for event_id in new_events ^ old_events} event_id: event_id in new_events for event_id in new_events ^ old_events
}
await portal.handle_matrix_pin(sender, changes, event_id) await portal.handle_matrix_pin(sender, changes, event_id)
@staticmethod @staticmethod
async def handle_room_upgrade(room_id: RoomID, sender: UserID, new_room_id: RoomID, async def handle_room_upgrade(
event_id: EventID) -> None: room_id: RoomID, sender: UserID, new_room_id: RoomID, event_id: EventID
) -> None:
portal = await po.Portal.get_by_mxid(room_id) portal = await po.Portal.get_by_mxid(room_id)
if portal and portal.allow_bridging: if portal and portal.allow_bridging:
await portal.handle_matrix_upgrade(sender, new_room_id, event_id) await portal.handle_matrix_upgrade(sender, new_room_id, event_id)
async def handle_member_info_change(self, room_id: RoomID, user_id: UserID, async def handle_member_info_change(
profile: MemberStateEventContent, self,
prev_profile: MemberStateEventContent, room_id: RoomID,
event_id: EventID) -> None: user_id: UserID,
profile: MemberStateEventContent,
prev_profile: MemberStateEventContent,
event_id: EventID,
) -> None:
if profile.displayname == prev_profile.displayname: if profile.displayname == prev_profile.displayname:
return return
@@ -293,18 +348,22 @@ class MatrixHandler(BaseMatrixHandler):
user = await u.User.get_and_start_by_mxid(user_id) user = await u.User.get_and_start_by_mxid(user_id)
if await user.needs_relaybot(portal): if await user.needs_relaybot(portal):
await portal.name_change_matrix(user, profile.displayname, prev_profile.displayname, await portal.name_change_matrix(
event_id) user, profile.displayname, prev_profile.displayname, event_id
)
@staticmethod @staticmethod
def parse_read_receipts(content: ReceiptEventContent) -> Iterable[tuple[UserID, EventID]]: def parse_read_receipts(content: ReceiptEventContent) -> Iterable[tuple[UserID, EventID]]:
return ((user_id, event_id) return (
for event_id, receipts in content.items() (user_id, event_id)
for user_id in receipts.get(ReceiptType.READ, {})) for event_id, receipts in content.items()
for user_id in receipts.get(ReceiptType.READ, {})
)
@staticmethod @staticmethod
async def handle_read_receipts(room_id: RoomID, receipts: Iterable[tuple[UserID, EventID]] async def handle_read_receipts(
) -> None: room_id: RoomID, receipts: Iterable[tuple[UserID, EventID]]
) -> None:
portal = await po.Portal.get_by_mxid(room_id) portal = await po.Portal.get_by_mxid(room_id)
if not portal or not portal.allow_bridging: if not portal or not portal.allow_bridging:
return return
@@ -357,16 +416,19 @@ class MatrixHandler(BaseMatrixHandler):
if evt.type == EventType.ROOM_POWER_LEVELS: if evt.type == EventType.ROOM_POWER_LEVELS:
await self.handle_power_levels(evt) await self.handle_power_levels(evt)
elif evt.type in (EventType.ROOM_NAME, EventType.ROOM_AVATAR, EventType.ROOM_TOPIC): elif evt.type in (EventType.ROOM_NAME, EventType.ROOM_AVATAR, EventType.ROOM_TOPIC):
await self.handle_room_meta(evt.type, evt.room_id, evt.sender, evt.content, await self.handle_room_meta(
evt.event_id) evt.type, evt.room_id, evt.sender, evt.content, evt.event_id
)
elif evt.type == EventType.ROOM_PINNED_EVENTS: elif evt.type == EventType.ROOM_PINNED_EVENTS:
new_events = set(evt.content.pinned) new_events = set(evt.content.pinned)
try: try:
old_events = set(evt.unsigned.prev_content.pinned) old_events = set(evt.unsigned.prev_content.pinned)
except (KeyError, ValueError, TypeError, AttributeError): except (KeyError, ValueError, TypeError, AttributeError):
old_events = set() old_events = set()
await self.handle_room_pin(evt.room_id, evt.sender, new_events, old_events, await self.handle_room_pin(
evt.event_id) evt.room_id, evt.sender, new_events, old_events, evt.event_id
)
elif evt.type == EventType.ROOM_TOMBSTONE: elif evt.type == EventType.ROOM_TOMBSTONE:
await self.handle_room_upgrade(evt.room_id, evt.sender, evt.content.replacement_room, await self.handle_room_upgrade(
evt.event_id) evt.room_id, evt.sender, evt.content.replacement_room, evt.event_id
)
+985 -518
View File
File diff suppressed because it is too large Load Diff
+54 -32
View File
@@ -15,24 +15,32 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import Awaitable, AsyncGenerator, AsyncIterable, TYPE_CHECKING, cast from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
from difflib import SequenceMatcher from difflib import SequenceMatcher
import unicodedata import unicodedata
from telethon.tl.types import (UserProfilePhoto, User, UpdateUserName, PeerUser, TypeInputPeer, from telethon.tl.types import (
InputPeerPhotoFileLocation, UserProfilePhotoEmpty, TypeInputUser) InputPeerPhotoFileLocation,
PeerUser,
TypeInputPeer,
TypeInputUser,
UpdateUserName,
User,
UserProfilePhoto,
UserProfilePhotoEmpty,
)
from yarl import URL from yarl import URL
from mautrix.appservice import IntentAPI from mautrix.appservice import IntentAPI
from mautrix.errors import MatrixError
from mautrix.bridge import BasePuppet, async_getter_lock from mautrix.bridge import BasePuppet, async_getter_lock
from mautrix.types import UserID, SyncToken, RoomID, ContentURI from mautrix.errors import MatrixError
from mautrix.types import ContentURI, RoomID, SyncToken, UserID
from mautrix.util.simple_template import SimpleTemplate from mautrix.util.simple_template import SimpleTemplate
from . import abstract_user as au, portal as p, util
from .config import Config from .config import Config
from .types import TelegramID
from .db import Puppet as DBPuppet from .db import Puppet as DBPuppet
from . import util, portal as p, abstract_user as au from .types import TelegramID
if TYPE_CHECKING: if TYPE_CHECKING:
from .__main__ import TelegramBridge from .__main__ import TelegramBridge
@@ -62,7 +70,7 @@ class Puppet(DBPuppet, BasePuppet):
custom_mxid: UserID | None = None, custom_mxid: UserID | None = None,
access_token: str | None = None, access_token: str | None = None,
next_batch: SyncToken | None = None, next_batch: SyncToken | None = None,
base_url: str | None = None base_url: str | None = None,
) -> None: ) -> None:
super().__init__( super().__init__(
id=id, id=id,
@@ -116,7 +124,7 @@ class Puppet(DBPuppet, BasePuppet):
return self.intent return self.intent
@classmethod @classmethod
def init_cls(cls, bridge: 'TelegramBridge') -> AsyncIterable[Awaitable[None]]: def init_cls(cls, bridge: "TelegramBridge") -> AsyncIterable[Awaitable[None]]:
cls.config = bridge.config cls.config = bridge.config
cls.loop = bridge.loop cls.loop = bridge.loop
cls.mx = bridge.matrix cls.mx = bridge.matrix
@@ -134,11 +142,15 @@ class Puppet(DBPuppet, BasePuppet):
cls.config["bridge.displayname_template"], "displayname" cls.config["bridge.displayname_template"], "displayname"
) )
cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"] cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
cls.homeserver_url_map = {server: URL(url) for server, url cls.homeserver_url_map = {
in cls.config["bridge.double_puppet_server_map"].items()} server: URL(url)
for server, url in cls.config["bridge.double_puppet_server_map"].items()
}
cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"] cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
cls.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret cls.login_shared_secret_map = {
in cls.config["bridge.login_shared_secret_map"].items()} server: secret.encode("utf-8")
for server, secret in cls.config["bridge.login_shared_secret_map"].items()
}
cls.login_device_name = "Telegram Bridge" cls.login_device_name = "Telegram Bridge"
return (puppet.try_start() async for puppet in cls.all_with_custom_mxid()) return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
@@ -146,10 +158,12 @@ class Puppet(DBPuppet, BasePuppet):
# region Info updating # region Info updating
def similarity(self, query: str) -> int: def similarity(self, query: str) -> int:
username_similarity = (SequenceMatcher(None, self.username, query).ratio() username_similarity = (
if self.username else 0) SequenceMatcher(None, self.username, query).ratio() if self.username else 0
displayname_similarity = (SequenceMatcher(None, self.plain_displayname, query).ratio() )
if self.displayname else 0) displayname_similarity = (
SequenceMatcher(None, self.plain_displayname, query).ratio() if self.displayname else 0
)
similarity = max(username_similarity, displayname_similarity) similarity = max(username_similarity, displayname_similarity)
return int(round(similarity * 100)) return int(round(similarity * 100))
@@ -157,12 +171,17 @@ class Puppet(DBPuppet, BasePuppet):
def _filter_name(name: str) -> str: def _filter_name(name: str) -> str:
if not name: if not name:
return "" return ""
whitespace = ("\t\n\r\v\f \u00a0\u034f\u180e\u2063\u202f\u205f\u2800\u3000\u3164\ufeff" whitespace = (
"\u2000\u2001\u2002\u2003\u2004\u2005\u2006\u2007\u2008\u2009\u200a\u200b" "\t\n\r\v\f \u00a0\u034f\u180e\u2063\u202f\u205f\u2800\u3000\u3164\ufeff\u2000\u2001"
"\u200c\u200d\u200e\u200f\ufe0f") "\u2002\u2003\u2004\u2005\u2006\u2007\u2008\u2009\u200a\u200b\u200c\u200d\u200e\u200f"
"\ufe0f"
)
allowed_other_format = ("\u200d", "\u200c") allowed_other_format = ("\u200d", "\u200c")
name = "".join(c for c in name.strip(whitespace) if unicodedata.category(c) != 'Cf' name = "".join(
or c in allowed_other_format) c
for c in name.strip(whitespace)
if unicodedata.category(c) != "Cf" or c in allowed_other_format
)
return name return name
@classmethod @classmethod
@@ -219,8 +238,9 @@ class Puppet(DBPuppet, BasePuppet):
if changed: if changed:
await self.save() await self.save()
async def update_displayname(self, source: au.AbstractUser, info: User | UpdateUserName async def update_displayname(
) -> bool: self, source: au.AbstractUser, info: User | UpdateUserName
) -> bool:
if self.disable_updates: if self.disable_updates:
return False return False
if source.is_relaybot or source.is_bot: if source.is_relaybot or source.is_bot:
@@ -249,15 +269,18 @@ class Puppet(DBPuppet, BasePuppet):
displayname, quality = self.get_displayname(info) displayname, quality = self.get_displayname(info)
if displayname != self.displayname and quality >= self.displayname_quality: if displayname != self.displayname and quality >= self.displayname_quality:
allow_because = f"{allow_because} and quality {quality} >= {self.displayname_quality}" allow_because = f"{allow_because} and quality {quality} >= {self.displayname_quality}"
self.log.debug(f"Updating displayname of {self.id} (src: {source.tgid}, allowed " self.log.debug(
f"because {allow_because}) from {self.displayname} to {displayname}") f"Updating displayname of {self.id} (src: {source.tgid}, allowed "
f"because {allow_because}) from {self.displayname} to {displayname}"
)
self.log.trace("Displayname source data: %s", info) self.log.trace("Displayname source data: %s", info)
self.displayname = displayname self.displayname = displayname
self.displayname_source = source.tgid self.displayname_source = source.tgid
self.displayname_quality = quality self.displayname_quality = quality
try: try:
await self.default_mxid_intent.set_displayname( await self.default_mxid_intent.set_displayname(
displayname[:self.config["bridge.displayname_max_length"]]) displayname[: self.config["bridge.displayname_max_length"]]
)
except MatrixError: except MatrixError:
self.log.exception("Failed to set displayname") self.log.exception("Failed to set displayname")
self.displayname = "" self.displayname = ""
@@ -269,8 +292,9 @@ class Puppet(DBPuppet, BasePuppet):
return True return True
return False return False
async def update_avatar(self, source: au.AbstractUser, async def update_avatar(
photo: UserProfilePhoto | UserProfilePhotoEmpty) -> bool: self, source: au.AbstractUser, photo: UserProfilePhoto | UserProfilePhotoEmpty
) -> bool:
if self.disable_updates: if self.disable_updates:
return False return False
@@ -294,9 +318,7 @@ class Puppet(DBPuppet, BasePuppet):
return True return True
loc = InputPeerPhotoFileLocation( loc = InputPeerPhotoFileLocation(
peer=await self.get_input_entity(source), peer=await self.get_input_entity(source), photo_id=photo.photo_id, big=True
photo_id=photo.photo_id,
big=True
) )
file = await util.transfer_file_to_matrix(source.client, self.default_mxid_intent, loc) file = await util.transfer_file_to_matrix(source.client, self.default_mxid_intent, loc)
if file: if file:
+35 -18
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -13,24 +13,35 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Union, Optional from typing import List, Optional, Union
from telethon import TelegramClient, utils from telethon import TelegramClient, utils
from telethon.tl.functions.messages import SendMediaRequest
from telethon.tl.types import (InputMediaUploadedDocument, InputMediaUploadedPhoto,
TypeDocumentAttribute, TypeInputMedia, TypeInputPeer,
TypeMessageEntity, TypeMessageMedia, TypePeer)
from telethon.tl.patched import Message
from telethon.sessions.abstract import Session from telethon.sessions.abstract import Session
from telethon.tl.functions.messages import SendMediaRequest
from telethon.tl.patched import Message
from telethon.tl.types import (
InputMediaUploadedDocument,
InputMediaUploadedPhoto,
TypeDocumentAttribute,
TypeInputMedia,
TypeInputPeer,
TypeMessageEntity,
TypeMessageMedia,
TypePeer,
)
class MautrixTelegramClient(TelegramClient): class MautrixTelegramClient(TelegramClient):
session: Session session: Session
async def upload_file_direct(self, file: bytes, mime_type: str = None, async def upload_file_direct(
attributes: List[TypeDocumentAttribute] = None, self,
file_name: str = None, max_image_size: float = 10 * 1000 ** 2, file: bytes,
) -> Union[InputMediaUploadedDocument, InputMediaUploadedPhoto]: mime_type: str = None,
attributes: List[TypeDocumentAttribute] = None,
file_name: str = None,
max_image_size: float = 10 * 1000 ** 2,
) -> Union[InputMediaUploadedDocument, InputMediaUploadedPhoto]:
file_handle = await super().upload_file(file, file_name=file_name) file_handle = await super().upload_file(file, file_name=file_name)
if (mime_type == "image/png" or mime_type == "image/jpeg") and len(file) < max_image_size: if (mime_type == "image/png" or mime_type == "image/jpeg") and len(file) < max_image_size:
@@ -42,14 +53,20 @@ class MautrixTelegramClient(TelegramClient):
return InputMediaUploadedDocument( return InputMediaUploadedDocument(
file=file_handle, file=file_handle,
mime_type=mime_type or "application/octet-stream", mime_type=mime_type or "application/octet-stream",
attributes=list(attr_dict.values())) attributes=list(attr_dict.values()),
)
async def send_media(self, entity: Union[TypeInputPeer, TypePeer], async def send_media(
media: Union[TypeInputMedia, TypeMessageMedia], self,
caption: str = None, entities: List[TypeMessageEntity] = None, entity: Union[TypeInputPeer, TypePeer],
reply_to: int = None) -> Optional[Message]: media: Union[TypeInputMedia, TypeMessageMedia],
caption: str = None,
entities: List[TypeMessageEntity] = None,
reply_to: int = None,
) -> Optional[Message]:
entity = await self.get_input_entity(entity) entity = await self.get_input_entity(entity)
reply_to = utils.get_message_id(reply_to) reply_to = utils.get_message_id(reply_to)
request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [], request = SendMediaRequest(
reply_to_msg_id=reply_to) entity, media, message=caption or "", entities=entities or [], reply_to_msg_id=reply_to
)
return self._get_response_message(request, await self(request), entity) return self._get_response_message(request, await self(request), entity)
+1 -1
View File
@@ -1,3 +1,3 @@
from typing import NewType from typing import NewType
TelegramID = NewType('TelegramID', int) TelegramID = NewType("TelegramID", int)
+131 -81
View File
@@ -15,49 +15,62 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import Awaitable, AsyncIterable, NamedTuple, AsyncGenerator, TYPE_CHECKING, cast from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, NamedTuple, cast
from datetime import datetime, timezone from datetime import datetime, timezone
import asyncio import asyncio
from telethon.tl.types import (TypeUpdate, UpdateNewMessage, UpdateNewChannelMessage,
UpdateShortChatMessage, UpdateShortMessage, User as TLUser, Chat,
ChatForbidden, UpdateFolderPeers, UpdatePinnedDialogs,
UpdateNotifySettings, NotifyPeer, InputUserSelf)
from telethon.tl.custom import Dialog
from telethon.tl.types.contacts import ContactsNotModified
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from telethon.tl.functions.account import UpdateStatusRequest
from telethon.tl.functions.users import GetUsersRequest
from telethon.tl.functions.updates import GetStateRequest
from telethon.errors import AuthKeyDuplicatedError, RPCError, UnauthorizedError from telethon.errors import AuthKeyDuplicatedError, RPCError, UnauthorizedError
from telethon.tl.custom import Dialog
from telethon.tl.functions.account import UpdateStatusRequest
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from telethon.tl.functions.updates import GetStateRequest
from telethon.tl.functions.users import GetUsersRequest
from telethon.tl.types import (
Chat,
ChatForbidden,
InputUserSelf,
NotifyPeer,
TypeUpdate,
UpdateFolderPeers,
UpdateNewChannelMessage,
UpdateNewMessage,
UpdateNotifySettings,
UpdatePinnedDialogs,
UpdateShortChatMessage,
UpdateShortMessage,
User as TLUser,
)
from telethon.tl.types.contacts import ContactsNotModified
from mautrix.client import Client
from mautrix.errors import MatrixRequestError, MNotFound
from mautrix.types import UserID, RoomID, PushRuleScope, PushRuleKind, PushActionType, RoomTagInfo
from mautrix.appservice import DOUBLE_PUPPET_SOURCE_KEY from mautrix.appservice import DOUBLE_PUPPET_SOURCE_KEY
from mautrix.bridge import BaseUser, async_getter_lock from mautrix.bridge import BaseUser, async_getter_lock
from mautrix.client import Client
from mautrix.errors import MatrixRequestError, MNotFound
from mautrix.types import PushActionType, PushRuleKind, PushRuleScope, RoomID, RoomTagInfo, UserID
from mautrix.util.bridge_state import BridgeState, BridgeStateEvent from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
from mautrix.util.opt_prometheus import Gauge from mautrix.util.opt_prometheus import Gauge
from .types import TelegramID
from .db import User as DBUser, Message as DBMessage, PgSession
from .abstract_user import AbstractUser
from . import portal as po, puppet as pu from . import portal as po, puppet as pu
from .abstract_user import AbstractUser
from .db import Message as DBMessage, PgSession, User as DBUser
from .types import TelegramID
if TYPE_CHECKING: if TYPE_CHECKING:
from .__main__ import TelegramBridge from .__main__ import TelegramBridge
SearchResult = NamedTuple('SearchResult', puppet='pu.Puppet', similarity=int) SearchResult = NamedTuple("SearchResult", puppet="pu.Puppet", similarity=int)
METRIC_LOGGED_IN = Gauge('bridge_logged_in', 'Users logged into bridge') METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into bridge")
METRIC_CONNECTED = Gauge('bridge_connected', 'Users connected to Telegram') METRIC_CONNECTED = Gauge("bridge_connected", "Users connected to Telegram")
BridgeState.human_readable_errors.update({ BridgeState.human_readable_errors.update(
"tg-not-connected": "Your Telegram connection failed", {
"tg-auth-key-duplicated": "The bridge accidentally logged you out", "tg-not-connected": "Your Telegram connection failed",
"tg-not-authenticated": "The stored auth token did not work", "tg-auth-key-duplicated": "The bridge accidentally logged you out",
"tg-no-auth": "You're not logged in", "tg-not-authenticated": "The stored auth token did not work",
}) "tg-no-auth": "You're not logged in",
}
)
class User(DBUser, AbstractUser, BaseUser): class User(DBUser, AbstractUser, BaseUser):
@@ -94,12 +107,14 @@ class User(DBUser, AbstractUser, BaseUser):
self._is_backfilling = False self._is_backfilling = False
self._portals_cache = None self._portals_cache = None
(self.relaybot_whitelisted, (
self.whitelisted, self.relaybot_whitelisted,
self.puppet_whitelisted, self.whitelisted,
self.matrix_puppet_whitelisted, self.puppet_whitelisted,
self.is_admin, self.matrix_puppet_whitelisted,
self.permissions) = self.config.get_permissions(self.mxid) self.is_admin,
self.permissions,
) = self.config.get_permissions(self.mxid)
@property @property
def name(self) -> str: def name(self) -> str:
@@ -124,7 +139,7 @@ class User(DBUser, AbstractUser, BaseUser):
return self.displayname return self.displayname
@classmethod @classmethod
def init_cls(cls, bridge: 'TelegramBridge') -> AsyncIterable[Awaitable[User]]: def init_cls(cls, bridge: "TelegramBridge") -> AsyncIterable[Awaitable[User]]:
cls.config = bridge.config cls.config = bridge.config
cls.bridge = bridge cls.bridge = bridge
cls.az = bridge.az cls.az = bridge.az
@@ -143,8 +158,9 @@ class User(DBUser, AbstractUser, BaseUser):
if not self.client and not await PgSession.has(self.mxid): if not self.client and not await PgSession.has(self.mxid):
self.log.warning("Didn't start user: no session stored") self.log.warning("Didn't start user: no session stored")
if self.tgid: if self.tgid:
await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, await self.push_bridge_state(
error="tg-no-auth") BridgeStateEvent.BAD_CREDENTIALS, error="tg-no-auth"
)
async def ensure_started(self, even_if_no_session=False) -> User: async def ensure_started(self, even_if_no_session=False) -> User:
if not self.puppet_whitelisted or self.connected: if not self.puppet_whitelisted or self.connected:
@@ -157,8 +173,9 @@ class User(DBUser, AbstractUser, BaseUser):
await super().start() await super().start()
except AuthKeyDuplicatedError: except AuthKeyDuplicatedError:
self.log.warning("Got AuthKeyDuplicatedError in start()") self.log.warning("Got AuthKeyDuplicatedError in start()")
await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, await self.push_bridge_state(
error="tg-auth-key-duplicated") BridgeStateEvent.BAD_CREDENTIALS, error="tg-auth-key-duplicated"
)
await self.client.disconnect() await self.client.disconnect()
await self.client.session.delete() await self.client.session.delete()
self.client = None self.client = None
@@ -180,8 +197,12 @@ class User(DBUser, AbstractUser, BaseUser):
except UnauthorizedError as e: except UnauthorizedError as e:
self.log.error(f"Authorization error in start(): {type(e)}: {e}") self.log.error(f"Authorization error in start(): {type(e)}: {e}")
if self.tgid: if self.tgid:
await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, await self.push_bridge_state(
error="tg-auth-error", message=str(e), ttl=3600) BridgeStateEvent.BAD_CREDENTIALS,
error="tg-auth-error",
message=str(e),
ttl=3600,
)
except RPCError as e: except RPCError as e:
self.log.error(f"Unknown RPC error in start(): {type(e)}: {e}") self.log.error(f"Unknown RPC error in start(): {type(e)}: {e}")
if self.tgid: if self.tgid:
@@ -200,8 +221,9 @@ class User(DBUser, AbstractUser, BaseUser):
@property @property
def _is_connected(self) -> bool: def _is_connected(self) -> bool:
return bool(self.client and self.client._sender return bool(
and self.client._sender._transport_connected()) self.client and self.client._sender and self.client._sender._transport_connected()
)
async def _track_connection(self) -> None: async def _track_connection(self) -> None:
self.log.debug("Starting loop to track connection state") self.log.debug("Starting loop to track connection state")
@@ -210,11 +232,16 @@ class User(DBUser, AbstractUser, BaseUser):
connected = self._is_connected connected = self._is_connected
self._track_metric(METRIC_CONNECTED, connected) self._track_metric(METRIC_CONNECTED, connected)
if connected: if connected:
await self.push_bridge_state(BridgeStateEvent.BACKFILLING if self._is_backfilling await self.push_bridge_state(
else BridgeStateEvent.CONNECTED, ttl=3600) BridgeStateEvent.BACKFILLING
if self._is_backfilling
else BridgeStateEvent.CONNECTED,
ttl=3600,
)
else: else:
await self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR, ttl=240, await self.push_bridge_state(
error="tg-not-connected") BridgeStateEvent.UNKNOWN_ERROR, ttl=240, error="tg-not-connected"
)
async def fill_bridge_state(self, state: BridgeState) -> None: async def fill_bridge_state(self, state: BridgeState) -> None:
await super().fill_bridge_state(state) await super().fill_bridge_state(state)
@@ -225,8 +252,11 @@ class User(DBUser, AbstractUser, BaseUser):
if not self.tgid: if not self.tgid:
return [] return []
if self._is_connected and await self.is_logged_in(): if self._is_connected and await self.is_logged_in():
state_event = (BridgeStateEvent.BACKFILLING if self._is_backfilling state_event = (
else BridgeStateEvent.CONNECTED) BridgeStateEvent.BACKFILLING
if self._is_backfilling
else BridgeStateEvent.CONNECTED
)
ttl = 3600 ttl = 3600
else: else:
state_event = BridgeStateEvent.UNKNOWN_ERROR state_event = BridgeStateEvent.UNKNOWN_ERROR
@@ -309,8 +339,9 @@ class User(DBUser, AbstractUser, BaseUser):
return (await self.client(GetUsersRequest([InputUserSelf()])))[0] return (await self.client(GetUsersRequest([InputUserSelf()])))[0]
except UnauthorizedError as e: except UnauthorizedError as e:
self.log.error(f"Authorization error in get_me(): {type(e)}: {e}") self.log.error(f"Authorization error in get_me(): {type(e)}: {e}")
await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, error="tg-auth-error", await self.push_bridge_state(
message=str(e), ttl=3600) BridgeStateEvent.BAD_CREDENTIALS, error="tg-auth-error", message=str(e), ttl=3600
)
await self.stop() await self.stop()
return None return None
@@ -347,8 +378,9 @@ class User(DBUser, AbstractUser, BaseUser):
await portal.cleanup_portal("Logged out of Telegram") await portal.cleanup_portal("Logged out of Telegram")
else: else:
try: try:
await portal.main_intent.kick_user(portal.mxid, self.mxid, await portal.main_intent.kick_user(
"Logged out of Telegram.") portal.mxid, self.mxid, "Logged out of Telegram."
)
except MatrixRequestError: except MatrixRequestError:
pass pass
@@ -375,8 +407,9 @@ class User(DBUser, AbstractUser, BaseUser):
self._track_metric(METRIC_LOGGED_IN, False) self._track_metric(METRIC_LOGGED_IN, False)
return ok return ok
async def _search_local(self, query: str, max_results: int = 5, min_similarity: int = 45 async def _search_local(
) -> list[SearchResult]: self, query: str, max_results: int = 5, min_similarity: int = 45
) -> list[SearchResult]:
results: list[SearchResult] = [] results: list[SearchResult] = []
for contact_id in await self.get_contacts(): for contact_id in await self.get_contacts():
contact = await pu.Puppet.get_by_tgid(contact_id, create=False) contact = await pu.Puppet.get_by_tgid(contact_id, create=False)
@@ -400,8 +433,9 @@ class User(DBUser, AbstractUser, BaseUser):
results.sort(key=lambda tup: tup[1], reverse=True) results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results] return results[0:max_results]
async def search(self, query: str, force_remote: bool = False async def search(
) -> tuple[list[SearchResult], bool]: self, query: str, force_remote: bool = False
) -> tuple[list[SearchResult], bool]:
if force_remote: if force_remote:
return await self._search_remote(query), True return await self._search_remote(query), True
@@ -418,8 +452,9 @@ class User(DBUser, AbstractUser, BaseUser):
if portal.mxid if portal.mxid
} }
async def _tag_room(self, puppet: pu.Puppet, portal: po.Portal, tag: str, active: bool async def _tag_room(
) -> None: self, puppet: pu.Puppet, portal: po.Portal, tag: str, active: bool
) -> None:
if not tag or not portal or not portal.mxid: if not tag or not portal or not portal.mxid:
return return
tag_info = await puppet.intent.get_room_tag(portal.mxid, tag) tag_info = await puppet.intent.get_room_tag(portal.mxid, tag)
@@ -428,8 +463,7 @@ class User(DBUser, AbstractUser, BaseUser):
tag_info[DOUBLE_PUPPET_SOURCE_KEY] = self.bridge.name tag_info[DOUBLE_PUPPET_SOURCE_KEY] = self.bridge.name
await puppet.intent.set_room_tag(portal.mxid, tag, tag_info) await puppet.intent.set_room_tag(portal.mxid, tag, tag_info)
elif ( elif (
not active and tag_info not active and tag_info and tag_info.get(DOUBLE_PUPPET_SOURCE_KEY) == self.bridge.name
and tag_info.get(DOUBLE_PUPPET_SOURCE_KEY) == self.bridge.name
): ):
await puppet.intent.remove_room_tag(portal.mxid, tag) await puppet.intent.remove_room_tag(portal.mxid, tag)
@@ -438,12 +472,17 @@ class User(DBUser, AbstractUser, BaseUser):
return return
now = datetime.utcnow().replace(tzinfo=timezone.utc) now = datetime.utcnow().replace(tzinfo=timezone.utc)
if mute_until is not None and mute_until > now: if mute_until is not None and mute_until > now:
await puppet.intent.set_push_rule(PushRuleScope.GLOBAL, PushRuleKind.ROOM, portal.mxid, await puppet.intent.set_push_rule(
actions=[PushActionType.DONT_NOTIFY]) PushRuleScope.GLOBAL,
PushRuleKind.ROOM,
portal.mxid,
actions=[PushActionType.DONT_NOTIFY],
)
else: else:
try: try:
await puppet.intent.remove_push_rule(PushRuleScope.GLOBAL, PushRuleKind.ROOM, await puppet.intent.remove_push_rule(
portal.mxid) PushRuleScope.GLOBAL, PushRuleKind.ROOM, portal.mxid
)
except MNotFound: except MNotFound:
pass pass
@@ -455,8 +494,9 @@ class User(DBUser, AbstractUser, BaseUser):
return return
for peer in update.folder_peers: for peer in update.folder_peers:
portal = await po.Portal.get_by_entity(peer.peer, tg_receiver=self.tgid, create=False) portal = await po.Portal.get_by_entity(peer.peer, tg_receiver=self.tgid, create=False)
await self._tag_room(puppet, portal, self.config["bridge.archive_tag"], await self._tag_room(
peer.folder_id == 1) puppet, portal, self.config["bridge.archive_tag"], peer.folder_id == 1
)
async def update_pinned_dialogs(self, update: UpdatePinnedDialogs) -> None: async def update_pinned_dialogs(self, update: UpdatePinnedDialogs) -> None:
if self.config["bridge.tag_only_on_create"]: if self.config["bridge.tag_only_on_create"]:
@@ -485,8 +525,9 @@ class User(DBUser, AbstractUser, BaseUser):
) )
await self._mute_room(puppet, portal, update.notify_settings.mute_until) await self._mute_room(puppet, portal, update.notify_settings.mute_until)
async def _sync_dialog(self, portal: po.Portal, dialog: Dialog, should_create: bool, async def _sync_dialog(
puppet: pu.Puppet | None) -> None: self, portal: po.Portal, dialog: Dialog, should_create: bool, puppet: pu.Puppet | None
) -> None:
was_created = False was_created = False
if portal.mxid: if portal.mxid:
try: try:
@@ -510,16 +551,19 @@ class User(DBUser, AbstractUser, BaseUser):
# e.g. if the last read message is a service message that isn't in the message db # e.g. if the last read message is a service message that isn't in the message db
last_read = await DBMessage.find_last(portal.mxid, tg_space) last_read = await DBMessage.find_last(portal.mxid, tg_space)
else: else:
last_read = await DBMessage.get_one_by_tgid(portal.tgid, tg_space, last_read = await DBMessage.get_one_by_tgid(
dialog.dialog.read_inbox_max_id) portal.tgid, tg_space, dialog.dialog.read_inbox_max_id
)
if last_read: if last_read:
await puppet.intent.mark_read(last_read.mx_room, last_read.mxid) await puppet.intent.mark_read(last_read.mx_room, last_read.mxid)
if was_created or not self.config["bridge.tag_only_on_create"]: if was_created or not self.config["bridge.tag_only_on_create"]:
await self._mute_room(puppet, portal, dialog.dialog.notify_settings.mute_until) await self._mute_room(puppet, portal, dialog.dialog.notify_settings.mute_until)
await self._tag_room(puppet, portal, self.config["bridge.pinned_tag"], await self._tag_room(
dialog.pinned) puppet, portal, self.config["bridge.pinned_tag"], dialog.pinned
await self._tag_room(puppet, portal, self.config["bridge.archive_tag"], )
dialog.archived) await self._tag_room(
puppet, portal, self.config["bridge.archive_tag"], dialog.archived
)
async def get_cached_portals(self) -> dict[tuple[TelegramID, TelegramID], po.Portal]: async def get_cached_portals(self) -> dict[tuple[TelegramID, TelegramID], po.Portal]:
if self._portals_cache is None: if self._portals_cache is None:
@@ -536,15 +580,17 @@ class User(DBUser, AbstractUser, BaseUser):
update_limit = self.config["bridge.sync_update_limit"] or None update_limit = self.config["bridge.sync_update_limit"] or None
create_limit = self.config["bridge.sync_create_limit"] create_limit = self.config["bridge.sync_create_limit"]
index = 0 index = 0
self.log.debug(f"Syncing dialogs (update_limit={update_limit}, " self.log.debug(
f"create_limit={create_limit})") f"Syncing dialogs (update_limit={update_limit}, create_limit={create_limit})"
)
await self.push_bridge_state(BridgeStateEvent.BACKFILLING) await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
puppet = await pu.Puppet.get_by_custom_mxid(self.mxid) puppet = await pu.Puppet.get_by_custom_mxid(self.mxid)
dialog: Dialog dialog: Dialog
old_portal_cache = await self.get_cached_portals() old_portal_cache = await self.get_cached_portals()
new_portal_cache = old_portal_cache.copy() new_portal_cache = old_portal_cache.copy()
async for dialog in self.client.iter_dialogs(limit=update_limit, ignore_migrated=True, async for dialog in self.client.iter_dialogs(
archived=False): limit=update_limit, ignore_migrated=True, archived=False
):
entity = dialog.entity entity = dialog.entity
if isinstance(entity, ChatForbidden): if isinstance(entity, ChatForbidden):
self.log.warning(f"Ignoring forbidden chat {entity} while syncing") self.log.warning(f"Ignoring forbidden chat {entity} while syncing")
@@ -557,8 +603,12 @@ class User(DBUser, AbstractUser, BaseUser):
continue continue
portal = await po.Portal.get_by_entity(entity, tg_receiver=self.tgid) portal = await po.Portal.get_by_entity(entity, tg_receiver=self.tgid)
new_portal_cache[portal.tgid_full] = portal new_portal_cache[portal.tgid_full] = portal
coro = self._sync_dialog(portal=portal, dialog=dialog, puppet=puppet, coro = self._sync_dialog(
should_create=not create_limit or index < create_limit) portal=portal,
dialog=dialog,
puppet=puppet,
should_create=not create_limit or index < create_limit,
)
creators.append(self.loop.create_task(coro)) creators.append(self.loop.create_task(coro))
index += 1 index += 1
if new_portal_cache.keys() != old_portal_cache.keys(): if new_portal_cache.keys() != old_portal_cache.keys():
@@ -592,8 +642,8 @@ class User(DBUser, AbstractUser, BaseUser):
def _hash_contacts(count: int, ids: list[TelegramID]) -> int: def _hash_contacts(count: int, ids: list[TelegramID]) -> int:
acc = 0 acc = 0
for contact in sorted([count] + ids): for contact in sorted([count] + ids):
acc = (acc * 20261 + contact) & 0xffffffff acc = (acc * 20261 + contact) & 0xFFFFFFFF
return acc & 0x7fffffff return acc & 0x7FFFFFFF
async def sync_contacts(self) -> None: async def sync_contacts(self) -> None:
existing_contacts = await self.get_contacts() existing_contacts = await self.get_contacts()
+5 -5
View File
@@ -1,7 +1,7 @@
from .file_transfer import transfer_file_to_matrix, convert_image
from .parallel_file_transfer import parallel_transfer_to_telegram
from .recursive_dict import recursive_del, recursive_set, recursive_get
from .color_log import ColorFormatter from .color_log import ColorFormatter
from .send_lock import PortalSendLock
from .deduplication import PortalDedup from .deduplication import PortalDedup
from .media_fallback import make_dice_event_content, make_contact_event_content from .file_transfer import convert_image, transfer_file_to_matrix
from .media_fallback import make_contact_event_content, make_dice_event_content
from .parallel_file_transfer import parallel_transfer_to_telegram
from .recursive_dict import recursive_del, recursive_get, recursive_set
from .send_lock import PortalSendLock
+12 -6
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -13,8 +13,12 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.logging.color import (ColorFormatter as BaseColorFormatter, from mautrix.util.logging.color import (
PREFIX, MXID_COLOR, RESET) MXID_COLOR,
PREFIX,
RESET,
ColorFormatter as BaseColorFormatter,
)
TELETHON_COLOR = PREFIX + "35;1m" # magenta TELETHON_COLOR = PREFIX + "35;1m" # magenta
TELETHON_MODULE_COLOR = PREFIX + "35m" TELETHON_MODULE_COLOR = PREFIX + "35m"
@@ -24,7 +28,9 @@ class ColorFormatter(BaseColorFormatter):
def _color_name(self, module: str) -> str: def _color_name(self, module: str) -> str:
if module.startswith("telethon"): if module.startswith("telethon"):
prefix, user_id, module = module.split(".", 2) prefix, user_id, module = module.split(".", 2)
return (f"{TELETHON_COLOR}{prefix}{RESET}." return (
f"{MXID_COLOR}{user_id}{RESET}." f"{TELETHON_COLOR}{prefix}{RESET}."
f"{TELETHON_MODULE_COLOR}{module}{RESET}") f"{MXID_COLOR}{user_id}{RESET}."
f"{TELETHON_MODULE_COLOR}{module}{RESET}"
)
return super()._color_name(module) return super()._color_name(module)
+35 -27
View File
@@ -13,22 +13,29 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Deque, Dict, Tuple, TYPE_CHECKING from __future__ import annotations
from typing import Tuple
from collections import deque from collections import deque
import hashlib import hashlib
from telethon.tl.patched import Message, MessageService from telethon.tl.patched import Message, MessageService
from telethon.tl.types import (MessageMediaContact, MessageMediaDocument, MessageMediaGeo, from telethon.tl.types import (
MessageMediaPhoto, TypeMessage, TypeUpdates, UpdateNewMessage, MessageMediaContact,
UpdateNewChannelMessage) MessageMediaDocument,
MessageMediaGeo,
MessageMediaPhoto,
TypeMessage,
TypeUpdates,
UpdateNewChannelMessage,
UpdateNewMessage,
)
from mautrix.types import EventID from mautrix.types import EventID
from .. import portal as po
from ..types import TelegramID from ..types import TelegramID
if TYPE_CHECKING:
from ..portal import Portal
DedupMXID = Tuple[EventID, TelegramID] DedupMXID = Tuple[EventID, TelegramID]
@@ -36,12 +43,12 @@ class PortalDedup:
pre_db_check: bool = False pre_db_check: bool = False
cache_queue_length: int = 20 cache_queue_length: int = 20
_dedup: Deque[str] _dedup: deque[str]
_dedup_mxid: Dict[str, DedupMXID] _dedup_mxid: dict[str, DedupMXID]
_dedup_action: Deque[str] _dedup_action: deque[str]
_portal: 'Portal' _portal: po.Portal
def __init__(self, portal: 'Portal') -> None: def __init__(self, portal: po.Portal) -> None:
self._dedup = deque() self._dedup = deque()
self._dedup_mxid = {} self._dedup_mxid = {}
self._dedup_action = deque() self._dedup_action = deque()
@@ -49,7 +56,7 @@ class PortalDedup:
@property @property
def _always_force_hash(self) -> bool: def _always_force_hash(self) -> bool:
return self._portal.peer_type == 'chat' return self._portal.peer_type == "chat"
@staticmethod @staticmethod
def _hash_event(event: TypeMessage) -> str: def _hash_event(event: TypeMessage) -> str:
@@ -73,10 +80,7 @@ class PortalDedup:
}[type(event.media)](event.media) }[type(event.media)](event.media)
except KeyError: except KeyError:
pass pass
return hashlib.md5("-" return hashlib.md5("-".join(str(a) for a in hash_content).encode("utf-8")).hexdigest()
.join(str(a) for a in hash_content)
.encode("utf-8")
).hexdigest()
def check_action(self, event: TypeMessage) -> bool: def check_action(self, event: TypeMessage) -> bool:
evt_hash = self._hash_event(event) if self._always_force_hash else event.id evt_hash = self._hash_event(event) if self._always_force_hash else event.id
@@ -89,9 +93,13 @@ class PortalDedup:
self._dedup_action.popleft() self._dedup_action.popleft()
return False return False
def update(self, event: TypeMessage, mxid: DedupMXID = None, def update(
expected_mxid: Optional[DedupMXID] = None, force_hash: bool = False self,
) -> Optional[DedupMXID]: event: TypeMessage,
mxid: DedupMXID = None,
expected_mxid: DedupMXID | None = None,
force_hash: bool = False,
) -> DedupMXID | None:
evt_hash = self._hash_event(event) if self._always_force_hash or force_hash else event.id evt_hash = self._hash_event(event) if self._always_force_hash or force_hash else event.id
try: try:
found_mxid = self._dedup_mxid[evt_hash] found_mxid = self._dedup_mxid[evt_hash]
@@ -103,11 +111,10 @@ class PortalDedup:
self._dedup_mxid[evt_hash] = mxid self._dedup_mxid[evt_hash] = mxid
return None return None
def check(self, event: TypeMessage, mxid: DedupMXID = None, force_hash: bool = False def check(
) -> Optional[DedupMXID]: self, event: TypeMessage, mxid: DedupMXID = None, force_hash: bool = False
evt_hash = (self._hash_event(event) ) -> DedupMXID | None:
if self._always_force_hash or force_hash evt_hash = self._hash_event(event) if self._always_force_hash or force_hash else event.id
else event.id)
if evt_hash in self._dedup: if evt_hash in self._dedup:
return self._dedup_mxid[evt_hash] return self._dedup_mxid[evt_hash]
@@ -120,7 +127,8 @@ class PortalDedup:
def register_outgoing_actions(self, response: TypeUpdates) -> None: def register_outgoing_actions(self, response: TypeUpdates) -> None:
for update in response.updates: for update in response.updates:
check_dedup = (isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)) check_dedup = isinstance(
and isinstance(update.message, MessageService)) update, (UpdateNewMessage, UpdateNewChannelMessage)
) and isinstance(update.message, MessageService)
if check_dedup: if check_dedup:
self.check(update.message) self.check(update.message)
+149 -68
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -13,27 +13,40 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Tuple, Union, Dict from __future__ import annotations
from typing import Optional, Union
from io import BytesIO from io import BytesIO
import time
import logging
import asyncio
import tempfile
import magic
from asyncpg import UniqueViolationError
from sqlite3 import IntegrityError from sqlite3 import IntegrityError
import asyncio
import logging
import tempfile
import time
from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation, from asyncpg import UniqueViolationError
TypePhotoSize, PhotoSize, PhotoCachedSize, InputPhotoFileLocation, from telethon.errors import (
InputPeerPhotoFileLocation) AuthBytesInvalidError,
from telethon.errors import (AuthBytesInvalidError, AuthKeyInvalidError, LocationInvalidError, AuthKeyInvalidError,
SecurityError, FileIdInvalidError) FileIdInvalidError,
LocationInvalidError,
SecurityError,
)
from telethon.tl.types import (
Document,
InputDocumentFileLocation,
InputFileLocation,
InputPeerPhotoFileLocation,
InputPhotoFileLocation,
PhotoCachedSize,
PhotoSize,
TypePhotoSize,
)
import magic
from mautrix.appservice import IntentAPI from mautrix.appservice import IntentAPI
from ..tgclient import MautrixTelegramClient
from ..db import TelegramFile as DBTelegramFile from ..db import TelegramFile as DBTelegramFile
from ..tgclient import MautrixTelegramClient
from ..util import sane_mimetypes from ..util import sane_mimetypes
from .parallel_file_transfer import parallel_transfer_to_matrix from .parallel_file_transfer import parallel_transfer_to_matrix
from .tgs_converter import convert_tgs_to from .tgs_converter import convert_tgs_to
@@ -55,13 +68,21 @@ except ImportError:
log: logging.Logger = logging.getLogger("mau.util") log: logging.Logger = logging.getLogger("mau.util")
TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLocation, TypeLocation = Union[
InputFileLocation, InputPhotoFileLocation] Document,
InputDocumentFileLocation,
InputPeerPhotoFileLocation,
InputFileLocation,
InputPhotoFileLocation,
]
def convert_image(file: bytes, source_mime: str = "image/webp", target_type: str = "png", def convert_image(
thumbnail_to: Optional[Tuple[int, int]] = None file: bytes,
) -> Tuple[str, bytes, Optional[int], Optional[int]]: source_mime: str = "image/webp",
target_type: str = "png",
thumbnail_to: tuple[int, int] | None = None,
) -> tuple[str, bytes, int | None, int | None]:
if not Image: if not Image:
return source_mime, file, None, None return source_mime, file, None, None
try: try:
@@ -77,8 +98,12 @@ def convert_image(file: bytes, source_mime: str = "image/webp", target_type: str
return source_mime, file, None, None return source_mime, file, None, None
def _read_video_thumbnail(data: bytes, video_ext: str = "mp4", frame_ext: str = "png", def _read_video_thumbnail(
max_size: Tuple[int, int] = (1024, 720)) -> Tuple[bytes, int, int]: data: bytes,
video_ext: str = "mp4",
frame_ext: str = "png",
max_size: tuple[int, int] = (1024, 720),
) -> tuple[bytes, int, int]:
with tempfile.NamedTemporaryFile(prefix="mxtg_video_", suffix=f".{video_ext}") as file: with tempfile.NamedTemporaryFile(prefix="mxtg_video_", suffix=f".{video_ext}") as file:
# We don't have any way to read the video from memory, so save it to disk. # We don't have any way to read the video from memory, so save it to disk.
file.write(data) file.write(data)
@@ -109,11 +134,17 @@ def _location_to_id(location: TypeLocation) -> str:
return str(location.photo_id) return str(location.photo_id)
async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: IntentAPI, async def transfer_thumbnail_to_matrix(
thumbnail_loc: TypeLocation, mime_type: str, encrypt: bool, client: MautrixTelegramClient,
video: Optional[bytes], custom_data: Optional[bytes] = None, intent: IntentAPI,
width: Optional[int] = None, height: [int] = None thumbnail_loc: TypeLocation,
) -> Optional[DBTelegramFile]: mime_type: str,
encrypt: bool,
video: bytes | None,
custom_data: bytes | None = None,
width: int | None = None,
height: int | None = None,
) -> DBTelegramFile | None:
if not Image or not VideoFileClip: if not Image or not VideoFileClip:
return None return None
@@ -151,28 +182,45 @@ async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: In
if decryption_info: if decryption_info:
decryption_info.url = content_uri decryption_info.url = content_uri
db_file = DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type, db_file = DBTelegramFile(
was_converted=False, timestamp=int(time.time()), size=len(file), id=loc_id,
width=width, height=height, decryption_info=decryption_info) mxc=content_uri,
mime_type=mime_type,
was_converted=False,
timestamp=int(time.time()),
size=len(file),
width=width,
height=height,
decryption_info=decryption_info,
)
try: try:
await db_file.insert() await db_file.insert()
except (UniqueViolationError, IntegrityError) as e: except (UniqueViolationError, IntegrityError) as e:
log.exception(f"{e.__class__.__name__} while saving transferred file thumbnail data. " log.exception(
"This was probably caused by two simultaneous transfers of the same file, " f"{e.__class__.__name__} while saving transferred file thumbnail data. "
"and might (but probably won't) cause problems with thumbnails or something.") "This was probably caused by two simultaneous transfers of the same file, "
"and might (but probably won't) cause problems with thumbnails or something."
)
return db_file return db_file
transfer_locks: Dict[str, asyncio.Lock] = {} transfer_locks: dict[str, asyncio.Lock] = {}
TypeThumbnail = Optional[Union[TypeLocation, TypePhotoSize]] TypeThumbnail = Optional[Union[TypeLocation, TypePhotoSize]]
async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI, async def transfer_file_to_matrix(
location: TypeLocation, thumbnail: TypeThumbnail = None, *, client: MautrixTelegramClient,
is_sticker: bool = False, tgs_convert: Optional[dict] = None, intent: IntentAPI,
filename: Optional[str] = None, encrypt: bool = False, location: TypeLocation,
parallel_id: Optional[int] = None) -> Optional[DBTelegramFile]: thumbnail: TypeThumbnail = None,
*,
is_sticker: bool = False,
tgs_convert: dict | None = None,
filename: str | None = None,
encrypt: bool = False,
parallel_id: int | None = None,
) -> DBTelegramFile | None:
location_id = _location_to_id(location) location_id = _location_to_id(location)
if not location_id: if not location_id:
return None return None
@@ -187,17 +235,32 @@ async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentA
lock = asyncio.Lock() lock = asyncio.Lock()
transfer_locks[location_id] = lock transfer_locks[location_id] = lock
async with lock: async with lock:
return await _unlocked_transfer_file_to_matrix(client, intent, location_id, location, return await _unlocked_transfer_file_to_matrix(
thumbnail, is_sticker, tgs_convert, client,
filename, encrypt, parallel_id) intent,
location_id,
location,
thumbnail,
is_sticker,
tgs_convert,
filename,
encrypt,
parallel_id,
)
async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI, async def _unlocked_transfer_file_to_matrix(
loc_id: str, location: TypeLocation, client: MautrixTelegramClient,
thumbnail: TypeThumbnail, is_sticker: bool, intent: IntentAPI,
tgs_convert: Optional[dict], filename: Optional[str], loc_id: str,
encrypt: bool, parallel_id: Optional[int] location: TypeLocation,
) -> Optional[DBTelegramFile]: thumbnail: TypeThumbnail,
is_sticker: bool,
tgs_convert: dict | None,
filename: str | None,
encrypt: bool,
parallel_id: int | None,
) -> DBTelegramFile | None:
db_file = await DBTelegramFile.get(loc_id) db_file = await DBTelegramFile.get(loc_id)
if db_file: if db_file:
return db_file return db_file
@@ -205,8 +268,9 @@ async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, inten
converted_anim = None converted_anim = None
if parallel_id and isinstance(location, Document) and (not is_sticker or not tgs_convert): if parallel_id and isinstance(location, Document) and (not is_sticker or not tgs_convert):
db_file = await parallel_transfer_to_matrix(client, intent, loc_id, location, filename, db_file = await parallel_transfer_to_matrix(
encrypt, parallel_id) client, intent, loc_id, location, filename, encrypt, parallel_id
)
mime_type = location.mime_type mime_type = location.mime_type
file = None file = None
else: else:
@@ -223,12 +287,13 @@ async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, inten
image_converted = False image_converted = False
# A weird bug in alpine/magic makes it return application/octet-stream for gzips... # A weird bug in alpine/magic makes it return application/octet-stream for gzips...
is_tgs = (mime_type == "application/gzip" is_tgs = mime_type == "application/gzip" or (
or (mime_type == "application/octet-stream" mime_type == "application/octet-stream" and magic.from_buffer(file).startswith("gzip")
and magic.from_buffer(file).startswith("gzip"))) )
if is_sticker and tgs_convert and is_tgs: if is_sticker and tgs_convert and is_tgs:
converted_anim = await convert_tgs_to(file, tgs_convert["target"], converted_anim = await convert_tgs_to(
**tgs_convert["args"]) file, tgs_convert["target"], **tgs_convert["args"]
)
mime_type = converted_anim.mime mime_type = converted_anim.mime
file = converted_anim.data file = converted_anim.data
width, height = converted_anim.width, converted_anim.height width, height = converted_anim.width, converted_anim.height
@@ -244,29 +309,45 @@ async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, inten
if decryption_info: if decryption_info:
decryption_info.url = content_uri decryption_info.url = content_uri
db_file = DBTelegramFile(id=loc_id, mxc=content_uri, decryption_info=decryption_info, db_file = DBTelegramFile(
mime_type=mime_type, was_converted=image_converted, id=loc_id,
timestamp=int(time.time()), size=len(file), mxc=content_uri,
width=width, height=height) decryption_info=decryption_info,
mime_type=mime_type,
was_converted=image_converted,
timestamp=int(time.time()),
size=len(file),
width=width,
height=height,
)
if thumbnail and (mime_type.startswith("video/") or mime_type == "image/gif"): if thumbnail and (mime_type.startswith("video/") or mime_type == "image/gif"):
if isinstance(thumbnail, (PhotoSize, PhotoCachedSize)): if isinstance(thumbnail, (PhotoSize, PhotoCachedSize)):
thumbnail = thumbnail.location thumbnail = thumbnail.location
try: try:
db_file.thumbnail = await transfer_thumbnail_to_matrix(client, intent, thumbnail, db_file.thumbnail = await transfer_thumbnail_to_matrix(
video=file, mime_type=mime_type, client, intent, thumbnail, video=file, mime_type=mime_type, encrypt=encrypt
encrypt=encrypt) )
except FileIdInvalidError: except FileIdInvalidError:
log.warning(f"Failed to transfer thumbnail for {thumbnail!s}", exc_info=True) log.warning(f"Failed to transfer thumbnail for {thumbnail!s}", exc_info=True)
elif converted_anim and converted_anim.thumbnail_data: elif converted_anim and converted_anim.thumbnail_data:
db_file.thumbnail = await transfer_thumbnail_to_matrix( db_file.thumbnail = await transfer_thumbnail_to_matrix(
client, intent, location, video=None, encrypt=encrypt, client,
custom_data=converted_anim.thumbnail_data, mime_type=converted_anim.thumbnail_mime, intent,
width=converted_anim.width, height=converted_anim.height) location,
video=None,
encrypt=encrypt,
custom_data=converted_anim.thumbnail_data,
mime_type=converted_anim.thumbnail_mime,
width=converted_anim.width,
height=converted_anim.height,
)
try: try:
await db_file.insert() await db_file.insert()
except (UniqueViolationError, IntegrityError) as e: except (UniqueViolationError, IntegrityError) as e:
log.exception(f"{e.__class__.__name__} while saving transferred file data. " log.exception(
"This was probably caused by two simultaneous transfers of the same file, " f"{e.__class__.__name__} while saving transferred file data. "
"and should not cause any problems.") "This was probably caused by two simultaneous transfers of the same file, "
"and should not cause any problems."
)
return db_file return db_file
+8 -7
View File
@@ -17,11 +17,11 @@ from __future__ import annotations
import html import html
from telethon.tl.types import MessageMediaDice, MessageMediaContact, PeerUser from telethon.tl.types import MessageMediaContact, MessageMediaDice, PeerUser
from mautrix.types import TextMessageEventContent, MessageType, Format from mautrix.types import Format, MessageType, TextMessageEventContent
from .. import puppet as pu, abstract_user as au from .. import abstract_user as au, puppet as pu
from ..types import TelegramID from ..types import TelegramID
try: try:
@@ -36,7 +36,7 @@ def _format_dice(roll: MessageMediaDice) -> str:
0: "\U0001F36B", # "🍫", 0: "\U0001F36B", # "🍫",
1: "\U0001F352", # "🍒", 1: "\U0001F352", # "🍒",
2: "\U0001F34B", # "🍋", 2: "\U0001F34B", # "🍋",
3: "7\ufe0f\u20e3" # "7️⃣", 3: "7\ufe0f\u20e3", # "7️⃣",
} }
res = roll.value - 1 res = roll.value - 1
slot1, slot2, slot3 = emojis[res % 4], emojis[res // 4 % 4], emojis[res // 16] slot1, slot2, slot3 = emojis[res % 4], emojis[res // 4 % 4], emojis[res // 16]
@@ -82,11 +82,12 @@ def make_dice_event_content(roll: MessageMediaDice) -> TextMessageEventContent:
"\U0001F3C0": " Basketball throw", "\U0001F3C0": " Basketball throw",
"\U0001F3B0": " Slot machine", "\U0001F3B0": " Slot machine",
"\U0001F3B3": " Bowling", "\U0001F3B3": " Bowling",
"\u26BD": " Football kick" "\u26BD": " Football kick",
} }
text = f"{roll.emoticon}{emoji_text.get(roll.emoticon, '')} result: {_format_dice(roll)}" text = f"{roll.emoticon}{emoji_text.get(roll.emoticon, '')} result: {_format_dice(roll)}"
content = TextMessageEventContent(msgtype=MessageType.TEXT, format=Format.HTML, body=text, content = TextMessageEventContent(
formatted_body=f"<h4>{text}</h4>") msgtype=MessageType.TEXT, format=Format.HTML, body=text, formatted_body=f"<h4>{text}</h4>"
)
content["net.maunium.telegram.dice"] = {"emoticon": roll.emoticon, "value": roll.value} content["net.maunium.telegram.dice"] = {"emoticon": roll.emoticon, "value": roll.value}
return content return content
+156 -78
View File
@@ -13,34 +13,45 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict, Tuple, cast from __future__ import annotations
from typing import AsyncGenerator, Awaitable, Union, cast
from collections import defaultdict from collections import defaultdict
import hashlib
import asyncio import asyncio
import hashlib
import logging import logging
import time
import math import math
import time
from aiohttp import ClientResponse from aiohttp import ClientResponse
from telethon import helpers, utils
from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation,
InputPhotoFileLocation, InputPeerPhotoFileLocation, TypeInputFile,
InputFileBig, InputFile)
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
from telethon.tl.functions import InvokeWithLayerRequest
from telethon.tl.functions.upload import (GetFileRequest, SaveFilePartRequest,
SaveBigFilePartRequest)
from telethon.tl.alltlobjects import LAYER
from telethon.network import MTProtoSender
from telethon.crypto import AuthKey from telethon.crypto import AuthKey
from telethon import utils, helpers from telethon.network import MTProtoSender
from telethon.tl.alltlobjects import LAYER
from telethon.tl.functions import InvokeWithLayerRequest
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
from telethon.tl.functions.upload import (
GetFileRequest,
SaveBigFilePartRequest,
SaveFilePartRequest,
)
from telethon.tl.types import (
Document,
InputDocumentFileLocation,
InputFile,
InputFileBig,
InputFileLocation,
InputPeerPhotoFileLocation,
InputPhotoFileLocation,
TypeInputFile,
)
from mautrix.appservice import IntentAPI from mautrix.appservice import IntentAPI
from mautrix.types import ContentURI, EncryptedFile from mautrix.types import ContentURI, EncryptedFile
from mautrix.util.logging import TraceLogger from mautrix.util.logging import TraceLogger
from ..tgclient import MautrixTelegramClient
from ..db import TelegramFile as DBTelegramFile from ..db import TelegramFile as DBTelegramFile
from ..tgclient import MautrixTelegramClient
try: try:
from mautrix.crypto.attachments import async_encrypt_attachment from mautrix.crypto.attachments import async_encrypt_attachment
@@ -49,8 +60,13 @@ except ImportError:
log: TraceLogger = cast(TraceLogger, logging.getLogger("mau.util")) log: TraceLogger = cast(TraceLogger, logging.getLogger("mau.util"))
TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLocation, TypeLocation = Union[
InputFileLocation, InputPhotoFileLocation] Document,
InputDocumentFileLocation,
InputPeerPhotoFileLocation,
InputFileLocation,
InputPhotoFileLocation,
]
class DownloadSender: class DownloadSender:
@@ -59,14 +75,21 @@ class DownloadSender:
remaining: int remaining: int
stride: int stride: int
def __init__(self, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int, def __init__(
stride: int, count: int) -> None: self,
sender: MTProtoSender,
file: TypeLocation,
offset: int,
limit: int,
stride: int,
count: int,
) -> None:
self.sender = sender self.sender = sender
self.request = GetFileRequest(file, offset=offset, limit=limit) self.request = GetFileRequest(file, offset=offset, limit=limit)
self.stride = stride self.stride = stride
self.remaining = count self.remaining = count
async def next(self) -> Optional[bytes]: async def next(self) -> bytes | None:
if not self.remaining: if not self.remaining:
return None return None
result = await self.sender.send(self.request) result = await self.sender.send(self.request)
@@ -80,14 +103,22 @@ class DownloadSender:
class UploadSender: class UploadSender:
sender: MTProtoSender sender: MTProtoSender
request: Union[SaveFilePartRequest, SaveBigFilePartRequest] request: SaveFilePartRequest < SaveBigFilePartRequest
part_count: int part_count: int
stride: int stride: int
previous: Optional[asyncio.Task] previous: asyncio.Task | None
loop: asyncio.AbstractEventLoop loop: asyncio.AbstractEventLoop
def __init__(self, sender: MTProtoSender, file_id: int, part_count: int, big: bool, index: int, def __init__(
stride: int, loop: asyncio.AbstractEventLoop) -> None: self,
sender: MTProtoSender,
file_id: int,
part_count: int,
big: bool,
index: int,
stride: int,
loop: asyncio.AbstractEventLoop,
) -> None:
self.sender = sender self.sender = sender
self.part_count = part_count self.part_count = part_count
if big: if big:
@@ -105,8 +136,10 @@ class UploadSender:
async def _next(self, data: bytes) -> None: async def _next(self, data: bytes) -> None:
self.request.bytes = data self.request.bytes = data
log.trace(f"Sending file part {self.request.file_part}/{self.part_count}" log.trace(
f" with {len(data)} bytes") f"Sending file part {self.request.file_part}/{self.part_count}"
f" with {len(data)} bytes"
)
await self.sender.send(self.request) await self.sender.send(self.request)
self.request.file_part += self.stride self.request.file_part += self.stride
@@ -120,16 +153,17 @@ class ParallelTransferrer:
client: MautrixTelegramClient client: MautrixTelegramClient
loop: asyncio.AbstractEventLoop loop: asyncio.AbstractEventLoop
dc_id: int dc_id: int
senders: Optional[List[Union[DownloadSender, UploadSender]]] senders: list[DownloadSender | UploadSender] | None
auth_key: AuthKey auth_key: AuthKey
upload_ticker: int upload_ticker: int
def __init__(self, client: MautrixTelegramClient, dc_id: Optional[int] = None) -> None: def __init__(self, client: MautrixTelegramClient, dc_id: int | None = None) -> None:
self.client = client self.client = client
self.loop = self.client.loop self.loop = self.client.loop
self.dc_id = dc_id or self.client.session.dc_id self.dc_id = dc_id or self.client.session.dc_id
self.auth_key = (None if dc_id and self.client.session.dc_id != dc_id self.auth_key = (
else self.client.session.auth_key) None if dc_id and self.client.session.dc_id != dc_id else self.client.session.auth_key
)
self.senders = None self.senders = None
self.upload_ticker = 0 self.upload_ticker = 0
@@ -138,14 +172,16 @@ class ParallelTransferrer:
self.senders = None self.senders = None
@staticmethod @staticmethod
def _get_connection_count(file_size: int, max_count: int = 20, def _get_connection_count(
full_size: int = 100 * 1024 * 1024) -> int: file_size: int, max_count: int = 20, full_size: int = 100 * 1024 * 1024
) -> int:
if file_size > full_size: if file_size > full_size:
return max_count return max_count
return math.ceil((file_size / full_size) * max_count) return math.ceil((file_size / full_size) * max_count)
async def _init_download(self, connections: int, file: TypeLocation, part_count: int, async def _init_download(
part_size: int) -> None: self, connections: int, file: TypeLocation, part_count: int, part_size: int
) -> None:
minimum, remainder = divmod(part_count, connections) minimum, remainder = divmod(part_count, connections)
def get_part_count() -> int: def get_part_count() -> int:
@@ -158,52 +194,72 @@ class ParallelTransferrer:
# The first cross-DC sender will export+import the authorization, so we always create it # The first cross-DC sender will export+import the authorization, so we always create it
# before creating any other senders. # before creating any other senders.
self.senders = [ self.senders = [
await self._create_download_sender(file, 0, part_size, connections * part_size, await self._create_download_sender(
get_part_count()), file, 0, part_size, connections * part_size, get_part_count()
),
*await asyncio.gather( *await asyncio.gather(
*(self._create_download_sender(file, i, part_size, connections * part_size, *(
get_part_count()) self._create_download_sender(
for i in range(1, connections))) file, i, part_size, connections * part_size, get_part_count()
)
for i in range(1, connections)
)
),
] ]
async def _create_download_sender(self, file: TypeLocation, index: int, part_size: int, async def _create_download_sender(
stride: int, self, file: TypeLocation, index: int, part_size: int, stride: int, part_count: int
part_count: int) -> DownloadSender: ) -> DownloadSender:
return DownloadSender(await self._create_sender(), file, index * part_size, part_size, return DownloadSender(
stride, part_count) await self._create_sender(), file, index * part_size, part_size, stride, part_count
)
async def _init_upload(self, connections: int, file_id: int, part_count: int, big: bool async def _init_upload(
) -> None: self, connections: int, file_id: int, part_count: int, big: bool
) -> None:
self.senders = [ self.senders = [
await self._create_upload_sender(file_id, part_count, big, 0, connections), await self._create_upload_sender(file_id, part_count, big, 0, connections),
*await asyncio.gather( *await asyncio.gather(
*(self._create_upload_sender(file_id, part_count, big, i, connections) *(
for i in range(1, connections))) self._create_upload_sender(file_id, part_count, big, i, connections)
for i in range(1, connections)
)
),
] ]
async def _create_upload_sender(self, file_id: int, part_count: int, big: bool, index: int, async def _create_upload_sender(
stride: int) -> UploadSender: self, file_id: int, part_count: int, big: bool, index: int, stride: int
return UploadSender(await self._create_sender(), file_id, part_count, big, index, stride, ) -> UploadSender:
loop=self.loop) return UploadSender(
await self._create_sender(), file_id, part_count, big, index, stride, loop=self.loop
)
async def _create_sender(self) -> MTProtoSender: async def _create_sender(self) -> MTProtoSender:
dc = await self.client._get_dc(self.dc_id) dc = await self.client._get_dc(self.dc_id)
sender = MTProtoSender(self.auth_key, loggers=self.client._log) sender = MTProtoSender(self.auth_key, loggers=self.client._log)
await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id, await sender.connect(
loggers=self.client._log, self.client._connection(
proxy=self.client._proxy)) dc.ip_address, dc.port, dc.id, loggers=self.client._log, proxy=self.client._proxy
)
)
if not self.auth_key: if not self.auth_key:
log.debug(f"Exporting auth to DC {self.dc_id}") log.debug(f"Exporting auth to DC {self.dc_id}")
auth = await self.client(ExportAuthorizationRequest(self.dc_id)) auth = await self.client(ExportAuthorizationRequest(self.dc_id))
self.client._init_request.query = ImportAuthorizationRequest(id=auth.id, self.client._init_request.query = ImportAuthorizationRequest(
bytes=auth.bytes) id=auth.id, bytes=auth.bytes
)
req = InvokeWithLayerRequest(LAYER, self.client._init_request) req = InvokeWithLayerRequest(LAYER, self.client._init_request)
await sender.send(req) await sender.send(req)
self.auth_key = sender.auth_key self.auth_key = sender.auth_key
return sender return sender
async def init_upload(self, file_id: int, file_size: int, part_size_kb: Optional[float] = None, async def init_upload(
connection_count: Optional[int] = None) -> Tuple[int, int, bool]: self,
file_id: int,
file_size: int,
part_size_kb: float | None = None,
connection_count: int | None = None,
) -> tuple[int, int, bool]:
connection_count = connection_count or self._get_connection_count(file_size) connection_count = connection_count or self._get_connection_count(file_size)
part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024 part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
part_count = (file_size + part_size - 1) // part_size part_count = (file_size + part_size - 1) // part_size
@@ -218,14 +274,19 @@ class ParallelTransferrer:
async def finish_upload(self) -> None: async def finish_upload(self) -> None:
await self._cleanup() await self._cleanup()
async def download(self, file: TypeLocation, file_size: int, async def download(
part_size_kb: Optional[float] = None, self,
connection_count: Optional[int] = None) -> AsyncGenerator[bytes, None]: file: TypeLocation,
file_size: int,
part_size_kb: float | None = None,
connection_count: int | None = None,
) -> AsyncGenerator[bytes, None]:
connection_count = connection_count or self._get_connection_count(file_size) connection_count = connection_count or self._get_connection_count(file_size)
part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024 part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
part_count = math.ceil(file_size / part_size) part_count = math.ceil(file_size / part_size)
log.debug("Starting parallel download: " log.debug(
f"{connection_count} {part_size} {part_count} {file!s}") f"Starting parallel download: {connection_count} {part_size} {part_count} {file!s}"
)
await self._init_download(connection_count, file, part_count, part_size) await self._init_download(connection_count, file, part_count, part_size)
part = 0 part = 0
@@ -245,12 +306,18 @@ class ParallelTransferrer:
await self._cleanup() await self._cleanup()
parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) parallel_transfer_locks: defaultdict[int, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: IntentAPI, async def parallel_transfer_to_matrix(
loc_id: str, location: TypeLocation, filename: str, client: MautrixTelegramClient,
encrypt: bool, parallel_id: int) -> DBTelegramFile: intent: IntentAPI,
loc_id: str,
location: TypeLocation,
filename: str,
encrypt: bool,
parallel_id: int,
) -> DBTelegramFile:
size = location.size size = location.size
mime_type = location.mime_type mime_type = location.mime_type
dc_id, location = utils.get_input_location(location) dc_id, location = utils.get_input_location(location)
@@ -261,6 +328,7 @@ async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: Int
decryption_info = None decryption_info = None
up_mime_type = mime_type up_mime_type = mime_type
if encrypt and async_encrypt_attachment: if encrypt and async_encrypt_attachment:
async def encrypted(stream): async def encrypted(stream):
nonlocal decryption_info nonlocal decryption_info
async for chunk in async_encrypt_attachment(stream): async for chunk in async_encrypt_attachment(stream):
@@ -271,17 +339,27 @@ async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: Int
data = encrypted(data) data = encrypted(data)
up_mime_type = "application/octet-stream" up_mime_type = "application/octet-stream"
content_uri = await intent.upload_media(data, mime_type=up_mime_type, filename=filename, content_uri = await intent.upload_media(
size=size if not encrypt else None) data, mime_type=up_mime_type, filename=filename, size=size if not encrypt else None
)
if decryption_info: if decryption_info:
decryption_info.url = content_uri decryption_info.url = content_uri
return DBTelegramFile(id=loc_id, mxc=content_uri, mime_type=mime_type, return DBTelegramFile(
was_converted=False, timestamp=int(time.time()), size=size, id=loc_id,
width=None, height=None, decryption_info=decryption_info) mxc=content_uri,
mime_type=mime_type,
was_converted=False,
timestamp=int(time.time()),
size=size,
width=None,
height=None,
decryption_info=decryption_info,
)
async def _internal_transfer_to_telegram(client: MautrixTelegramClient, response: ClientResponse async def _internal_transfer_to_telegram(
) -> Tuple[TypeInputFile, int]: client: MautrixTelegramClient, response: ClientResponse
) -> tuple[TypeInputFile, int]:
file_id = helpers.generate_random_long() file_id = helpers.generate_random_long()
file_size = response.content_length file_size = response.content_length
@@ -313,9 +391,9 @@ async def _internal_transfer_to_telegram(client: MautrixTelegramClient, response
return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size
async def parallel_transfer_to_telegram(client: MautrixTelegramClient, intent: IntentAPI, async def parallel_transfer_to_telegram(
uri: ContentURI, parallel_id: int client: MautrixTelegramClient, intent: IntentAPI, uri: ContentURI, parallel_id: int
) -> Tuple[TypeInputFile, int]: ) -> tuple[TypeInputFile, int]:
url = intent.api.get_download_url(uri) url = intent.api.get_download_url(uri)
async with parallel_transfer_locks[parallel_id]: async with parallel_transfer_locks[parallel_id]:
async with intent.api.session.get(url) as response: async with intent.api.session.get(url) as response:
+7 -5
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -13,12 +13,14 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Any from __future__ import annotations
from typing import Any
from mautrix.util.config import RecursiveDict from mautrix.util.config import RecursiveDict
def recursive_set(data: Dict[str, Any], key: str, value: Any) -> bool: def recursive_set(data: dict[str, Any], key: str, value: Any) -> bool:
key, next_key = RecursiveDict.parse_key(key) key, next_key = RecursiveDict.parse_key(key)
if next_key is not None: if next_key is not None:
if key not in data: if key not in data:
@@ -31,7 +33,7 @@ def recursive_set(data: Dict[str, Any], key: str, value: Any) -> bool:
return True return True
def recursive_get(data: Dict[str, Any], key: str) -> Any: def recursive_get(data: dict[str, Any], key: str) -> Any:
key, next_key = RecursiveDict.parse_key(key) key, next_key = RecursiveDict.parse_key(key)
if next_key is not None: if next_key is not None:
next_data = data.get(key, None) next_data = data.get(key, None)
@@ -41,7 +43,7 @@ def recursive_get(data: Dict[str, Any], key: str) -> Any:
return data.get(key, None) return data.get(key, None)
def recursive_del(data: Dict[str, any], key: str) -> bool: def recursive_del(data: dict[str, any], key: str) -> bool:
key, next_key = RecursiveDict.parse_key(key) key, next_key = RecursiveDict.parse_key(key)
if next_key is not None: if next_key is not None:
if key not in data: if key not in data:
+4 -4
View File
@@ -13,7 +13,8 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict from __future__ import annotations
from asyncio import Lock from asyncio import Lock
from ..types import TelegramID from ..types import TelegramID
@@ -28,7 +29,7 @@ class FakeLock:
class PortalSendLock: class PortalSendLock:
_send_locks: Dict[int, Lock] _send_locks: dict[int, Lock]
_noop_lock: Lock = FakeLock() _noop_lock: Lock = FakeLock()
def __init__(self) -> None: def __init__(self) -> None:
@@ -40,5 +41,4 @@ class PortalSendLock:
try: try:
return self._send_locks[user_id] return self._send_locks[user_id]
except KeyError: except KeyError:
return (self._send_locks.setdefault(user_id, Lock()) return self._send_locks.setdefault(user_id, Lock()) if required else self._noop_lock
if required else self._noop_lock)
+102 -43
View File
@@ -14,11 +14,13 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Callable, Awaitable, Optional, Tuple, Any from __future__ import annotations
from typing import Any, Awaitable, Callable
import asyncio.subprocess import asyncio.subprocess
import logging import logging
import shutil
import os.path import os.path
import shutil
import tempfile import tempfile
from attr import dataclass from attr import dataclass
@@ -30,17 +32,17 @@ log: logging.Logger = logging.getLogger("mau.util.tgs")
class ConvertedSticker: class ConvertedSticker:
mime: str mime: str
data: bytes data: bytes
thumbnail_mime: Optional[str] = None thumbnail_mime: str | None = None
thumbnail_data: Optional[bytes] = None thumbnail_data: bytes | None = None
width: int = 0 width: int = 0
height: int = 0 height: int = 0
Converter = Callable[[bytes, int, int, Any], Awaitable[ConvertedSticker]] Converter = Callable[[bytes, int, int, Any], Awaitable[ConvertedSticker]]
converters: Dict[str, Converter] = {} converters: dict[str, Converter] = {}
def abswhich(program: Optional[str]) -> Optional[str]: def abswhich(program: str | None) -> str | None:
path = shutil.which(program) path = shutil.which(program)
return os.path.abspath(path) if path else None return os.path.abspath(path) if path else None
@@ -49,77 +51,134 @@ lottieconverter = abswhich("lottieconverter")
ffmpeg = abswhich("ffmpeg") ffmpeg = abswhich("ffmpeg")
if lottieconverter: if lottieconverter:
async def tgs_to_png(file: bytes, width: int, height: int, **_: Any) -> ConvertedSticker: async def tgs_to_png(file: bytes, width: int, height: int, **_: Any) -> ConvertedSticker:
frame = 1 frame = 1
proc = await asyncio.create_subprocess_exec(lottieconverter, "-", "-", "png", proc = await asyncio.create_subprocess_exec(
f"{width}x{height}", str(frame), lottieconverter,
stdout=asyncio.subprocess.PIPE, "-",
stdin=asyncio.subprocess.PIPE) "-",
"png",
f"{width}x{height}",
str(frame),
stdout=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate(file) stdout, stderr = await proc.communicate(file)
if proc.returncode == 0: if proc.returncode == 0:
return ConvertedSticker("image/png", stdout) return ConvertedSticker("image/png", stdout)
else: else:
log.error("lottieconverter error: " + (stderr.decode("utf-8") if stderr is not None log.error(
else f"unknown ({proc.returncode})")) "lottieconverter error: "
+ (
stderr.decode("utf-8")
if stderr is not None
else f"unknown ({proc.returncode})"
)
)
return ConvertedSticker("application/gzip", file) return ConvertedSticker("application/gzip", file)
async def tgs_to_gif(
async def tgs_to_gif(file: bytes, width: int, height: int, fps: int = 25, file: bytes, width: int, height: int, fps: int = 25, **_: Any
**_: Any) -> ConvertedSticker: ) -> ConvertedSticker:
proc = await asyncio.create_subprocess_exec(lottieconverter, "-", "-", "gif", proc = await asyncio.create_subprocess_exec(
f"{width}x{height}", str(fps), lottieconverter,
stdout=asyncio.subprocess.PIPE, "-",
stdin=asyncio.subprocess.PIPE) "-",
"gif",
f"{width}x{height}",
str(fps),
stdout=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate(file) stdout, stderr = await proc.communicate(file)
if proc.returncode == 0: if proc.returncode == 0:
return ConvertedSticker("image/gif", stdout) return ConvertedSticker("image/gif", stdout)
else: else:
log.error("lottieconverter error: " + (stderr.decode("utf-8") if stderr is not None log.error(
else f"unknown ({proc.returncode})")) "lottieconverter error: "
+ (
stderr.decode("utf-8")
if stderr is not None
else f"unknown ({proc.returncode})"
)
)
return ConvertedSticker("application/gzip", file) return ConvertedSticker("application/gzip", file)
converters["png"] = tgs_to_png converters["png"] = tgs_to_png
converters["gif"] = tgs_to_gif converters["gif"] = tgs_to_gif
if lottieconverter and ffmpeg: if lottieconverter and ffmpeg:
async def tgs_to_webm(file: bytes, width: int, height: int, fps: int = 30,
**_: Any) -> ConvertedSticker: async def tgs_to_webm(
file: bytes, width: int, height: int, fps: int = 30, **_: Any
) -> ConvertedSticker:
with tempfile.TemporaryDirectory(prefix="tgs_") as tmpdir: with tempfile.TemporaryDirectory(prefix="tgs_") as tmpdir:
file_template = tmpdir + "/out_" file_template = tmpdir + "/out_"
proc = await asyncio.create_subprocess_exec(lottieconverter, "-", file_template, proc = await asyncio.create_subprocess_exec(
"pngs", f"{width}x{height}", str(fps), lottieconverter,
stdout=asyncio.subprocess.PIPE, "-",
stdin=asyncio.subprocess.PIPE) file_template,
"pngs",
f"{width}x{height}",
str(fps),
stdout=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE,
)
_, stderr = await proc.communicate(file) _, stderr = await proc.communicate(file)
if proc.returncode == 0: if proc.returncode == 0:
with open(f"{file_template}00.png", "rb") as first_frame_file: with open(f"{file_template}00.png", "rb") as first_frame_file:
first_frame_data = first_frame_file.read() first_frame_data = first_frame_file.read()
proc = await asyncio.create_subprocess_exec(ffmpeg, "-hide_banner", "-loglevel", proc = await asyncio.create_subprocess_exec(
"error", "-framerate", str(fps), ffmpeg,
"-pattern_type", "glob", "-i", "-hide_banner",
file_template + "*.png", "-loglevel",
"-c:v", "libvpx-vp9", "-pix_fmt", "error",
"yuva420p", "-f", "webm", "-", "-framerate",
stdout=asyncio.subprocess.PIPE, str(fps),
stdin=asyncio.subprocess.PIPE) "-pattern_type",
"glob",
"-i",
file_template + "*.png",
"-c:v",
"libvpx-vp9",
"-pix_fmt",
"yuva420p",
"-f",
"webm",
"-",
stdout=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate() stdout, stderr = await proc.communicate()
if proc.returncode == 0: if proc.returncode == 0:
return ConvertedSticker("video/webm", stdout, "image/png", first_frame_data) return ConvertedSticker("video/webm", stdout, "image/png", first_frame_data)
else: else:
log.error("ffmpeg error: " + (stderr.decode("utf-8") if stderr is not None log.error(
else f"unknown ({proc.returncode})")) "ffmpeg error: "
+ (
stderr.decode("utf-8")
if stderr is not None
else f"unknown ({proc.returncode})"
)
)
else: else:
log.error("lottieconverter error: " + (stderr.decode("utf-8") if stderr is not None log.error(
else f"unknown ({proc.returncode})")) "lottieconverter error: "
+ (
stderr.decode("utf-8")
if stderr is not None
else f"unknown ({proc.returncode})"
)
)
return ConvertedSticker("application/gzip", file) return ConvertedSticker("application/gzip", file)
converters["webm"] = tgs_to_webm converters["webm"] = tgs_to_webm
async def convert_tgs_to(file: bytes, convert_to: str, width: int, height: int, **kwargs: Any async def convert_tgs_to(
) -> ConvertedSticker: file: bytes, convert_to: str, width: int, height: int, **kwargs: Any
) -> ConvertedSticker:
if convert_to in converters: if convert_to in converters:
converter = converters[convert_to] converter = converters[convert_to]
converted = await converter(file, width, height, **kwargs) converted = await converter(file, width, height, **kwargs)
+1 -1
View File
@@ -1 +1 @@
from .get_version import git_tag, git_revision, version, linkified_version from .get_version import git_revision, git_tag, linkified_version, version
+210 -84
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -13,17 +13,31 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
import abc import abc
import asyncio import asyncio
import logging import logging
from aiohttp import web from aiohttp import web
from telethon.errors import (
AccessTokenExpiredError,
AccessTokenInvalidError,
FloodWaitError,
PasswordEmptyError,
PasswordHashInvalidError,
PhoneCodeExpiredError,
PhoneCodeInvalidError,
PhoneNumberAppSignupForbiddenError,
PhoneNumberBannedError,
PhoneNumberFloodError,
PhoneNumberInvalidError,
PhoneNumberUnoccupiedError,
SessionPasswordNeededError,
)
from telethon.errors import * from mautrix.bridge import InvalidAccessToken, OnlyLoginSelf
from mautrix.bridge import OnlyLoginSelf, InvalidAccessToken
from mautrix.util.format_duration import format_duration from mautrix.util.format_duration import format_duration
from ...commands.telegram.auth import enter_password from ...commands.telegram.auth import enter_password
@@ -39,81 +53,141 @@ class AuthAPI(abc.ABC):
self.loop = loop self.loop = loop
@abstractmethod @abstractmethod
def get_login_response(self, status: int = 200, state: str = "", username: str = "", def get_login_response(
phone: str = "", human_tg_id: str = "", mxid: str = "", self,
message: str = "", error: str = "", errcode: str = "") -> web.Response: status: int = 200,
state: str = "",
username: str = "",
phone: str = "",
human_tg_id: str = "",
mxid: str = "",
message: str = "",
error: str = "",
errcode: str = "",
) -> web.Response:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def get_mx_login_response(self, status: int = 200, state: str = "", username: str = "", def get_mx_login_response(
phone: str = "", human_tg_id: str = "", mxid: str = "", self,
message: str = "", error: str = "", errcode: str = "" status: int = 200,
) -> web.Response: state: str = "",
username: str = "",
phone: str = "",
human_tg_id: str = "",
mxid: str = "",
message: str = "",
error: str = "",
errcode: str = "",
) -> web.Response:
raise NotImplementedError() raise NotImplementedError()
async def post_matrix_token(self, user: User, token: str) -> web.Response: async def post_matrix_token(self, user: User, token: str) -> web.Response:
puppet = await Puppet.get_by_tgid(user.tgid) puppet = await Puppet.get_by_tgid(user.tgid)
if puppet.is_real_user: if puppet.is_real_user:
return self.get_mx_login_response(state="already-logged-in", status=409, return self.get_mx_login_response(
error="You have already logged in with your Matrix " state="already-logged-in",
"account.", errcode="already-logged-in") status=409,
error="You have already logged in with your Matrix account.",
errcode="already-logged-in",
)
try: try:
await puppet.switch_mxid(token.strip(), user.mxid) await puppet.switch_mxid(token.strip(), user.mxid)
except OnlyLoginSelf: except OnlyLoginSelf:
return self.get_mx_login_response(status=403, errcode="only-login-self", return self.get_mx_login_response(
error="You can only log in as your own Matrix user.") status=403,
errcode="only-login-self",
error="You can only log in as your own Matrix user.",
)
except InvalidAccessToken: except InvalidAccessToken:
return self.get_mx_login_response(status=401, errcode="invalid-access-token", return self.get_mx_login_response(
error="Failed to verify access token.") status=401, errcode="invalid-access-token", error="Failed to verify access token."
)
return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in") return self.get_mx_login_response(mxid=user.mxid, status=200, state="logged-in")
async def post_matrix_password(self, user: User, password: str) -> web.Response: async def post_matrix_password(self, user: User, password: str) -> web.Response:
return self.get_mx_login_response(mxid=user.mxid, status=501, error="Not yet implemented", return self.get_mx_login_response(
errcode="not-yet-implemented") mxid=user.mxid, status=501, error="Not yet implemented", errcode="not-yet-implemented"
)
async def post_login_phone(self, user: User, phone: str) -> web.Response: async def post_login_phone(self, user: User, phone: str) -> web.Response:
if not phone or not phone.strip(): if not phone or not phone.strip():
return self.get_login_response(mxid=user.mxid, state="request", status=400, return self.get_login_response(
errcode="phone_number_invalid", mxid=user.mxid,
error="Phone number not given.") state="request",
status=400,
errcode="phone_number_invalid",
error="Phone number not given.",
)
try: try:
await user.client.sign_in(phone.strip()) await user.client.sign_in(phone.strip())
return self.get_login_response(mxid=user.mxid, state="code", status=200, return self.get_login_response(
message="Code requested successfully. Check your SMS " mxid=user.mxid,
"or Telegram client and enter the code below.") state="code",
status=200,
message="Code requested successfully. Check your SMS "
"or Telegram client and enter the code below.",
)
except PhoneNumberInvalidError: except PhoneNumberInvalidError:
return self.get_login_response(mxid=user.mxid, state="request", status=400, return self.get_login_response(
errcode="phone_number_invalid", mxid=user.mxid,
error="Invalid phone number.") state="request",
status=400,
errcode="phone_number_invalid",
error="Invalid phone number.",
)
except PhoneNumberBannedError: except PhoneNumberBannedError:
return self.get_login_response(mxid=user.mxid, state="request", status=403, return self.get_login_response(
errcode="phone_number_banned", mxid=user.mxid,
error="Your phone number is banned from Telegram.") state="request",
status=403,
errcode="phone_number_banned",
error="Your phone number is banned from Telegram.",
)
except PhoneNumberAppSignupForbiddenError: except PhoneNumberAppSignupForbiddenError:
return self.get_login_response(mxid=user.mxid, state="request", status=403, return self.get_login_response(
errcode="phone_number_app_signup_forbidden", mxid=user.mxid,
error="You have disabled 3rd party apps on your " state="request",
"account.") status=403,
errcode="phone_number_app_signup_forbidden",
error="You have disabled 3rd party apps on your account.",
)
except PhoneNumberUnoccupiedError: except PhoneNumberUnoccupiedError:
return self.get_login_response(mxid=user.mxid, state="request", status=404, return self.get_login_response(
errcode="phone_number_unoccupied", mxid=user.mxid,
error="That phone number has not been registered.") state="request",
status=404,
errcode="phone_number_unoccupied",
error="That phone number has not been registered.",
)
except PhoneNumberFloodError: except PhoneNumberFloodError:
return self.get_login_response( return self.get_login_response(
mxid=user.mxid, state="request", status=429, errcode="phone_number_flood", mxid=user.mxid,
state="request",
status=429,
errcode="phone_number_flood",
error="Your phone number has been temporarily blocked for flooding. " error="Your phone number has been temporarily blocked for flooding. "
"The ban is usually applied for around a day.") "The ban is usually applied for around a day.",
)
except FloodWaitError as e: except FloodWaitError as e:
return self.get_login_response( return self.get_login_response(
mxid=user.mxid, state="request", status=429, errcode="flood_wait", mxid=user.mxid,
state="request",
status=429,
errcode="flood_wait",
error="Your phone number has been temporarily blocked for flooding. " error="Your phone number has been temporarily blocked for flooding. "
f"Please wait for {format_duration(e.seconds)} before trying again.") f"Please wait for {format_duration(e.seconds)} before trying again.",
)
except Exception: except Exception:
self.log.exception("Error requesting phone code") self.log.exception("Error requesting phone code")
return self.get_login_response(mxid=user.mxid, state="request", status=500, return self.get_login_response(
errcode="unknown_error", mxid=user.mxid,
error="Internal server error while requesting code.") state="request",
status=500,
errcode="unknown_error",
error="Internal server error while requesting code.",
)
async def postprocess_login(self, user: User, user_info) -> None: async def postprocess_login(self, user: User, user_info) -> None:
existing_user = await User.get_by_tgid(user_info.id) existing_user = await User.get_by_tgid(user_info.id)
@@ -127,39 +201,70 @@ class AuthAPI(abc.ABC):
try: try:
user_info = await user.client.sign_in(bot_token=token.strip()) user_info = await user.client.sign_in(bot_token=token.strip())
await self.postprocess_login(user, user_info) await self.postprocess_login(user, user_info)
return self.get_login_response(mxid=user.mxid, state="logged-in", status=200, return self.get_login_response(
username=user_info.username, phone=None, mxid=user.mxid,
human_tg_id=f"@{user_info.username}") state="logged-in",
status=200,
username=user_info.username,
phone=None,
human_tg_id=f"@{user_info.username}",
)
except AccessTokenInvalidError: except AccessTokenInvalidError:
return self.get_login_response(mxid=user.mxid, state="token", status=401, return self.get_login_response(
errcode="bot_token_invalid", mxid=user.mxid,
error="Bot token invalid.") state="token",
status=401,
errcode="bot_token_invalid",
error="Bot token invalid.",
)
except AccessTokenExpiredError: except AccessTokenExpiredError:
return self.get_login_response(mxid=user.mxid, state="token", status=403, return self.get_login_response(
errcode="bot_token_expired", mxid=user.mxid,
error="Bot token expired.") state="token",
status=403,
errcode="bot_token_expired",
error="Bot token expired.",
)
except Exception: except Exception:
self.log.exception("Error sending bot token") self.log.exception("Error sending bot token")
return self.get_login_response(mxid=user.mxid, state="token", status=500, return self.get_login_response(
error="Internal server error while sending token.") mxid=user.mxid,
state="token",
status=500,
error="Internal server error while sending token.",
)
async def post_login_code(self, user: User, code: int, password_in_data: bool async def post_login_code(
) -> Optional[web.Response]: self, user: User, code: int, password_in_data: bool
) -> web.Response | None:
try: try:
user_info = await user.client.sign_in(code=code) user_info = await user.client.sign_in(code=code)
await self.postprocess_login(user, user_info) await self.postprocess_login(user, user_info)
human_tg_id = f"@{user_info.username}" if user_info.username else f"+{user_info.phone}" human_tg_id = f"@{user_info.username}" if user_info.username else f"+{user_info.phone}"
return self.get_login_response(mxid=user.mxid, state="logged-in", status=200, return self.get_login_response(
username=user_info.username, phone=user_info.phone, mxid=user.mxid,
human_tg_id=human_tg_id) state="logged-in",
status=200,
username=user_info.username,
phone=user_info.phone,
human_tg_id=human_tg_id,
)
except PhoneCodeInvalidError: except PhoneCodeInvalidError:
return self.get_login_response(mxid=user.mxid, state="code", status=401, return self.get_login_response(
errcode="phone_code_invalid", mxid=user.mxid,
error="Incorrect phone code.") state="code",
status=401,
errcode="phone_code_invalid",
error="Incorrect phone code.",
)
except PhoneCodeExpiredError: except PhoneCodeExpiredError:
return self.get_login_response(mxid=user.mxid, state="code", status=403, return self.get_login_response(
errcode="phone_code_expired", mxid=user.mxid,
error="Phone code expired.") state="code",
status=403,
errcode="phone_code_expired",
error="Phone code expired.",
)
except SessionPasswordNeededError: except SessionPasswordNeededError:
if not password_in_data: if not password_in_data:
if user.command_status and user.command_status["action"] == "Login": if user.command_status and user.command_status["action"] == "Login":
@@ -177,28 +282,49 @@ class AuthAPI(abc.ABC):
return None return None
except Exception: except Exception:
self.log.exception("Error sending phone code") self.log.exception("Error sending phone code")
return self.get_login_response(mxid=user.mxid, state="code", status=500, return self.get_login_response(
errcode="unknown_error", mxid=user.mxid,
error="Internal server error while sending code.") state="code",
status=500,
errcode="unknown_error",
error="Internal server error while sending code.",
)
async def post_login_password(self, user: User, password: str) -> web.Response: async def post_login_password(self, user: User, password: str) -> web.Response:
try: try:
user_info = await user.client.sign_in(password=password.strip()) user_info = await user.client.sign_in(password=password.strip())
await self.postprocess_login(user, user_info) await self.postprocess_login(user, user_info)
human_tg_id = f"@{user_info.username}" if user_info.username else f"+{user_info.phone}" human_tg_id = f"@{user_info.username}" if user_info.username else f"+{user_info.phone}"
return self.get_login_response(mxid=user.mxid, state="logged-in", status=200, return self.get_login_response(
username=user_info.username, phone=user_info.phone, mxid=user.mxid,
human_tg_id=human_tg_id) state="logged-in",
status=200,
username=user_info.username,
phone=user_info.phone,
human_tg_id=human_tg_id,
)
except PasswordEmptyError: except PasswordEmptyError:
return self.get_login_response(mxid=user.mxid, state="password", status=400, return self.get_login_response(
errcode="password_empty", mxid=user.mxid,
error="Empty password.") state="password",
status=400,
errcode="password_empty",
error="Empty password.",
)
except PasswordHashInvalidError: except PasswordHashInvalidError:
return self.get_login_response(mxid=user.mxid, state="password", status=401, return self.get_login_response(
errcode="password_invalid", mxid=user.mxid,
error="Incorrect password.") state="password",
status=401,
errcode="password_invalid",
error="Incorrect password.",
)
except Exception: except Exception:
self.log.exception("Error sending password") self.log.exception("Error sending password")
return self.get_login_response(mxid=user.mxid, state="password", status=500, return self.get_login_response(
errcode="unknown_error", mxid=user.mxid,
error="Internal server error while sending password.") state="password",
status=500,
errcode="unknown_error",
error="Internal server error while sending password.",
)
+256 -142
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -13,24 +13,25 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Callable, Dict, Optional, Tuple, TYPE_CHECKING from __future__ import annotations
from typing import TYPE_CHECKING, Awaitable, Callable
import asyncio import asyncio
import logging
import json import json
import logging
from aiohttp import web from aiohttp import web
from telethon.tl.types import ChannelForbidden, ChatForbidden, TypeChat
from telethon.utils import get_peer_id, resolve_id from telethon.utils import get_peer_id, resolve_id
from telethon.tl.types import ChatForbidden, ChannelForbidden, TypeChat
from mautrix.appservice import AppService from mautrix.appservice import AppService
from mautrix.errors import MatrixRequestError, IntentError from mautrix.errors import IntentError, MatrixRequestError
from mautrix.types import UserID from mautrix.types import UserID
from ...commands.portal.util import get_initial_state, user_has_power_level
from ...portal import Portal
from ...types import TelegramID from ...types import TelegramID
from ...user import User from ...user import User
from ...portal import Portal
from ...commands.portal.util import user_has_power_level, get_initial_state
from ..common import AuthAPI from ..common import AuthAPI
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -41,7 +42,7 @@ class ProvisioningAPI(AuthAPI):
log: logging.Logger = logging.getLogger("mau.web.provisioning") log: logging.Logger = logging.getLogger("mau.web.provisioning")
secret: str secret: str
az: AppService az: AppService
bridge: 'TelegramBridge' bridge: "TelegramBridge"
app: web.Application app: web.Application
def __init__(self, bridge: "TelegramBridge") -> None: def __init__(self, bridge: "TelegramBridge") -> None:
@@ -55,8 +56,9 @@ class ProvisioningAPI(AuthAPI):
portal_prefix = "/portal/{mxid:![^/]+}" portal_prefix = "/portal/{mxid:![^/]+}"
self.app.router.add_route("GET", f"{portal_prefix}", self.get_portal_by_mxid) self.app.router.add_route("GET", f"{portal_prefix}", self.get_portal_by_mxid)
self.app.router.add_route("GET", "/portal/{tgid:-[0-9]+}", self.get_portal_by_tgid) self.app.router.add_route("GET", "/portal/{tgid:-[0-9]+}", self.get_portal_by_tgid)
self.app.router.add_route("POST", portal_prefix + "/connect/{chat_id:-[0-9]+}", self.app.router.add_route(
self.connect_chat) "POST", portal_prefix + "/connect/{chat_id:-[0-9]+}", self.connect_chat
)
self.app.router.add_route("POST", f"{portal_prefix}/create", self.create_chat) self.app.router.add_route("POST", f"{portal_prefix}/create", self.create_chat)
self.app.router.add_route("POST", f"{portal_prefix}/disconnect", self.disconnect_chat) self.app.router.add_route("POST", f"{portal_prefix}/disconnect", self.disconnect_chat)
@@ -80,8 +82,9 @@ class ProvisioningAPI(AuthAPI):
mxid = request.match_info["mxid"] mxid = request.match_info["mxid"]
portal = await Portal.get_by_mxid(mxid) portal = await Portal.get_by_mxid(mxid)
if not portal: if not portal:
return self.get_error_response(404, "portal_not_found", return self.get_error_response(
"Portal with given Matrix ID not found.") 404, "portal_not_found", "Portal with given Matrix ID not found."
)
return await self._get_portal_response(UserID(request.query.get("user_id", "")), portal) return await self._get_portal_response(UserID(request.query.get("user_id", "")), portal)
async def get_portal_by_tgid(self, request: web.Request) -> web.Response: async def get_portal_by_tgid(self, request: web.Request) -> web.Response:
@@ -92,26 +95,30 @@ class ProvisioningAPI(AuthAPI):
try: try:
tgid, _ = resolve_id(int(request.match_info["tgid"])) tgid, _ = resolve_id(int(request.match_info["tgid"]))
except ValueError: except ValueError:
return self.get_error_response(400, "tgid_invalid", return self.get_error_response(400, "tgid_invalid", "Given chat ID is not valid.")
"Given chat ID is not valid.")
portal = await Portal.get_by_tgid(tgid) portal = await Portal.get_by_tgid(tgid)
if not portal: if not portal:
return self.get_error_response(404, "portal_not_found", return self.get_error_response(
"Portal to given Telegram chat not found.") 404, "portal_not_found", "Portal to given Telegram chat not found."
)
return await self._get_portal_response(UserID(request.query.get("user_id", "")), portal) return await self._get_portal_response(UserID(request.query.get("user_id", "")), portal)
async def _get_portal_response(self, user_id: UserID, portal: Portal) -> web.Response: async def _get_portal_response(self, user_id: UserID, portal: Portal) -> web.Response:
user, _ = await self.get_user(user_id, expect_logged_in=None, require_puppeting=False) user, _ = await self.get_user(user_id, expect_logged_in=None, require_puppeting=False)
return web.json_response({ return web.json_response(
"mxid": portal.mxid, {
"chat_id": get_peer_id(portal.peer), "mxid": portal.mxid,
"peer_type": portal.peer_type, "chat_id": get_peer_id(portal.peer),
"title": portal.title, "peer_type": portal.peer_type,
"about": portal.about, "title": portal.title,
"username": portal.username, "about": portal.about,
"megagroup": portal.megagroup, "username": portal.username,
"can_unbridge": (await portal.can_user_perform(user, "unbridge")) if user else False, "megagroup": portal.megagroup,
}) "can_unbridge": (await portal.can_user_perform(user, "unbridge"))
if user
else False,
}
)
async def connect_chat(self, request: web.Request) -> web.Response: async def connect_chat(self, request: web.Request) -> web.Response:
err = self.check_authorization(request) err = self.check_authorization(request)
@@ -120,8 +127,9 @@ class ProvisioningAPI(AuthAPI):
room_id = request.match_info["mxid"] room_id = request.match_info["mxid"]
if await Portal.get_by_mxid(room_id): if await Portal.get_by_mxid(room_id):
return self.get_error_response(409, "room_already_bridged", return self.get_error_response(
"Room is already bridged to another Telegram chat.") 409, "room_already_bridged", "Room is already bridged to another Telegram chat."
)
chat_id = request.match_info["chat_id"] chat_id = request.match_info["chat_id"]
if chat_id.startswith("-100"): if chat_id.startswith("-100"):
@@ -133,38 +141,51 @@ class ProvisioningAPI(AuthAPI):
else: else:
return self.get_error_response(400, "tgid_invalid", "Invalid Telegram chat ID.") return self.get_error_response(400, "tgid_invalid", "Invalid Telegram chat ID.")
user, err = await self.get_user(request.query.get("user_id", None), expect_logged_in=None, user, err = await self.get_user(
require_puppeting=False) request.query.get("user_id", None), expect_logged_in=None, require_puppeting=False
)
if err is not None: if err is not None:
return err return err
elif user and not await user_has_power_level(room_id, self.az.intent, user, "bridge"): elif user and not await user_has_power_level(room_id, self.az.intent, user, "bridge"):
return self.get_error_response(403, "not_enough_permissions", return self.get_error_response(
"You do not have the permissions to bridge that room.") 403,
"not_enough_permissions",
"You do not have the permissions to bridge that room.",
)
is_logged_in = user is not None and await user.is_logged_in() is_logged_in = user is not None and await user.is_logged_in()
acting_user = user if is_logged_in else self.bridge.bot acting_user = user if is_logged_in else self.bridge.bot
if not acting_user: if not acting_user:
return self.get_login_response(status=403, errcode="not_logged_in", return self.get_login_response(
error="You are not logged in and there is no relay bot.") status=403,
errcode="not_logged_in",
error="You are not logged in and there is no relay bot.",
)
portal = await Portal.get_by_tgid(tgid, peer_type=peer_type) portal = await Portal.get_by_tgid(tgid, peer_type=peer_type)
if portal.mxid == room_id: if portal.mxid == room_id:
return self.get_error_response(200, "bridge_exists", return self.get_error_response(
"Telegram chat is already bridged to that Matrix room.") 200, "bridge_exists", "Telegram chat is already bridged to that Matrix room."
)
elif portal.mxid: elif portal.mxid:
force = request.query.get("force", None) force = request.query.get("force", None)
if force in ("delete", "unbridge"): if force in ("delete", "unbridge"):
delete = force == "delete" delete = force == "delete"
await portal.cleanup_portal("Portal deleted (moving to another room)" if delete await portal.cleanup_portal(
else "Room unbridged (portal moving to another room)", "Portal deleted (moving to another room)"
puppets_only=not delete) if delete
else "Room unbridged (portal moving to another room)",
puppets_only=not delete,
)
else: else:
return self.get_error_response(409, "chat_already_bridged", return self.get_error_response(
"Telegram chat is already bridged to another " 409,
"Matrix room.") "chat_already_bridged",
"Telegram chat is already bridged to another Matrix room.",
)
async with portal._room_create_lock: async with portal._room_create_lock:
entity: Optional[TypeChat] = None entity: TypeChat | None = None
try: try:
entity = await acting_user.client.get_entity(portal.peer) entity = await acting_user.client.get_entity(portal.peer)
except Exception: except Exception:
@@ -172,22 +193,28 @@ class ProvisioningAPI(AuthAPI):
if not entity or isinstance(entity, (ChatForbidden, ChannelForbidden)): if not entity or isinstance(entity, (ChatForbidden, ChannelForbidden)):
if is_logged_in: if is_logged_in:
return self.get_error_response(403, "user_not_in_chat", return self.get_error_response(
"Failed to get info of Telegram chat. " 403,
"Are you in the chat?") "user_not_in_chat",
return self.get_error_response(403, "bot_not_in_chat", "Failed to get info of Telegram chat. Are you in the chat?",
"Failed to get info of Telegram chat. " )
"Is the relay bot in the chat?") return self.get_error_response(
403,
"bot_not_in_chat",
"Failed to get info of Telegram chat. Is the relay bot in the chat?",
)
portal.mxid = room_id portal.mxid = room_id
portal.by_mxid[portal.mxid] = portal portal.by_mxid[portal.mxid] = portal
(portal.title, portal.about, levels, (portal.title, portal.about, levels, portal.encrypted) = await get_initial_state(
portal.encrypted) = await get_initial_state(self.az.intent, room_id) self.az.intent, room_id
)
portal.photo_id = "" portal.photo_id = ""
await portal.save() await portal.save()
asyncio.ensure_future(portal.update_matrix_room(user, entity, direct=False, levels=levels), asyncio.ensure_future(
loop=self.loop) portal.update_matrix_room(user, entity, direct=False, levels=levels), loop=self.loop
)
return web.Response(status=202, body="{}") return web.Response(status=202, body="{}")
@@ -202,25 +229,32 @@ class ProvisioningAPI(AuthAPI):
room_id = request.match_info["mxid"] room_id = request.match_info["mxid"]
if await Portal.get_by_mxid(room_id): if await Portal.get_by_mxid(room_id):
return self.get_error_response(409, "room_already_bridged", return self.get_error_response(
"Room is already bridged to another Telegram chat.") 409, "room_already_bridged", "Room is already bridged to another Telegram chat."
)
user, err = await self.get_user(request.query.get("user_id", None), expect_logged_in=None, user, err = await self.get_user(
require_puppeting=False) request.query.get("user_id", None), expect_logged_in=None, require_puppeting=False
)
if err is not None: if err is not None:
return err return err
elif not await user.is_logged_in() or user.is_bot: elif not await user.is_logged_in() or user.is_bot:
return self.get_error_response(403, "not_logged_in_real_account", return self.get_error_response(
"You are not logged in with a real account.") 403, "not_logged_in_real_account", "You are not logged in with a real account."
)
elif not await user_has_power_level(room_id, self.az.intent, user, "bridge"): elif not await user_has_power_level(room_id, self.az.intent, user, "bridge"):
return self.get_error_response(403, "not_enough_permissions", return self.get_error_response(
"You do not have the permissions to bridge that room.") 403,
"not_enough_permissions",
"You do not have the permissions to bridge that room.",
)
try: try:
title, about, _, encrypted = await get_initial_state(self.az.intent, room_id) title, about, _, encrypted = await get_initial_state(self.az.intent, room_id)
except (MatrixRequestError, IntentError): except (MatrixRequestError, IntentError):
return self.get_error_response(403, "bot_not_in_room", return self.get_error_response(
"The bridge bot is not in the given room.") 403, "bot_not_in_room", "The bridge bot is not in the given room."
)
about = data.get("about", about) about = data.get("about", about)
@@ -230,8 +264,9 @@ class ProvisioningAPI(AuthAPI):
type = data.get("type", "") type = data.get("type", "")
if type not in ("group", "chat", "supergroup", "channel"): if type not in ("group", "chat", "supergroup", "channel"):
return self.get_error_response(400, "body_value_invalid", return self.get_error_response(
"Given chat type is not valid.") 400, "body_value_invalid", "Given chat type is not valid."
)
supergroup = type == "supergroup" supergroup = type == "supergroup"
type = { type = {
@@ -241,17 +276,27 @@ class ProvisioningAPI(AuthAPI):
"group": "chat", "group": "chat",
}[type] }[type]
portal = Portal(tgid=TelegramID(0), mxid=room_id, title=title, about=about, peer_type=type, portal = Portal(
encrypted=encrypted, tg_receiver=TelegramID(0)) tgid=TelegramID(0),
mxid=room_id,
title=title,
about=about,
peer_type=type,
encrypted=encrypted,
tg_receiver=TelegramID(0),
)
try: try:
await portal.create_telegram_chat(user, supergroup=supergroup) await portal.create_telegram_chat(user, supergroup=supergroup)
except ValueError as e: except ValueError as e:
await portal.delete() await portal.delete()
return self.get_error_response(500, "unknown_error", e.args[0]) return self.get_error_response(500, "unknown_error", e.args[0])
return web.json_response({ return web.json_response(
"chat_id": portal.tgid, {
}, status=201) "chat_id": portal.tgid,
},
status=201,
)
async def disconnect_chat(self, request: web.Request) -> web.Response: async def disconnect_chat(self, request: web.Request) -> web.Response:
err = self.check_authorization(request) err = self.check_authorization(request)
@@ -260,17 +305,24 @@ class ProvisioningAPI(AuthAPI):
portal = await Portal.get_by_mxid(request.match_info["mxid"]) portal = await Portal.get_by_mxid(request.match_info["mxid"])
if not portal or not portal.tgid: if not portal or not portal.tgid:
return self.get_error_response(404, "portal_not_found", return self.get_error_response(404, "portal_not_found", "Room is not a portal.")
"Room is not a portal.")
user, err = await self.get_user(request.query.get("user_id", None), expect_logged_in=None, user, err = await self.get_user(
require_puppeting=False, require_user=False) request.query.get("user_id", None),
expect_logged_in=None,
require_puppeting=False,
require_user=False,
)
if err is not None: if err is not None:
return err return err
elif user and not await user_has_power_level(portal.mxid, self.az.intent, user, elif user and not await user_has_power_level(
"unbridge"): portal.mxid, self.az.intent, user, "unbridge"
return self.get_error_response(403, "not_enough_permissions", ):
"You do not have the permissions to unbridge that room.") return self.get_error_response(
403,
"not_enough_permissions",
"You do not have the permissions to unbridge that room.",
)
delete = request.query.get("delete", "").lower() in ("true", "t", "1", "yes", "y") delete = request.query.get("delete", "").lower() in ("true", "t", "1", "yes", "y")
sync = request.query.get("delete", "").lower() in ("true", "t", "1", "yes", "y") sync = request.query.get("delete", "").lower() in ("true", "t", "1", "yes", "y")
@@ -287,8 +339,9 @@ class ProvisioningAPI(AuthAPI):
return web.json_response({}, status=200 if sync else 202) return web.json_response({}, status=200 if sync else 202)
async def get_user_info(self, request: web.Request) -> web.Response: async def get_user_info(self, request: web.Request) -> web.Response:
data, user, err = await self.get_user_request_info(request, expect_logged_in=None, data, user, err = await self.get_user_request_info(
require_puppeting=False) request, expect_logged_in=None, require_puppeting=False
)
if err is not None: if err is not None:
return err return err
@@ -305,11 +358,13 @@ class ProvisioningAPI(AuthAPI):
"phone": user.tg_phone, "phone": user.tg_phone,
"is_bot": user.is_bot, "is_bot": user.is_bot,
} }
return web.json_response({ return web.json_response(
"telegram": user_data, {
"mxid": user.mxid, "telegram": user_data,
"permissions": user.permissions, "mxid": user.mxid,
}) "permissions": user.permissions,
}
)
async def get_chats(self, request: web.Request) -> web.Response: async def get_chats(self, request: web.Request) -> web.Response:
data, user, err = await self.get_user_request_info(request, expect_logged_in=True) data, user, err = await self.get_user_request_info(request, expect_logged_in=True)
@@ -317,15 +372,28 @@ class ProvisioningAPI(AuthAPI):
return err return err
if not user.is_bot: if not user.is_bot:
return web.json_response([{ return web.json_response(
"id": chat.id, [
"title": chat.title, {
} async for chat in user.client.iter_dialogs(ignore_migrated=True, archived=False)]) "id": chat.id,
"title": chat.title,
}
async for chat in user.client.iter_dialogs(
ignore_migrated=True, archived=False
)
]
)
else: else:
return web.json_response([{ return web.json_response(
"id": get_peer_id(chat.peer), [
"title": chat.title, {
} for chat in (await user.get_cached_portals()).values() if chat.tgid]) "id": get_peer_id(chat.peer),
"title": chat.title,
}
for chat in (await user.get_cached_portals()).values()
if chat.tgid
]
)
async def send_bot_token(self, request: web.Request) -> web.Response: async def send_bot_token(self, request: web.Request) -> web.Response:
data, user, err = await self.get_user_request_info(request) data, user, err = await self.get_user_request_info(request)
@@ -352,48 +420,78 @@ class ProvisioningAPI(AuthAPI):
return await self.post_login_password(user, data.get("password", "")) return await self.post_login_password(user, data.get("password", ""))
async def logout(self, request: web.Request) -> web.Response: async def logout(self, request: web.Request) -> web.Response:
_, user, err = await self.get_user_request_info(request, expect_logged_in=None, _, user, err = await self.get_user_request_info(
require_puppeting=False, request, expect_logged_in=None, require_puppeting=False, want_data=False
want_data=False) )
if err is not None: if err is not None:
return err return err
await user.log_out() await user.log_out()
return web.json_response({}, status=200) return web.json_response({}, status=200)
async def bridge_info(self, request: web.Request) -> web.Response: async def bridge_info(self, request: web.Request) -> web.Response:
return web.json_response({ return web.json_response(
"relaybot_username": (self.bridge.bot.tg_username {
if self.bridge.bot is not None else None), "relaybot_username": (
}, status=200) self.bridge.bot.tg_username if self.bridge.bot is not None else None
),
},
status=200,
)
@staticmethod @staticmethod
async def error_middleware(_, handler: Callable[[web.Request], Awaitable[web.Response]] async def error_middleware(
) -> Callable[[web.Request], Awaitable[web.Response]]: _, handler: Callable[[web.Request], Awaitable[web.Response]]
) -> Callable[[web.Request], Awaitable[web.Response]]:
async def middleware_handler(request: web.Request) -> web.Response: async def middleware_handler(request: web.Request) -> web.Response:
try: try:
return await handler(request) return await handler(request)
except web.HTTPException as ex: except web.HTTPException as ex:
return web.json_response({ return web.json_response(
"error": f"Unhandled HTTP {ex.status}", {
"errcode": f"unhandled_http_{ex.status}", "error": f"Unhandled HTTP {ex.status}",
}, status=ex.status) "errcode": f"unhandled_http_{ex.status}",
},
status=ex.status,
)
return middleware_handler return middleware_handler
@staticmethod @staticmethod
def get_error_response(status=200, errcode="", error="") -> web.Response: def get_error_response(status=200, errcode="", error="") -> web.Response:
return web.json_response({ return web.json_response(
"error": error, {
"errcode": errcode, "error": error,
}, status=status) "errcode": errcode,
},
status=status,
)
def get_mx_login_response(self, status=200, state="", username="", phone="", human_tg_id="", def get_mx_login_response(
mxid="", message="", error="", errcode=""): self,
status=200,
state="",
username="",
phone="",
human_tg_id="",
mxid="",
message="",
error="",
errcode="",
):
raise NotImplementedError() raise NotImplementedError()
def get_login_response(self, status=200, state="", username="", phone: str = "", def get_login_response(
human_tg_id: str = "", mxid="", message="", error="", errcode="" self,
) -> web.Response: status=200,
state="",
username="",
phone: str = "",
human_tg_id: str = "",
mxid="",
message="",
error="",
errcode="",
) -> web.Response:
if username or phone: if username or phone:
resp = { resp = {
"state": "logged-in", "state": "logged-in",
@@ -414,52 +512,63 @@ class ProvisioningAPI(AuthAPI):
resp["state"] = state resp["state"] = state
return web.json_response(resp, status=status) return web.json_response(resp, status=status)
def check_authorization(self, request: web.Request) -> Optional[web.Response]: def check_authorization(self, request: web.Request) -> web.Response | None:
auth = request.headers.get("Authorization", "") auth = request.headers.get("Authorization", "")
if auth != f"Bearer {self.secret}": if auth != f"Bearer {self.secret}":
return self.get_error_response(error="Shared secret is not valid.", return self.get_error_response(
errcode="shared_secret_invalid", error="Shared secret is not valid.", errcode="shared_secret_invalid", status=401
status=401) )
return None return None
@staticmethod @staticmethod
async def get_data(request: web.Request) -> Optional[dict]: async def get_data(request: web.Request) -> dict | None:
try: try:
return await request.json() return await request.json()
except json.JSONDecodeError: except json.JSONDecodeError:
return None return None
async def get_user(self, mxid: Optional[UserID], expect_logged_in: Optional[bool] = False, async def get_user(
require_puppeting: bool = True, require_user: bool = True self,
) -> Tuple[Optional[User], Optional[web.Response]]: mxid: UserID | None,
expect_logged_in: bool | None = False,
require_puppeting: bool = True,
require_user: bool = True,
) -> tuple[User | None, web.Response | None]:
if not mxid: if not mxid:
if not require_user: if not require_user:
return None, None return None, None
return None, self.get_login_response(error="User ID not given.", return None, self.get_login_response(
errcode="mxid_empty", status=400) error="User ID not given.", errcode="mxid_empty", status=400
)
user = await User.get_and_start_by_mxid(mxid, even_if_no_session=True) user = await User.get_and_start_by_mxid(mxid, even_if_no_session=True)
if require_puppeting and not user.puppet_whitelisted: if require_puppeting and not user.puppet_whitelisted:
return user, self.get_login_response(error="You are not whitelisted.", return user, self.get_login_response(
errcode="mxid_not_whitelisted", status=403) error="You are not whitelisted.", errcode="mxid_not_whitelisted", status=403
)
if expect_logged_in is not None: if expect_logged_in is not None:
logged_in = await user.is_logged_in() logged_in = await user.is_logged_in()
if not expect_logged_in and logged_in: if not expect_logged_in and logged_in:
return user, self.get_login_response(username=user.tg_username, phone=user.tg_phone, return user, self.get_login_response(
status=409, username=user.tg_username,
error="You are already logged in.", phone=user.tg_phone,
errcode="already_logged_in") status=409,
error="You are already logged in.",
errcode="already_logged_in",
)
elif expect_logged_in and not logged_in: elif expect_logged_in and not logged_in:
return user, self.get_login_response(status=403, error="You are not logged in.", return user, self.get_login_response(
errcode="not_logged_in") status=403, error="You are not logged in.", errcode="not_logged_in"
)
return user, None return user, None
async def get_user_request_info(self, request: web.Request, async def get_user_request_info(
expect_logged_in: Optional[bool] = False, self,
require_puppeting: bool = False, request: web.Request,
want_data: bool = True, expect_logged_in: bool | None = False,
) -> (Tuple[Optional[Dict], Optional[User], require_puppeting: bool = False,
Optional[web.Response]]): want_data: bool = True,
) -> tuple[dict | None, User | None, web.Response | None]:
err = self.check_authorization(request) err = self.check_authorization(request)
if err is not None: if err is not None:
return None, None, err return None, None, err
@@ -468,8 +577,13 @@ class ProvisioningAPI(AuthAPI):
if want_data and (request.method == "POST" or request.method == "PUT"): if want_data and (request.method == "POST" or request.method == "PUT"):
data = await self.get_data(request) data = await self.get_data(request)
if not data: if not data:
return None, None, self.get_login_response(error="Invalid JSON.", return (
errcode="json_invalid", status=400) None,
None,
self.get_login_response(
error="Invalid JSON.", errcode="json_invalid", status=400
),
)
mxid = request.match_info["mxid"] mxid = request.match_info["mxid"]
user, err = await self.get_user(mxid, expect_logged_in, require_puppeting) user, err = await self.get_user(mxid, expect_logged_in, require_puppeting)
+100 -56
View File
@@ -1,5 +1,5 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge # mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@@ -13,22 +13,23 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional from __future__ import annotations
import asyncio import asyncio
import logging import logging
import random import random
import string import string
import time import time
from mako.template import Template
from aiohttp import web from aiohttp import web
from mako.template import Template
import pkg_resources import pkg_resources
from mautrix.types import UserID from mautrix.types import UserID
from mautrix.util.signed_token import sign_token, verify_token from mautrix.util.signed_token import sign_token, verify_token
from ...user import User
from ...puppet import Puppet from ...puppet import Puppet
from ...user import User
from ..common import AuthAPI from ..common import AuthAPI
@@ -43,31 +44,38 @@ class PublicBridgeWebsite(AuthAPI):
super().__init__(loop) super().__init__(loop)
self.secret_key = "".join(random.choices(string.ascii_lowercase + string.digits, k=64)) self.secret_key = "".join(random.choices(string.ascii_lowercase + string.digits, k=64))
self.login = Template(pkg_resources.resource_string( self.login = Template(
"mautrix_telegram", "web/public/login.html.mako")) pkg_resources.resource_string("mautrix_telegram", "web/public/login.html.mako")
)
self.mx_login = Template(pkg_resources.resource_string( self.mx_login = Template(
"mautrix_telegram", "web/public/matrix-login.html.mako")) pkg_resources.resource_string("mautrix_telegram", "web/public/matrix-login.html.mako")
)
self.app = web.Application(loop=loop) self.app = web.Application(loop=loop)
self.app.router.add_route("GET", "/login", self.get_login) self.app.router.add_route("GET", "/login", self.get_login)
self.app.router.add_route("POST", "/login", self.post_login) self.app.router.add_route("POST", "/login", self.post_login)
self.app.router.add_route("GET", "/matrix-login", self.get_matrix_login) self.app.router.add_route("GET", "/matrix-login", self.get_matrix_login)
self.app.router.add_route("POST", "/matrix-login", self.post_matrix_login) self.app.router.add_route("POST", "/matrix-login", self.post_matrix_login)
self.app.router.add_static("/", pkg_resources.resource_filename("mautrix_telegram", self.app.router.add_static(
"web/public/")) "/", pkg_resources.resource_filename("mautrix_telegram", "web/public/")
)
def make_token(self, mxid: str, endpoint: str = "/login", expires_in: int = 900) -> str: def make_token(self, mxid: str, endpoint: str = "/login", expires_in: int = 900) -> str:
return sign_token(self.secret_key, { return sign_token(
"mxid": mxid, self.secret_key,
"endpoint": endpoint, {
"expiry": int(time.time()) + expires_in, "mxid": mxid,
}) "endpoint": endpoint,
"expiry": int(time.time()) + expires_in,
},
)
def verify_token(self, token: str, endpoint: str = "/login") -> Optional[UserID]: def verify_token(self, token: str, endpoint: str = "/login") -> UserID | None:
token = verify_token(self.secret_key, token) token = verify_token(self.secret_key, token)
if token and (token.get("expiry", 0) > int(time.time()) and if token and (
token.get("endpoint", None) == endpoint): token.get("expiry", 0) > int(time.time()) and token.get("endpoint", None) == endpoint
):
return UserID(token.get("mxid", None)) return UserID(token.get("mxid", None))
return None return None
@@ -82,8 +90,9 @@ class PublicBridgeWebsite(AuthAPI):
if not user: if not user:
return self.get_login_response(mxid=mxid, state=state) return self.get_login_response(mxid=mxid, state=state)
elif not user.puppet_whitelisted: elif not user.puppet_whitelisted:
return self.get_login_response(mxid=user.mxid, error="You are not whitelisted.", return self.get_login_response(
status=403) mxid=user.mxid, error="You are not whitelisted.", status=403
)
await user.ensure_started() await user.ensure_started()
if not await user.is_logged_in(): if not await user.is_logged_in():
return self.get_login_response(mxid=user.mxid, state=state) return self.get_login_response(mxid=user.mxid, state=state)
@@ -91,8 +100,9 @@ class PublicBridgeWebsite(AuthAPI):
return self.get_login_response(mxid=user.mxid, human_tg_id=user.human_tg_id) return self.get_login_response(mxid=user.mxid, human_tg_id=user.human_tg_id)
async def get_matrix_login(self, request: web.Request) -> web.Response: async def get_matrix_login(self, request: web.Request) -> web.Response:
mxid = self.verify_token(request.rel_url.query.get("token", None), mxid = self.verify_token(
endpoint="/matrix-login") request.rel_url.query.get("token", None), endpoint="/matrix-login"
)
if not mxid: if not mxid:
return self.get_mx_login_response(status=401, state="invalid-token") return self.get_mx_login_response(status=401, state="invalid-token")
user = await User.get_by_mxid(mxid, create=False) if mxid else None user = await User.get_by_mxid(mxid, create=False) if mxid else None
@@ -100,12 +110,14 @@ class PublicBridgeWebsite(AuthAPI):
if not user: if not user:
return self.get_mx_login_response(mxid=mxid) return self.get_mx_login_response(mxid=mxid)
elif not user.puppet_whitelisted: elif not user.puppet_whitelisted:
return self.get_mx_login_response(mxid=user.mxid, error="You are not whitelisted.", return self.get_mx_login_response(
status=403) mxid=user.mxid, error="You are not whitelisted.", status=403
)
await user.ensure_started() await user.ensure_started()
if not await user.is_logged_in(): if not await user.is_logged_in():
return self.get_mx_login_response(mxid=user.mxid, status=403, return self.get_mx_login_response(
error="You are not logged in to Telegram.") mxid=user.mxid, status=403, error="You are not logged in to Telegram."
)
puppet = await Puppet.get_by_tgid(user.tgid) puppet = await Puppet.get_by_tgid(user.tgid)
if puppet.is_real_user: if puppet.is_real_user:
@@ -113,24 +125,50 @@ class PublicBridgeWebsite(AuthAPI):
return self.get_mx_login_response(mxid=user.mxid) return self.get_mx_login_response(mxid=user.mxid)
def get_login_response(self, status: int = 200, state: str = "", username: str = "", def get_login_response(
phone: str = "", human_tg_id: str = "", mxid: str = "", self,
message: str = "", error: str = "", errcode: str = "") -> web.Response: status: int = 200,
return web.Response(status=status, content_type="text/html", state: str = "",
text=self.login.render(human_tg_id=human_tg_id, state=state, username: str = "",
error=error, message=message, mxid=mxid)) phone: str = "",
human_tg_id: str = "",
mxid: str = "",
message: str = "",
error: str = "",
errcode: str = "",
) -> web.Response:
return web.Response(
status=status,
content_type="text/html",
text=self.login.render(
human_tg_id=human_tg_id, state=state, error=error, message=message, mxid=mxid
),
)
def get_mx_login_response(self, status: int = 200, state: str = "", username: str = "", def get_mx_login_response(
phone: str = "", human_tg_id: str = "", mxid: str = "", self,
message: str = "", error: str = "", errcode: str = "" status: int = 200,
) -> web.Response: state: str = "",
return web.Response(status=status, content_type="text/html", username: str = "",
text=self.mx_login.render(human_tg_id=human_tg_id, state=state, phone: str = "",
error=error, message=message, mxid=mxid)) human_tg_id: str = "",
mxid: str = "",
message: str = "",
error: str = "",
errcode: str = "",
) -> web.Response:
return web.Response(
status=status,
content_type="text/html",
text=self.mx_login.render(
human_tg_id=human_tg_id, state=state, error=error, message=message, mxid=mxid
),
)
async def post_matrix_login(self, request: web.Request) -> web.Response: async def post_matrix_login(self, request: web.Request) -> web.Response:
mxid = self.verify_token(request.rel_url.query.get("token", None), mxid = self.verify_token(
endpoint="/matrix-login") request.rel_url.query.get("token", None), endpoint="/matrix-login"
)
if not mxid: if not mxid:
return self.get_mx_login_response(status=401, state="invalid-token") return self.get_mx_login_response(status=401, state="invalid-token")
@@ -138,19 +176,21 @@ class PublicBridgeWebsite(AuthAPI):
user = await User.get_and_start_by_mxid(mxid) user = await User.get_and_start_by_mxid(mxid)
if not user.puppet_whitelisted: if not user.puppet_whitelisted:
return self.get_mx_login_response(mxid=user.mxid, error="You are not whitelisted.", return self.get_mx_login_response(
status=403) mxid=user.mxid, error="You are not whitelisted.", status=403
)
elif not await user.is_logged_in(): elif not await user.is_logged_in():
return self.get_mx_login_response(mxid=user.mxid, status=403, return self.get_mx_login_response(
error="You are not logged in to Telegram.") mxid=user.mxid, status=403, error="You are not logged in to Telegram."
)
mode = data.get("mode", "access_token") mode = data.get("mode", "access_token")
if mode == "password": if mode == "password":
return await self.post_matrix_password(user, data["value"]) return await self.post_matrix_password(user, data["value"])
elif mode == "access_token": elif mode == "access_token":
return await self.post_matrix_token(user, data["value"]) return await self.post_matrix_token(user, data["value"])
return self.get_mx_login_response(mxid=user.mxid, status=400, return self.get_mx_login_response(
error="You must provide an access token or " mxid=user.mxid, status=400, error="You must provide an access token or password."
"password.") )
async def post_login(self, request: web.Request) -> web.Response: async def post_login(self, request: web.Request) -> web.Response:
mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/login") mxid = self.verify_token(request.rel_url.query.get("token", None), endpoint="/login")
@@ -159,10 +199,11 @@ class PublicBridgeWebsite(AuthAPI):
data = await request.post() data = await request.post()
user = await User.get_by_mxid(mxid).ensure_started(even_if_no_session=True) user = await User.get_and_start_by_mxid(mxid, even_if_no_session=True)
if not user.puppet_whitelisted: if not user.puppet_whitelisted:
return self.get_login_response(mxid=user.mxid, error="You are not whitelisted.", return self.get_login_response(
status=403) mxid=user.mxid, error="You are not whitelisted.", status=403
)
elif await user.is_logged_in(): elif await user.is_logged_in():
return self.get_login_response(mxid=user.mxid, human_tg_id=user.human_tg_id) return self.get_login_response(mxid=user.mxid, human_tg_id=user.human_tg_id)
@@ -176,11 +217,14 @@ class PublicBridgeWebsite(AuthAPI):
try: try:
code = int(data["code"].strip()) code = int(data["code"].strip())
except ValueError: except ValueError:
return self.get_login_response(mxid=user.mxid, state="code", status=400, return self.get_login_response(
errcode="phone_code_invalid", mxid=user.mxid,
error="Phone code must be a number.") state="code",
resp = await self.post_login_code(user, code, status=400,
password_in_data="password" in data) errcode="phone_code_invalid",
error="Phone code must be a number.",
)
resp = await self.post_login_code(user, code, password_in_data="password" in data)
if resp or "password" not in data: if resp or "password" not in data:
return resp return resp
elif "password" not in data: elif "password" not in data:
+12
View File
@@ -0,0 +1,12 @@
[tool.isort]
profile = "black"
force_to_top = "typing"
from_first = true
combine_as_imports = true
known_first_party = "mautrix"
line_length = 99
[tool.black]
line-length = 99
target-version = ["py38"]
required-version = "21.12b0"