Add user+portal-specific lock for sending/receiving messages of authenticated users. Fixes #108
This commit is contained in:
+66
-26
@@ -76,6 +76,8 @@ class Portal:
|
|||||||
self._dedup_mxid = {}
|
self._dedup_mxid = {}
|
||||||
self._dedup_action = deque()
|
self._dedup_action = deque()
|
||||||
|
|
||||||
|
self._send_locks = {}
|
||||||
|
|
||||||
if tgid:
|
if tgid:
|
||||||
self.by_tgid[self.tgid_full] = self
|
self.by_tgid[self.tgid_full] = self
|
||||||
if mxid:
|
if mxid:
|
||||||
@@ -634,11 +636,32 @@ class Portal:
|
|||||||
message, entities = None, None
|
message, entities = None, None
|
||||||
return message, entities
|
return message, entities
|
||||||
|
|
||||||
async def _handle_matrix_text(self, client, message, reply_to):
|
def require_send_lock(self, id):
|
||||||
message, entities = await self._matrix_event_to_entities(client, message)
|
if id is None:
|
||||||
return await client.send_message(self.peer, message, entities=entities, reply_to=reply_to)
|
return None
|
||||||
|
try:
|
||||||
|
return self._send_locks[id]
|
||||||
|
except KeyError:
|
||||||
|
self._send_locks[id] = asyncio.Lock()
|
||||||
|
return self._send_locks[id]
|
||||||
|
|
||||||
async def _handle_matrix_file(self, client, message, reply_to):
|
def optional_send_lock(self, id):
|
||||||
|
if id is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return self._send_locks[id]
|
||||||
|
except KeyError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _handle_matrix_text(self, sender_id, event_id, space, client, message, reply_to):
|
||||||
|
message, entities = await self._matrix_event_to_entities(client, message)
|
||||||
|
|
||||||
|
lock = self.require_send_lock(sender_id)
|
||||||
|
async with lock:
|
||||||
|
response = await client.send_message(self.peer, message, entities=entities, reply_to=reply_to)
|
||||||
|
self._add_telegram_message_to_db(event_id, space, response)
|
||||||
|
|
||||||
|
async def _handle_matrix_file(self, sender_id, event_id, space, client, message, reply_to):
|
||||||
file = await self.main_intent.download_file(message["url"])
|
file = await self.main_intent.download_file(message["url"])
|
||||||
|
|
||||||
info = message["info"]
|
info = message["info"]
|
||||||
@@ -651,11 +674,14 @@ class Portal:
|
|||||||
attributes.append(DocumentAttributeImageSize(w=info["w"], h=info["h"]))
|
attributes.append(DocumentAttributeImageSize(w=info["w"], h=info["h"]))
|
||||||
|
|
||||||
caption = message["body"] if message["body"] != file_name else None
|
caption = message["body"] if message["body"] != file_name else None
|
||||||
return await client.send_file(self.peer, file, mime, caption=caption,
|
|
||||||
attributes=attributes, file_name=file_name,
|
|
||||||
reply_to=reply_to)
|
|
||||||
|
|
||||||
async def _handle_matrix_location(self, client, message, reply_to):
|
media = await client.upload_file(file, mime, attributes, file_name)
|
||||||
|
lock = self.require_send_lock(sender_id)
|
||||||
|
async with lock:
|
||||||
|
response = await client.send_media(self.peer, media, reply_to=reply_to, caption=caption)
|
||||||
|
self._add_telegram_message_to_db(event_id, space, response)
|
||||||
|
|
||||||
|
async def _handle_matrix_location(self, sender_id, event_id, space, client, message, reply_to):
|
||||||
try:
|
try:
|
||||||
lat, long = message["geo_uri"][len("geo:"):].split(",")
|
lat, long = message["geo_uri"][len("geo:"):].split(",")
|
||||||
lat, long = float(lat), float(long)
|
lat, long = float(lat), float(long)
|
||||||
@@ -664,11 +690,26 @@ class Portal:
|
|||||||
return None
|
return None
|
||||||
message, entities = await self._matrix_event_to_entities(client, message)
|
message, entities = await self._matrix_event_to_entities(client, message)
|
||||||
media = MessageMediaGeo(geo=GeoPoint(lat, long))
|
media = MessageMediaGeo(geo=GeoPoint(lat, long))
|
||||||
return await client.send_media(self.peer, media, reply_to=reply_to, caption=message,
|
|
||||||
entities=entities)
|
lock = self.require_send_lock(sender_id)
|
||||||
|
async with lock:
|
||||||
|
response = await client.send_media(self.peer, media, reply_to=reply_to, caption=message,
|
||||||
|
entities=entities)
|
||||||
|
self._add_telegram_message_to_db(event_id, space, response)
|
||||||
|
|
||||||
|
def _add_telegram_message_to_db(self, event_id, space, response):
|
||||||
|
self.log.debug("Handled Matrix message: %s", response)
|
||||||
|
self.is_duplicate(response, (event_id, space))
|
||||||
|
self.db.add(DBMessage(
|
||||||
|
tgid=response.id,
|
||||||
|
tg_space=space,
|
||||||
|
mx_room=self.mxid,
|
||||||
|
mxid=event_id))
|
||||||
|
self.db.commit()
|
||||||
|
|
||||||
async def handle_matrix_message(self, sender, message, event_id):
|
async def handle_matrix_message(self, sender, message, event_id):
|
||||||
client = sender.client if sender.logged_in else self.bot.client
|
client = sender.client if sender.logged_in else self.bot.client
|
||||||
|
sender_id = sender.tgid if sender.logged_in else self.bot.tgid
|
||||||
space = (self.tgid if self.peer_type == "channel" # Channels have their own ID space
|
space = (self.tgid if self.peer_type == "channel" # Channels have their own ID space
|
||||||
else (sender.tgid if sender.logged_in else self.bot.tgid))
|
else (sender.tgid if sender.logged_in else self.bot.tgid))
|
||||||
reply_to = formatter.matrix_reply_to_telegram(message, space, room_id=self.mxid)
|
reply_to = formatter.matrix_reply_to_telegram(message, space, room_id=self.mxid)
|
||||||
@@ -678,26 +719,13 @@ class Portal:
|
|||||||
type = message["msgtype"]
|
type = message["msgtype"]
|
||||||
|
|
||||||
if type == "m.text" or (self.bridge_notices and type == "m.notice"):
|
if type == "m.text" or (self.bridge_notices and type == "m.notice"):
|
||||||
response = await self._handle_matrix_text(client, message, reply_to)
|
await self._handle_matrix_text(sender_id, event_id, space, client, message, reply_to)
|
||||||
elif type == "m.location":
|
elif type == "m.location":
|
||||||
response = await self._handle_matrix_location(client, message, reply_to)
|
await self._handle_matrix_location(sender_id, event_id, space, client, message, reply_to)
|
||||||
elif type in ("m.image", "m.file", "m.audio", "m.video"):
|
elif type in ("m.image", "m.file", "m.audio", "m.video"):
|
||||||
response = await self._handle_matrix_file(client, message, reply_to)
|
await self._handle_matrix_file(sender_id, event_id, space, client, message, reply_to)
|
||||||
else:
|
else:
|
||||||
self.log.debug("Unhandled Matrix event: %s", message)
|
self.log.debug("Unhandled Matrix event: %s", message)
|
||||||
response = None
|
|
||||||
|
|
||||||
if not response:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.log.debug("Handled Matrix message: %s", response)
|
|
||||||
self.is_duplicate(response, (event_id, space))
|
|
||||||
self.db.add(DBMessage(
|
|
||||||
tgid=response.id,
|
|
||||||
tg_space=space,
|
|
||||||
mx_room=self.mxid,
|
|
||||||
mxid=event_id))
|
|
||||||
self.db.commit()
|
|
||||||
|
|
||||||
async def handle_matrix_pin(self, sender, pinned_message):
|
async def handle_matrix_pin(self, sender, pinned_message):
|
||||||
if self.peer_type != "channel":
|
if self.peer_type != "channel":
|
||||||
@@ -1073,7 +1101,13 @@ class Portal:
|
|||||||
self.log.debug("Edits as replies disabled, ignoring edit event...")
|
self.log.debug("Edits as replies disabled, ignoring edit event...")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
lock = self.optional_send_lock(sender.tgid if sender else None)
|
||||||
|
if lock:
|
||||||
|
async with lock:
|
||||||
|
pass
|
||||||
|
|
||||||
tg_space = self.tgid if self.peer_type == "channel" else source.tgid
|
tg_space = self.tgid if self.peer_type == "channel" else source.tgid
|
||||||
|
|
||||||
temporary_identifier = f"${random.randint(1000000000000,9999999999999)}TGBRIDGEDITEMP"
|
temporary_identifier = f"${random.randint(1000000000000,9999999999999)}TGBRIDGEDITEMP"
|
||||||
duplicate_found = self.is_duplicate(evt, (temporary_identifier, tg_space), force_hash=True)
|
duplicate_found = self.is_duplicate(evt, (temporary_identifier, tg_space), force_hash=True)
|
||||||
if duplicate_found:
|
if duplicate_found:
|
||||||
@@ -1111,6 +1145,11 @@ class Portal:
|
|||||||
if not self.mxid:
|
if not self.mxid:
|
||||||
await self.create_matrix_room(source, invites=[source.mxid], update_if_exists=False)
|
await self.create_matrix_room(source, invites=[source.mxid], update_if_exists=False)
|
||||||
|
|
||||||
|
lock = self.optional_send_lock(sender.tgid if sender else None)
|
||||||
|
if lock:
|
||||||
|
async with lock:
|
||||||
|
pass
|
||||||
|
|
||||||
tg_space = self.tgid if self.peer_type == "channel" else source.tgid
|
tg_space = self.tgid if self.peer_type == "channel" else source.tgid
|
||||||
|
|
||||||
temporary_identifier = f"${random.randint(1000000000000,9999999999999)}TGBRIDGETEMP"
|
temporary_identifier = f"${random.randint(1000000000000,9999999999999)}TGBRIDGETEMP"
|
||||||
@@ -1122,6 +1161,7 @@ class Portal:
|
|||||||
DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space))
|
DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space))
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
return
|
return
|
||||||
|
|
||||||
allowed_media = (MessageMediaPhoto, MessageMediaDocument, MessageMediaGeo)
|
allowed_media = (MessageMediaPhoto, MessageMediaDocument, MessageMediaGeo)
|
||||||
media = evt.media if hasattr(evt, "media") and isinstance(evt.media,
|
media = evt.media if hasattr(evt, "media") and isinstance(evt.media,
|
||||||
allowed_media) else None
|
allowed_media) else None
|
||||||
|
|||||||
@@ -51,29 +51,23 @@ class MautrixTelegramClient(TelegramClient):
|
|||||||
|
|
||||||
return self._get_response_message(request, result)
|
return self._get_response_message(request, result)
|
||||||
|
|
||||||
async def send_file(self, entity, file, mime_type=None, caption=None, entities=None,
|
async def upload_file(self, file, mime_type=None, attributes=None, file_name=None):
|
||||||
attributes=None, file_name=None, reply_to=None, **kwargs):
|
file_handle = await super().upload_file(file, file_name=file_name, use_cache=False)
|
||||||
entity = await self.get_input_entity(entity)
|
|
||||||
reply_to = self._get_message_id(reply_to)
|
|
||||||
|
|
||||||
file_handle = await self.upload_file(file, file_name=file_name, use_cache=False)
|
|
||||||
|
|
||||||
if mime_type == "image/png" or mime_type == "image/jpeg":
|
if mime_type == "image/png" or mime_type == "image/jpeg":
|
||||||
media = InputMediaUploadedPhoto(file_handle)
|
return InputMediaUploadedPhoto(file_handle)
|
||||||
else:
|
else:
|
||||||
attributes = attributes or []
|
attributes = attributes or []
|
||||||
attr_dict = {type(attr): attr for attr in attributes}
|
attr_dict = {type(attr): attr for attr in attributes}
|
||||||
|
|
||||||
media = InputMediaUploadedDocument(
|
return InputMediaUploadedDocument(
|
||||||
file=file_handle,
|
file=file_handle,
|
||||||
mime_type=mime_type or "application/octet-stream",
|
mime_type=mime_type or "application/octet-stream",
|
||||||
attributes=list(attr_dict.values()))
|
attributes=list(attr_dict.values()))
|
||||||
|
|
||||||
request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [],
|
|
||||||
reply_to_msg_id=reply_to)
|
|
||||||
return self._get_response_message(request, await self(request))
|
|
||||||
|
|
||||||
async def send_media(self, entity, media, caption=None, entities=None, reply_to=None):
|
async def send_media(self, entity, media, caption=None, entities=None, reply_to=None):
|
||||||
|
entity = await self.get_input_entity(entity)
|
||||||
|
reply_to = self._get_message_id(reply_to)
|
||||||
request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [],
|
request = SendMediaRequest(entity, media, message=caption or "", entities=entities or [],
|
||||||
reply_to_msg_id=reply_to)
|
reply_to_msg_id=reply_to)
|
||||||
return self._get_response_message(request, await self(request))
|
return self._get_response_message(request, await self(request))
|
||||||
|
|||||||
Reference in New Issue
Block a user