summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2022-05-26 15:55:06 -0400
committerPatrick Cloke <patrickc@matrix.org>2022-06-13 09:57:05 -0400
commit1ec3885aa9b416b4f697745bcaf8f8bda51738d9 (patch)
tree69df0ad55044eb7afdec14804a333444a73fb952
parentCreate a separate RangedReadReceipt class. (diff)
downloadsynapse-1ec3885aa9b416b4f697745bcaf8f8bda51738d9.tar.xz
Accept a start & end event ID when creating a receipt.
-rw-r--r--synapse/handlers/receipts.py45
-rw-r--r--synapse/rest/client/read_marker.py4
-rw-r--r--synapse/rest/client/receipts.py2
-rw-r--r--synapse/storage/databases/main/receipts.py51
4 files changed, 66 insertions, 36 deletions
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 5588545850..989ede27c1 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -19,7 +19,9 @@ from synapse.appservice import ApplicationService
 from synapse.streams import EventSource
 from synapse.types import (
     JsonDict,
+    RangedReadReceipt,
     ReadReceipt,
+    Receipt,
     StreamKeyType,
     UserID,
     get_domain_from_id,
@@ -65,7 +67,7 @@ class ReceiptsHandler:
 
     async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
         """Called when we receive an EDU of type m.receipt from a remote HS."""
-        receipts = []
+        receipts: List[Receipt] = []
         for room_id, room_values in content.items():
             # If we're not in the room just ditch the event entirely. This is
             # probably an old server that has come back and thinks we're still in
@@ -103,7 +105,7 @@ class ReceiptsHandler:
 
         await self._handle_new_receipts(receipts)
 
-    async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
+    async def _handle_new_receipts(self, receipts: List[Receipt]) -> bool:
         """Takes a list of receipts, stores them and informs the notifier."""
         min_batch_id: Optional[int] = None
         max_batch_id: Optional[int] = None
@@ -140,24 +142,45 @@ class ReceiptsHandler:
         return True
 
     async def received_client_receipt(
-        self, room_id: str, receipt_type: str, user_id: str, event_id: str
+        self,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        end_event_id: str,
+        start_event_id: Optional[str] = None,
     ) -> None:
         """Called when a client tells us a local user has read up to the given
         event_id in the room.
         """
-        receipt = ReadReceipt(
-            room_id=room_id,
-            receipt_type=receipt_type,
-            user_id=user_id,
-            event_ids=[event_id],
-            data={"ts": int(self.clock.time_msec())},
-        )
+
+        if start_event_id:
+            receipt: Receipt = RangedReadReceipt(
+                room_id=room_id,
+                receipt_type=receipt_type,
+                user_id=user_id,
+                start_event_id=start_event_id,
+                end_event_id=end_event_id,
+                data={"ts": int(self.clock.time_msec())},
+            )
+        else:
+            receipt = ReadReceipt(
+                room_id=room_id,
+                receipt_type=receipt_type,
+                user_id=user_id,
+                event_ids=[end_event_id],
+                data={"ts": int(self.clock.time_msec())},
+            )
 
         is_new = await self._handle_new_receipts([receipt])
         if not is_new:
             return
 
-        if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE:
+        # XXX How to handle this for a ranged read receipt.
+        if (
+            isinstance(receipt, ReadReceipt)
+            and self.federation_sender
+            and receipt_type != ReceiptTypes.READ_PRIVATE
+        ):
             await self.federation_sender.send_read_receipt(receipt)
 
 
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 3644705e6a..15c031159e 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -71,7 +71,7 @@ class ReadMarkerRestServlet(RestServlet):
                 room_id,
                 ReceiptTypes.READ,
                 user_id=requester.user.to_string(),
-                event_id=read_event_id,
+                end_event_id=read_event_id,
             )
 
         read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None)
@@ -80,7 +80,7 @@ class ReadMarkerRestServlet(RestServlet):
                 room_id,
                 ReceiptTypes.READ_PRIVATE,
                 user_id=requester.user.to_string(),
-                event_id=read_private_event_id,
+                end_event_id=read_private_event_id,
             )
 
         read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None)
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 4b03eb876b..f1acd42f5e 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -80,7 +80,7 @@ class ReceiptRestServlet(RestServlet):
                 room_id,
                 receipt_type,
                 user_id=requester.user.to_string(),
-                event_id=event_id,
+                end_event_id=event_id,
             )
 
         return 200, {}
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 2252dd8608..d318753650 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -42,7 +42,7 @@ from synapse.storage.util.id_generators import (
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
-from synapse.types import JsonDict, ReadReceipt
+from synapse.types import JsonDict, RangedReadReceipt, ReadReceipt, Receipt
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -725,7 +725,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         else:
             raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
 
-    async def insert_receipt(self, receipt: ReadReceipt) -> Optional[Tuple[int, int]]:
+    async def insert_receipt(self, receipt: Receipt) -> Optional[Tuple[int, int]]:
         """Insert a receipt, either from local client or remote server.
 
         Automatically does conversion between linearized and graph
@@ -737,19 +737,25 @@ class ReceiptsWorkerStore(SQLBaseStore):
         """
         assert self._can_write_to_receipts
 
-        if not receipt.event_ids:
-            return None
+        if isinstance(receipt, ReadReceipt):
+            event_ids = receipt.event_ids
+            if not event_ids:
+                return None
 
-        if len(receipt.event_ids) == 1:
-            linearized_event_id = receipt.event_ids[0]
+            if len(event_ids) == 1:
+                linearized_event_id = event_ids[0]
+            else:
+                # we need to points in graph -> linearized form.
+                linearized_event_id = await self.db_pool.runInteraction(
+                    "insert_receipt_conv",
+                    self._graph_to_linear,
+                    receipt.room_id,
+                    event_ids,
+                )
+        elif isinstance(receipt, RangedReadReceipt):
+            linearized_event_id = receipt.end_event_id
         else:
-            # we need to points in graph -> linearized form.
-            linearized_event_id = await self.db_pool.runInteraction(
-                "insert_receipt_conv",
-                self._graph_to_linear,
-                receipt.room_id,
-                receipt.event_ids,
-            )
+            raise ValueError("Unexpected receipt type: %s", type(receipt))
 
         async with self._receipts_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
             event_ts = await self.db_pool.runInteraction(
@@ -779,15 +785,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
             now - event_ts,
         )
 
-        await self.db_pool.runInteraction(
-            "insert_graph_receipt",
-            self._insert_graph_receipt_txn,
-            receipt.room_id,
-            receipt.receipt_type,
-            receipt.user_id,
-            receipt.event_ids,
-            receipt.data,
-        )
+        # XXX These aren't really used right now, go away.
+        # await self.db_pool.runInteraction(
+        #     "insert_graph_receipt",
+        #     self._insert_graph_receipt_txn,
+        #     room_id,
+        #     receipt_type,
+        #     user_id,
+        #     event_ids,
+        #     data,
+        # )
 
         max_persisted_id = self._receipts_id_gen.get_current_token()