Minor improvements
This commit is contained in:
+12
-14
@@ -13,9 +13,8 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Awaitable, Callable, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
|
from typing import Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
|
|
||||||
from telethon.tl.patched import Message, MessageService
|
from telethon.tl.patched import Message, MessageService
|
||||||
from telethon.tl.types import (
|
from telethon.tl.types import (
|
||||||
@@ -44,7 +43,6 @@ ReplyFunc = Callable[[str], Awaitable[Message]]
|
|||||||
|
|
||||||
class Bot(AbstractUser):
|
class Bot(AbstractUser):
|
||||||
log: logging.Logger = logging.getLogger("mau.user.bot")
|
log: logging.Logger = logging.getLogger("mau.user.bot")
|
||||||
mxid_regex: Pattern = re.compile("@.+:.+")
|
|
||||||
|
|
||||||
token: str
|
token: str
|
||||||
chats: Dict[int, str]
|
chats: Dict[int, str]
|
||||||
@@ -110,9 +108,9 @@ class Bot(AbstractUser):
|
|||||||
if isinstance(chat, ChatForbidden) or chat.left or chat.deactivated:
|
if isinstance(chat, ChatForbidden) or chat.left or chat.deactivated:
|
||||||
self.remove_chat(TelegramID(chat.id))
|
self.remove_chat(TelegramID(chat.id))
|
||||||
|
|
||||||
channel_ids = [InputChannel(chat_id, 0)
|
channel_ids = (InputChannel(chat_id, 0)
|
||||||
for chat_id, chat_type in self.chats.items()
|
for chat_id, chat_type in self.chats.items()
|
||||||
if chat_type == "channel"]
|
if chat_type == "channel")
|
||||||
for channel_id in channel_ids:
|
for channel_id in channel_ids:
|
||||||
try:
|
try:
|
||||||
await self.client(GetChannelsRequest([channel_id]))
|
await self.client(GetChannelsRequest([channel_id]))
|
||||||
@@ -165,7 +163,7 @@ class Bot(AbstractUser):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
|
async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
|
||||||
if not await self._can_use_commands(event.to_id, event.from_id):
|
if not await self._can_use_commands(event.to_id, TelegramID(event.from_id)):
|
||||||
await reply("You do not have the permission to use that command.")
|
await reply("You do not have the permission to use that command.")
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
@@ -193,7 +191,7 @@ class Bot(AbstractUser):
|
|||||||
elif not portal.mxid:
|
elif not portal.mxid:
|
||||||
return await reply("Portal does not have Matrix room. "
|
return await reply("Portal does not have Matrix room. "
|
||||||
"Create one with /portal first.")
|
"Create one with /portal first.")
|
||||||
if not self.mxid_regex.match(mxid_input):
|
if mxid_input[0] != '@' or mxid_input.find(':') < 2:
|
||||||
return await reply("That doesn't look like a Matrix ID.")
|
return await reply("That doesn't look like a Matrix ID.")
|
||||||
user = await u.User.get_by_mxid(mxid_input).ensure_started()
|
user = await u.User.get_by_mxid(mxid_input).ensure_started()
|
||||||
if not user.relaybot_whitelisted:
|
if not user.relaybot_whitelisted:
|
||||||
@@ -203,7 +201,7 @@ class Bot(AbstractUser):
|
|||||||
return await reply("That user seems to be logged in. "
|
return await reply("That user seems to be logged in. "
|
||||||
f"Just invite [{displayname}](tg://user?id={user.tgid})")
|
f"Just invite [{displayname}](tg://user?id={user.tgid})")
|
||||||
else:
|
else:
|
||||||
await portal.main_intent.invite(portal.mxid, user.mxid)
|
await portal.main_intent.invite_user(portal.mxid, user.mxid)
|
||||||
return await reply(f"Invited `{user.mxid}` to the portal.")
|
return await reply(f"Invited `{user.mxid}` to the portal.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -252,15 +250,15 @@ class Bot(AbstractUser):
|
|||||||
mxid = text[text.index(" ") + 1:]
|
mxid = text[text.index(" ") + 1:]
|
||||||
except ValueError:
|
except ValueError:
|
||||||
mxid = ""
|
mxid = ""
|
||||||
await self.handle_command_invite(portal, reply, mxid_input=mxid)
|
await self.handle_command_invite(portal, reply, mxid_input=UserID(mxid))
|
||||||
|
|
||||||
def handle_service_message(self, message: MessageService) -> None:
|
def handle_service_message(self, message: MessageService) -> None:
|
||||||
to_id: TelegramID = message.to_id
|
to_peer = message.to_id
|
||||||
if isinstance(to_id, PeerChannel):
|
if isinstance(to_peer, PeerChannel):
|
||||||
to_id = to_id.channel_id
|
to_id = TelegramID(to_peer.channel_id)
|
||||||
chat_type = "channel"
|
chat_type = "channel"
|
||||||
elif isinstance(to_id, PeerChat):
|
elif isinstance(to_peer, PeerChat):
|
||||||
to_id = to_id.chat_id
|
to_id = TelegramID(to_peer.chat_id)
|
||||||
chat_type = "chat"
|
chat_type = "chat"
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -13,12 +13,11 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Awaitable, Dict, List, Optional, Pattern, Tuple, Union, Any, TYPE_CHECKING
|
from typing import Awaitable, Dict, List, Optional, Tuple, Union, Any, TYPE_CHECKING
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import re
|
|
||||||
|
|
||||||
from telethon.tl.functions.messages import ExportChatInviteRequest
|
from telethon.tl.functions.messages import ExportChatInviteRequest
|
||||||
from telethon.tl.types import (Channel, ChannelFull, Chat, ChatFull, ChatInviteEmpty, InputChannel,
|
from telethon.tl.types import (Channel, ChannelFull, Chat, ChatFull, ChatInviteEmpty, InputChannel,
|
||||||
@@ -69,7 +68,8 @@ class BasePortal(ABC):
|
|||||||
public_portals: bool = False
|
public_portals: bool = False
|
||||||
|
|
||||||
alias_template: str = None
|
alias_template: str = None
|
||||||
mx_alias_regex: Pattern = None
|
_mx_alias_prefix: str = None
|
||||||
|
_mx_alias_suffix: str = None
|
||||||
hs_domain: str = None
|
hs_domain: str = None
|
||||||
|
|
||||||
# Instance cache
|
# Instance cache
|
||||||
@@ -346,9 +346,10 @@ class BasePortal(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_username_from_mx_alias(cls, alias: str) -> Optional[str]:
|
def get_username_from_mx_alias(cls, alias: str) -> Optional[str]:
|
||||||
match = cls.mx_alias_regex.match(alias)
|
prefix = cls._mx_alias_prefix
|
||||||
if match:
|
suffix = cls._mx_alias_suffix
|
||||||
return match.group(1)
|
if alias[:len(prefix)] == prefix and alias[-len(suffix):] == suffix:
|
||||||
|
return alias[len(prefix):-len(suffix)]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -473,7 +474,10 @@ def init(context: Context) -> None:
|
|||||||
BasePortal.public_portals = config["bridge.public_portals"]
|
BasePortal.public_portals = config["bridge.public_portals"]
|
||||||
BasePortal.filter_mode = config["bridge.filter.mode"]
|
BasePortal.filter_mode = config["bridge.filter.mode"]
|
||||||
BasePortal.filter_list = config["bridge.filter.list"]
|
BasePortal.filter_list = config["bridge.filter.list"]
|
||||||
BasePortal.alias_template = config.get("bridge.alias_template", "telegram_{groupname}")
|
|
||||||
BasePortal.hs_domain = config["homeserver.domain"]
|
BasePortal.hs_domain = config["homeserver.domain"]
|
||||||
BasePortal.mx_alias_regex = re.compile(
|
BasePortal.alias_template = config["bridge.alias_template"]
|
||||||
f"#{BasePortal.alias_template.format(groupname='(.+)')}:{BasePortal.hs_domain}")
|
index = BasePortal.alias_template.index("{groupname}")
|
||||||
|
length = len("{groupname}")
|
||||||
|
BasePortal._mx_alias_prefix = f"#{BasePortal.alias_template[:index]}"
|
||||||
|
BasePortal._mx_alias_suffix = (f"{BasePortal.alias_template[index + length:]}"
|
||||||
|
f":{BasePortal.hs_domain}")
|
||||||
|
|||||||
Reference in New Issue
Block a user