Store user portals and kick when logging out. Fixes #53

This commit is contained in:
Tulir Asokan
2018-02-13 00:58:03 +02:00
parent e42fcd2fb3
commit 2064f2b2d1
4 changed files with 90 additions and 20 deletions
+18 -2
View File
@@ -13,7 +13,7 @@
# #
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy import Column, UniqueConstraint, ForeignKey, Integer, String from sqlalchemy import Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer, String
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from .base import Base from .base import Base
@@ -50,6 +50,18 @@ class Message(Base):
__table_args__ = (UniqueConstraint('mxid', 'mx_room', 'tg_space', name='_mx_id_room'),) __table_args__ = (UniqueConstraint('mxid', 'mx_room', 'tg_space', name='_mx_id_room'),)
class UserPortal(Base):
query = None
__tablename__ = "user_portal"
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True)
portal = Column(Integer, primary_key=True)
portal_receiver = Column(Integer, primary_key=True)
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
("portal.tgid", "portal.tg_receiver")),)
class User(Base): class User(Base):
query = None query = None
__tablename__ = "user" __tablename__ = "user"
@@ -58,7 +70,10 @@ class User(Base):
tgid = Column(Integer, nullable=True) tgid = Column(Integer, nullable=True)
tg_username = Column(String, nullable=True) tg_username = Column(String, nullable=True)
saved_contacts = Column(Integer, default=0) saved_contacts = Column(Integer, default=0)
contacts = relationship("Contact", uselist=True) contacts = relationship("Contact", uselist=True,
cascade="save-update, merge, delete, delete-orphan")
portals = relationship("Portal", secondary="user_portal", single_parent=True,
cascade="save-update, merge, delete, delete-orphan")
class Contact(Base): class Contact(Base):
@@ -82,5 +97,6 @@ class Puppet(Base):
def init(db_session): def init(db_session):
Portal.query = db_session.query_property() Portal.query = db_session.query_property()
Message.query = db_session.query_property() Message.query = db_session.query_property()
UserPortal.query = db_session.query_property()
User.query = db_session.query_property() User.query = db_session.query_property()
Puppet.query = db_session.query_property() Puppet.query = db_session.query_property()
+1
View File
@@ -79,6 +79,7 @@ class MatrixHandler:
pass pass
portal.mxid = room portal.mxid = room
portal.save() portal.save()
inviter.register_portal(portal)
await puppet.intent.send_notice(room, "Portal to private chat created.") await puppet.intent.send_notice(room, "Portal to private chat created.")
else: else:
await puppet.intent.join_room(room) await puppet.intent.join_room(room)
+5 -1
View File
@@ -204,7 +204,7 @@ class Portal:
if alias: if alias:
# TODO properly handle existing room aliases # TODO properly handle existing room aliases
intent.remove_room_alias(alias) await intent.remove_room_alias(alias)
room = await intent.create_room(alias=alias, is_public=public, invitees=invites or [], room = await intent.create_room(alias=alias, is_public=public, invitees=invites or [],
name=self.title, is_direct=direct) name=self.title, is_direct=direct)
if not room: if not room:
@@ -213,6 +213,7 @@ class Portal:
self.mxid = room["room_id"] self.mxid = room["room_id"]
self.by_mxid[self.mxid] = self self.by_mxid[self.mxid] = self
self.save() self.save()
user.register_portal(self)
power_level_requirement = 0 if self.peer_type == "chat" and entity.admins_enabled else 50 power_level_requirement = 0 if self.peer_type == "chat" and entity.admins_enabled else 50
levels = await self.main_intent.get_power_levels(self.mxid) levels = await self.main_intent.get_power_levels(self.mxid)
@@ -245,6 +246,7 @@ class Portal:
user = u.User.get_by_tgid(user_id) user = u.User.get_by_tgid(user_id)
if user: if user:
user.register_portal(self)
await self.main_intent.invite(self.mxid, user.mxid) await self.main_intent.invite(self.mxid, user.mxid)
async def delete_telegram_user(self, user_id, kick_message=None): async def delete_telegram_user(self, user_id, kick_message=None):
@@ -255,6 +257,7 @@ class Portal:
else: else:
await puppet.intent.leave_room(self.mxid) await puppet.intent.leave_room(self.mxid)
if user: if user:
user.unregister_portal(self)
await self.main_intent.kick(self.mxid, user.mxid, kick_message or "Left Telegram chat") await self.main_intent.kick(self.mxid, user.mxid, kick_message or "Left Telegram chat")
async def update_info(self, user, entity=None): async def update_info(self, user, entity=None):
@@ -840,6 +843,7 @@ class Portal:
user_levels = levels["users"] user_levels = levels["users"]
if user: if user:
user.register_portal(self)
user_level_defined = user.mxid in user_levels user_level_defined = user.mxid in user_levels
user_has_right_level = (user_levels[user.mxid] == new_level user_has_right_level = (user_levels[user.mxid] == new_level
if user_level_defined else new_level == 0) if user_level_defined else new_level == 0)
+66 -17
View File
@@ -21,12 +21,12 @@ from telethon.tl.types import *
from telethon.tl.types.contacts import ContactsNotModified from telethon.tl.types.contacts import ContactsNotModified
from telethon.tl.types import User as TLUser from telethon.tl.types import User as TLUser
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from mautrix_appservice import MatrixRequestError
from .db import User as DBUser, Message as DBMessage, Contact as DBContact from .db import User as DBUser, Message as DBMessage, Contact as DBContact
from .tgclient import MautrixTelegramClient from .tgclient import MautrixTelegramClient
from . import portal as po, puppet as pu, __version__ from . import portal as po, puppet as pu, __version__
config = None config = None
@@ -38,26 +38,20 @@ class User:
by_mxid = {} by_mxid = {}
by_tgid = {} by_tgid = {}
def __init__(self, mxid, tgid=None, username=None, db_contacts=None, saved_contacts=0): def __init__(self, mxid, tgid=None, username=None, db_contacts=None, saved_contacts=0,
db_portals=None):
self.mxid = mxid self.mxid = mxid
self.tgid = tgid self.tgid = tgid
self.username = username self.username = username
self.contacts = [] self.contacts = []
self.saved_contacts = saved_contacts self.saved_contacts = saved_contacts
self.db_contacts = db_contacts self.db_contacts = db_contacts
self.portals = {}
self.db_portals = db_portals
self.command_status = None self.command_status = None
self.connected = False self.connected = False
device = f"{platform.system()} {platform.release()}" self._init_client()
sysversion = MautrixTelegramClient.__version__
self.client = MautrixTelegramClient(self.mxid,
config["telegram.api_id"],
config["telegram.api_hash"],
loop=self.loop,
app_version=__version__,
system_version=sysversion,
device_model=device)
self.client.add_update_handler(self.update_catch)
self.is_admin = self.mxid in config.get("bridge.admins", []) self.is_admin = self.mxid in config.get("bridge.admins", [])
@@ -91,6 +85,19 @@ class User:
else: else:
self.contacts = [] self.contacts = []
@property
def db_portals(self):
return [portal.to_db(merge=False) for _, portal in self.portals.items()]
@db_portals.setter
def db_portals(self, portals):
if portals:
self.portals = {(portal.tgid, portal.tg_receiver):
po.Portal.get_by_tgid(portal.tgid, portal.tg_receiver)
for portal in portals}
else:
self.portals = {}
def get_input_entity(self, user): def get_input_entity(self, user):
return user.client.get_input_entity(InputUser(user_id=self.tgid, access_hash=0)) return user.client.get_input_entity(InputUser(user_id=self.tgid, access_hash=0))
@@ -99,7 +106,8 @@ class User:
def to_db(self): def to_db(self):
return self.db.merge( return self.db.merge(
DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username, DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
contacts=self.db_contacts, saved_contacts=self.saved_contacts)) contacts=self.db_contacts, saved_contacts=self.saved_contacts,
portals=self.db_portals))
def save(self): def save(self):
self.to_db() self.to_db()
@@ -108,11 +116,23 @@ class User:
@classmethod @classmethod
def from_db(cls, db_user): def from_db(cls, db_user):
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts, return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts,
db_user.saved_contacts) db_user.saved_contacts, db_user.portals)
# endregion # endregion
# region Telegram connection management # region Telegram connection management
def _init_client(self):
device = f"{platform.system()} {platform.release()}"
sysversion = MautrixTelegramClient.__version__
self.client = MautrixTelegramClient(self.mxid,
config["telegram.api_id"],
config["telegram.api_hash"],
loop=self.loop,
app_version=__version__,
system_version=sysversion,
device_model=device)
self.client.add_update_handler(self.update_catch)
async def start(self): async def start(self):
self.connected = await self.client.connect() self.connected = await self.client.connect()
if self.logged_in: if self.logged_in:
@@ -148,7 +168,14 @@ class User:
self.save() self.save()
async def log_out(self): async def log_out(self):
self.connected = False for _, portal in self.portals.items():
try:
await portal.main_intent.kick(portal.mxid, self.mxid, "Logged out of Telegram.")
except MatrixRequestError:
pass
self.portals = {}
self.contacts = []
self.save()
if self.tgid: if self.tgid:
try: try:
del self.by_tgid[self.tgid] del self.by_tgid[self.tgid]
@@ -156,8 +183,12 @@ class User:
pass pass
self.tgid = None self.tgid = None
self.save() self.save()
await self.client.log_out() ok = await self.client.log_out()
# TODO kick user from portals if not ok:
return False
self._init_client()
await self.start()
return True
def _search_local(self, query, max_results=5, min_similarity=45): def _search_local(self, query, max_results=5, min_similarity=45):
results = [] results = []
@@ -200,9 +231,27 @@ class User:
if invalid: if invalid:
continue continue
portal = po.Portal.get_by_entity(entity) portal = po.Portal.get_by_entity(entity)
self.portals[portal.tgid_full] = portal
creators.append(portal.create_matrix_room(self, entity, invites=[self.mxid])) creators.append(portal.create_matrix_room(self, entity, invites=[self.mxid]))
self.save()
await asyncio.gather(*creators, loop=self.loop) await asyncio.gather(*creators, loop=self.loop)
def register_portal(self, portal):
try:
if self.portals[portal.tgid_full] == portal:
return
except KeyError:
pass
self.portals[portal.tgid_full] = portal
self.save()
def unregister_portal(self, portal):
try:
del self.portals[portal.tgid_full]
self.save()
except KeyError:
pass
def _hash_contacts(self): def _hash_contacts(self):
acc = 0 acc = 0
for id in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]): for id in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]):