diff options
Diffstat (limited to 'synapse/handlers/receipts.py')
-rw-r--r-- | synapse/handlers/receipts.py | 49 |
1 files changed, 29 insertions, 20 deletions
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index a85dd8cdee..73973502a4 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -17,7 +17,7 @@ import logging from twisted.internet import defer from synapse.handlers._base import BaseHandler -from synapse.types import ReadReceipt +from synapse.types import ReadReceipt, get_domain_from_id logger = logging.getLogger(__name__) @@ -40,18 +40,27 @@ class ReceiptsHandler(BaseHandler): def _received_remote_receipt(self, origin, content): """Called when we receive an EDU of type m.receipt from a remote HS. """ - receipts = [ - ReadReceipt( - room_id=room_id, - receipt_type=receipt_type, - user_id=user_id, - event_ids=user_values["event_ids"], - data=user_values.get("data", {}), - ) - for room_id, room_values in content.items() - for receipt_type, users in room_values.items() - for user_id, user_values in users.items() - ] + receipts = [] + for room_id, room_values in content.items(): + for receipt_type, users in room_values.items(): + for user_id, user_values in users.items(): + if get_domain_from_id(user_id) != origin: + logger.info( + "Received receipt for user %r from server %s, ignoring", + user_id, + origin, + ) + continue + + receipts.append( + ReadReceipt( + room_id=room_id, + receipt_type=receipt_type, + user_id=user_id, + event_ids=user_values["event_ids"], + data=user_values.get("data", {}), + ) + ) yield self._handle_new_receipts(receipts) @@ -84,7 +93,7 @@ class ReceiptsHandler(BaseHandler): if min_batch_id is None: # no new receipts - defer.returnValue(False) + return False affected_room_ids = list(set([r.room_id for r in receipts])) @@ -94,7 +103,7 @@ class ReceiptsHandler(BaseHandler): min_batch_id, max_batch_id, affected_room_ids ) - defer.returnValue(True) + return True @defer.inlineCallbacks def received_client_receipt(self, room_id, receipt_type, user_id, event_id): @@ -124,9 +133,9 @@ class ReceiptsHandler(BaseHandler): ) if not result: - defer.returnValue([]) + return [] - defer.returnValue(result) + return result class ReceiptEventSource(object): @@ -139,13 +148,13 @@ class ReceiptEventSource(object): to_key = yield self.get_current_key() if from_key == to_key: - defer.returnValue(([], to_key)) + return ([], to_key) events = yield self.store.get_linearized_receipts_for_rooms( room_ids, from_key=from_key, to_key=to_key ) - defer.returnValue((events, to_key)) + return (events, to_key) def get_current_key(self, direction="f"): return self.store.get_max_receipt_stream_id() @@ -164,4 +173,4 @@ class ReceiptEventSource(object): room_ids, from_key=from_key, to_key=to_key ) - defer.returnValue((events, to_key)) + return (events, to_key) |