Stop handling events from custom puppets
This commit is contained in:
@@ -110,8 +110,10 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st
|
|||||||
context.mx = MatrixHandler(context)
|
context.mx = MatrixHandler(context)
|
||||||
init_formatter(context)
|
init_formatter(context)
|
||||||
init_portal(context)
|
init_portal(context)
|
||||||
init_puppet(context)
|
startup_actions = (init_puppet(context) +
|
||||||
startup_actions = init_user(context) + [start, context.mx.init_as_bot()]
|
init_user(context) +
|
||||||
|
[start,
|
||||||
|
context.mx.init_as_bot()])
|
||||||
|
|
||||||
if context.bot:
|
if context.bot:
|
||||||
startup_actions.append(context.bot.start())
|
startup_actions.append(context.bot.start())
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ class AbstractUser:
|
|||||||
self.log.debug("%s connected: %s", self.mxid, self.connected)
|
self.log.debug("%s connected: %s", self.mxid, self.connected)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def ensure_started(self, even_if_no_session=False):
|
async def ensure_started(self, even_if_no_session=False) -> "AbstractUser":
|
||||||
if not self.puppet_whitelisted:
|
if not self.puppet_whitelisted:
|
||||||
return self
|
return self
|
||||||
self.log.debug("ensure_started(%s, connected=%s, even_if_no_session=%s, session_count=%s)",
|
self.log.debug("ensure_started(%s, connected=%s, even_if_no_session=%s, session_count=%s)",
|
||||||
|
|||||||
@@ -56,15 +56,11 @@ async def ping_bot(evt: CommandEvent):
|
|||||||
"account")
|
"account")
|
||||||
async def login_matrix(evt: CommandEvent):
|
async def login_matrix(evt: CommandEvent):
|
||||||
puppet = pu.Puppet.get(evt.sender.tgid)
|
puppet = pu.Puppet.get(evt.sender.tgid)
|
||||||
prev_info = puppet.custom_mxid, puppet.access_token
|
resp = puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid)
|
||||||
puppet.custom_mxid = evt.sender.mxid
|
if resp == 2:
|
||||||
puppet.access_token = " ".join(evt.args)
|
return await evt.reply("You can only log in as your own Matrix user.")
|
||||||
puppet.refresh_intents()
|
elif resp == 1:
|
||||||
if not await puppet.get_profile():
|
|
||||||
puppet.custom_mxid, puppet.access_token = prev_info
|
|
||||||
puppet.refresh_intents()
|
|
||||||
return await evt.reply("Failed to verify access token.")
|
return await evt.reply("Failed to verify access token.")
|
||||||
puppet.save()
|
|
||||||
return await evt.reply(
|
return await evt.reply(
|
||||||
f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.")
|
f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.")
|
||||||
|
|
||||||
|
|||||||
+11
-12
@@ -17,14 +17,14 @@
|
|||||||
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
|
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
|
||||||
BigInteger, String, Boolean, Text)
|
BigInteger, String, Boolean, Text)
|
||||||
from sqlalchemy.sql import expression
|
from sqlalchemy.sql import expression
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship, Query
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from .base import Base
|
from .base import Base
|
||||||
|
|
||||||
|
|
||||||
class Portal(Base):
|
class Portal(Base):
|
||||||
query = None
|
query = None # type: Query
|
||||||
__tablename__ = "portal"
|
__tablename__ = "portal"
|
||||||
|
|
||||||
# Telegram chat information
|
# Telegram chat information
|
||||||
@@ -42,9 +42,8 @@ class Portal(Base):
|
|||||||
about = Column(String, nullable=True)
|
about = Column(String, nullable=True)
|
||||||
photo_id = Column(String, nullable=True)
|
photo_id = Column(String, nullable=True)
|
||||||
|
|
||||||
|
|
||||||
class Message(Base):
|
class Message(Base):
|
||||||
query = None
|
query = None # type: Query
|
||||||
__tablename__ = "message"
|
__tablename__ = "message"
|
||||||
|
|
||||||
mxid = Column(String)
|
mxid = Column(String)
|
||||||
@@ -56,7 +55,7 @@ class Message(Base):
|
|||||||
|
|
||||||
|
|
||||||
class UserPortal(Base):
|
class UserPortal(Base):
|
||||||
query = None
|
query = None # type: Query
|
||||||
__tablename__ = "user_portal"
|
__tablename__ = "user_portal"
|
||||||
|
|
||||||
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
|
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
|
||||||
@@ -70,7 +69,7 @@ class UserPortal(Base):
|
|||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
query = None
|
query = None # type: Query
|
||||||
__tablename__ = "user"
|
__tablename__ = "user"
|
||||||
|
|
||||||
mxid = Column(String, primary_key=True)
|
mxid = Column(String, primary_key=True)
|
||||||
@@ -83,7 +82,7 @@ class User(Base):
|
|||||||
|
|
||||||
|
|
||||||
class RoomState(Base):
|
class RoomState(Base):
|
||||||
query = None
|
query = None # type: Query
|
||||||
__tablename__ = "mx_room_state"
|
__tablename__ = "mx_room_state"
|
||||||
|
|
||||||
room_id = Column(String, primary_key=True)
|
room_id = Column(String, primary_key=True)
|
||||||
@@ -107,7 +106,7 @@ class RoomState(Base):
|
|||||||
|
|
||||||
|
|
||||||
class UserProfile(Base):
|
class UserProfile(Base):
|
||||||
query = None
|
query = None # type: Query
|
||||||
__tablename__ = "mx_user_profile"
|
__tablename__ = "mx_user_profile"
|
||||||
|
|
||||||
room_id = Column(String, primary_key=True)
|
room_id = Column(String, primary_key=True)
|
||||||
@@ -125,7 +124,7 @@ class UserProfile(Base):
|
|||||||
|
|
||||||
|
|
||||||
class Contact(Base):
|
class Contact(Base):
|
||||||
query = None
|
query = None # type: Query
|
||||||
__tablename__ = "contact"
|
__tablename__ = "contact"
|
||||||
|
|
||||||
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True)
|
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True)
|
||||||
@@ -133,7 +132,7 @@ class Contact(Base):
|
|||||||
|
|
||||||
|
|
||||||
class Puppet(Base):
|
class Puppet(Base):
|
||||||
query = None
|
query = None # type: Query
|
||||||
__tablename__ = "puppet"
|
__tablename__ = "puppet"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
@@ -149,14 +148,14 @@ class Puppet(Base):
|
|||||||
|
|
||||||
# Fucking Telegram not telling bots what chats they are in 3:<
|
# Fucking Telegram not telling bots what chats they are in 3:<
|
||||||
class BotChat(Base):
|
class BotChat(Base):
|
||||||
query = None
|
query = None # type: Query
|
||||||
__tablename__ = "bot_chat"
|
__tablename__ = "bot_chat"
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
type = Column(String, nullable=False)
|
type = Column(String, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
class TelegramFile(Base):
|
class TelegramFile(Base):
|
||||||
query = None
|
query = None # type: Query
|
||||||
__tablename__ = "telegram_file"
|
__tablename__ = "telegram_file"
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
id = Column(String, primary_key=True)
|
||||||
|
|||||||
@@ -824,7 +824,12 @@ class Portal:
|
|||||||
mxid=event_id))
|
mxid=event_id))
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
async def handle_matrix_message(self, sender, message, event_id):
|
async def handle_matrix_message(self, sender: u.User, message: dict, event_id: str):
|
||||||
|
puppet = p.Puppet.get_by_custom_mxid(sender.mxid)
|
||||||
|
if puppet and message.get("net.maunium.telegram.puppet", False):
|
||||||
|
self.log.debug("Ignoring puppet-sent message by confirmed puppet user %s", sender.mxid)
|
||||||
|
return
|
||||||
|
|
||||||
logged_in = not await sender.needs_relaybot(self)
|
logged_in = not await sender.needs_relaybot(self)
|
||||||
client = sender.client if logged_in else self.bot.client
|
client = sender.client if logged_in else self.bot.client
|
||||||
sender_id = sender.tgid if logged_in else self.bot.tgid
|
sender_id = sender.tgid if logged_in else self.bot.tgid
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ class Puppet:
|
|||||||
username_template = None
|
username_template = None
|
||||||
hs_domain = None
|
hs_domain = None
|
||||||
cache = {}
|
cache = {}
|
||||||
|
by_custom_mxid = {}
|
||||||
|
|
||||||
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
|
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
|
||||||
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
|
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
|
||||||
@@ -60,22 +61,51 @@ class Puppet:
|
|||||||
self.refresh_intents()
|
self.refresh_intents()
|
||||||
|
|
||||||
self.cache[id] = self
|
self.cache[id] = self
|
||||||
|
if self.custom_mxid:
|
||||||
|
self.by_custom_mxid[self.custom_mxid] = self
|
||||||
|
|
||||||
def refresh_intents(self):
|
def refresh_intents(self):
|
||||||
self.is_real_user = self.custom_mxid and self.access_token
|
self.is_real_user = self.custom_mxid and self.access_token
|
||||||
self.intent = (self.az.intent.user(self.custom_mxid, self.access_token)
|
self.intent = (self.az.intent.user(self.custom_mxid, self.access_token)
|
||||||
if self.is_real_user else self.default_mxid_intent)
|
if self.is_real_user else self.default_mxid_intent)
|
||||||
|
|
||||||
async def get_profile(self):
|
|
||||||
try:
|
|
||||||
return await self.intent.get_profile(self.custom_mxid)
|
|
||||||
except MatrixError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tgid(self):
|
def tgid(self):
|
||||||
return self.id
|
return self.id
|
||||||
|
|
||||||
|
async def switch_mxid(self, access_token, mxid):
|
||||||
|
prev_mxid = self.custom_mxid
|
||||||
|
self.custom_mxid = mxid
|
||||||
|
self.access_token = access_token
|
||||||
|
self.refresh_intents()
|
||||||
|
|
||||||
|
err = await self.test_custom_mxid()
|
||||||
|
if err != 0:
|
||||||
|
return err
|
||||||
|
|
||||||
|
try:
|
||||||
|
del self.by_custom_mxid[prev_mxid]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
self.mxid = self.custom_mxid or self.default_mxid
|
||||||
|
self.by_custom_mxid[self.mxid] = self
|
||||||
|
self.save()
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def test_custom_mxid(self):
|
||||||
|
if not self.is_real_user:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
mxid = await self.intent.whoami()
|
||||||
|
if not mxid or mxid != self.custom_mxid:
|
||||||
|
self.custom_mxid = None
|
||||||
|
self.access_token = None
|
||||||
|
self.refresh_intents()
|
||||||
|
if mxid != self.custom_mxid:
|
||||||
|
return 2
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
async def is_logged_in(self):
|
async def is_logged_in(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -212,6 +242,30 @@ class Puppet:
|
|||||||
tgid = cls.get_id_from_mxid(mxid)
|
tgid = cls.get_id_from_mxid(mxid)
|
||||||
return cls.get(tgid, create) if tgid else None
|
return cls.get(tgid, create) if tgid else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_by_custom_mxid(cls, mxid):
|
||||||
|
if not mxid:
|
||||||
|
raise ValueError("Matrix ID can't be empty")
|
||||||
|
|
||||||
|
try:
|
||||||
|
return cls.by_custom_mxid[mxid]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
puppet = DBPuppet.query.filter(DBPuppet.custom_mxid == mxid).one_or_none()
|
||||||
|
if puppet:
|
||||||
|
puppet = cls.from_db(puppet)
|
||||||
|
return puppet
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_all_with_custom_mxid(cls):
|
||||||
|
return [cls.by_custom_mxid[puppet.mxid]
|
||||||
|
if puppet.custom_mxid in cls.by_custom_mxid
|
||||||
|
else cls.from_db(puppet)
|
||||||
|
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_id_from_mxid(cls, mxid):
|
def get_id_from_mxid(cls, mxid):
|
||||||
match = cls.mxid_regex.match(mxid)
|
match = cls.mxid_regex.match(mxid)
|
||||||
@@ -261,3 +315,4 @@ def init(context):
|
|||||||
Puppet.hs_domain = config["homeserver"]["domain"]
|
Puppet.hs_domain = config["homeserver"]["domain"]
|
||||||
localpart = Puppet.username_template.format(userid="(.+)")
|
localpart = Puppet.username_template.format(userid="(.+)")
|
||||||
Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}")
|
Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}")
|
||||||
|
return [puppet.test_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Dict
|
from typing import Dict, Awaitable, Optional
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
@@ -185,6 +185,9 @@ class User(AbstractUser):
|
|||||||
# endregion
|
# endregion
|
||||||
# region Telegram actions that need custom methods
|
# region Telegram actions that need custom methods
|
||||||
|
|
||||||
|
def ensure_started(self, even_if_no_session=False) -> "Awaitable[User]":
|
||||||
|
return super().ensure_started(even_if_no_session)
|
||||||
|
|
||||||
async def update_info(self, info: User = None):
|
async def update_info(self, info: User = None):
|
||||||
info = info or await self.client.get_me()
|
info = info or await self.client.get_me()
|
||||||
changed = False
|
changed = False
|
||||||
@@ -309,7 +312,7 @@ class User(AbstractUser):
|
|||||||
# region Class instance lookup
|
# region Class instance lookup
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_by_mxid(cls, mxid, create=True):
|
def get_by_mxid(cls, mxid, create=True) -> "Optional[User]":
|
||||||
if not mxid:
|
if not mxid:
|
||||||
raise ValueError("Matrix ID can't be empty")
|
raise ValueError("Matrix ID can't be empty")
|
||||||
|
|
||||||
@@ -332,7 +335,7 @@ class User(AbstractUser):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_by_tgid(cls, tgid):
|
def get_by_tgid(cls, tgid) -> "Optional[User]":
|
||||||
try:
|
try:
|
||||||
return cls.by_tgid[tgid]
|
return cls.by_tgid[tgid]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@@ -346,7 +349,7 @@ class User(AbstractUser):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_by_username(cls, username):
|
def find_by_username(cls, username) -> "Optional[User]":
|
||||||
if not username:
|
if not username:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user