Stop handling events from custom puppets

This commit is contained in:
Tulir Asokan
2018-07-20 14:13:13 -04:00
parent 2b92483c50
commit ecdca21e32
7 changed files with 94 additions and 34 deletions
+4 -2
View File
@@ -110,8 +110,10 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st
context.mx = MatrixHandler(context) context.mx = MatrixHandler(context)
init_formatter(context) init_formatter(context)
init_portal(context) init_portal(context)
init_puppet(context) startup_actions = (init_puppet(context) +
startup_actions = init_user(context) + [start, context.mx.init_as_bot()] init_user(context) +
[start,
context.mx.init_as_bot()])
if context.bot: if context.bot:
startup_actions.append(context.bot.start()) startup_actions.append(context.bot.start())
+1 -1
View File
@@ -124,7 +124,7 @@ class AbstractUser:
self.log.debug("%s connected: %s", self.mxid, self.connected) self.log.debug("%s connected: %s", self.mxid, self.connected)
return self return self
async def ensure_started(self, even_if_no_session=False): async def ensure_started(self, even_if_no_session=False) -> "AbstractUser":
if not self.puppet_whitelisted: if not self.puppet_whitelisted:
return self return self
self.log.debug("ensure_started(%s, connected=%s, even_if_no_session=%s, session_count=%s)", self.log.debug("ensure_started(%s, connected=%s, even_if_no_session=%s, session_count=%s)",
+4 -8
View File
@@ -56,15 +56,11 @@ async def ping_bot(evt: CommandEvent):
"account") "account")
async def login_matrix(evt: CommandEvent): async def login_matrix(evt: CommandEvent):
puppet = pu.Puppet.get(evt.sender.tgid) puppet = pu.Puppet.get(evt.sender.tgid)
prev_info = puppet.custom_mxid, puppet.access_token resp = puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid)
puppet.custom_mxid = evt.sender.mxid if resp == 2:
puppet.access_token = " ".join(evt.args) return await evt.reply("You can only log in as your own Matrix user.")
puppet.refresh_intents() elif resp == 1:
if not await puppet.get_profile():
puppet.custom_mxid, puppet.access_token = prev_info
puppet.refresh_intents()
return await evt.reply("Failed to verify access token.") return await evt.reply("Failed to verify access token.")
puppet.save()
return await evt.reply( return await evt.reply(
f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.") f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.")
+11 -12
View File
@@ -17,14 +17,14 @@
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer, from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
BigInteger, String, Boolean, Text) BigInteger, String, Boolean, Text)
from sqlalchemy.sql import expression from sqlalchemy.sql import expression
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship, Query
import json import json
from .base import Base from .base import Base
class Portal(Base): class Portal(Base):
query = None query = None # type: Query
__tablename__ = "portal" __tablename__ = "portal"
# Telegram chat information # Telegram chat information
@@ -42,9 +42,8 @@ class Portal(Base):
about = Column(String, nullable=True) about = Column(String, nullable=True)
photo_id = Column(String, nullable=True) photo_id = Column(String, nullable=True)
class Message(Base): class Message(Base):
query = None query = None # type: Query
__tablename__ = "message" __tablename__ = "message"
mxid = Column(String) mxid = Column(String)
@@ -56,7 +55,7 @@ class Message(Base):
class UserPortal(Base): class UserPortal(Base):
query = None query = None # type: Query
__tablename__ = "user_portal" __tablename__ = "user_portal"
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"), user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
@@ -70,7 +69,7 @@ class UserPortal(Base):
class User(Base): class User(Base):
query = None query = None # type: Query
__tablename__ = "user" __tablename__ = "user"
mxid = Column(String, primary_key=True) mxid = Column(String, primary_key=True)
@@ -83,7 +82,7 @@ class User(Base):
class RoomState(Base): class RoomState(Base):
query = None query = None # type: Query
__tablename__ = "mx_room_state" __tablename__ = "mx_room_state"
room_id = Column(String, primary_key=True) room_id = Column(String, primary_key=True)
@@ -107,7 +106,7 @@ class RoomState(Base):
class UserProfile(Base): class UserProfile(Base):
query = None query = None # type: Query
__tablename__ = "mx_user_profile" __tablename__ = "mx_user_profile"
room_id = Column(String, primary_key=True) room_id = Column(String, primary_key=True)
@@ -125,7 +124,7 @@ class UserProfile(Base):
class Contact(Base): class Contact(Base):
query = None query = None # type: Query
__tablename__ = "contact" __tablename__ = "contact"
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) user = Column(Integer, ForeignKey("user.tgid"), primary_key=True)
@@ -133,7 +132,7 @@ class Contact(Base):
class Puppet(Base): class Puppet(Base):
query = None query = None # type: Query
__tablename__ = "puppet" __tablename__ = "puppet"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
@@ -149,14 +148,14 @@ class Puppet(Base):
# Fucking Telegram not telling bots what chats they are in 3:< # Fucking Telegram not telling bots what chats they are in 3:<
class BotChat(Base): class BotChat(Base):
query = None query = None # type: Query
__tablename__ = "bot_chat" __tablename__ = "bot_chat"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
type = Column(String, nullable=False) type = Column(String, nullable=False)
class TelegramFile(Base): class TelegramFile(Base):
query = None query = None # type: Query
__tablename__ = "telegram_file" __tablename__ = "telegram_file"
id = Column(String, primary_key=True) id = Column(String, primary_key=True)
+6 -1
View File
@@ -824,7 +824,12 @@ class Portal:
mxid=event_id)) mxid=event_id))
self.db.commit() self.db.commit()
async def handle_matrix_message(self, sender, message, event_id): async def handle_matrix_message(self, sender: u.User, message: dict, event_id: str):
puppet = p.Puppet.get_by_custom_mxid(sender.mxid)
if puppet and message.get("net.maunium.telegram.puppet", False):
self.log.debug("Ignoring puppet-sent message by confirmed puppet user %s", sender.mxid)
return
logged_in = not await sender.needs_relaybot(self) logged_in = not await sender.needs_relaybot(self)
client = sender.client if logged_in else self.bot.client client = sender.client if logged_in else self.bot.client
sender_id = sender.tgid if logged_in else self.bot.tgid sender_id = sender.tgid if logged_in else self.bot.tgid
+61 -6
View File
@@ -36,6 +36,7 @@ class Puppet:
username_template = None username_template = None
hs_domain = None hs_domain = None
cache = {} cache = {}
by_custom_mxid = {}
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None, def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
displayname=None, displayname_source=None, photo_id=None, is_bot=None, displayname=None, displayname_source=None, photo_id=None, is_bot=None,
@@ -60,22 +61,51 @@ class Puppet:
self.refresh_intents() self.refresh_intents()
self.cache[id] = self self.cache[id] = self
if self.custom_mxid:
self.by_custom_mxid[self.custom_mxid] = self
def refresh_intents(self): def refresh_intents(self):
self.is_real_user = self.custom_mxid and self.access_token self.is_real_user = self.custom_mxid and self.access_token
self.intent = (self.az.intent.user(self.custom_mxid, self.access_token) self.intent = (self.az.intent.user(self.custom_mxid, self.access_token)
if self.is_real_user else self.default_mxid_intent) if self.is_real_user else self.default_mxid_intent)
async def get_profile(self):
try:
return await self.intent.get_profile(self.custom_mxid)
except MatrixError:
return None
@property @property
def tgid(self): def tgid(self):
return self.id return self.id
async def switch_mxid(self, access_token, mxid):
prev_mxid = self.custom_mxid
self.custom_mxid = mxid
self.access_token = access_token
self.refresh_intents()
err = await self.test_custom_mxid()
if err != 0:
return err
try:
del self.by_custom_mxid[prev_mxid]
except KeyError:
pass
self.mxid = self.custom_mxid or self.default_mxid
self.by_custom_mxid[self.mxid] = self
self.save()
return 0
async def test_custom_mxid(self):
if not self.is_real_user:
return 0
mxid = await self.intent.whoami()
if not mxid or mxid != self.custom_mxid:
self.custom_mxid = None
self.access_token = None
self.refresh_intents()
if mxid != self.custom_mxid:
return 2
return 1
return 0
async def is_logged_in(self): async def is_logged_in(self):
return True return True
@@ -212,6 +242,30 @@ class Puppet:
tgid = cls.get_id_from_mxid(mxid) tgid = cls.get_id_from_mxid(mxid)
return cls.get(tgid, create) if tgid else None return cls.get(tgid, create) if tgid else None
@classmethod
def get_by_custom_mxid(cls, mxid):
if not mxid:
raise ValueError("Matrix ID can't be empty")
try:
return cls.by_custom_mxid[mxid]
except KeyError:
pass
puppet = DBPuppet.query.filter(DBPuppet.custom_mxid == mxid).one_or_none()
if puppet:
puppet = cls.from_db(puppet)
return puppet
return None
@classmethod
def get_all_with_custom_mxid(cls):
return [cls.by_custom_mxid[puppet.mxid]
if puppet.custom_mxid in cls.by_custom_mxid
else cls.from_db(puppet)
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
@classmethod @classmethod
def get_id_from_mxid(cls, mxid): def get_id_from_mxid(cls, mxid):
match = cls.mxid_regex.match(mxid) match = cls.mxid_regex.match(mxid)
@@ -261,3 +315,4 @@ def init(context):
Puppet.hs_domain = config["homeserver"]["domain"] Puppet.hs_domain = config["homeserver"]["domain"]
localpart = Puppet.username_template.format(userid="(.+)") localpart = Puppet.username_template.format(userid="(.+)")
Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}") Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}")
return [puppet.test_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
+7 -4
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict from typing import Dict, Awaitable, Optional
import logging import logging
import asyncio import asyncio
import re import re
@@ -185,6 +185,9 @@ class User(AbstractUser):
# endregion # endregion
# region Telegram actions that need custom methods # region Telegram actions that need custom methods
def ensure_started(self, even_if_no_session=False) -> "Awaitable[User]":
return super().ensure_started(even_if_no_session)
async def update_info(self, info: User = None): async def update_info(self, info: User = None):
info = info or await self.client.get_me() info = info or await self.client.get_me()
changed = False changed = False
@@ -309,7 +312,7 @@ class User(AbstractUser):
# region Class instance lookup # region Class instance lookup
@classmethod @classmethod
def get_by_mxid(cls, mxid, create=True): def get_by_mxid(cls, mxid, create=True) -> "Optional[User]":
if not mxid: if not mxid:
raise ValueError("Matrix ID can't be empty") raise ValueError("Matrix ID can't be empty")
@@ -332,7 +335,7 @@ class User(AbstractUser):
return None return None
@classmethod @classmethod
def get_by_tgid(cls, tgid): def get_by_tgid(cls, tgid) -> "Optional[User]":
try: try:
return cls.by_tgid[tgid] return cls.by_tgid[tgid]
except KeyError: except KeyError:
@@ -346,7 +349,7 @@ class User(AbstractUser):
return None return None
@classmethod @classmethod
def find_by_username(cls, username): def find_by_username(cls, username) -> "Optional[User]":
if not username: if not username:
return None return None