summary refs log tree commit diff
path: root/synapse/storage/databases/main/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/receipts.py')
-rw-r--r--synapse/storage/databases/main/receipts.py87
1 files changed, 64 insertions, 23 deletions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index ddb8e80b69..52fe0db924 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -540,7 +540,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     async def get_all_updated_receipts(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+    ) -> Tuple[
+        List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], int, bool
+    ]:
         """Get updates for receipts replication stream.
 
         Args:
@@ -567,9 +569,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         def get_all_updated_receipts_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[Tuple[int, list]], int, bool]:
+        ) -> Tuple[
+            List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
+            int,
+            bool,
+        ]:
             sql = """
-                SELECT stream_id, room_id, receipt_type, user_id, event_id, data
+                SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
                 FROM receipts_linearized
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
@@ -578,8 +584,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
             txn.execute(sql, (last_id, current_id, limit))
 
             updates = cast(
-                List[Tuple[int, list]],
-                [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
+                List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
+                [(r[0], r[1:6] + (db_to_json(r[6]),)) for r in txn],
             )
 
             limited = False
@@ -631,6 +637,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipt_type: str,
         user_id: str,
         event_id: str,
+        thread_id: Optional[str],
         data: JsonDict,
         stream_id: int,
     ) -> Optional[int]:
@@ -657,12 +664,27 @@ class ReceiptsWorkerStore(SQLBaseStore):
         # We don't want to clobber receipts for more recent events, so we
         # have to compare orderings of existing receipts
         if stream_ordering is not None:
-            sql = (
-                "SELECT stream_ordering, event_id FROM events"
-                " INNER JOIN receipts_linearized AS r USING (event_id, room_id)"
-                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
+            if thread_id is None:
+                thread_clause = "r.thread_id IS NULL"
+                thread_args: Tuple[str, ...] = ()
+            else:
+                thread_clause = "r.thread_id = ?"
+                thread_args = (thread_id,)
+
+            sql = f"""
+            SELECT stream_ordering, event_id FROM events
+            INNER JOIN receipts_linearized AS r USING (event_id, room_id)
+            WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND {thread_clause}
+            """
+            txn.execute(
+                sql,
+                (
+                    room_id,
+                    receipt_type,
+                    user_id,
+                )
+                + thread_args,
             )
-            txn.execute(sql, (room_id, receipt_type, user_id))
 
             for so, eid in txn:
                 if int(so) >= stream_ordering:
@@ -682,21 +704,28 @@ class ReceiptsWorkerStore(SQLBaseStore):
             self._receipts_stream_cache.entity_has_changed, room_id, stream_id
         )
 
+        keyvalues = {
+            "room_id": room_id,
+            "receipt_type": receipt_type,
+            "user_id": user_id,
+        }
+        where_clause = ""
+        if thread_id is None:
+            where_clause = "thread_id IS NULL"
+        else:
+            keyvalues["thread_id"] = thread_id
+
         self.db_pool.simple_upsert_txn(
             txn,
             table="receipts_linearized",
-            keyvalues={
-                "room_id": room_id,
-                "receipt_type": receipt_type,
-                "user_id": user_id,
-            },
+            keyvalues=keyvalues,
             values={
                 "stream_id": stream_id,
                 "event_id": event_id,
                 "event_stream_ordering": stream_ordering,
                 "data": json_encoder.encode(data),
-                "thread_id": None,
             },
+            where_clause=where_clause,
             # receipts_linearized has a unique constraint on
             # (user_id, room_id, receipt_type), so no need to lock
             lock=False,
@@ -748,6 +777,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipt_type: str,
         user_id: str,
         event_ids: List[str],
+        thread_id: Optional[str],
         data: dict,
     ) -> Optional[Tuple[int, int]]:
         """Insert a receipt, either from local client or remote server.
@@ -780,6 +810,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 receipt_type,
                 user_id,
                 linearized_event_id,
+                thread_id,
                 data,
                 stream_id=stream_id,
                 # Read committed is actually beneficial here because we check for a receipt with
@@ -794,7 +825,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         now = self._clock.time_msec()
         logger.debug(
-            "RR for event %s in %s (%i ms old)",
+            "Receipt %s for event %s in %s (%i ms old)",
+            receipt_type,
             linearized_event_id,
             room_id,
             now - event_ts,
@@ -807,6 +839,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             receipt_type,
             user_id,
             event_ids,
+            thread_id,
             data,
         )
 
@@ -821,6 +854,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipt_type: str,
         user_id: str,
         event_ids: List[str],
+        thread_id: Optional[str],
         data: JsonDict,
     ) -> None:
         assert self._can_write_to_receipts
@@ -832,19 +866,26 @@ class ReceiptsWorkerStore(SQLBaseStore):
         # FIXME: This shouldn't invalidate the whole cache
         txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))
 
+        keyvalues = {
+            "room_id": room_id,
+            "receipt_type": receipt_type,
+            "user_id": user_id,
+        }
+        where_clause = ""
+        if thread_id is None:
+            where_clause = "thread_id IS NULL"
+        else:
+            keyvalues["thread_id"] = thread_id
+
         self.db_pool.simple_upsert_txn(
             txn,
             table="receipts_graph",
-            keyvalues={
-                "room_id": room_id,
-                "receipt_type": receipt_type,
-                "user_id": user_id,
-            },
+            keyvalues=keyvalues,
             values={
                 "event_ids": json_encoder.encode(event_ids),
                 "data": json_encoder.encode(data),
-                "thread_id": None,
             },
+            where_clause=where_clause,
             # receipts_graph has a unique constraint on
             # (user_id, room_id, receipt_type), so no need to lock
             lock=False,