summary refs log tree commit diff
path: root/synapse/storage/databases/main/receipts.py
diff options
context:
space:
mode:
authorErik Johnston <erikj@matrix.org>2023-10-25 16:16:19 +0100
committerGitHub <noreply@github.com>2023-10-25 16:16:19 +0100
commitba47fea5286e084ec70d568aa62eb4820b857c47 (patch)
tree6e2c608feb1ea0c23b2b9cc40d11211cc3a10aa5 /synapse/storage/databases/main/receipts.py
parentFix tests on Twisted trunk. (#16528) (diff)
downloadsynapse-ba47fea5286e084ec70d568aa62eb4820b857c47.tar.xz
Allow multiple workers to write to receipts stream. (#16432)
Fixes #16417
Diffstat (limited to 'synapse/storage/databases/main/receipts.py')
-rw-r--r--synapse/storage/databases/main/receipts.py148
1 files changed, 109 insertions, 39 deletions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index b2645ab43c..56e8eb16a8 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -28,6 +28,8 @@ from typing import (
     cast,
 )
 
+from immutabledict import immutabledict
+
 from synapse.api.constants import EduTypes
 from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import (
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
-from synapse.types import JsonDict, JsonMapping
+from synapse.types import (
+    JsonDict,
+    JsonMapping,
+    MultiWriterStreamToken,
+    PersistedPosition,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "receipts_linearized",
             entity_column="room_id",
             stream_column="stream_id",
-            max_value=max_receipts_stream_id,
+            max_value=max_receipts_stream_id.stream,
             limit=10000,
         )
         self._receipts_stream_cache = StreamChangeCache(
@@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
             prefilled_cache=receipts_stream_prefill,
         )
 
-    def get_max_receipt_stream_id(self) -> int:
+    def get_max_receipt_stream_id(self) -> MultiWriterStreamToken:
         """Get the current max stream ID for receipts stream"""
-        return self._receipts_id_gen.get_current_token()
+
+        min_pos = self._receipts_id_gen.get_current_token()
+
+        positions = {}
+        if isinstance(self._receipts_id_gen, MultiWriterIdGenerator):
+            # The `min_pos` is the minimum position that we know all instances
+            # have finished persisting to, so we only care about instances whose
+            # positions are ahead of that. (Instance positions can be behind the
+            # min position as there are times we can work out that the minimum
+            # position is ahead of the naive minimum across all current
+            # positions. See MultiWriterIdGenerator for details)
+            positions = {
+                i: p
+                for i, p in self._receipts_id_gen.get_positions().items()
+                if p > min_pos
+            }
+
+        return MultiWriterStreamToken(
+            stream=min_pos, instance_map=immutabledict(positions)
+        )
+
+    def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
+        return self._receipts_id_gen.get_current_token_for_writer(instance_name)
 
     def get_last_unthreaded_receipt_for_user_txn(
         self,
@@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         }
 
     async def get_linearized_receipts_for_rooms(
-        self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
+        self,
+        room_ids: Iterable[str],
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> List[JsonMapping]:
         """Get receipts for multiple rooms for sending to clients.
 
@@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             # Only ask the database about rooms where there have been new
             # receipts added since `from_key`
             room_ids = self._receipts_stream_cache.get_entities_changed(
-                room_ids, from_key
+                room_ids, from_key.stream
             )
 
         results = await self._get_linearized_receipts_for_rooms(
@@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return [ev for res in results.values() for ev in res]
 
     async def get_linearized_receipts_for_room(
-        self, room_id: str, to_key: int, from_key: Optional[int] = None
+        self,
+        room_id: str,
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Sequence[JsonMapping]:
         """Get receipts for a single room for sending to clients.
 
@@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if from_key is not None:
             # Check the cache first to see if any new receipts have been added
             # since`from_key`. If not we can no-op.
-            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+            if not self._receipts_stream_cache.has_entity_changed(
+                room_id, from_key.stream
+            ):
                 return []
 
         return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
 
     @cached(tree=True)
     async def _get_linearized_receipts_for_room(
-        self, room_id: str, to_key: int, from_key: Optional[int] = None
+        self,
+        room_id: str,
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Sequence[JsonMapping]:
         """See get_linearized_receipts_for_room"""
 
         def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
             if from_key:
-                sql = (
-                    "SELECT receipt_type, user_id, event_id, data"
-                    " FROM receipts_linearized WHERE"
-                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
-                )
+                sql = """
+                    SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+                    FROM receipts_linearized
+                    WHERE room_id = ? AND stream_id > ? AND stream_id <= ?
+                """
 
-                txn.execute(sql, (room_id, from_key, to_key))
-            else:
-                sql = (
-                    "SELECT receipt_type, user_id, event_id, data"
-                    " FROM receipts_linearized WHERE"
-                    " room_id = ? AND stream_id <= ?"
+                txn.execute(
+                    sql, (room_id, from_key.stream, to_key.get_max_stream_pos())
                 )
+            else:
+                sql = """
+                    SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+                    FROM receipts_linearized WHERE
+                    room_id = ? AND stream_id <= ?
+                """
 
-                txn.execute(sql, (room_id, to_key))
+                txn.execute(sql, (room_id, to_key.get_max_stream_pos()))
 
-            return cast(List[Tuple[str, str, str, str]], txn.fetchall())
+            return [
+                (receipt_type, user_id, event_id, data)
+                for stream_id, instance_name, receipt_type, user_id, event_id, data in txn
+                if MultiWriterStreamToken.is_stream_position_in_range(
+                    from_key, to_key, instance_name, stream_id
+                )
+            ]
 
         rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
 
@@ -352,7 +400,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         num_args=3,
     )
     async def _get_linearized_receipts_for_rooms(
-        self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+        self,
+        room_ids: Collection[str],
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Mapping[str, Sequence[JsonMapping]]:
         if not room_ids:
             return {}
@@ -362,7 +413,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
         ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
             if from_key:
                 sql = """
-                    SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+                    SELECT stream_id, instance_name, room_id, receipt_type,
+                        user_id, event_id, thread_id, data
                     FROM receipts_linearized WHERE
                     stream_id > ? AND stream_id <= ? AND
                 """
@@ -370,10 +422,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
                     self.database_engine, "room_id", room_ids
                 )
 
-                txn.execute(sql + clause, [from_key, to_key] + list(args))
+                txn.execute(
+                    sql + clause,
+                    [from_key.stream, to_key.get_max_stream_pos()] + list(args),
+                )
             else:
                 sql = """
-                    SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+                    SELECT stream_id, instance_name, room_id, receipt_type,
+                        user_id, event_id, thread_id, data
                     FROM receipts_linearized WHERE
                     stream_id <= ? AND
                 """
@@ -382,11 +438,15 @@ class ReceiptsWorkerStore(SQLBaseStore):
                     self.database_engine, "room_id", room_ids
                 )
 
-                txn.execute(sql + clause, [to_key] + list(args))
+                txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))
 
-            return cast(
-                List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall()
-            )
+            return [
+                (room_id, receipt_type, user_id, event_id, thread_id, data)
+                for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
+                if MultiWriterStreamToken.is_stream_position_in_range(
+                    from_key, to_key, instance_name, stream_id
+                )
+            ]
 
         txn_results = await self.db_pool.runInteraction(
             "_get_linearized_receipts_for_rooms", f
@@ -420,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         num_args=2,
     )
     async def get_linearized_receipts_for_all_rooms(
-        self, to_key: int, from_key: Optional[int] = None
+        self,
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Mapping[str, JsonMapping]:
         """Get receipts for all rooms between two stream_ids, up
         to a limit of the latest 100 read receipts.
@@ -437,25 +499,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
         def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
             if from_key:
                 sql = """
-                    SELECT room_id, receipt_type, user_id, event_id, data
+                    SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
                     FROM receipts_linearized WHERE
                     stream_id > ? AND stream_id <= ?
                     ORDER BY stream_id DESC
                     LIMIT 100
                 """
-                txn.execute(sql, [from_key, to_key])
+                txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()])
             else:
                 sql = """
-                    SELECT room_id, receipt_type, user_id, event_id, data
+                    SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
                     FROM receipts_linearized WHERE
                     stream_id <= ?
                     ORDER BY stream_id DESC
                     LIMIT 100
                 """
 
-                txn.execute(sql, [to_key])
+                txn.execute(sql, [to_key.get_max_stream_pos()])
 
-            return cast(List[Tuple[str, str, str, str, str]], txn.fetchall())
+            return [
+                (room_id, receipt_type, user_id, event_id, data)
+                for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn
+                if MultiWriterStreamToken.is_stream_position_in_range(
+                    from_key, to_key, instance_name, stream_id
+                )
+            ]
 
         txn_results = await self.db_pool.runInteraction(
             "get_linearized_receipts_for_all_rooms", f
@@ -545,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
                 FROM receipts_linearized
                 WHERE ? < stream_id AND stream_id <= ?
+                AND instance_name = ?
                 ORDER BY stream_id ASC
                 LIMIT ?
             """
-            txn.execute(sql, (last_id, current_id, limit))
+            txn.execute(sql, (last_id, current_id, instance_name, limit))
 
             updates = cast(
                 List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
@@ -695,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             keyvalues=keyvalues,
             values={
                 "stream_id": stream_id,
+                "instance_name": self._instance_name,
                 "event_id": event_id,
                 "event_stream_ordering": stream_ordering,
                 "data": json_encoder.encode(data),
@@ -750,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         event_ids: List[str],
         thread_id: Optional[str],
         data: dict,
-    ) -> Optional[int]:
+    ) -> Optional[PersistedPosition]:
         """Insert a receipt, either from local client or remote server.
 
         Automatically does conversion between linearized and graph
@@ -812,7 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             data,
         )
 
-        return stream_id
+        return PersistedPosition(self._instance_name, stream_id)
 
     async def _insert_graph_receipt(
         self,