summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6280.misc1
-rw-r--r--synapse/federation/send_queue.py4
-rw-r--r--synapse/handlers/read_marker.py13
-rw-r--r--synapse/handlers/receipts.py37
-rw-r--r--synapse/rest/client/v2_alpha/read_marker.py13
-rw-r--r--synapse/rest/client/v2_alpha/receipts.py11
-rw-r--r--synapse/storage/data_stores/main/events.py7
7 files changed, 33 insertions, 53 deletions
diff --git a/changelog.d/6280.misc b/changelog.d/6280.misc
new file mode 100644
index 0000000000..96a0eb21b2
--- /dev/null
+++ b/changelog.d/6280.misc
@@ -0,0 +1 @@
+Port receipt and read markers to async/wait.
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 454456a52d..ced4925a98 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -36,6 +36,8 @@ from six import iteritems
 
 from sortedcontainers import SortedDict
 
+from twisted.internet import defer
+
 from synapse.metrics import LaterGauge
 from synapse.storage.presence import UserPresenceState
 from synapse.util.metrics import Measure
@@ -212,7 +214,7 @@ class FederationRemoteSendQueue(object):
             receipt (synapse.types.ReadReceipt):
         """
         # nothing to do here: the replication listener will handle it.
-        pass
+        return defer.succeed(None)
 
     def send_presence(self, states):
         """As per FederationSender
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index 3e4d8c93a4..e3b528d271 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.util.async_helpers import Linearizer
 
 from ._base import BaseHandler
@@ -32,8 +30,7 @@ class ReadMarkerHandler(BaseHandler):
         self.read_marker_linearizer = Linearizer(name="read_marker")
         self.notifier = hs.get_notifier()
 
-    @defer.inlineCallbacks
-    def received_client_read_marker(self, room_id, user_id, event_id):
+    async def received_client_read_marker(self, room_id, user_id, event_id):
         """Updates the read marker for a given user in a given room if the event ID given
         is ahead in the stream relative to the current read marker.
 
@@ -41,8 +38,8 @@ class ReadMarkerHandler(BaseHandler):
         the read marker has changed.
         """
 
