summary refs log tree commit diff
path: root/synapse/storage/databases/main/room.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/room.py')
-rw-r--r--synapse/storage/databases/main/room.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7264a33cd4..6a65b2a89b 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -43,6 +43,7 @@ from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.config.homeserver import HomeServerConfig
 from synapse.events import EventBase
+from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -144,6 +145,13 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 "stream_id",
             )
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == UnPartialStatedRoomStream.NAME:
+            self._un_partial_stated_rooms_stream_id_gen.advance(instance_name, token)
+        return super().process_replication_position(stream_name, instance_name, token)
+
     async def store_room(
         self,
         room_id: str,
@@ -1281,13 +1289,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         )
         return result["join_event_id"], result["device_lists_stream_id"]
 
-    def get_un_partial_stated_rooms_token(self) -> int:
-        # TODO(faster_joins, multiple writers): This is inappropriate if there
-        #     are multiple writers because workers that don't write often will
-        #     hold all readers up.
-        #     (See `MultiWriterIdGenerator.get_persisted_upto_position` for an
-        #      explanation.)
-        return self._un_partial_stated_rooms_stream_id_gen.get_current_token()
+    def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
+        return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(
+            instance_name
+        )
 
     async def get_un_partial_stated_rooms_from_stream(
         self, instance_name: str, last_id: int, current_id: int, limit: int