author     Erik Johnston <erikj@matrix.org>  2023-10-25 16:16:19 +0100
committer  GitHub <noreply@github.com>       2023-10-25 16:16:19 +0100
commit     ba47fea5286e084ec70d568aa62eb4820b857c47
tree       6e2c608feb1ea0c23b2b9cc40d11211cc3a10aa5
parent     Fix tests on Twisted trunk. (#16528)
download   synapse-ba47fea5286e084ec70d568aa62eb4820b857c47.tar.xz
Allow multiple workers to write to receipts stream. (#16432)
Fixes #16417
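
For context, receipt writers are selected via the `stream_writers` homeserver option, and this change lets that list name more than one worker. A minimal sketch of a sharded setup, in the dict form the new replication test below uses (worker names and ports are illustrative):

    # Hypothetical config fragment, mirroring the test's default_config().
    conf = {}
    conf["stream_writers"] = {"receipts": ["worker1", "worker2"]}
    conf["instance_map"] = {
        "main": {"host": "testserv", "port": 8765},
        "worker1": {"host": "testserv", "port": 1001},
        "worker2": {"host": "testserv", "port": 1002},
    }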
Diffstat
-rw-r--r--  changelog.d/16432.feature                                                  |   1
-rw-r--r--  synapse/config/workers.py                                                  |   4
-rw-r--r--  synapse/handlers/appservice.py                                             |  42
-rw-r--r--  synapse/handlers/initial_sync.py                                           |   2
-rw-r--r--  synapse/handlers/receipts.py                                               |  19
-rw-r--r--  synapse/handlers/sync.py                                                   |   7
-rw-r--r--  synapse/notifier.py                                                        |  45
-rw-r--r--  synapse/replication/tcp/client.py                                          |   3
-rw-r--r--  synapse/storage/databases/main/receipts.py                                 | 148
-rw-r--r--  synapse/storage/databases/main/relations.py                                |   4
-rw-r--r--  synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite  |  17
-rw-r--r--  synapse/streams/events.py                                                  |   4
-rw-r--r--  synapse/types/__init__.py                                                  | 137
-rw-r--r--  tests/handlers/test_appservice.py                                          |  17
-rw-r--r--  tests/replication/test_sharded_receipts.py                                 | 243
15 files changed, 604 insertions, 89 deletions
diff --git a/changelog.d/16432.feature b/changelog.d/16432.feature
new file mode 100644
index 0000000000..9a76e85592
--- /dev/null
+++ b/changelog.d/16432.feature
@@ -0,0 +1 @@
+Allow multiple workers to write to receipts stream.
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index f1766088fc..6d67a8cd5c 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -358,9 +358,9 @@ class WorkerConfig(Config):
                 "Must only specify one instance to handle `account_data` messages."
             )
 
-        if len(self.writers.receipts) != 1:
+        if len(self.writers.receipts) == 0:
             raise ConfigError(
-                "Must only specify one instance to handle `receipts` messages."
+                "Must specify at least one instance to handle `receipts` messages."
             )
 
         if len(self.writers.events) == 0:
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index c200a45f3a..873dadc3bd 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -47,6 +47,7 @@ from synapse.types import (
     DeviceListUpdates,
     JsonDict,
     JsonMapping,
+    MultiWriterStreamToken,
     RoomAlias,
     RoomStreamToken,
     StreamKeyType,
@@ -217,7 +218,7 @@ class ApplicationServicesHandler:
     def notify_interested_services_ephemeral(
         self,
         stream_key: StreamKeyType,
-        new_token: Union[int, RoomStreamToken],
+        new_token: Union[int, RoomStreamToken, MultiWriterStreamToken],
         users: Collection[Union[str, UserID]],
     ) -> None:
         """
@@ -259,19 +260,6 @@ class ApplicationServicesHandler:
         ):
             return
 
-        # Assert that new_token is an integer (and not a RoomStreamToken).
-        # All of the supported streams that this function handles use an
-        # integer to track progress (rather than a RoomStreamToken - a
-        # vector clock implementation) as they don't support multiple
-        # stream writers.
-        #
-        # As a result, we simply assert that new_token is an integer.
-        # If we do end up needing to pass a RoomStreamToken down here
-        # in the future, using RoomStreamToken.stream (the minimum stream
-        # position) to convert to an ascending integer value should work.
-        # Additional context: https://github.com/matrix-org/synapse/pull/11137
-        assert isinstance(new_token, int)
-
         # Ignore to-device messages if the feature flag is not enabled
         if (
             stream_key == StreamKeyType.TO_DEVICE
@@ -286,6 +274,9 @@ class ApplicationServicesHandler:
         ):
             return
 
+        # We know `new_token` is not a `RoomStreamToken` at this point.
+        assert not isinstance(new_token, RoomStreamToken)
+
         # Check whether there are any appservices which have registered to receive
         # ephemeral events.
         #
@@ -327,7 +318,7 @@ class ApplicationServicesHandler:
         self,
         services: List[ApplicationService],
         stream_key: StreamKeyType,
-        new_token: int,
+        new_token: Union[int, MultiWriterStreamToken],
         users: Collection[Union[str, UserID]],
     ) -> None:
         logger.debug("Checking interested services for %s", stream_key)
@@ -340,6 +331,7 @@ class ApplicationServicesHandler:
                     #
                     # Instead we simply grab the latest typing updates in _handle_typing
                     # and, if they apply to this application service, send it off.
+                    assert isinstance(new_token, int)
                     events = await self._handle_typing(service, new_token)
                     if events:
                         self.scheduler.enqueue_for_appservice(service, ephemeral=events)
@@ -350,15 +342,23 @@ class ApplicationServicesHandler:
                     (service.id, stream_key)
                 ):
                     if stream_key == StreamKeyType.RECEIPT:
+                        assert isinstance(new_token, MultiWriterStreamToken)
+
+                        # We store appservice tokens as integers, so we ignore
+                        # the `instance_map` component and instead simply
+                        # follow the base stream position.
+                        new_token = MultiWriterStreamToken(stream=new_token.stream)
+
                         events = await self._handle_receipts(service, new_token)
                         self.scheduler.enqueue_for_appservice(service, ephemeral=events)
 
                         # Persist the latest handled stream token for this appservice
                         await self.store.set_appservice_stream_type_pos(
-                            service, "read_receipt", new_token
+                            service, "read_receipt", new_token.stream
                         )
 
                     elif stream_key == StreamKeyType.PRESENCE:
+                        assert isinstance(new_token, int)
                         events = await self._handle_presence(service, users, new_token)
                         self.scheduler.enqueue_for_appservice(service, ephemeral=events)
 
@@ -368,6 +368,7 @@ class ApplicationServicesHandler:
                         )
 
                     elif stream_key == StreamKeyType.TO_DEVICE:
+                        assert isinstance(new_token, int)
                         # Retrieve a list of to-device message events, as well as the
                         # maximum stream token of the messages we were able to retrieve.
                         to_device_messages = await self._get_to_device_messages(
@@ -383,6 +384,7 @@ class ApplicationServicesHandler:
                         )
 
                     elif stream_key == StreamKeyType.DEVICE_LIST:
+                        assert isinstance(new_token, int)
                         device_list_summary = await self._get_device_list_summary(
                             service, new_token
                         )
@@ -432,7 +434,7 @@ class ApplicationServicesHandler:
         return typing
 
     async def _handle_receipts(
-        self, service: ApplicationService, new_token: int
+        self, service: ApplicationService, new_token: MultiWriterStreamToken
     ) -> List[JsonMapping]:
         """
         Return the latest read receipts that the given application service should receive.
@@ -455,15 +457,17 @@ class ApplicationServicesHandler:
         from_key = await self.store.get_type_stream_id_for_appservice(
             service, "read_receipt"
         )
-        if new_token is not None and new_token <= from_key:
+        if new_token is not None and new_token.stream <= from_key:
             logger.debug(
                 "Rejecting token lower than or equal to stored: %s" % (new_token,)
             )
             return []
 
+        from_token = MultiWriterStreamToken(stream=from_key)
+
         receipts_source = self.event_sources.sources.receipt
         receipts, _ = await receipts_source.get_new_events_as(
-            service=service, from_key=from_key, to_key=new_token
+            service=service, from_key=from_token, to_key=new_token
         )
         return receipts
 
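Worth a worked example: appservices persist their read-receipt position as a plain integer, so the handler above collapses the multi-writer token down to its base `stream` position before storing it. A minimal sketch of that flattening, with made-up positions (the `instance_map` semantics are assumed from this commit's `MultiWriterStreamToken`):

    from immutabledict import immutabledict

    from synapse.types import MultiWriterStreamToken

    # A receipts token whose writers have diverged: the minimum position is 7,
    # but worker1 has written up to 10 (illustrative values).
    token = MultiWriterStreamToken(
        stream=7, instance_map=immutabledict({"worker1": 10})
    )

    # Collapse to the base stream position, as the handler does; the
    # instance_map is dropped.
    flat = MultiWriterStreamToken(stream=token.stream)

    # Persisting flat.stream (7) is conservative: it is a lower bound, so at
    # worst receipts between positions 7 and 10 get fetched again next time.
    assert flat.stream == 7 and not flat.instance_map
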
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index c34bd7db95..b1d8be866f 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -145,7 +145,7 @@ class InitialSyncHandler:
         joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
         receipt = await self.store.get_linearized_receipts_for_rooms(
             joined_rooms,
-            to_key=int(now_token.receipt_key),
+            to_key=now_token.receipt_key,
         )
 
         receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 69ac468f75..b5f7a8b47e 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -20,6 +20,7 @@ from synapse.streams import EventSource
 from synapse.types import (
     JsonDict,
     JsonMapping,
+    MultiWriterStreamToken,
     ReadReceipt,
     StreamKeyType,
     UserID,
@@ -200,7 +201,7 @@ class ReceiptsHandler:
             await self.federation_sender.send_read_receipt(receipt)
 
 
-class ReceiptEventSource(EventSource[int, JsonMapping]):
+class ReceiptEventSource(EventSource[MultiWriterStreamToken, JsonMapping]):
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self.config = hs.config
@@ -273,13 +274,12 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
     async def get_new_events(
         self,
         user: UserID,
-        from_key: int,
+        from_key: MultiWriterStreamToken,
         limit: int,
         room_ids: Iterable[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
-    ) -> Tuple[List[JsonMapping], int]:
-        from_key = int(from_key)
+    ) -> Tuple[List[JsonMapping], MultiWriterStreamToken]:
         to_key = self.get_current_key()
 
         if from_key == to_key:
@@ -296,8 +296,11 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
         return events, to_key
 
     async def get_new_events_as(
-        self, from_key: int, to_key: int, service: ApplicationService
-    ) -> Tuple[List[JsonMapping], int]:
+        self,
+        from_key: MultiWriterStreamToken,
+        to_key: MultiWriterStreamToken,
+        service: ApplicationService,
+    ) -> Tuple[List[JsonMapping], MultiWriterStreamToken]:
         """Returns a set of new read receipt events that an appservice
         may be interested in.
 
@@ -312,8 +315,6 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
                   appservice may be interested in.
                 * The current read receipt stream token.
         """
-        from_key = int(from_key)
-
         if from_key == to_key:
             return [], to_key
 
@@ -333,5 +334,5 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
 
         return events, to_key
 
-    def get_current_key(self) -> int:
+    def get_current_key(self) -> MultiWriterStreamToken:
         return self.store.get_max_receipt_stream_id()
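Since `ReceiptEventSource` is now an `EventSource[MultiWriterStreamToken, JsonMapping]`, callers hand tokens around rather than ints. A rough usage sketch, mirroring `_handle_receipts` above (the `handler` and `from_stream_id` names are illustrative, not part of this commit):

    from synapse.types import MultiWriterStreamToken

    async def fetch_new_receipts(handler, service, from_stream_id: int):
        # Appservice positions are stored as ints, so wrap the stored value in
        # a token first, exactly as _handle_receipts now does.
        from_token = MultiWriterStreamToken(stream=from_stream_id)

        source = handler.event_sources.sources.receipt
        to_token = source.get_current_key()  # now a MultiWriterStreamToken
        events, _ = await source.get_new_events_as(
            service=service, from_key=from_token, to_key=to_token
        )
        return events
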
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f131c0e8e0..f75c1548ca 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -57,6 +57,7 @@ from synapse.types import (
     DeviceListUpdates,
     JsonDict,
     JsonMapping,
+    MultiWriterStreamToken,
     MutableStateMap,
     Requester,
     RoomStreamToken,
@@ -477,7 +478,11 @@ class SyncHandler:
                 event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
                 ephemeral_by_room.setdefault(room_id, []).append(event_copy)
 
-            receipt_key = since_token.receipt_key if since_token else 0
+            receipt_key = (
+                since_token.receipt_key
+                if since_token
+                else MultiWriterStreamToken(stream=0)
+            )
 
             receipt_source = self.event_sources.sources.receipt
             receipts, receipt_key = await receipt_source.get_new_events(
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 99e7715896..ee0bd84f1e 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -21,11 +21,13 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Literal,
     Optional,
     Set,
     Tuple,
     TypeVar,
     Union,
+    overload,
 )
 
 import attr
@@ -44,6 +46,7 @@ from synapse.metrics import LaterGauge
 from synapse.streams.config import PaginationConfig
 from synapse.types import (
     JsonDict,
+    MultiWriterStreamToken,
     PersistedEventPosition,
     RoomStreamToken,
     StrCollection,
@@ -127,7 +130,7 @@ class _NotifierUserStream:
     def notify(
         self,
         stream_key: StreamKeyType,
-        stream_id: Union[int, RoomStreamToken],
+        stream_id: Union[int, RoomStreamToken, MultiWriterStreamToken],
         time_now_ms: int,
     ) -> None:
         """Notify any listeners for this user of a new event from an
@@ -452,10 +455,48 @@ class Notifier:
         except Exception:
             logger.exception("Error pusher pool of event")
 
+    @overload
+    def on_new_event(
+        self,
+        stream_key: Literal[StreamKeyType.ROOM],
+        new_token: RoomStreamToken,
+        users: Optional[Collection[Union[str, UserID]]] = None,
+        rooms: Optional[StrCollection] = None,
+    ) -> None:
+        ...
+
+    @overload
+    def on_new_event(
+        self,
+        stream_key: Literal[StreamKeyType.RECEIPT],
+        new_token: MultiWriterStreamToken,
+        users: Optional[Collection[Union[str, UserID]]] = None,
+        rooms: Optional[StrCollection] = None,
+    ) -> None:
+        ...
+
+    @overload
+    def on_new_event(
+        self,
+        stream_key: Literal[
+            StreamKeyType.ACCOUNT_DATA,
+            StreamKeyType.DEVICE_LIST,
+            StreamKeyType.PRESENCE,
+            StreamKeyType.PUSH_RULES,
+            StreamKeyType.TO_DEVICE,
+            StreamKeyType.TYPING,
+            StreamKeyType.UN_PARTIAL_STATED_ROOMS,
+        ],
+        new_token: int,
+        users: Optional[Collection[Union[str, UserID]]] = None,
+        rooms: Optional[StrCollection] = None,
+    ) -> None:
+        ...
+
     def on_new_event(
         self,
         stream_key: StreamKeyType,
-        new_token: Union[int, RoomStreamToken],
+        new_token: Union[int, RoomStreamToken, MultiWriterStreamToken],
         users: Optional[Collection[Union[str, UserID]]] = None,
         rooms: Optional[StrCollection] = None,
     ) -> None:
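
The `@overload` stanzas tie the token type to the stream key, so a type checker will reject, say, passing an `int` alongside `StreamKeyType.ROOM`. A self-contained sketch of the pattern with toy types (not Synapse's own):

    from enum import Enum
    from typing import Literal, Union, overload

    class Key(Enum):
        ROOM = "room"
        RECEIPT = "receipt"
        TYPING = "typing"

    class RoomToken:
        pass

    class ReceiptToken:
        pass

    @overload
    def notify(key: Literal[Key.ROOM], token: RoomToken) -> None: ...
    @overload
    def notify(key: Literal[Key.RECEIPT], token: ReceiptToken) -> None: ...
    @overload
    def notify(key: Literal[Key.TYPING], token: int) -> None: ...

    def notify(key: Key, token: Union[RoomToken, ReceiptToken, int]) -> None:
        # One runtime implementation; the overloads exist purely so the type
        # checker can narrow `token` from the literal stream key.
        print(key, token)

    notify(Key.RECEIPT, ReceiptToken())   # accepted
    # notify(Key.TYPING, ReceiptToken())  # rejected by mypy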
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 384355698d..1312b6f21e 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -126,8 +126,9 @@ class ReplicationDataHandler:
                 StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
             )
         elif stream_name == ReceiptsStream.NAME:
+            new_token = self.store.get_max_receipt_stream_id()
             self.notifier.on_new_event(
-                StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
+                StreamKeyType.RECEIPT, new_token, rooms=[row.room_id for row in rows]
             )
             await self._pusher_pool.on_new_receipts({row.user_id for row in rows})
         elif stream_name == ToDeviceStream.NAME:
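Note the substitution here: the `token` delivered over replication is a single writer's integer position, while the `RECEIPT` overload above now takes a `MultiWriterStreamToken`, so the handler re-reads the combined token from the store. A toy sketch of the shape of that call (stand-in classes, not Synapse's):

    from typing import List

    class FakeStore:
        def get_max_receipt_stream_id(self) -> str:
            # Stands in for the combined MultiWriterStreamToken across all
            # receipt writers.
            return "MultiWriterStreamToken(stream=7, instance_map={'worker1': 10})"

    def on_receipts_rdata(store: FakeStore, token: int, room_ids: List[str]) -> None:
        # `token` is only this writer's position on the receipts stream; fetch
        # the multi-writer token before notifying.
        new_token = store.get_max_receipt_stream_id()
        print("notify RECEIPT", new_token, room_ids)

    on_receipts_rdata(FakeStore(), 10, ["!room:example.com"])
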
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index b2645ab43c..56e8eb16a8 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -28,6 +28,8 @@ from typing import (
     cast,
 )
 
+from immutabledict import immutabledict
+
 from synapse.api.constants import EduTypes
 from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import (
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
-from synapse.types import JsonDict, JsonMapping
+from synapse.types import (
+    JsonDict,
+    JsonMapping,
+    MultiWriterStreamToken,
+    PersistedPosition,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "receipts_linearized",
             entity_column="room_id",
             stream_column="stream_id",
-            max_value=max_receipts_stream_id,
+            max_value=max_receipts_stream_id.stream,
             limit=10000,
         )
         self._receipts_stream_cache = StreamChangeCache(
@@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
             prefilled_cache=receipts_stream_prefill,
         )
 
-    def get_max_receipt_stream_id(self) -> int:
+    def get_max_receipt_stream_id(self) -> MultiWriterStreamToken:
         """Get the current max stream ID for receipts stream"""
-        return self._receipts_id_gen.get_current_token()
+
+        min_pos = self._receipts_id_gen.get_current_token()
+
+        positions = {}
+        if isinstance(self._receipts_id_gen, MultiWriterIdGenerator):
+            # The `min_pos` is the minimum position that we know all instances
+            # have finished persisting to, so we only care about instances whose
+            # positions are ahead of that. (Instance positions can be behind the
+            # min position as there are times we can work out that the minimum
+            # position is ahead of the naive minimum across all current
+            # positions. See MultiWriterIdGenerator for details)
+            positions = {
+                i: p
+                for i, p in self._receipts_id_gen.get_positions().items()
+                if p > min_pos
+            }
+
+        return MultiWriterStreamToken(
+            stream=min_pos, instance_map=immutabledict(positions)
+        )
+
+    def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
+        return self._receipts_id_gen.get_current_token_for_writer(instance_name)
 
     def get_last_unthreaded_receipt_for_user_txn(
         self,
@@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         }
 
     async def get_linearized_receipts_for_rooms(
-        self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
+        self,
+        room_ids: Iterable[str],
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> List[JsonMapping]:
         """Get receipts for multiple rooms for sending to clients.
 
@@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             # Only ask the database about rooms where there have been new
             # receipts added since `from_key`
             room_ids = self._receipts_stream_cache.get_entities_changed(
-                room_ids, from_key
+                room_ids, from_key.stream
             )
 
         results = await self._get_linearized_receipts_for_rooms(
@@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return [ev for res in results.values() for ev in res]
 
     async def get_linearized_receipts_for_room(
-        self, room_id: str, to_key: int, from_key: Optional[int] = None
+        self,
+        room_id: str,
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Sequence[JsonMapping]:
         """Get receipts for a single room for sending to clients.
 
@@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if from_key is not None:
             # Check the cache first to see if any new receipts have been added
             # since `from_key`. If not, we can no-op.
-            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+            if not self._receipts_stream_cache.has_entity_changed(
+                room_id, from_key.stream
+            ):
                 return []
 
         return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
 
     @cached(tree=True)
     async def _get_linearized_receipts_for_room(
-        self, room_id: str, to_key: int, from_key: Optional[int] = None
+        self,
+        room_id: str,
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Sequence[JsonMapping]:
         """See get_linearized_receipts_for_room"""
 
         def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
             if from_key:
-                sql = (
-                    "SELECT receipt_type, user_id, event_id, data"
-                    " FROM receipts_linearized WHERE"
-                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
-                )
+                sql = """
+                    SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+                    FROM receipts_linearized
+                    WHERE room_id = ? AND stream_id > ? AND stream_id <= ?
+                """
 
-                txn.execute(sql, (room_id, from_key, to_key))
-            else:
-                sql = (
-                    "SELECT receipt_type, user_id, event_id, data"
-                    " FROM receipts_linearized WHERE"
-                    " room_id = ? AND stream_id <= ?"
+                txn.execute(
+                    sql, (room_id, from_key.stream, to_key.get_max_stream_pos())
                 )
+            else:
+                sql = """
+                    SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+                    FROM receipts_linearized WHERE
+                    room_id = ? AND stream_id <= ?
+                """
 
-                txn.execute(sql, (room_id, to_key))
+                txn.execute(sql, (room_id, to_key.get_max_stream_pos()))
 
-            return cast(List[Tuple[str, str, str, str]], txn.fetchall())
+            return [
+                (receipt_type, user_id, event_id, data)
+                for stream_id, instance_name, receipt_type, user_id, event_id, data in txn
+                if MultiWriterStreamToken.is_stream_position_in_range(
+                    from_key, to_key, instance_name, stream_id
+                )
+            ]
 
         rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
 
@@ -352,7 +400,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         num_args=3,
     )
     async def _get_linearized_receipts_for_rooms(
-        self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+        self,
+        room_ids: Collection[str],
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Mapping[str, Sequence[JsonMapping]]:
         if not room_ids:
             return {}
@@ -362,7 +413,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
         ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
             if from_key:
                 sql = """
-                    SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+                    SELECT stream_id, instance_name, room_id, receipt_type,
+                        user_id, event_id, thread_id, data
                     FROM receipts_linearized WHERE
                     stream_id > ? AND stream_id <= ? AND
                 """
@@ -370,10 +422,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
                     self.database_engine, "room_id", room_ids
                 )
 
-                txn.execute(sql + clause, [from_key, to_key] + list(args))
+                txn.execute(
+                    sql + clause,
+                    [from_key.stream, to_key.get_max_stream_pos()] + list(args),
+                )
             else:
                 sql = """
-                    SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+                    SELECT stream_id, instance_name, room_id, receipt_type,
+                        user_id, event_id, thread_id, data
                     FROM receipts_linearized WHERE
                     stream_id <= ? AND
                 """
@@ -382,11 +438,15 @@ class ReceiptsWorkerStore(SQLBaseStore):
                     self.database_engine, "room_id", room_ids
                 )
 
-                txn.execute(sql + clause, [to_key] + list(args))
+                txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))
 
-            return cast(
-                List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall()
-            )
+            return [
+                (room_id, receipt_type, user_id, event_id, thread_id, data)
+                for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
+                if MultiWriterStreamToken.is_stream_position_in_range(
+                    from_key, to_key, instance_name, stream_id
+                )
+            ]
 
         txn_results = await self.db_pool.runInteraction(
             "_get_linearized_receipts_for_rooms", f
@@ -420,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         num_args=2,
     )
     async def get_linearized_receipts_for_all_rooms(
-        self, to_key: int, from_key: Optional[int] = None
+        self,
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Mapping[str, JsonMapping]:
         """Get receipts for all rooms between two stream_ids, up
         to a limit of the latest 100 read receipts.
@@ -437,25 +499,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
         def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
             if from_key:
                 sql = """
-                    SELECT room_id, receipt_type, user_id, event_id, data
+                    SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
                     FROM receipts_linearized WHERE
                     stream_id > ? AND stream_id <= ?
                     ORDER BY stream_id DESC
                     LIMIT 100
                 """
-                txn.execute(sql, [from_key, to_key])
+                txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()])
             else:
                 sql = """
-                    SELECT room_id, receipt_type, user_id, event_id, data
+                    SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
                     FROM receipts_linearized WHERE
                     stream_id <= ?
                     ORDER BY stream_id DESC
                     LIMIT 100
                 """
 
-                txn.execute(sql, [to_key])
+                txn.execute(sql, [to_key.get_max_stream_pos()])
 
-            return cast(List[Tuple[str, str, str, str, str]], txn.fetchall())
+            return [
+                (room_id, receipt_type, user_id, event_id, data)
+                for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn
+                if MultiWriterStreamToken.is_stream_position_in_range(
+                    from_key, to_key, instance_name, stream_id
+                )
+            ]
 
         txn_results = await self.db_pool.runInteraction(
             "get_linearized_receipts_for_all_rooms", f
@@ -545,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
                 FROM receipts_linearized
                 WHERE ? < stream_id AND stream_id <= ?
+                AND instance_name = ?
                 ORDER BY stream_id ASC
                 LIMIT ?
             """
-            txn.execute(sql, (last_id, current_id, limit))
+            txn.execute(sql, (last_id, current_id, instance_name, limit))
 
             updates = cast(
                 List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
@@ -695,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             keyvalues=keyvalues,
             values={
                 "stream_id": stream_id,
+                "instance_name": self._instance_name,
                 "event_id": event_id,
                 "event_stream_ordering": stream_ordering,
                 "data": json_encoder.encode(data),
@@ -750,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         event_ids: List[str],
         thread_id: Optional[str],
         data: dict,
-    ) -> Optional[int]:
+    ) -> Optional[PersistedPosition]:
         """Insert a receipt, either from local client or remote server.
 
         Automatically does conversion between linearized and graph
@@ -812,7 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             data,
         )
 
-        return stream_id
+        return PersistedPosition(self._instance_name, stream_id)
 
     async def _insert_graph_receipt(
         self,
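To make the storage changes concrete: suppose worker1 has persisted receipts up to position 10 while worker2 is still at 7, so the known-complete minimum is 7. `get_max_receipt_stream_id` then returns a token of `stream=7` carrying only the writers ahead of that minimum, and the rewritten queries over-select up to `get_max_stream_pos()` before filtering each row per writer. A small sketch with illustrative positions (`get_max_stream_pos` assumed to be the maximum across all writers):

    from immutabledict import immutabledict

    from synapse.types import MultiWriterStreamToken

    min_pos = 7
    positions = {"worker1": 10, "worker2": 7}

    # Only writers strictly ahead of the minimum go into the instance map.
    token = MultiWriterStreamToken(
        stream=min_pos,
        instance_map=immutabledict(
            {i: p for i, p in positions.items() if p > min_pos}
        ),
    )
    assert dict(token.instance_map) == {"worker1": 10}

    # SQL upper bound used by the rewritten queries.
    assert token.get_max_stream_pos() == 10

    # Per-row filtering: a worker1 row at position 9 is in range, but a
    # worker2 row at 9 is not, since worker2 has only reached 7.
    in_range = MultiWriterStreamToken.is_stream_position_in_range
    assert in_range(None, token, "worker1", 9)
    assert not in_range(None, token, "worker2", 9)
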
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7f40e2c446..ce7bfd5146 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -47,7 +47,7 @@ from synapse.storage.databases.main.stream import (
     generate_pagination_where_clause,
 )
 from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, StreamKeyType, StreamToken
+from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
@@ -314,7 +314,7 @@ class RelationsWorkerStore(SQLBaseStore):
                         room_key=next_key,
                         presence_key=0,
                         typing_key=0,
-                        receipt_key=0,
+                        receipt_key=MultiWriterStreamToken(stream=0),
                         account_data_key=0,
                         push_rules_key=0,
                         to_device_key=0,
diff --git a/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite b/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite
new file mode 100644
index 0000000000..6c7ad0fd37
--- /dev/null
+++ b/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite
@@ -0,0 +1,17 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This already exists on Postgres.
+ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 609a0978a9..d0bb83b184 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource
 from synapse.handlers.typing import TypingNotificationEventSource
 from synapse.logging.opentracing import trace
 from synapse.streams import EventSource
-from synapse.types import StreamKeyType, StreamToken
+from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -111,7 +111,7 @@ class EventSources:
             room_key=await self.sources.room.get_current_key_for_room(room_id),
             presence_key=0,
             typing_key=0,
-            receipt_key=0,
+            receipt_key=MultiWriterStreamToken(stream=0),
             account_data_key=0,
             push_rules_key=0,
             to_device_key=0,
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 09a88c86a7..4c5b26ad93 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -695,6 +695,90 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
             return "s%d" % (self.stream,)
 
 
+@attr.s(frozen=True, slots=True, order=False)
+class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
+    """A basic stream token class for streams that supports multiple writers."""
+
+    @classmethod
+    async def parse(cls, store: "DataStore", string: str) -> "MultiWriterStreamToken":
+        try:
+            if string[0].isdigit():
+                return cls(stream=int(string))
+            if string[0] == "m":
+                parts = string[1:].split("~")
+                stream = int(parts[0])
+
+                instance_map = {}
+                for part in parts[1:]:
+                    key, value = part.split(".")
+                    instance_id = int(key)
+                    pos = int(value)
+
+                    instance_name = await store.get_name_from_instance_id(instance_id)
+                    instance_map[instance_name] = pos
+
+                return cls(
+                    stream=stream,
+                    instance_map=immutabledict(instance_map),
+                )
+        except CancelledError:
+            raise
+        except Exception:
+            pass
+        raise SynapseError(400, "Invalid stream token %r" % (string,))
+
+    async def to_string(self, store: "DataStore") -> str:
+        if self.instance_map:
+            entries = []
+            for name, pos in self.instance_map.items():
+                if pos <= self.stream:
+                    # Ignore instances at or below the minimum stream position
+                    # (we might know they've advanced without seeing a recent
+                    # write from them).
+                    continue
+
+                instance_id = await store.get_id_for_instance(name)
+                entries.append(f"{instance_id}.{pos}")
+
+            encoded_map = "~".join(entries)
+            return f"m{self.stream}~{encoded_map}"
+        else:
+            return str(self.stream)
+
+    @staticmethod
+    def is_stream_position_in_range(
+        low: Optional["AbstractMultiWriterStreamToken"],
+        high: Optional["AbstractMultiWriterStreamToken"],
+        instance_name: Optional[str],
+        pos: int,
+    ) -> bool:
+        """Checks if a given persisted position is between the two given tokens.
+
+        If `instance_name` is None then the row was persisted before multi
+        writer support.
+        """
+
+        if low:
+            if instance_name:
+                low_stream = low.instance_map.get(instance_name, low.stream)
+            else:
+                low_stream = low.stream
+
+            if pos <= low_stream:
+                return False
+
+        if high:
+            if instance_name:
+                high_stream = high.instance_map.get(instance_name, high.stream)
+            else:
+                high_stream = high.stream
+
+            if high_stream < pos:
+                return False
+
+        return True
+
+
 class StreamKeyType(Enum):
     """Known stream types.
 
@@ -776,7 +860,9 @@ class StreamToken:
     )
     presence_key: int
     typing_key: int
-    receipt_key: int
+    receipt_key: MultiWriterStreamToken = attr.ib(
+        validator=attr.validators.instance_of(MultiWriterStreamToken)
+    )
     account_data_key: int
     push_rules_key: int
     to_device_key: int
@@ -799,8 +885,31 @@ class StreamToken:
             while len(keys) < len(attr.fields(cls)):
                 # i.e. old token from before receipt_key
                 keys.append("0")
+
+            (
+                room_key,
+                presence_key,
+                typing_key,
+                receipt_key,
+                account_data_key,
+                push_rules_key,
+                to_device_key,
+                device_list_key,
+                groups_key,
+                un_partial_stated_rooms_key,
+            ) = keys
+
             return cls(
-                await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
+                room_key=await RoomStreamToken.parse(store, room_key),
+                presence_key=int(presence_key),
+                typing_key=int(typing_key),
+                receipt_key=await MultiWriterStreamToken.parse(store, receipt_key),
+                account_data_key=int(account_data_key),
+                push_rules_key=int(push_rules_key),
+                to_device_key=int(to_device_key),
+                device_list_key=int(device_list_key),
+                groups_key=int(groups_key),
+                un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
             )
         except CancelledError:
             raise
@@ -813,7 +922,7 @@ class StreamToken:
                 await self.room_key.to_string(store),
                 str(self.presence_key),
                 str(self.typing_key),
-                str(self.receipt_key),
+                await self.receipt_key.to_string(store),
                 str(self.account_data_key),
                 str(self.push_rules_key),
                 str(self.to_device_key),
@@ -841,6 +950,11 @@ class StreamToken:
                 StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value)
             )
             return new_token
+        elif key == StreamKeyType.RECEIPT:
+            new_token = self.copy_and_replace(
+                StreamKeyType.RECEIPT, self.receipt_key.copy_and_advance(new_value)
+            )
+            return new_token
 
         new_token = self.copy_and_replace(key, new_value)
         new_id = new_token.get_field(key)
@@ -859,6 +973,10 @@ class StreamToken:
         ...
 
     @overload
+    def get_field(self, key: Literal[StreamKeyType.RECEIPT]) -> MultiWriterStreamToken:
+        ...
+
+    @overload
     def get_field(
         self,
         key: Literal[
@@ -866,7 +984,6 @@ class StreamToken:
             StreamKeyType.DEVICE_LIST,
             StreamKeyType.PRESENCE,
             StreamKeyType.PUSH_RULES,
-            StreamKeyType.RECEIPT,
             StreamKeyType.TO_DEVICE,
             StreamKeyType.TYPING,
             StreamKeyType.UN_PARTIAL_STATED_ROOMS,
@@ -875,15 +992,21 @@ class StreamToken:
         ...
 
     @overload
-    def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+    def get_field(
+        self, key: StreamKeyType
+    ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]:
         ...
 
-    def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+    def get_field(
+        self, key: StreamKeyType
+    ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]:
         """Returns the stream ID for the given key."""
         return getattr(self, key.value)
 
 
-StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
+StreamToken.START = StreamToken(
+    RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
+)
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
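
The serialised form mirrors `RoomStreamToken`'s vector-clock encoding: a bare integer when no writer is ahead of the minimum, otherwise `m<stream>~<instance_id>.<pos>~...`. A worked round trip with a stub store (assuming the store maps instance `worker1` to id 1):

    import asyncio

    from immutabledict import immutabledict

    from synapse.types import MultiWriterStreamToken

    class FakeStore:
        # Minimal stand-in for the DataStore's instance-id lookups.
        async def get_id_for_instance(self, name: str) -> int:
            return {"worker1": 1}[name]

        async def get_name_from_instance_id(self, instance_id: int) -> str:
            return {1: "worker1"}[instance_id]

    async def main() -> None:
        store = FakeStore()
        token = MultiWriterStreamToken(
            stream=7, instance_map=immutabledict({"worker1": 10})
        )

        # Writers at or below the minimum are omitted from the encoding.
        s = await token.to_string(store)
        assert s == "m7~1.10"

        parsed = await MultiWriterStreamToken.parse(store, s)
        assert parsed.stream == 7
        assert dict(parsed.instance_map) == {"worker1": 10}

    asyncio.run(main())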
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index c888d1ff01..78646cb5dc 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -31,7 +31,12 @@ from synapse.appservice import (
 from synapse.handlers.appservice import ApplicationServicesHandler
 from synapse.rest.client import login, receipts, register, room, sendtodevice
 from synapse.server import HomeServer
-from synapse.types import JsonDict, RoomStreamToken, StreamKeyType
+from synapse.types import (
+    JsonDict,
+    MultiWriterStreamToken,
+    RoomStreamToken,
+    StreamKeyType,
+)
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
@@ -305,7 +310,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         )
 
         self.handler.notify_interested_services_ephemeral(
-            StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
+            StreamKeyType.RECEIPT,
+            MultiWriterStreamToken(stream=580),
+            ["@fakerecipient:example.com"],
         )
         self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
             interested_service, ephemeral=[event]
@@ -333,7 +340,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         )
 
         self.handler.notify_interested_services_ephemeral(
-            StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
+            StreamKeyType.RECEIPT,
+            MultiWriterStreamToken(stream=580),
+            ["@fakerecipient:example.com"],
         )
         # This method will be called, but with an empty list of events
         self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@@ -636,7 +645,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
                 self.hs.get_application_service_handler()._notify_interested_services_ephemeral(
                     services=[interested_appservice],
                     stream_key=StreamKeyType.RECEIPT,
-                    new_token=stream_token,
+                    new_token=MultiWriterStreamToken(stream=stream_token),
                     users=[self.exclusive_as_user],
                 )
             )
diff --git a/tests/replication/test_sharded_receipts.py b/tests/replication/test_sharded_receipts.py
new file mode 100644
index 0000000000..41876b36de
--- /dev/null
+++ b/tests/replication/test_sharded_receipts.py
@@ -0,0 +1,243 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import ReceiptTypes
+from synapse.rest import admin
+from synapse.rest.client import login, receipts, room, sync
+from synapse.server import HomeServer
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.types import StreamToken
+from synapse.util import Clock
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import make_request
+
+logger = logging.getLogger(__name__)
+
+
+class ReceiptsShardTestCase(BaseMultiWorkerStreamTestCase):
+    """Checks receipts sharding works"""
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+        receipts.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        # Register a user who sends a message that we'll get notified about
+        self.other_user_id = self.register_user("otheruser", "pass")
+        self.other_access_token = self.login("otheruser", "pass")
+
+        self.room_creator = self.hs.get_room_creation_handler()
+        self.store = hs.get_datastores().main
+
+    def default_config(self) -> dict:
+        conf = super().default_config()
+        conf["stream_writers"] = {"receipts": ["worker1", "worker2"]}
+        conf["instance_map"] = {
+            "main": {"host": "testserv", "port": 8765},
+            "worker1": {"host": "testserv", "port": 1001},
+            "worker2": {"host": "testserv", "port": 1002},
+        }
+        return conf
+
+    def test_basic(self) -> None:
+        """Simple test to ensure that receipts can be sent on multiple
+        workers.
+        """
+
+        worker1 = self.make_worker_hs(
+            "synapse.app.generic_worker",
+            {"worker_name": "worker1"},
+        )
+        worker1_site = self._hs_to_site[worker1]
+
+        worker2 = self.make_worker_hs(
+            "synapse.app.generic_worker",
+            {"worker_name": "worker2"},
+        )
+        worker2_site = self._hs_to_site[worker2]
+
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # Create a room
+        room_id = self.helper.create_room_as(user_id, tok=access_token)
+
+        # The other user joins
+        self.helper.join(
+            room=room_id, user=self.other_user_id, tok=self.other_access_token
+        )
+
+        # The first user sends a message; the other user sends a receipt.
+        response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token)
+        event_id = response["event_id"]
+
+        channel = make_request(
+            reactor=self.reactor,
+            site=worker1_site,
+            method="POST",
+            path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}",
+            access_token=access_token,
+            content={},
+        )
+        self.assertEqual(200, channel.code)
+
+        # Now we do it again using the second worker
+        response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token)
+        event_id = response["event_id"]
+
+        channel = make_request(
+            reactor=self.reactor,
+            site=worker2_site,
+            method="POST",
+            path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}",
+            access_token=access_token,
+            content={},
+        )
+        self.assertEqual(200, channel.code)
+
+    def test_vector_clock_token(self) -> None:
+        """Tests that using a stream token with a vector clock component works
+        correctly with basic /sync usage.
+        """
+
+        worker_hs1 = self.make_worker_hs(
+            "synapse.app.generic_worker",
+            {"worker_name": "worker1"},
+        )
+        worker1_site = self._hs_to_site[worker_hs1]
+
+        worker_hs2 = self.make_worker_hs(
+            "synapse.app.generic_worker",
+            {"worker_name": "worker2"},
+        )
+        worker2_site = self._hs_to_site[worker_hs2]
+
+        sync_hs = self.make_worker_hs(
+            "synapse.app.generic_worker",
+            {"worker_name": "sync"},
+        )
+        sync_hs_site = self._hs_to_site[sync_hs]
+
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        store = self.hs.get_datastores().main
+
+        room_id = self.helper.create_room_as(user_id, tok=access_token)
+
+        # The other user joins
+        self.helper.join(
+            room=room_id, user=self.other_user_id, tok=self.other_access_token
+        )
+
+        response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token)
+        first_event = response["event_id"]
+
+        # Do an initial sync so that we're up to date.
+        channel = make_request(
+            self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
+        )
+        next_batch = channel.json_body["next_batch"]
+
+        # We now gut-wrench into the receipts stream MultiWriterIdGenerator on
+        # worker2 to mimic it getting stuck persisting a receipt. This ensures
+        # that when we send a receipt on worker1 we end up in a state where
+        # worker2's receipts stream position lags that on worker1, resulting in
+        # a receipts token with a non-empty instance map component.
+        #
+        # Worker2's receipts stream position will not advance until we call
+        # __aexit__ below.
+        worker_store2 = worker_hs2.get_datastores().main
+        assert isinstance(worker_store2._receipts_id_gen, MultiWriterIdGenerator)
+
+        actx = worker_store2._receipts_id_gen.get_next()
+        self.get_success(actx.__aenter__())
+
+        channel = make_request(
+            reactor=self.reactor,
+            site=worker1_site,
+            method="POST",
+            path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{first_event}",
+            access_token=access_token,
+            content={},
+        )
+        self.assertEqual(200, channel.code)
+
+        # Assert that the current stream token has an instance map component, as
+        # we are trying to test vector clock tokens.
+        receipts_token = store.get_max_receipt_stream_id()
+        self.assertGreater(len(receipts_token.instance_map), 0)
+
+        # Check that syncing still gets the new receipt, despite the gap in the
+        # stream IDs.
+        channel = make_request(
+            self.reactor,
+            sync_hs_site,
+            "GET",
+            f"/sync?since={next_batch}",
+            access_token=access_token,
+        )
+
+        # We should only see the new event and nothing else
+        self.assertIn(room_id, channel.json_body["rooms"]["join"])
+
+        events = channel.json_body["rooms"]["join"][room_id]["ephemeral"]["events"]
+        self.assertEqual(len(events), 1)
+        self.assertIn(first_event, events[0]["content"])
+
+        # Get the next batch and make sure it's a vector-clock-style token.
+        vector_clock_token = channel.json_body["next_batch"]
+        parsed_token = self.get_success(
+            StreamToken.from_string(store, vector_clock_token)
+        )
+        self.assertGreaterEqual(len(parsed_token.receipt_key.instance_map), 1)
+
+        # Now that we've got a vector clock token, we finish the fake receipt
+        # persist we started above.
+        self.get_success(actx.__aexit__(None, None, None))
+
+        # Now try to send another receipt via the other worker.
+        response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token)
+        second_event = response["event_id"]
+
+        channel = make_request(
+            reactor=self.reactor,
+            site=worker2_site,
+            method="POST",
+            path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{second_event}",
+            access_token=access_token,
+            content={},
+        )
+
+        channel = make_request(
+            self.reactor,
+            sync_hs_site,
+            "GET",
+            f"/sync?since={vector_clock_token}",
+            access_token=access_token,
+        )
+
+        self.assertIn(room_id, channel.json_body["rooms"]["join"])
+
+        events = channel.json_body["rooms"]["join"][room_id]["ephemeral"]["events"]
+        self.assertEqual(len(events), 1)
+        self.assertIn(second_event, events[0]["content"])