Merge pull request #239 from tulir/sqlalchemy-core
Port Message table to SQLAlchemy Core
This commit is contained in:
@@ -113,7 +113,7 @@ if config["appservice.provisioning.enabled"]:
|
|||||||
context.provisioning_api = provisioning_api
|
context.provisioning_api = provisioning_api
|
||||||
|
|
||||||
with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start:
|
with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start:
|
||||||
init_db(db_session)
|
init_db(db_session, db_engine)
|
||||||
init_abstract_user(context)
|
init_abstract_user(context)
|
||||||
context.bot = init_bot(context)
|
context.bot = init_bot(context)
|
||||||
context.mx = MatrixHandler(context)
|
context.mx = MatrixHandler(context)
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ class AbstractUser(ABC):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# We check that these are user read receipts, so tg_space is always the user ID.
|
# We check that these are user read receipts, so tg_space is always the user ID.
|
||||||
message = DBMessage.query.get((update.max_id, self.tgid))
|
message = DBMessage.get_by_tgid(update.max_id, self.tgid)
|
||||||
if not message:
|
if not message:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -323,12 +323,11 @@ class AbstractUser(ABC):
|
|||||||
return
|
return
|
||||||
|
|
||||||
for message in update.messages:
|
for message in update.messages:
|
||||||
message = DBMessage.query.get((message, self.tgid))
|
message = DBMessage.get_by_tgid(TelegramID(message), self.tgid)
|
||||||
if not message:
|
if not message:
|
||||||
continue
|
continue
|
||||||
self.db.delete(message)
|
message.delete()
|
||||||
number_left = DBMessage.query.filter(DBMessage.mxid == message.mxid,
|
number_left = DBMessage.count_spaces_by_mxid(message.mxid, message.mx_room)
|
||||||
DBMessage.mx_room == message.mx_room).count()
|
|
||||||
if number_left == 0:
|
if number_left == 0:
|
||||||
portal = po.Portal.get_by_mxid(message.mx_room)
|
portal = po.Portal.get_by_mxid(message.mx_room)
|
||||||
await self._try_redact(portal, message)
|
await self._try_redact(portal, message)
|
||||||
@@ -343,10 +342,10 @@ class AbstractUser(ABC):
|
|||||||
return
|
return
|
||||||
|
|
||||||
for message in update.messages:
|
for message in update.messages:
|
||||||
message = DBMessage.query.get((message, portal.tgid))
|
message = DBMessage.get_by_tgid(TelegramID(message), portal.tgid)
|
||||||
if not message:
|
if not message:
|
||||||
continue
|
continue
|
||||||
self.db.delete(message)
|
message.delete()
|
||||||
await self._try_redact(portal, message)
|
await self._try_redact(portal, message)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
|
|||||||
+72
-4
@@ -15,9 +15,12 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
|
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
|
||||||
BigInteger, String, Boolean, Text)
|
BigInteger, String, Boolean, Text, Table,
|
||||||
|
and_, func, select)
|
||||||
|
from sqlalchemy.engine import Engine, RowProxy
|
||||||
from sqlalchemy.sql import expression
|
from sqlalchemy.sql import expression
|
||||||
from sqlalchemy.orm import relationship, Query
|
from sqlalchemy.orm import relationship, Query
|
||||||
|
from sqlalchemy.sql.base import ImmutableColumnCollection
|
||||||
from typing import Dict, Optional, List
|
from typing import Dict, Optional, List
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@@ -49,7 +52,9 @@ class Portal(Base):
|
|||||||
|
|
||||||
|
|
||||||
class Message(Base):
|
class Message(Base):
|
||||||
query = None # type: Query
|
db = None # type: Engine
|
||||||
|
t = None # type: Table
|
||||||
|
c = None # type: ImmutableColumnCollection
|
||||||
__tablename__ = "message"
|
__tablename__ = "message"
|
||||||
|
|
||||||
mxid = Column(String) # type: MatrixEventID
|
mxid = Column(String) # type: MatrixEventID
|
||||||
@@ -59,6 +64,67 @@ class Message(Base):
|
|||||||
|
|
||||||
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
|
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _one_or_none(rows: RowProxy) -> Optional['Message']:
|
||||||
|
try:
|
||||||
|
mxid, mx_room, tgid, tg_space = next(rows)
|
||||||
|
return Message(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _all(rows: RowProxy) -> List['Message']:
|
||||||
|
return [Message(mxid=row[0], mx_room=row[1], tgid=row[2], tg_space=row[3])
|
||||||
|
for row in rows]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']:
|
||||||
|
rows = cls.db.execute(cls.t.select()
|
||||||
|
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)))
|
||||||
|
return cls._one_or_none(rows)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int:
|
||||||
|
rows = cls.db.execute(select([func.count(cls.c.tg_space)])
|
||||||
|
.where(and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room)))
|
||||||
|
try:
|
||||||
|
count, = next(rows)
|
||||||
|
return count
|
||||||
|
except StopIteration:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID
|
||||||
|
) -> Optional['Message']:
|
||||||
|
rows = cls.db.execute(cls.t.select().where(
|
||||||
|
and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space)))
|
||||||
|
return cls._one_or_none(rows)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None:
|
||||||
|
cls.db.execute(cls.t.update()
|
||||||
|
.where(and_(cls.c.tgid == s_tgid, cls.c.tg_space == s_tg_space))
|
||||||
|
.values(**values))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_by_mxid(cls, s_mxid: MatrixEventID, s_mx_room: MatrixRoomID, **values) -> None:
|
||||||
|
cls.db.execute(cls.t.update()
|
||||||
|
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
|
||||||
|
.values(**values))
|
||||||
|
|
||||||
|
def update(self, **values) -> None:
|
||||||
|
for key, value in values.items():
|
||||||
|
setattr(self, key, value)
|
||||||
|
self.update_by_tgid(self.tgid, self.tg_space, **values)
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
self.db.execute(self.t.delete().where(
|
||||||
|
and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)))
|
||||||
|
|
||||||
|
def insert(self) -> None:
|
||||||
|
self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid,
|
||||||
|
tg_space=self.tg_space))
|
||||||
|
|
||||||
|
|
||||||
class UserPortal(Base):
|
class UserPortal(Base):
|
||||||
query = None # type: Query
|
query = None # type: Query
|
||||||
@@ -178,9 +244,11 @@ class TelegramFile(Base):
|
|||||||
thumbnail = relationship("TelegramFile", uselist=False)
|
thumbnail = relationship("TelegramFile", uselist=False)
|
||||||
|
|
||||||
|
|
||||||
def init(db_session) -> None:
|
def init(db_session, db_engine) -> None:
|
||||||
Portal.query = db_session.query_property()
|
Portal.query = db_session.query_property()
|
||||||
Message.query = db_session.query_property()
|
Message.db = db_engine
|
||||||
|
Message.t = Message.__table__
|
||||||
|
Message.c = Message.t.c
|
||||||
UserPortal.query = db_session.query_property()
|
UserPortal.query = db_session.query_property()
|
||||||
User.query = db_session.query_property()
|
User.query = db_session.query_property()
|
||||||
Puppet.query = db_session.query_property()
|
Puppet.query = db_session.query_property()
|
||||||
|
|||||||
@@ -105,9 +105,7 @@ def matrix_reply_to_telegram(content: Dict[str, Any], tg_space: TelegramID,
|
|||||||
pass
|
pass
|
||||||
content["body"] = trim_reply_fallback_text(content["body"])
|
content["body"] = trim_reply_fallback_text(content["body"])
|
||||||
|
|
||||||
message = DBMessage.query.filter(DBMessage.mxid == event_id,
|
message = DBMessage.get_by_mxid(event_id, room_id, tg_space)
|
||||||
DBMessage.tg_space == tg_space,
|
|
||||||
DBMessage.mx_room == room_id).one_or_none()
|
|
||||||
if message:
|
if message:
|
||||||
return message.tgid
|
return message.tgid
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict:
|
|||||||
space = (evt.to_id.channel_id
|
space = (evt.to_id.channel_id
|
||||||
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
|
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
|
||||||
else source.tgid)
|
else source.tgid)
|
||||||
msg = DBMessage.query.get((evt.reply_to_msg_id, space))
|
msg = DBMessage.get_by_tgid(evt.reply_to_msg_id, space)
|
||||||
if msg:
|
if msg:
|
||||||
return {
|
return {
|
||||||
"m.in_reply_to": {
|
"m.in_reply_to": {
|
||||||
@@ -124,7 +124,7 @@ async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: M
|
|||||||
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
|
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
|
||||||
else source.tgid)
|
else source.tgid)
|
||||||
|
|
||||||
msg = DBMessage.query.get((evt.reply_to_msg_id, space))
|
msg = DBMessage.get_by_tgid(evt.reply_to_msg_id, space)
|
||||||
if not msg:
|
if not msg:
|
||||||
return text, html
|
return text, html
|
||||||
|
|
||||||
@@ -325,7 +325,7 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
|
|||||||
|
|
||||||
portal = po.Portal.find_by_username(group)
|
portal = po.Portal.find_by_username(group)
|
||||||
if portal:
|
if portal:
|
||||||
message = DBMessage.query.get((msgid, portal.tgid))
|
message = DBMessage.get_by_tgid(TelegramID(msgid), portal.tgid)
|
||||||
if message:
|
if message:
|
||||||
url = f"https://matrix.to/#/{portal.mxid}/{message.mxid}"
|
url = f"https://matrix.to/#/{portal.mxid}/{message.mxid}"
|
||||||
|
|
||||||
|
|||||||
+19
-35
@@ -772,9 +772,7 @@ class Portal:
|
|||||||
if user.is_bot:
|
if user.is_bot:
|
||||||
return
|
return
|
||||||
space = self.tgid if self.peer_type == "channel" else user.tgid
|
space = self.tgid if self.peer_type == "channel" else user.tgid
|
||||||
message = DBMessage.query.filter(DBMessage.mxid == event_id,
|
message = DBMessage.get_by_mxid(event_id, self.mxid, space)
|
||||||
DBMessage.mx_room == self.mxid,
|
|
||||||
DBMessage.tg_space == space).one_or_none()
|
|
||||||
if not message:
|
if not message:
|
||||||
return
|
return
|
||||||
if self.peer_type == "channel":
|
if self.peer_type == "channel":
|
||||||
@@ -959,12 +957,11 @@ class Portal:
|
|||||||
response: TypeMessage) -> None:
|
response: TypeMessage) -> None:
|
||||||
self.log.debug("Handled Matrix message: %s", response)
|
self.log.debug("Handled Matrix message: %s", response)
|
||||||
self.is_duplicate(response, (event_id, space))
|
self.is_duplicate(response, (event_id, space))
|
||||||
self.db.add(DBMessage(
|
DBMessage(
|
||||||
tgid=response.id,
|
tgid=response.id,
|
||||||
tg_space=space,
|
tg_space=space,
|
||||||
mx_room=self.mxid,
|
mx_room=self.mxid,
|
||||||
mxid=event_id))
|
mxid=event_id).insert()
|
||||||
self.db.commit()
|
|
||||||
|
|
||||||
async def handle_matrix_message(self, sender: 'u.User', message: Dict[str, Any],
|
async def handle_matrix_message(self, sender: 'u.User', message: Dict[str, Any],
|
||||||
event_id: MatrixEventID) -> None:
|
event_id: MatrixEventID) -> None:
|
||||||
@@ -1009,9 +1006,10 @@ class Portal:
|
|||||||
if not pinned_message:
|
if not pinned_message:
|
||||||
await sender.client(UpdatePinnedMessageRequest(channel=self.peer, id=0))
|
await sender.client(UpdatePinnedMessageRequest(channel=self.peer, id=0))
|
||||||
else:
|
else:
|
||||||
message = DBMessage.query.filter(DBMessage.mxid == pinned_message,
|
message = DBMessage.get_by_mxid(pinned_message, self.mxid, self.tgid)
|
||||||
DBMessage.tg_space == self.tgid,
|
if message is None:
|
||||||
DBMessage.mx_room == self.mxid).one_or_none()
|
self.log.warning(f"Could not find pinned {pinned_message} in {self.mxid}")
|
||||||
|
return
|
||||||
await sender.client(UpdatePinnedMessageRequest(channel=self.peer, id=message.tgid))
|
await sender.client(UpdatePinnedMessageRequest(channel=self.peer, id=message.tgid))
|
||||||
except ChatNotModifiedError:
|
except ChatNotModifiedError:
|
||||||
pass
|
pass
|
||||||
@@ -1019,9 +1017,7 @@ class Portal:
|
|||||||
async def handle_matrix_deletion(self, deleter: 'u.User', event_id: MatrixEventID) -> None:
|
async def handle_matrix_deletion(self, deleter: 'u.User', event_id: MatrixEventID) -> None:
|
||||||
real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
|
real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
|
||||||
space = self.tgid if self.peer_type == "channel" else real_deleter.tgid
|
space = self.tgid if self.peer_type == "channel" else real_deleter.tgid
|
||||||
message = DBMessage.query.filter(DBMessage.mxid == event_id,
|
message = DBMessage.get_by_mxid(event_id, self.mxid, space)
|
||||||
DBMessage.tg_space == space,
|
|
||||||
DBMessage.mx_room == self.mxid).one_or_none()
|
|
||||||
if not message:
|
if not message:
|
||||||
return
|
return
|
||||||
await real_deleter.client.delete_messages(self.peer, [message.tgid])
|
await real_deleter.client.delete_messages(self.peer, [message.tgid])
|
||||||
@@ -1413,10 +1409,9 @@ class Portal:
|
|||||||
if duplicate_found:
|
if duplicate_found:
|
||||||
mxid, other_tg_space = duplicate_found
|
mxid, other_tg_space = duplicate_found
|
||||||
if tg_space != other_tg_space:
|
if tg_space != other_tg_space:
|
||||||
msg = DBMessage.query.get((evt.id, tg_space))
|
DBMessage.update_by_tgid(evt.id, tg_space,
|
||||||
msg.mxid = mxid
|
mxid=mxid,
|
||||||
msg.mx_room = self.mxid
|
mx_room=self.mxid)
|
||||||
self.db.commit()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
evt.reply_to_msg_id = evt.id
|
evt.reply_to_msg_id = evt.id
|
||||||
@@ -1429,19 +1424,14 @@ class Portal:
|
|||||||
|
|
||||||
mxid = response["event_id"]
|
mxid = response["event_id"]
|
||||||
|
|
||||||
msg = DBMessage.query.get((evt.id, tg_space))
|
msg = DBMessage.get_by_tgid(evt.id, tg_space)
|
||||||
if not msg:
|
if not msg:
|
||||||
self.log.info(f"Didn't find edited message {evt.id}@{tg_space} (src {source.tgid}) "
|
self.log.info(f"Didn't find edited message {evt.id}@{tg_space} (src {source.tgid}) "
|
||||||
"in database.")
|
"in database.")
|
||||||
# Oh crap
|
# Oh crap
|
||||||
return
|
return
|
||||||
msg.mxid = mxid
|
msg.update(mxid=mxid, mx_room=self.mxid)
|
||||||
msg.mx_room = self.mxid
|
DBMessage.update_by_mxid(temporary_identifier, self.mxid, mxid=mxid)
|
||||||
DBMessage.query \
|
|
||||||
.filter(DBMessage.mx_room == self.mxid,
|
|
||||||
DBMessage.mxid == temporary_identifier) \
|
|
||||||
.update({"mxid": mxid})
|
|
||||||
self.db.commit()
|
|
||||||
|
|
||||||
async def handle_telegram_message(self, source: "AbstractUser", sender: p.Puppet,
|
async def handle_telegram_message(self, source: "AbstractUser", sender: p.Puppet,
|
||||||
evt: Message) -> None:
|
evt: Message) -> None:
|
||||||
@@ -1463,13 +1453,11 @@ class Portal:
|
|||||||
self.log.debug(f"Ignoring message {evt.id}@{tg_space} (src {source.tgid}) "
|
self.log.debug(f"Ignoring message {evt.id}@{tg_space} (src {source.tgid}) "
|
||||||
f"as it was already handled (in space {other_tg_space})")
|
f"as it was already handled (in space {other_tg_space})")
|
||||||
if tg_space != other_tg_space:
|
if tg_space != other_tg_space:
|
||||||
self.db.add(
|
DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space).insert()
|
||||||
DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space))
|
|
||||||
self.db.commit()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.dedup_pre_db_check and self.peer_type == "channel":
|
if self.dedup_pre_db_check and self.peer_type == "channel":
|
||||||
msg = DBMessage.query.get((evt.id, tg_space))
|
msg = DBMessage.get_by_tgid(evt.id, tg_space)
|
||||||
if msg:
|
if msg:
|
||||||
self.log.debug(f"Ignoring message {evt.id} (src {source.tgid}) as it was already"
|
self.log.debug(f"Ignoring message {evt.id} (src {source.tgid}) as it was already"
|
||||||
f"handled into {msg.mxid}. This duplicate was catched in the db "
|
f"handled into {msg.mxid}. This duplicate was catched in the db "
|
||||||
@@ -1523,12 +1511,8 @@ class Portal:
|
|||||||
|
|
||||||
self.log.debug("Handled Telegram message: %s", evt)
|
self.log.debug("Handled Telegram message: %s", evt)
|
||||||
try:
|
try:
|
||||||
self.db.add(DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space))
|
DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space).insert()
|
||||||
self.db.commit()
|
DBMessage.update_by_mxid(temporary_identifier, self.mxid, mxid=mxid)
|
||||||
DBMessage.query \
|
|
||||||
.filter(DBMessage.mx_room == self.mxid,
|
|
||||||
DBMessage.mxid == temporary_identifier) \
|
|
||||||
.update({"mxid": mxid})
|
|
||||||
except FlushError as e:
|
except FlushError as e:
|
||||||
self.log.exception(f"{e.__class__.__name__} while saving message mapping. "
|
self.log.exception(f"{e.__class__.__name__} while saving message mapping. "
|
||||||
"This might mean that an update was handled after it left the "
|
"This might mean that an update was handled after it left the "
|
||||||
@@ -1610,7 +1594,7 @@ class Portal:
|
|||||||
self._temp_pinned_message_id = None
|
self._temp_pinned_message_id = None
|
||||||
self._temp_pinned_message_sender = None
|
self._temp_pinned_message_sender = None
|
||||||
|
|
||||||
message = DBMessage.query.get((msg_id, self.tgid))
|
message = DBMessage.get_by_tgid(msg_id, self.tgid)
|
||||||
if message:
|
if message:
|
||||||
await intent.set_pinned_messages(self.mxid, [message.mxid])
|
await intent.set_pinned_messages(self.mxid, [message.mxid])
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user