diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 6854c751a6..9283c039e3 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.handlers._base import BaseHandler
from synapse.types import ReadReceipt, get_domain_from_id
+from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -36,8 +37,7 @@ class ReceiptsHandler(BaseHandler):
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
- @defer.inlineCallbacks
- def _received_remote_receipt(self, origin, content):
+ async def _received_remote_receipt(self, origin, content):
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
receipts = []
@@ -62,17 +62,16 @@ class ReceiptsHandler(BaseHandler):
)
)
- yield self._handle_new_receipts(receipts)
+ await self._handle_new_receipts(receipts)
- @defer.inlineCallbacks
- def _handle_new_receipts(self, receipts):
+ async def _handle_new_receipts(self, receipts):
"""Takes a list of receipts, stores them and informs the notifier.
"""
min_batch_id = None
max_batch_id = None
for receipt in receipts:
- res = yield self.store.insert_receipt(
+ res = await self.store.insert_receipt(
receipt.room_id,
receipt.receipt_type,
receipt.user_id,
@@ -99,14 +98,15 @@ class ReceiptsHandler(BaseHandler):
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
- yield self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids
+ await maybe_awaitable(
+ self.hs.get_pusherpool().on_new_receipts(
+ min_batch_id, max_batch_id, affected_room_ids
+ )
)
return True
- @defer.inlineCallbacks
- def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
+ async def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
@@ -118,24 +118,11 @@ class ReceiptsHandler(BaseHandler):
data={"ts": int(self.clock.time_msec())},
)
- is_new = yield self._handle_new_receipts([receipt])
+ is_new = await self._handle_new_receipts([receipt])
if not is_new:
return
- yield self.federation.send_read_receipt(receipt)
-
- @defer.inlineCallbacks
- def get_receipts_for_room(self, room_id, to_key):
- """Gets all receipts for a room, upto the given key.
- """
- result = yield self.store.get_linearized_receipts_for_room(
- room_id, to_key=to_key
- )
-
- if not result:
- return []
-
- return result
+ await self.federation.send_read_receipt(receipt)
class ReceiptEventSource(object):
|