Minor improvements

This commit is contained in:
Tulir Asokan
2019-08-08 22:21:24 +03:00
parent 1338a43c03
commit ac24bc86a0
2 changed files with 25 additions and 23 deletions
+12 -14
View File
@@ -13,9 +13,8 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Callable, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING from typing import Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
import logging import logging
import re
from telethon.tl.patched import Message, MessageService from telethon.tl.patched import Message, MessageService
from telethon.tl.types import ( from telethon.tl.types import (
@@ -44,7 +43,6 @@ ReplyFunc = Callable[[str], Awaitable[Message]]
class Bot(AbstractUser): class Bot(AbstractUser):
log: logging.Logger = logging.getLogger("mau.user.bot") log: logging.Logger = logging.getLogger("mau.user.bot")
mxid_regex: Pattern = re.compile("@.+:.+")
token: str token: str
chats: Dict[int, str] chats: Dict[int, str]
@@ -110,9 +108,9 @@ class Bot(AbstractUser):
if isinstance(chat, ChatForbidden) or chat.left or chat.deactivated: if isinstance(chat, ChatForbidden) or chat.left or chat.deactivated:
self.remove_chat(TelegramID(chat.id)) self.remove_chat(TelegramID(chat.id))
channel_ids = [InputChannel(chat_id, 0) channel_ids = (InputChannel(chat_id, 0)
for chat_id, chat_type in self.chats.items() for chat_id, chat_type in self.chats.items()
if chat_type == "channel"] if chat_type == "channel")
for channel_id in channel_ids: for channel_id in channel_ids:
try: try:
await self.client(GetChannelsRequest([channel_id])) await self.client(GetChannelsRequest([channel_id]))
@@ -165,7 +163,7 @@ class Bot(AbstractUser):
return False return False
async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool: async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
if not await self._can_use_commands(event.to_id, event.from_id): if not await self._can_use_commands(event.to_id, TelegramID(event.from_id)):
await reply("You do not have the permission to use that command.") await reply("You do not have the permission to use that command.")
return False return False
return True return True
@@ -193,7 +191,7 @@ class Bot(AbstractUser):
elif not portal.mxid: elif not portal.mxid:
return await reply("Portal does not have Matrix room. " return await reply("Portal does not have Matrix room. "
"Create one with /portal first.") "Create one with /portal first.")
if not self.mxid_regex.match(mxid_input): if mxid_input[0] != '@' or mxid_input.find(':') < 2:
return await reply("That doesn't look like a Matrix ID.") return await reply("That doesn't look like a Matrix ID.")
user = await u.User.get_by_mxid(mxid_input).ensure_started() user = await u.User.get_by_mxid(mxid_input).ensure_started()
if not user.relaybot_whitelisted: if not user.relaybot_whitelisted:
@@ -203,7 +201,7 @@ class Bot(AbstractUser):
return await reply("That user seems to be logged in. " return await reply("That user seems to be logged in. "
f"Just invite [{displayname}](tg://user?id={user.tgid})") f"Just invite [{displayname}](tg://user?id={user.tgid})")
else: else:
await portal.main_intent.invite(portal.mxid, user.mxid) await portal.main_intent.invite_user(portal.mxid, user.mxid)
return await reply(f"Invited `{user.mxid}` to the portal.") return await reply(f"Invited `{user.mxid}` to the portal.")
@staticmethod @staticmethod
@@ -252,15 +250,15 @@ class Bot(AbstractUser):
mxid = text[text.index(" ") + 1:] mxid = text[text.index(" ") + 1:]
except ValueError: except ValueError:
mxid = "" mxid = ""
await self.handle_command_invite(portal, reply, mxid_input=mxid) await self.handle_command_invite(portal, reply, mxid_input=UserID(mxid))
def handle_service_message(self, message: MessageService) -> None: def handle_service_message(self, message: MessageService) -> None:
to_id: TelegramID = message.to_id to_peer = message.to_id
if isinstance(to_id, PeerChannel): if isinstance(to_peer, PeerChannel):
to_id = to_id.channel_id to_id = TelegramID(to_peer.channel_id)
chat_type = "channel" chat_type = "channel"
elif isinstance(to_id, PeerChat): elif isinstance(to_peer, PeerChat):
to_id = to_id.chat_id to_id = TelegramID(to_peer.chat_id)
chat_type = "chat" chat_type = "chat"
else: else:
return return
+13 -9
View File
@@ -13,12 +13,11 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Dict, List, Optional, Pattern, Tuple, Union, Any, TYPE_CHECKING from typing import Awaitable, Dict, List, Optional, Tuple, Union, Any, TYPE_CHECKING
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
import logging import logging
import json import json
import re
from telethon.tl.functions.messages import ExportChatInviteRequest from telethon.tl.functions.messages import ExportChatInviteRequest
from telethon.tl.types import (Channel, ChannelFull, Chat, ChatFull, ChatInviteEmpty, InputChannel, from telethon.tl.types import (Channel, ChannelFull, Chat, ChatFull, ChatInviteEmpty, InputChannel,
@@ -69,7 +68,8 @@ class BasePortal(ABC):
public_portals: bool = False public_portals: bool = False
alias_template: str = None alias_template: str = None
mx_alias_regex: Pattern = None _mx_alias_prefix: str = None
_mx_alias_suffix: str = None
hs_domain: str = None hs_domain: str = None
# Instance cache # Instance cache
@@ -346,9 +346,10 @@ class BasePortal(ABC):
@classmethod @classmethod
def get_username_from_mx_alias(cls, alias: str) -> Optional[str]: def get_username_from_mx_alias(cls, alias: str) -> Optional[str]:
match = cls.mx_alias_regex.match(alias) prefix = cls._mx_alias_prefix
if match: suffix = cls._mx_alias_suffix
return match.group(1) if alias[:len(prefix)] == prefix and alias[-len(suffix):] == suffix:
return alias[len(prefix):-len(suffix)]
return None return None
@classmethod @classmethod
@@ -473,7 +474,10 @@ def init(context: Context) -> None:
BasePortal.public_portals = config["bridge.public_portals"] BasePortal.public_portals = config["bridge.public_portals"]
BasePortal.filter_mode = config["bridge.filter.mode"] BasePortal.filter_mode = config["bridge.filter.mode"]
BasePortal.filter_list = config["bridge.filter.list"] BasePortal.filter_list = config["bridge.filter.list"]
BasePortal.alias_template = config.get("bridge.alias_template", "telegram_{groupname}")
BasePortal.hs_domain = config["homeserver.domain"] BasePortal.hs_domain = config["homeserver.domain"]
BasePortal.mx_alias_regex = re.compile( BasePortal.alias_template = config["bridge.alias_template"]
f"#{BasePortal.alias_template.format(groupname='(.+)')}:{BasePortal.hs_domain}") index = BasePortal.alias_template.index("{groupname}")
length = len("{groupname}")
BasePortal._mx_alias_prefix = f"#{BasePortal.alias_template[:index]}"
BasePortal._mx_alias_suffix = (f"{BasePortal.alias_template[index + length:]}"
f":{BasePortal.hs_domain}")