-        with (yield self.read_marker_linearizer.queue((room_id, user_id))):
-            existing_read_marker = yield self.store.get_account_data_for_room_and_type(
+        with await self.read_marker_linearizer.queue((room_id, user_id)):
+            existing_read_marker = await self.store.get_account_data_for_room_and_type(
                 user_id, room_id, "m.fully_read"
             )
 
@@ -50,13 +47,13 @@ class ReadMarkerHandler(BaseHandler):
 
             if existing_read_marker:
                 # Only update if the new marker is ahead in the stream
-                should_update = yield self.store.is_event_after(
+                should_update = await self.store.is_event_after(
                     event_id, existing_read_marker["event_id"]
                 )
 
             if should_update:
                 content = {"event_id": event_id}
-                max_id = yield self.store.add_account_data_to_room(
+                max_id = await self.store.add_account_data_to_room(
                     user_id, room_id, "m.fully_read", content
                 )
                 self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 6854c751a6..9283c039e3 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
 
 from synapse.handlers._base import BaseHandler
 from synapse.types import ReadReceipt, get_domain_from_id
+from synapse.util.async_helpers import maybe_awaitable
 
 logger = logging.getLogger(__name__)
 
@@ -36,8 +37,7 @@ class ReceiptsHandler(BaseHandler):
         self.clock = self.hs.get_clock()
         self.state = hs.get_state_handler()
 
-    @defer.inlineCallbacks
-    def _received_remote_receipt(self, origin, content):
+    async def _received_remote_receipt(self, origin, content):
         """Called when we receive an EDU of type m.receipt from a remote HS.
         """
         receipts = []
@@ -62,17 +62,16 @@ class ReceiptsHandler(BaseHandler):
                         )
                     )
 
-        yield self._handle_new_receipts(receipts)
+        await self._handle_new_receipts(receipts)
 
-    @defer.inlineCallbacks
-    def _handle_new_receipts(self, receipts):
+    async def _handle_new_receipts(self, receipts):
         """Takes a list of receipts, stores them and informs the notifier.
         """
         min_batch_id = None
         max_batch_id = None
 
         for receipt in receipts:
-            res = yield self.store.insert_receipt(
+            res = await self.store.insert_receipt(
                 receipt.room_id,
                 receipt.receipt_type,
                 receipt.user_id,
@@ -99,14 +98,15 @@ class ReceiptsHandler(BaseHandler):
 
         self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
         # Note that the min here shouldn't be relied upon to be accurate.
-        yield self.hs.get_pusherpool().on_new_receipts(
-            min_batch_id, max_batch_id, affected_room_ids
+        await maybe_awaitable(
+            self.hs.get_pusherpool().on_new_receipts(
+                min_batch_id, max_batch_id, affected_room_ids
+            )
         )
 
         return True
 
-    @defer.inlineCallbacks
-    def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
+    async def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
         """Called when a client tells us a local user has read up to the given
         event_id in the room.
         """
@@ -118,24 +118,11 @@ class ReceiptsHandler(BaseHandler):
             data={"ts": int(self.clock.time_msec())},
         )
 
-        is_new = yield self._handle_new_receipts([receipt])
+        is_new = await self._handle_new_receipts([receipt])
         if not is_new:
             return
 
-        yield self.federation.send_read_receipt(receipt)
-
-    @defer.inlineCallbacks
-    def get_receipts_for_room(self, room_id, to_key):
-        """Gets all receipts for a room, upto the given key.
-        """
-        result = yield self.store.get_linearized_receipts_for_room(
-            room_id, to_key=to_key
-        )
-
-        if not result:
-            return []
-
-        return result
+        await self.federation.send_read_receipt(receipt)
 
 
 class ReceiptEventSource(object):
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index b3bf8567e1..67cbc37312 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 
 from ._base import client_patterns
@@ -34,17 +32,16 @@ class ReadMarkerRestServlet(RestServlet):
         self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_id):
-        requester = yield self.auth.get_user_by_req(request)
+    async def on_POST(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request)
 
-        yield self.presence_handler.bump_presence_active_time(requester.user)
+        await self.presence_handler.bump_presence_active_time(requester.user)
 
         body = parse_json_object_from_request(request)
 
         read_event_id = body.get("m.read", None)
         if read_event_id:
-            yield self.receipts_handler.received_client_receipt(
+            await self.receipts_handler.received_client_receipt(
                 room_id,
                 "m.read",
                 user_id=requester.user.to_string(),
@@ -53,7 +50,7 @@ class ReadMarkerRestServlet(RestServlet):
 
         read_marker_event_id = body.get("m.fully_read", None)
         if read_marker_event_id:
-            yield self.read_marker_handler.received_client_read_marker(
+            await self.read_marker_handler.received_client_read_marker(
                 room_id,
                 user_id=requester.user.to_string(),
                 event_id=read_marker_event_id,
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 0dab03d227..92555bd4a9 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import RestServlet
 
@@ -39,16 +37,15 @@ class ReceiptRestServlet(RestServlet):
         self.receipts_handler = hs.get_receipts_handler()
         self.presence_handler = hs.get_presence_handler()
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_id, receipt_type, event_id):
-        requester = yield self.auth.get_user_by_req(request)
+    async def on_POST(self, request, room_id, receipt_type, event_id):
+        requester = await self.auth.get_user_by_req(request)
 
         if receipt_type != "m.read":
             raise SynapseError(400, "Receipt type must be 'm.read'")
 
-        yield self.presence_handler.bump_presence_active_time(requester.user)
+        await self.presence_handler.bump_presence_active_time(requester.user)
 
-        yield self.receipts_handler.received_client_receipt(
+        await self.receipts_handler.received_client_receipt(
             room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id
         )
 
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 03b5111c5d..067e77ae00 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -2439,12 +2439,11 @@ class EventsStore(
 
         logger.info("[purge] done")
 
-    @defer.inlineCallbacks
-    def is_event_after(self, event_id1, event_id2):
+    async def is_event_after(self, event_id1, event_id2):
         """Returns True if event_id1 is after event_id2 in the stream
         """
-        to_1, so_1 = yield self._get_event_ordering(event_id1)
-        to_2, so_2 = yield self._get_event_ordering(event_id2)
+        to_1, so_1 = await self._get_event_ordering(event_id1)
+        to_2, so_2 = await self._get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
 
     @cachedInlineCallbacks(max_entries=5000)