Start migrating to mautrix-python

This commit is contained in:
Tulir Asokan
2019-07-13 00:23:46 +03:00
parent e0d3c940f8
commit 8d4a9dc231
14 changed files with 263 additions and 797 deletions
+50 -115
View File
@@ -13,25 +13,15 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, List, Any from itertools import chain
from time import time
import argparse
import asyncio
import logging.config
import sys
import copy
import signal
import os
import sqlalchemy as sql from mautrix.bridge import Bridge
from mautrix_appservice import AppService
from alchemysession import AlchemySessionContainer from alchemysession import AlchemySessionContainer
from .web.provisioning import ProvisioningAPI from .web.provisioning import ProvisioningAPI
from .web.public import PublicBridgeWebsite from .web.public import PublicBridgeWebsite
from .abstract_user import init as init_abstract_user from .abstract_user import init as init_abstract_user
from .bot import init as init_bot from .bot import Bot, init as init_bot
from .config import Config from .config import Config
from .context import Context from .context import Context
from .db import Base, init as init_db from .db import Base, init as init_db
@@ -48,115 +38,60 @@ try:
except ImportError: except ImportError:
prometheus = None prometheus = None
parser = argparse.ArgumentParser(
description="A Matrix-Telegram puppeting bridge.",
prog="python -m mautrix-telegram")
parser.add_argument("-c", "--config", type=str, default="config.yaml",
metavar="<path>", help="the path to your config file")
parser.add_argument("-b", "--base-config", type=str, default="example-config.yaml",
metavar="<path>", help="the path to the example config "
"(for automatic config updates)")
parser.add_argument("-g", "--generate-registration", action="store_true",
help="generate registration and quit")
parser.add_argument("-r", "--registration", type=str, default="registration.yaml",
metavar="<path>", help="the path to save the generated registration to")
args = parser.parse_args()
config = Config(args.config, args.registration, args.base_config, os.environ) class TelegramBridge(Bridge):
config.load() name = "mautrix-telegram"
config.update() command = "python -m mautrix-telegram"
description = "A Matrix-Telegram puppeting bridge."
real_user_content_key = "net.maunium.telegram.puppet"
version = __version__
config_class = Config
matrix_class = MatrixHandler
state_store_class = SQLStateStore
if args.generate_registration: config: Config
config.generate_registration() session_container: AlchemySessionContainer
config.save() bot: Bot
print(f"Registration generated and saved to {config.registration_path}")
sys.exit(0)
logging.config.dictConfig(copy.deepcopy(config["logging"])) def prepare_db(self) -> None:
log: logging.Logger = logging.getLogger("mau.init") super().prepare_db()
log.debug(f"Initializing mautrix-telegram {__version__}") init_db(self.db)
self.session_container = AlchemySessionContainer(
engine=self.db, table_base=Base, session=False,
table_prefix="telethon_", manage_tables=False)
db_engine = sql.create_engine(config["appservice.database"] or "sqlite:///mautrix-telegram.db") def prepare_bridge(self) -> None:
Base.metadata.bind = db_engine self.bot = init_bot(self.config)
context = Context(self.az, self.config, self.loop, self.session_container, self.bot)
session_container = AlchemySessionContainer(engine=db_engine, table_base=Base, session=False, if self.config["appservice.public.enabled"]:
table_prefix="telethon_", manage_tables=False) public_website = PublicBridgeWebsite(self.loop)
session_container.core_mode = True self.az.app.add_subapp(self.config["appservice.public.prefix"], public_website.app)
context.public_website = public_website
try: if self.config["appservice.provisioning.enabled"]:
import uvloop provisioning_api = ProvisioningAPI(context)
self.az.app.add_subapp(self.config["appservice.provisioning.prefix"],
provisioning_api.app)
context.provisioning_api = provisioning_api
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) self.matrix = context.mx = MatrixHandler(context)
log.debug("Using uvloop for asyncio")
except ImportError:
pass
loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if self.config["metrics.enabled"]:
if prometheus:
prometheus.start_http_server(self.config["metrics.listen_port"])
else:
self.log.warn("Metrics are enabled in the config, "
"but prometheus_client is not installed.")
state_store = SQLStateStore() init_abstract_user(context)
mebibyte = 1024 ** 2 init_formatter(context)
appserv = AppService(config["homeserver.address"], config["homeserver.domain"], init_portal(context)
config["appservice.as_token"], config["appservice.hs_token"], puppet_startup = init_puppet(context)
config["appservice.bot_username"], log="mau.as", loop=loop, user_startup = init_user(context)
verify_ssl=config["homeserver.verify_ssl"], state_store=state_store, self.startup_actions = chain(puppet_startup, user_startup,
real_user_content_key="net.maunium.telegram.puppet", [self.bot.start] if self.bot else [])
aiohttp_params={
"client_max_size": config["appservice.max_body_size"] * mebibyte
})
bot = init_bot(config)
context = Context(appserv, config, loop, session_container, bot)
if config["appservice.public.enabled"]: async def stop(self) -> None:
public_website = PublicBridgeWebsite(loop) self.shutdown_actions = [user.stop() for user in User.by_tgid.values()]
appserv.app.add_subapp(config["appservice.public.prefix"] or "/public", public_website.app) await super().stop()
context.public_website = public_website
if config["appservice.provisioning.enabled"]:
provisioning_api = ProvisioningAPI(context)
appserv.app.add_subapp(config["appservice.provisioning.prefix"] or "/_matrix/provisioning",
provisioning_api.app)
context.provisioning_api = provisioning_api
context.mx = MatrixHandler(context)
if config["metrics.enabled"]:
if prometheus:
prometheus.start_http_server(config["metrics.listen_port"])
else:
log.warn("Metrics are enabled in the config, but prometheus_client is not installed.")
with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start:
start_ts = time()
init_db(db_engine)
init_abstract_user(context)
init_formatter(context)
init_portal(context)
startup_actions: List[Awaitable[Any]] = (init_puppet(context) +
init_user(context) +
[start, context.mx.init_as_bot()])
if context.bot:
startup_actions.append(context.bot.start())
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
end_ts = time()
try:
log.debug(f"Initialization complete in {round(end_ts - start_ts, 2)} seconds,"
" running startup actions")
start_ts = time()
loop.run_until_complete(asyncio.gather(*startup_actions, loop=loop))
end_ts = time()
log.debug(f"Startup actions complete in {round(end_ts - start_ts, 2)} seconds,"
" now running forever")
loop.run_forever()
except KeyboardInterrupt:
log.debug("Interrupt received, stopping clients")
loop.run_until_complete(
asyncio.gather(*[user.stop() for user in User.by_tgid.values()], loop=loop))
log.debug("Clients stopped, shutting down")
sys.exit(0)
except Exception as e:
log.exception("Unexpected error")
sys.exit(1)
+2 -3
View File
@@ -65,7 +65,7 @@ class AbstractUser(ABC):
loop: asyncio.AbstractEventLoop = None loop: asyncio.AbstractEventLoop = None
log: logging.Logger log: logging.Logger
az: AppService az: AppService
bot: 'Bot' relaybot: Optional['Bot']
ignore_incoming_bot_events: bool = True ignore_incoming_bot_events: bool = True
client: Optional[MautrixTelegramClient] client: Optional[MautrixTelegramClient]
@@ -76,7 +76,6 @@ class AbstractUser(ABC):
is_bot: bool is_bot: bool
is_relaybot: bool is_relaybot: bool
relaybot: Optional['Bot']
puppet_whitelisted: bool puppet_whitelisted: bool
whitelisted: bool whitelisted: bool
@@ -404,7 +403,7 @@ class AbstractUser(ABC):
portal.tgid_log) portal.tgid_log)
return return
if self.ignore_incoming_bot_events and self.bot and sender.id == self.bot.tgid: if self.ignore_incoming_bot_events and self.relaybot and sender.id == self.relaybot.tgid:
self.log.debug(f"Ignoring relaybot-sent message %s to %s", update, portal.tgid_log) self.log.debug(f"Ignoring relaybot-sent message %s to %s", update, portal.tgid_log)
return return
+33 -172
View File
@@ -13,157 +13,33 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, List, NamedTuple
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap from ruamel.yaml.comments import CommentedMap
import random import random
import string import string
import os
yaml: YAML = YAML() from mautrix.types import UserID
yaml.indent(4) from mautrix.client import Client
from mautrix.bridge.config import BaseBridgeConfig, ConfigUpdateHelper
Permissions = NamedTuple("Permissions", relaybot=bool, user=bool, puppeting=bool,
matrix_puppeting=bool, admin=bool, level=str)
class DictWithRecursion: class Config(BaseBridgeConfig):
_data: CommentedMap
def __init__(self, data: Optional[CommentedMap] = None) -> None:
self._data = data or CommentedMap()
@staticmethod
def _parse_key(key: str) -> Tuple[str, Optional[str]]:
if '.' not in key:
return key, None
key, next_key = key.split('.', 1)
if len(key) > 0 and key[0] == "[":
end_index = next_key.index("]")
key = key[1:] + "." + next_key[:end_index]
next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None
return key, next_key
def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
key, next_key = self._parse_key(key)
if next_key is not None:
next_data = data.get(key, CommentedMap())
return self._recursive_get(next_data, next_key, default_value)
return data.get(key, default_value)
def get(self, key: str, default_value: Any, allow_recursion: bool = True) -> Any:
if allow_recursion and '.' in key:
return self._recursive_get(self._data, key, default_value)
return self._data.get(key, default_value)
def __getitem__(self, key: str) -> Any:
return self.get(key, None)
def __contains__(self, key: str) -> bool:
return self[key] is not None
def _recursive_set(self, data: CommentedMap, key: str, value: Any) -> None:
key, next_key = self._parse_key(key)
if next_key is not None:
if key not in data:
data[key] = CommentedMap()
next_data = data.get(key, CommentedMap())
return self._recursive_set(next_data, next_key, value)
data[key] = value
def set(self, key: str, value: Any, allow_recursion: bool = True) -> None:
if allow_recursion and '.' in key:
self._recursive_set(self._data, key, value)
return
self._data[key] = value
def __setitem__(self, key: str, value: Any) -> None:
self.set(key, value)
def _recursive_del(self, data: CommentedMap, key: str) -> None:
key, next_key = self._parse_key(key)
if next_key is not None:
if key not in data:
return
next_data = data[key]
return self._recursive_del(next_data, next_key)
try:
del data[key]
del data.ca.items[key]
except KeyError:
pass
def delete(self, key: str, allow_recursion: bool = True) -> None:
if allow_recursion and '.' in key:
self._recursive_del(self._data, key)
return
try:
del self._data[key]
del self._data.ca.items[key]
except KeyError:
pass
def __delitem__(self, key: str) -> None:
self.delete(key)
class Config(DictWithRecursion):
path: str
registration_path: str
base_path: str
_registration: Optional[Dict[str, Any]]
_overrides: Dict[str, Any]
def __init__(self, path: str, registration_path: str, base_path: str,
overrides: Dict[str, Any] = None) -> None:
super().__init__()
self.path = path
self.registration_path = registration_path
self.base_path = base_path
self._registration = None
self._overrides = overrides or {}
def __getitem__(self, key: str) -> Any: def __getitem__(self, key: str) -> Any:
try: try:
return self._overrides[f"MAUTRIX_TELEGRAM_{key.replace('.', '_').upper()}"] return os.environ[f"MAUTRIX_TELEGRAM_{key.replace('.', '_').upper()}"]
except KeyError: except KeyError:
return super().__getitem__(key) return super().__getitem__(key)
def load(self) -> None:
with open(self.path, 'r') as stream:
self._data = yaml.load(stream)
def load_base(self) -> Optional[DictWithRecursion]:
try:
with open(self.base_path, 'r') as stream:
return DictWithRecursion(yaml.load(stream))
except OSError:
pass
return None
def save(self) -> None:
with open(self.path, 'w') as stream:
yaml.dump(self._data, stream)
if self._registration and self.registration_path:
with open(self.registration_path, 'w') as stream:
yaml.dump(self._registration, stream)
@staticmethod @staticmethod
def _new_token() -> str: def _new_token() -> str:
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
def update(self) -> None: def do_update(self, helper: ConfigUpdateHelper) -> None:
base = self.load_base() copy, copy_dict, base = helper
if not base:
return
def copy(from_path, to_path=None) -> None:
if from_path in self:
base[to_path or from_path] = self[from_path]
def copy_dict(from_path, to_path=None, override_existing_map=True) -> None:
if from_path in self:
to_path = to_path or from_path
if override_existing_map or to_path not in base:
base[to_path] = CommentedMap()
for key, value in self[from_path].items():
base[to_path][key] = value
copy("homeserver.address") copy("homeserver.address")
copy("homeserver.domain") copy("homeserver.domain")
@@ -309,58 +185,43 @@ class Config(DictWithRecursion):
else: else:
copy("logging") copy("logging")
self._data = base._data def _get_permissions(self, key: str) -> Permissions:
self.save()
def _get_permissions(self, key: str) -> Tuple[bool, bool, bool, bool, bool, bool]:
level = self["bridge.permissions"].get(key, "") level = self["bridge.permissions"].get(key, "")
admin = level == "admin" admin = level == "admin"
matrix_puppeting = level == "full" or admin matrix_puppeting = level == "full" or admin
puppeting = level == "puppeting" or matrix_puppeting puppeting = level == "puppeting" or matrix_puppeting
user = level == "user" or puppeting user = level == "user" or puppeting
relaybot = level == "relaybot" or user relaybot = level == "relaybot" or user
return relaybot, user, puppeting, matrix_puppeting, admin, level return Permissions(relaybot, user, puppeting, matrix_puppeting, admin, level)
def get_permissions(self, mxid: str) -> Tuple[bool, bool, bool, bool, bool, bool]: def get_permissions(self, mxid: UserID) -> Permissions:
permissions = self["bridge.permissions"] or {} permissions = self["bridge.permissions"]
if mxid in permissions: if mxid in permissions:
return self._get_permissions(mxid) return self._get_permissions(mxid)
homeserver = mxid[mxid.index(":") + 1:] _, homeserver = Client.parse_user_id(mxid)
if homeserver in permissions: if homeserver in permissions:
return self._get_permissions(homeserver) return self._get_permissions(homeserver)
return self._get_permissions("*") return self._get_permissions("*")
def generate_registration(self) -> None: @property
def namespaces(self) -> Dict[str, List[Dict[str, Any]]]:
homeserver = self["homeserver.domain"] homeserver = self["homeserver.domain"]
username_format = self.get("bridge.username_template", username_format = self["bridge.username_template"].format(userid=".+")
"telegram_{userid}").format(userid=".+") alias_format = self["bridge.alias_template"].format(groupname=".+")
alias_format = self.get("bridge.alias_template", group_id = ({"group_id": self["appservice.community_id"]}
"telegram_{groupname}").format(groupname=".+") if self["appservice.community_id"] else {})
self.set("appservice.as_token", self._new_token()) return {
self.set("appservice.hs_token", self._new_token()) "users": [{
"exclusive": True,
self._registration = { "regex": f"@{username_format}:{homeserver}",
"id": self["appservice.id"] or "telegram", **group_id,
"as_token": self["appservice.as_token"], }],
"hs_token": self["appservice.hs_token"], "aliases": [{
"namespaces": { "exclusive": True,
"users": [{ "regex": f"#{alias_format}:{homeserver}",
"exclusive": True, }]
"regex": f"@{username_format}:{homeserver}"
}],
"aliases": [{
"exclusive": True,
"regex": f"#{alias_format}:{homeserver}"
}]
},
"url": self["appservice.address"],
"sender_localpart": self["appservice.bot_username"],
"rate_limited": False
} }
if self["appservice.community_id"]:
self._registration["namespaces"]["users"][0]["group_id"] = self[
"appservice.community_id"]
+11 -10
View File
@@ -15,11 +15,12 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Tuple, TYPE_CHECKING from typing import Optional, Tuple, TYPE_CHECKING
if TYPE_CHECKING: import asyncio
import asyncio
from alchemysession import AlchemySessionContainer from alchemysession import AlchemySessionContainer
from mautrix_appservice import AppService from mautrix_appservice import AppService
if TYPE_CHECKING:
from .web import PublicBridgeWebsite, ProvisioningAPI from .web import PublicBridgeWebsite, ProvisioningAPI
from .config import Config from .config import Config
@@ -28,17 +29,17 @@ if TYPE_CHECKING:
class Context: class Context:
az: 'AppService' az: AppService
config: 'Config' config: 'Config'
loop: 'asyncio.AbstractEventLoop' loop: asyncio.AbstractEventLoop
bot: Optional['Bot'] bot: Optional['Bot']
mx: Optional['MatrixHandler'] mx: Optional['MatrixHandler']
session_container: 'AlchemySessionContainer' session_container: AlchemySessionContainer
public_website: Optional['PublicBridgeWebsite'] public_website: Optional['PublicBridgeWebsite']
provisioning_api: Optional['ProvisioningAPI'] provisioning_api: Optional['ProvisioningAPI']
def __init__(self, az: 'AppService', config: 'Config', loop: 'asyncio.AbstractEventLoop', def __init__(self, az: AppService, config: 'Config', loop: asyncio.AbstractEventLoop,
session_container: 'AlchemySessionContainer', bot: Optional['Bot']) -> None: session_container: AlchemySessionContainer, bot: Optional['Bot']) -> None:
self.az = az self.az = az
self.config = config self.config = config
self.loop = loop self.loop = loop
@@ -49,5 +50,5 @@ class Context:
self.provisioning_api = None self.provisioning_api = None
@property @property
def core(self) -> Tuple['AppService', 'Config', 'asyncio.AbstractEventLoop', Optional['Bot']]: def core(self) -> Tuple[AppService, 'Config', asyncio.AbstractEventLoop, Optional['Bot']]:
return self.az, self.config, self.loop, self.bot return self.az, self.config, self.loop, self.bot
-61
View File
@@ -1,61 +0,0 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import Column, String, Text
from typing import Dict, Optional
import json
from ..types import MatrixRoomID
from .base import Base
class RoomState(Base):
__tablename__ = "mx_room_state"
room_id = Column(String, primary_key=True) # type: MatrixRoomID
power_levels = Column("power_levels", Text, nullable=True) # type: Optional[Dict]
@property
def _power_levels_text(self) -> Optional[str]:
return json.dumps(self.power_levels) if self.power_levels else None
@property
def has_power_levels(self) -> bool:
return bool(self.power_levels)
@classmethod
def get(cls, room_id: MatrixRoomID) -> Optional['RoomState']:
rows = cls.db.execute(cls.t.select().where(cls.c.room_id == room_id))
try:
room_id, power_levels_text = next(rows)
return cls(room_id=room_id, power_levels=(json.loads(power_levels_text)
if power_levels_text else None))
except StopIteration:
return None
def update(self) -> None:
with self.db.begin() as conn:
conn.execute(self.t.update()
.where(self.c.room_id == self.room_id)
.values(power_levels=self._power_levels_text))
@property
def _edit_identity(self):
return self.c.room_id == self.room_id
def insert(self) -> None:
with self.db.begin() as conn:
conn.execute(self.t.insert().values(room_id=self.room_id,
power_levels=self._power_levels_text))
-68
View File
@@ -1,68 +0,0 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import Column, String, and_
from typing import Dict, Optional
from ..types import MatrixUserID, MatrixRoomID
from .base import Base
class UserProfile(Base):
__tablename__ = "mx_user_profile"
room_id = Column(String, primary_key=True) # type: MatrixRoomID
user_id = Column(String, primary_key=True) # type: MatrixUserID
membership = Column(String, nullable=False, default="leave")
displayname = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
def dict(self) -> Dict[str, str]:
return {
"membership": self.membership,
"displayname": self.displayname,
"avatar_url": self.avatar_url,
}
@classmethod
def get(cls, room_id: MatrixRoomID, user_id: MatrixUserID) -> Optional['UserProfile']:
rows = cls.db.execute(
cls.t.select().where(and_(cls.c.room_id == room_id, cls.c.user_id == user_id)))
try:
room_id, user_id, membership, displayname, avatar_url = next(rows)
return cls(room_id=room_id, user_id=user_id, membership=membership,
displayname=displayname, avatar_url=avatar_url)
except StopIteration:
return None
@classmethod
def delete_all(cls, room_id: MatrixRoomID) -> None:
with cls.db.begin() as conn:
conn.execute(cls.t.delete().where(cls.c.room_id == room_id))
def update(self) -> None:
super().update(membership=self.membership, displayname=self.displayname,
avatar_url=self.avatar_url)
@property
def _edit_identity(self):
return and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id)
def insert(self) -> None:
with self.db.begin() as conn:
conn.execute(self.t.insert().values(room_id=self.room_id, user_id=self.user_id,
membership=self.membership,
displayname=self.displayname,
avatar_url=self.avatar_url))
+128 -246
View File
@@ -13,66 +13,59 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Match, Optional, Set, Tuple, TYPE_CHECKING from typing import Dict, Match, Optional, Set, Tuple, Union, Iterable, TYPE_CHECKING
import logging
import asyncio
import time import time
import re import re
from mautrix_appservice import MatrixRequestError, IntentError from mautrix.bridge import BaseMatrixHandler
from mautrix.types import (Event, EventType, RoomID, UserID, EventID, ReceiptEvent, ReceiptType,
ReceiptEventContent, PresenceEvent, PresenceState, TypingEvent,
MessageEvent, StateEvent, RedactionEvent, RoomNameStateEventContent,
RoomAvatarStateEventContent, RoomTopicStateEventContent,
MemberStateEventContent)
from mautrix.errors import MatrixError
from .types import MatrixEvent, MatrixEventID, MatrixRoomID, MatrixUserID
from . import user as u, portal as po, puppet as pu, commands as com from . import user as u, portal as po, puppet as pu, commands as com
if TYPE_CHECKING: if TYPE_CHECKING:
from .context import Context from .context import Context
from .config import Config
from .bot import Bot from .bot import Bot
from mautrix_appservice import AppService
try: try:
from prometheus_client import Histogram from prometheus_client import Histogram
EVENT_TIME = Histogram("matrix_event", "Time spent processing Matrix events", EVENT_TIME = Histogram("matrix_event", "Time spent processing Matrix events", ["event_type"])
["event_type"])
except ImportError: except ImportError:
Histogram = None Histogram = None
EVENT_TIME = None EVENT_TIME = None
class MatrixHandler: RoomMetaStateEventContent = Union[RoomNameStateEventContent, RoomAvatarStateEventContent,
log: logging.Logger = logging.getLogger("mau.mx") RoomTopicStateEventContent]
az: 'AppService'
config: 'Config'
class MatrixHandler(BaseMatrixHandler):
bot: 'Bot' bot: 'Bot'
commands: 'com.CommandProcessor' commands: 'com.CommandProcessor'
previously_typing: Dict[MatrixRoomID, Set[MatrixUserID]] previously_typing: Dict[RoomID, Set[UserID]]
def __init__(self, context: 'Context') -> None: def __init__(self, context: 'Context') -> None:
self.az, self.config, _, self.tgbot = context.core super(MatrixHandler, self).__init__(context.az, context.config, loop=context.loop,
self.commands = com.CommandProcessor(context) command_processor=com.CommandProcessor(context))
self.bot = context.bot
self.previously_typing = {} self.previously_typing = {}
self.az.matrix_event_handler(self.handle_event) async def get_user(self, user_id: UserID) -> 'u.User':
return await u.User.get_by_mxid(user_id).ensure_started()
async def init_as_bot(self) -> None: async def get_portal(self, room_id: RoomID) -> 'po.Portal':
displayname = self.config["appservice.bot_displayname"] return po.Portal.get_by_mxid(room_id)
if displayname:
try:
await self.az.intent.set_display_name(
displayname if displayname != "remove" else "")
except asyncio.TimeoutError:
self.log.exception("TimeoutError when trying to set displayname")
avatar = self.config["appservice.bot_avatar"] async def get_puppet(self, user_id: UserID) -> 'pu.Puppet':
if avatar: return pu.Puppet.get_by_mxid(user_id)
try:
await self.az.intent.set_avatar(avatar if avatar != "remove" else "")
except asyncio.TimeoutError:
self.log.exception("TimeoutError when trying to set avatar")
async def handle_puppet_invite(self, room_id: MatrixRoomID, puppet: pu.Puppet, inviter: u.User async def handle_puppet_invite(self, room_id: RoomID, puppet: pu.Puppet, inviter: u.User,
) -> None: event_id: EventID) -> None:
intent = puppet.default_mxid_intent intent = puppet.default_mxid_intent
self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}") self.log.debug(f"{inviter} invited puppet for {puppet.tgid} to {room_id}")
if not await inviter.is_logged_in(): if not await inviter.is_logged_in():
@@ -90,7 +83,7 @@ class MatrixHandler:
return return
try: try:
members = await self.az.intent.get_room_members(room_id) members = await self.az.intent.get_room_members(room_id)
except MatrixRequestError: except MatrixError:
members = [] members = []
if self.az.bot_mxid not in members: if self.az.bot_mxid not in members:
if len(members) > 1: if len(members) > 1:
@@ -113,7 +106,7 @@ class MatrixHandler:
"</a>")) "</a>"))
await intent.leave_room(room_id) await intent.leave_room(room_id)
return return
except MatrixRequestError: except MatrixError:
pass pass
portal.mxid = room_id portal.mxid = room_id
portal.save() portal.save()
@@ -124,67 +117,25 @@ class MatrixHandler:
await intent.send_notice(room_id, "This puppet will remain inactive until a " await intent.send_notice(room_id, "This puppet will remain inactive until a "
"Telegram chat is created for this room.") "Telegram chat is created for this room.")
async def accept_bot_invite(self, room_id: MatrixRoomID, inviter: u.User) -> None: async def send_welcome_message(self, room_id: RoomID, inviter: 'u.User', event_id: EventID
tries = 0 ) -> None:
while tries < 5:
try:
await self.az.intent.join_room(room_id)
break
except (IntentError, MatrixRequestError):
tries += 1
wait_for_seconds = (tries + 1) * 10
if tries < 5:
self.log.exception(f"Failed to join room {room_id} with bridge bot, "
f"retrying in {wait_for_seconds} seconds...")
await asyncio.sleep(wait_for_seconds)
else:
self.log.exception("Failed to join room {room}, giving up.")
return
if not inviter.whitelisted:
await self.az.intent.send_notice(
room_id,
text="You are not whitelisted to use this bridge.\n\n"
"If you are the owner of this bridge, see the "
"`bridge.permissions` section in your config file.",
html="<p>You are not whitelisted to use this bridge.</p>"
"<p>If you are the owner of this bridge, see the "
"<code>bridge.permissions</code> section in your config file.</p>")
await self.az.intent.leave_room(room_id)
try: try:
is_management = len(await self.az.intent.get_room_members(room_id)) == 2 is_management = len(await self.az.intent.get_room_members(room_id)) == 2
except MatrixRequestError: except MatrixError:
is_management = False # The AS bot is not in the room.
return
cmd_prefix = self.commands.command_prefix cmd_prefix = self.commands.command_prefix
text = html = "Hello, I'm a Telegram bridge bot. " text = html = "Hello, I'm a Telegram bridge bot. "
if is_management and inviter.puppet_whitelisted and not await inviter.is_logged_in(): if is_management and inviter.puppet_whitelisted and not await inviter.is_logged_in():
text += f"Use `{cmd_prefix} help` for help or `{cmd_prefix} login` to log in." text += f"Use `{cmd_prefix} help` for help or `{cmd_prefix} login` to log in."
html += (f"Use <code>{cmd_prefix} help</code> for help" html += (f"Use <code>{cmd_prefix} help</code> for help"
f" or <code>{cmd_prefix} login</code> to log in.") f" or <code>{cmd_prefix} login</code> to log in.")
pass
else: else:
text += f"Use `{cmd_prefix} help` for help." text += f"Use `{cmd_prefix} help` for help."
html += f"Use <code>{cmd_prefix} help</code> for help." html += f"Use <code>{cmd_prefix} help</code> for help."
await self.az.intent.send_notice(room_id, text=text, html=html) await self.az.intent.send_notice(room_id, text=text, html=html)
async def handle_invite(self, room_id: MatrixRoomID, user_id: MatrixUserID, async def handle_invite(self, room_id: RoomID, user_id: UserID, inviter: 'u.User') -> None:
inviter_mxid: MatrixUserID) -> None:
self.log.debug(f"{inviter_mxid} invited {user_id} to {room_id}")
inviter = u.User.get_by_mxid(inviter_mxid)
if inviter is None:
self.log.exception("Failed to find user with Matrix ID {inviter_mxid}")
await inviter.ensure_started()
if user_id == self.az.bot_mxid:
return await self.accept_bot_invite(room_id, inviter)
elif not inviter.whitelisted:
return
puppet = pu.Puppet.get_by_mxid(user_id)
if puppet:
await self.handle_puppet_invite(room_id, puppet, inviter)
return
user = u.User.get_by_mxid(user_id, create=False) user = u.User.get_by_mxid(user_id, create=False)
if not user: if not user:
return return
@@ -194,10 +145,8 @@ class MatrixHandler:
await portal.invite_telegram(inviter, user) await portal.invite_telegram(inviter, user)
return return
# The rest can probably be ignored async def handle_join(self, room_id: RoomID, user_id: UserID,
event_id: EventID) -> None:
async def handle_join(self, room_id: MatrixRoomID, user_id: MatrixUserID,
event_id: MatrixEventID) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started() user = await u.User.get_by_mxid(user_id).ensure_started()
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
@@ -218,11 +167,11 @@ class MatrixHandler:
if await user.is_logged_in() or portal.has_bot: if await user.is_logged_in() or portal.has_bot:
await portal.join_matrix(user, event_id) await portal.join_matrix(user, event_id)
async def handle_part(self, room_id: MatrixRoomID, user_id: MatrixUserID, async def handle_raw_leave(self, room_id: RoomID, user_id: UserID, sender_id: UserID,
sender_mxid: MatrixUserID, event_id: MatrixEventID) -> None: reason: str, event_id: EventID) -> None:
self.log.debug(f"{user_id} left {room_id}") self.log.debug(f"{user_id} left {room_id}")
sender = u.User.get_by_mxid(sender_mxid, create=False) sender = u.User.get_by_mxid(sender_id, create=False)
if not sender: if not sender:
return return
await sender.ensure_started() await sender.ensure_started()
@@ -233,98 +182,67 @@ class MatrixHandler:
puppet = pu.Puppet.get_by_mxid(user_id) puppet = pu.Puppet.get_by_mxid(user_id)
if puppet: if puppet:
if sender: await portal.kick_matrix(puppet, sender)
await portal.kick_matrix(puppet, sender)
return return
user = u.User.get_by_mxid(user_id, create=False) user = u.User.get_by_mxid(user_id, create=False)
if not user: if not user:
return return
await user.ensure_started() await user.ensure_started()
if await user.is_logged_in() or portal.has_bot: if sender_id != user_id:
await portal.leave_matrix(user, sender, event_id) await portal.kick_matrix(user, sender)
else:
def is_command(self, message: Dict) -> Tuple[bool, str]: await portal.leave_matrix(user, event_id)
text = message.get("body", "")
prefix = self.config["bridge.command_prefix"]
is_command = text.startswith(prefix)
if is_command:
text = text[len(prefix) + 1:].lstrip()
return is_command, text
async def handle_message(self, room: MatrixRoomID, sender_id: MatrixUserID, message: Dict,
event_id: MatrixEventID) -> None:
is_command, text = self.is_command(message)
sender = await u.User.get_by_mxid(sender_id).ensure_started()
if not sender.relaybot_whitelisted:
self.log.debug(f"Ignoring message \"{message}\" from {sender} to {room}:"
" User is not whitelisted.")
return
self.log.debug(f"Received Matrix event \"{message}\" from {sender} in {room}")
portal = po.Portal.get_by_mxid(room)
if not is_command and portal and (await sender.is_logged_in() or portal.has_bot):
await portal.handle_matrix_message(sender, message, event_id)
return
if not sender.whitelisted or message.get("msgtype", "m.unknown") != "m.text":
return
try:
is_management = len(await self.az.intent.get_room_members(room)) == 2
except MatrixRequestError:
# The AS bot is not in the room.
return
if is_command or is_management:
try:
command, arguments = text.split(" ", 1)
args = arguments.split(" ")
except ValueError:
# Not enough values to unpack, i.e. no arguments
command = text
args = []
await self.commands.handle(room, event_id, sender, command, args, is_management,
is_portal=portal is not None)
@staticmethod @staticmethod
async def handle_redaction(room_id: MatrixRoomID, sender_mxid: MatrixUserID, async def allow_message(user: 'u.User') -> bool:
event_id: MatrixEventID) -> None: return user.relaybot_whitelisted
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
@staticmethod
async def allow_command(user: 'u.User') -> bool:
return user.whitelisted
@staticmethod
async def allow_bridging_message(user: 'u.User', portal: 'po.Portal') -> bool:
return await user.is_logged_in() or portal.has_bot
@staticmethod
async def handle_redaction(evt: RedactionEvent) -> None:
sender = await u.User.get_by_mxid(evt.sender).ensure_started()
if not sender.relaybot_whitelisted: if not sender.relaybot_whitelisted:
return return
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(evt.room_id)
if not portal: if not portal:
return return
await portal.handle_matrix_deletion(sender, event_id) await portal.handle_matrix_deletion(sender, evt.redacts)
@staticmethod @staticmethod
async def handle_power_levels(room_id: MatrixRoomID, sender_mxid: MatrixUserID, async def handle_power_levels(evt: StateEvent) -> None:
new: Dict, old: Dict) -> None: portal = po.Portal.get_by_mxid(evt.event_id)
portal = po.Portal.get_by_mxid(room_id) sender = await u.User.get_by_mxid(evt.sender).ensure_started()
sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal: if await sender.has_full_access(allow_bot=True) and portal:
await portal.handle_matrix_power_levels(sender, new["users"], old["users"]) await portal.handle_matrix_power_levels(sender, evt.content.users,
evt.unsigned.prev_content.users)
@staticmethod @staticmethod
async def handle_room_meta(evt_type: str, room_id: MatrixRoomID, sender_mxid: MatrixUserID, async def handle_room_meta(evt_type: EventType, room_id: RoomID, sender_mxid: UserID,
content: dict) -> None: content: RoomMetaStateEventContent) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started() sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
if await sender.has_full_access(allow_bot=True) and portal: if await sender.has_full_access(allow_bot=True) and portal:
handler, content_key = { handler, content_key = {
"m.room.name": (portal.handle_matrix_title, "name"), EventType.ROOM_NAME: (portal.handle_matrix_title, "name"),
"m.room.topic": (portal.handle_matrix_about, "topic"), EventType.ROOM_TOPIC: (portal.handle_matrix_about, "topic"),
"m.room.avatar": (portal.handle_matrix_avatar, "url"), EventType.ROOM_AVATAR: (portal.handle_matrix_avatar, "url"),
}[evt_type] }[evt_type]
if content_key not in content: if content_key not in content:
return return
await handler(sender, content[content_key]) await handler(sender, content[content_key])
@staticmethod @staticmethod
async def handle_room_pin(room_id: MatrixRoomID, sender_mxid: MatrixUserID, async def handle_room_pin(room_id: RoomID, sender_mxid: UserID,
new_events: Set[str], old_events: Set[str]) -> None: new_events: Set[str], old_events: Set[str]) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
sender = await u.User.get_by_mxid(sender_mxid).ensure_started() sender = await u.User.get_by_mxid(sender_mxid).ensure_started()
@@ -332,55 +250,61 @@ class MatrixHandler:
events = new_events - old_events events = new_events - old_events
if len(events) > 0: if len(events) > 0:
# New event pinned, set that as pinned in Telegram. # New event pinned, set that as pinned in Telegram.
await portal.handle_matrix_pin(sender, MatrixEventID(events.pop())) await portal.handle_matrix_pin(sender, EventID(events.pop()))
elif len(new_events) == 0: elif len(new_events) == 0:
# All pinned events removed, remove pinned event in Telegram. # All pinned events removed, remove pinned event in Telegram.
await portal.handle_matrix_pin(sender, None) await portal.handle_matrix_pin(sender, None)
@staticmethod @staticmethod
async def handle_room_upgrade(room_id: MatrixRoomID, new_room_id: MatrixRoomID) -> None: async def handle_room_upgrade(room_id: RoomID, new_room_id: RoomID) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
if portal: if portal:
await portal.handle_matrix_upgrade(new_room_id) await portal.handle_matrix_upgrade(new_room_id)
@staticmethod @staticmethod
async def handle_name_change(room_id: MatrixRoomID, user_id: MatrixUserID, displayname: str, async def handle_member_info_change(room_id: RoomID, user_id: UserID,
prev_displayname: str, event_id: MatrixEventID) -> None: profile: MemberStateEventContent,
prev_profile: MemberStateEventContent,
event_id: EventID) -> None:
if profile.displayname == prev_profile.displayname:
return
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
if not portal or not portal.has_bot: if not portal or not portal.has_bot:
return return
user = await u.User.get_by_mxid(user_id).ensure_started() user = await u.User.get_by_mxid(user_id).ensure_started()
if await user.needs_relaybot(portal): if await user.needs_relaybot(portal):
await portal.name_change_matrix(user, displayname, prev_displayname, event_id) await portal.name_change_matrix(user, profile.displayname, prev_profile.displayname,
event_id)
@staticmethod @staticmethod
def parse_read_receipts(content: Dict) -> Dict[MatrixUserID, MatrixEventID]: def parse_read_receipts(content: ReceiptEventContent) -> Iterable[Tuple[UserID, EventID]]:
return {user_id: event_id return ((user_id, event_id)
for event_id, receipts in content.items() for event_id, receipts in content.items()
for user_id in receipts.get("m.read", {})} for user_id in receipts.get(ReceiptType.READ, {}))
@staticmethod @staticmethod
async def handle_read_receipts(room_id: MatrixRoomID, async def handle_read_receipts(room_id: RoomID, receipts: Iterable[Tuple[UserID, EventID]]
receipts: Dict[MatrixUserID, MatrixEventID]) -> None: ) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
if not portal: if not portal:
return return
for user_id, event_id in receipts.items(): for user_id, event_id in receipts:
user = await u.User.get_by_mxid(user_id).ensure_started() user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in(): if not await user.is_logged_in():
continue continue
await portal.mark_read(user, event_id) await portal.mark_read(user, event_id)
@staticmethod @staticmethod
async def handle_presence(user_id: MatrixUserID, presence: str) -> None: async def handle_presence(user_id: UserID, presence: PresenceState) -> None:
user = await u.User.get_by_mxid(user_id).ensure_started() user = await u.User.get_by_mxid(user_id).ensure_started()
if not await user.is_logged_in(): if not await user.is_logged_in():
return return
await user.set_presence(presence == "online") await user.set_presence(presence == PresenceState.ONLINE)
async def handle_typing(self, room_id: MatrixRoomID, now_typing: Set[MatrixUserID]) -> None: async def handle_typing(self, room_id: RoomID, now_typing: Set[UserID]) -> None:
portal = po.Portal.get_by_mxid(room_id) portal = po.Portal.get_by_mxid(room_id)
if not portal: if not portal:
return return
@@ -401,86 +325,44 @@ class MatrixHandler:
self.previously_typing[room_id] = now_typing self.previously_typing[room_id] = now_typing
def filter_matrix_event(self, event: MatrixEvent) -> bool: def filter_matrix_event(self, evt: Event) -> bool:
sender = event.get("sender", None) if not isinstance(evt, (MessageEvent, StateEvent)):
if not sender: return True
return False return evt.sender and (evt.sender == self.az.bot_mxid
return (sender == self.az.bot_mxid or pu.Puppet.get_id_from_mxid(evt.sender) is not None)
or pu.Puppet.get_id_from_mxid(sender) is not None)
async def try_handle_ephemeral_event(self, evt: MatrixEvent) -> None: async def handle_ephemeral_event(self, evt: Union[ReceiptEvent, PresenceEvent, TypingEvent]
try: ) -> None:
await self.handle_ephemeral_event(evt) if evt.type == EventType.RECEIPT:
except Exception: await self.handle_read_receipts(evt.room_id, self.parse_read_receipts(evt.content))
self.log.exception("Error handling manually received Matrix event") elif evt.type == EventType.PRESENCE:
await self.handle_presence(evt.sender, evt.content.presence)
elif evt.type == EventType.TYPING:
await self.handle_typing(evt.room_id, set(evt.content.user_ids))
async def handle_ephemeral_event(self, evt: MatrixEvent) -> None: async def handle_event(self, evt: Event) -> None:
evt_type: str = evt.get("type", "m.unknown") if evt.type == EventType.ROOM_REDACTION:
room_id: Optional[MatrixRoomID] = evt.get("room_id", None) await self.handle_redaction(evt)
sender: Optional[MatrixUserID] = evt.get("sender", None)
content: Dict = evt.get("content", {})
if evt_type == "m.receipt":
await self.handle_read_receipts(room_id, self.parse_read_receipts(content))
elif evt_type == "m.presence":
await self.handle_presence(sender, content.get("presence", "offline"))
elif evt_type == "m.typing":
await self.handle_typing(room_id, set(content.get("user_ids", [])))
async def handle_event(self, evt: MatrixEvent) -> None: async def handle_state_event(self, evt: StateEvent) -> None:
if self.filter_matrix_event(evt): if evt.type == EventType.ROOM_POWER_LEVELS:
return await self.handle_power_levels(evt)
start_time = time.time() elif evt.type in (EventType.ROOM_NAME, EventType.ROOM_AVATAR, EventType.ROOM_TOPIC):
self.log.debug("Received event: %s", evt) await self.handle_room_meta(evt.type, evt.room_id, evt.sender, evt.content)
evt_type: str = evt.get("type", "m.unknown") elif evt.type == EventType.ROOM_PINNED_EVENTS:
room_id: Optional[MatrixRoomID] = evt.get("room_id", None) new_events = set(evt.content.pinned)
event_id: Optional[MatrixEventID] = evt.get("event_id", None) try:
sender: Optional[MatrixUserID] = evt.get("sender", None) old_events = set(evt.unsigned.prev_content.pinned)
state_key = evt.get("state_key", None) except (KeyError, ValueError, TypeError, AttributeError):
content: Dict = evt.get("content", {}) old_events = set()
if state_key is not None: await self.handle_room_pin(evt.room_id, evt.sender, new_events, old_events)
if evt_type == "m.room.member": elif evt.type == EventType.ROOM_TOMBSTONE:
prev_content: Dict = evt.get("unsigned", {}).get("prev_content", {}) await self.handle_room_upgrade(evt.room_id, evt.content.replacement_room)
membership: str = content.get("membership", "")
prev_membership: str = prev_content.get("membership", "leave")
if membership == prev_membership:
match: Match = re.compile("@(.+):(.+)").match(state_key)
mxid: str = match.group(0)
displayname: str = content.get("displayname", None) or mxid
prev_displayname: str = prev_content.get("displayname", None) or mxid
if displayname != prev_displayname:
await self.handle_name_change(room_id, state_key, displayname,
prev_displayname, event_id)
elif membership == "invite":
await self.handle_invite(room_id, state_key, sender)
elif prev_membership == "join" and membership == "leave":
await self.handle_part(room_id, state_key, sender, event_id)
elif membership == "join":
await self.handle_join(room_id, state_key, event_id)
elif evt_type == "m.room.power_levels":
prev_content = evt.get("unsigned", {}).get("prev_content", {})
await self.handle_power_levels(room_id, sender, evt["content"], prev_content)
elif evt_type in ("m.room.name", "m.room.avatar", "m.room.topic"):
await self.handle_room_meta(evt_type, room_id, sender, evt["content"])
elif evt_type == "m.room.pinned_events":
new_events = set(evt["content"]["pinned"])
try:
old_events = set(evt["unsigned"]["prev_content"]["pinned"])
except KeyError:
old_events = set()
await self.handle_room_pin(room_id, sender, new_events, old_events)
elif evt_type == "m.room.tombstone":
await self.handle_room_upgrade(room_id, evt["content"]["replacement_room"])
else:
return
else:
if evt_type in ("m.room.message", "m.sticker"):
if evt_type != "m.room.message":
content["msgtype"] = evt_type
await self.handle_message(room_id, sender, content, event_id)
elif evt_type == "m.room.redaction":
await self.handle_redaction(room_id, sender, evt["redacts"])
else:
return
if EVENT_TIME: # async def handle_event(self, evt: MatrixEvent) -> None:
EVENT_TIME.labels(event_type=evt_type).observe(time.time() - start_time) # if self.filter_matrix_event(evt):
# return
# start_time = time.time()
#
# if EVENT_TIME:
# EVENT_TIME.labels(event_type=evt_type).observe(time.time() - start_time)
+17 -10
View File
@@ -864,23 +864,32 @@ class Portal:
else: else:
await user.client(ReadMessageHistoryRequest(peer=self.peer, max_id=message.tgid)) await user.client(ReadMessageHistoryRequest(peer=self.peer, max_id=message.tgid))
async def kick_matrix(self, user: Union['u.User', 'p.Puppet'], source: 'u.User') -> None: async def kick_matrix(self, user: Union['u.User', 'p.Puppet'], source: 'u.User',
ban: bool = False) -> None:
if user.tgid == source.tgid: if user.tgid == source.tgid:
return return
if isinstance(user, u.User) and await user.needs_relaybot(self):
if not self.bot:
return
# TODO kick and ban message
return
if await source.needs_relaybot(self): if await source.needs_relaybot(self):
if not self.has_bot:
return
source = self.bot source = self.bot
target = await user.get_input_entity(source)
if self.peer_type == "chat": if self.peer_type == "chat":
await source.client(DeleteChatUserRequest(chat_id=self.tgid, user_id=user.tgid)) await source.client(DeleteChatUserRequest(chat_id=self.tgid, user_id=target))
elif self.peer_type == "channel": elif self.peer_type == "channel":
channel = await self.get_input_entity(source) channel = await self.get_input_entity(source)
rights = ChatBannedRights(datetime.fromtimestamp(0), True) await source.client.edit_permissions(channel, target, view_messages=False)
await source.client(EditBannedRequest(channel=channel, if not ban:
user_id=user.tgid, await source.client.edit_permissions(channel, target, view_messages=True)
banned_rights=rights))
async def leave_matrix(self, user: 'u.User', source: 'u.User', async def leave_matrix(self, user: 'u.User', event_id: MatrixEventID) -> None:
event_id: MatrixEventID) -> None:
if await user.needs_relaybot(self): if await user.needs_relaybot(self):
if not self.has_bot:
return
async with self.require_send_lock(self.bot.tgid): async with self.require_send_lock(self.bot.tgid):
message = await self._get_state_change_message("leave", user) message = await self._get_state_change_message("leave", user)
if not message: if not message:
@@ -900,8 +909,6 @@ class Portal:
del self.by_mxid[self.mxid] del self.by_mxid[self.mxid]
except KeyError: except KeyError:
pass pass
elif source and source.tgid != user.tgid:
await self.kick_matrix(user, source)
elif self.peer_type == "chat": elif self.peer_type == "chat":
await user.client(DeleteChatUserRequest(chat_id=self.tgid, user_id=InputUserSelf())) await user.client(DeleteChatUserRequest(chat_id=self.tgid, user_id=InputUserSelf()))
elif self.peer_type == "channel": elif self.peer_type == "channel":
+2 -2
View File
@@ -521,7 +521,7 @@ class Puppet:
# endregion # endregion
def init(context: 'Context') -> List[Awaitable[Any]]: # [None, None, PuppetError] def init(context: 'Context') -> Iterable[Awaitable[Any]]:
global config global config
Puppet.az, config, Puppet.loop, _ = context.core Puppet.az, config, Puppet.loop, _ = context.core
Puppet.mx = context.mx Puppet.mx = context.mx
@@ -529,4 +529,4 @@ def init(context: 'Context') -> List[Awaitable[Any]]: # [None, None, PuppetErro
Puppet.hs_domain = config["homeserver"]["domain"] Puppet.hs_domain = config["homeserver"]["domain"]
Puppet.mxid_regex = re.compile( Puppet.mxid_regex = re.compile(
f"@{Puppet.username_template.format(userid='([0-9]+)')}:{Puppet.hs_domain}") f"@{Puppet.username_template.format(userid='([0-9]+)')}:{Puppet.hs_domain}")
return [puppet.init_custom_mxid() for puppet in Puppet.all_with_custom_mxid()] return (puppet.init_custom_mxid() for puppet in Puppet.all_with_custom_mxid())
+15 -98
View File
@@ -13,109 +13,26 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Tuple from mautrix.types import UserID
from mautrix.bridge.db import SQLStateStore as BaseSQLStateStore
from mautrix_appservice import StateStore
from .types import MatrixUserID, MatrixRoomID
from . import puppet as pu from . import puppet as pu
from .db import RoomState, UserProfile
class SQLStateStore(StateStore): class SQLStateStore(BaseSQLStateStore):
profile_cache: Dict[Tuple[str, str], UserProfile] def is_registered(self, user_id: UserID) -> bool:
room_state_cache: Dict[str, RoomState] puppet = pu.Puppet.get_by_mxid(user_id, create=False)
if puppet:
return puppet.is_registered
custom_puppet = pu.Puppet.get_by_custom_mxid(user_id)
if custom_puppet:
return True
return super().is_registered(user_id)
def __init__(self) -> None: def registered(self, user_id: UserID) -> None:
super().__init__() puppet = pu.Puppet.get_by_mxid(user_id, create=True)
self.profile_cache = {}
self.room_state_cache = {}
@staticmethod
def is_registered(user: MatrixUserID) -> bool:
puppet = pu.Puppet.get_by_mxid(user)
return puppet.is_registered if puppet else False
@staticmethod
def registered(user: MatrixUserID) -> None:
puppet = pu.Puppet.get_by_mxid(user)
if puppet: if puppet:
puppet.is_registered = True puppet.is_registered = True
puppet.save() puppet.save()
else:
def update_state(self, event: Dict) -> None: super().registered(user_id)
event_type = event["type"]
if event_type == "m.room.power_levels":
self.set_power_levels(event["room_id"], event["content"])
elif event_type == "m.room.member":
self.set_member(event["room_id"], event["state_key"], event["content"])
def _get_user_profile(self, room_id: MatrixRoomID, user_id: MatrixUserID, create: bool = True
) -> UserProfile:
key = (room_id, user_id)
try:
return self.profile_cache[key]
except KeyError:
pass
profile = UserProfile.get(*key)
if profile:
self.profile_cache[key] = profile
elif create:
profile = UserProfile(room_id=room_id, user_id=user_id, membership="leave")
profile.insert()
self.profile_cache[key] = profile
return profile
def get_member(self, room: MatrixRoomID, user: MatrixUserID) -> Dict:
return self._get_user_profile(room, user).dict()
def set_member(self, room: MatrixRoomID, user: MatrixUserID, member: Dict) -> None:
profile = self._get_user_profile(room, user)
profile.membership = member.get("membership", profile.membership or "leave")
profile.displayname = member.get("displayname", profile.displayname)
profile.avatar_url = member.get("avatar_url", profile.avatar_url)
profile.update()
def set_membership(self, room: MatrixRoomID, user: MatrixUserID, membership: str) -> None:
self.set_member(room, user, {
"membership": membership,
})
def _get_room_state(self, room_id: MatrixRoomID, create: bool = True) -> RoomState:
try:
return self.room_state_cache[room_id]
except KeyError:
pass
room = RoomState.get(room_id)
if room:
self.room_state_cache[room_id] = room
elif create:
room = RoomState(room_id=room_id)
room.insert()
self.room_state_cache[room_id] = room
return room
def has_power_levels(self, room: MatrixRoomID) -> bool:
return bool(self._get_room_state(room).power_levels)
def get_power_levels(self, room: MatrixRoomID) -> Dict:
return self._get_room_state(room).power_levels
def set_power_level(self, room: MatrixRoomID, user: MatrixUserID, level: int) -> None:
room_state = self._get_room_state(room)
power_levels = room_state.power_levels
if not power_levels:
power_levels = {
"users": {},
"events": {},
}
power_levels[room]["users"][user] = level
room_state.power_levels = power_levels
room_state.update()
def set_power_levels(self, room: MatrixRoomID, content: Dict) -> None:
state = self._get_room_state(room)
state.power_levels = content
state.update()
-6
View File
@@ -1,9 +1,3 @@
from typing import Dict, NewType from typing import Dict, NewType
MatrixUserID = NewType('MatrixUserID', str)
MatrixRoomID = NewType('MatrixRoomID', str)
MatrixEventID = NewType('MatrixEventID', str)
MatrixEvent = NewType('MatrixEvent', Dict)
TelegramID = NewType('TelegramID', int) TelegramID = NewType('TelegramID', int)
+3 -4
View File
@@ -331,7 +331,7 @@ class User(AbstractUser):
async def needs_relaybot(self, portal: po.Portal) -> bool: async def needs_relaybot(self, portal: po.Portal) -> bool:
return not await self.is_logged_in() or ( return not await self.is_logged_in() or (
(portal.has_bot or self.bot) and portal.tgid_full not in self.portals) (portal.has_bot or self.is_bot) and portal.tgid_full not in self.portals)
def _hash_contacts(self) -> int: def _hash_contacts(self) -> int:
acc = 0 acc = 0
@@ -408,9 +408,8 @@ class User(AbstractUser):
# endregion # endregion
def init(context: 'Context') -> List[Awaitable['User']]: def init(context: 'Context') -> Iterable[Awaitable['User']]:
global config global config
config = context.config config = context.config
users = [User.from_db(user) for user in DBUser.all()] return (User.from_db(db_user).ensure_started() for db_user in DBUser.all() if db_user.tgid)
return [user.ensure_started() for user in users if user.tgid]
+1 -1
View File
@@ -1,5 +1,5 @@
aiohttp aiohttp
mautrix-appservice mautrix
ruamel.yaml ruamel.yaml
python-magic python-magic
SQLAlchemy SQLAlchemy
+1 -1
View File
@@ -31,7 +31,7 @@ setuptools.setup(
install_requires=[ install_requires=[
"aiohttp>=3.0.1,<4", "aiohttp>=3.0.1,<4",
"mautrix-appservice>=0.3.11,<0.4.0", "mautrix>=0.4.0.dev46,<0.5",
"SQLAlchemy>=1.2.3,<2", "SQLAlchemy>=1.2.3,<2",
"alembic>=1.0.0,<2", "alembic>=1.0.0,<2",
"commonmark>=0.8.1,<1", "commonmark>=0.8.1,<1",