summary refs log tree commit diff
path: root/synapse/storage/databases/main/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/receipts.py')
-rw-r--r--synapse/storage/databases/main/receipts.py89
1 files changed, 47 insertions, 42 deletions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index f74aa1e3f3..21e954ccc1 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -597,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def insert_linearized_receipt_txn(
+    def _insert_linearized_receipt_txn(
         self,
         txn: LoggingTransaction,
         room_id: str,
@@ -686,6 +686,44 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         return rx_ts
 
+    def _graph_to_linear(
+        self, txn: LoggingTransaction, room_id: str, event_ids: List[str]
+    ) -> str:
+        """
+        Generate a linearized event from a list of events (i.e. a list of forward
+        extremities in the room).
+
+        This should allow for calculation of the correct read receipt even if
+        servers have different event ordering.
+
+        Args:
+            txn: The transaction
+            room_id: The room ID the events are in.
+            event_ids: The list of event IDs to linearize.
+
+        Returns:
+            The linearized event ID.
+        """
+        # TODO: Make this better.
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "event_id", event_ids
+        )
+
+        sql = """
+            SELECT event_id WHERE room_id = ? AND stream_ordering IN (
+                SELECT max(stream_ordering) WHERE %s
+            )
+        """ % (
+            clause,
+        )
+
+        txn.execute(sql, [room_id] + list(args))
+        rows = txn.fetchall()
+        if rows:
+            return rows[0][0]
+        else:
+            raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
+
     async def insert_receipt(
         self,
         room_id: str,
@@ -712,35 +750,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
             linearized_event_id = event_ids[0]
         else:
             # we need to points in graph -> linearized form.
-            # TODO: Make this better.
-            def graph_to_linear(txn: LoggingTransaction) -> str:
-                clause, args = make_in_list_sql_clause(
-                    self.database_engine, "event_id", event_ids
-                )
-
-                sql = """
-                    SELECT event_id WHERE room_id = ? AND stream_ordering IN (
-                        SELECT max(stream_ordering) WHERE %s
-                    )
-                """ % (
-                    clause,
-                )
-
-                txn.execute(sql, [room_id] + list(args))
-                rows = txn.fetchall()
-                if rows:
-                    return rows[0][0]
-                else:
-                    raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
-
             linearized_event_id = await self.db_pool.runInteraction(
-                "insert_receipt_conv", graph_to_linear
+                "insert_receipt_conv", self._graph_to_linear, room_id, event_ids
             )
 
         async with self._receipts_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
-                self.insert_linearized_receipt_txn,
+                self._insert_linearized_receipt_txn,
                 room_id,
                 receipt_type,
                 user_id,
@@ -761,25 +778,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
             now - event_ts,
         )
 
-        await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
-
-        max_persisted_id = self._receipts_id_gen.get_current_token()
-
-        return stream_id, max_persisted_id
-
-    async def insert_graph_receipt(
-        self,
-        room_id: str,
-        receipt_type: str,
-        user_id: str,
-        event_ids: List[str],
-        data: JsonDict,
-    ) -> None:
-        assert self._can_write_to_receipts
-
         await self.db_pool.runInteraction(
             "insert_graph_receipt",
-            self.insert_graph_receipt_txn,
+            self._insert_graph_receipt_txn,
             room_id,
             receipt_type,
             user_id,
@@ -787,7 +788,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
             data,
         )
 
-    def insert_graph_receipt_txn(
+        max_persisted_id = self._receipts_id_gen.get_current_token()
+
+        return stream_id, max_persisted_id
+
+    def _insert_graph_receipt_txn(
         self,
         txn: LoggingTransaction,
         room_id: str,