summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py36
1 files changed, 23 insertions, 13 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 3e5c16b15b..aa58c2adc3 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -38,6 +38,7 @@ from synapse.logging.opentracing import (
     whitelisted_homeserver,
 )
 from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
@@ -85,19 +86,28 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        # In the worker store this is an ID tracker which we overwrite in the non-worker
-        # class below that is used on the main process.
-        self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
-            db_conn,
-            "device_lists_stream",
-            "stream_id",
-            extra_tables=[
-                ("user_signature_stream", "stream_id"),
-                ("device_lists_outbound_pokes", "stream_id"),
-                ("device_lists_changes_in_room", "stream_id"),
-            ],
-            is_writer=hs.config.worker.worker_app is None,
-        )
+        if hs.config.worker.worker_app is None:
+            self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+                db_conn,
+                "device_lists_stream",
+                "stream_id",
+                extra_tables=[
+                    ("user_signature_stream", "stream_id"),
+                    ("device_lists_outbound_pokes", "stream_id"),
+                    ("device_lists_changes_in_room", "stream_id"),
+                ],
+            )
+        else:
+            self._device_list_id_gen = SlavedIdTracker(
+                db_conn,
+                "device_lists_stream",
+                "stream_id",
+                extra_tables=[
+                    ("user_signature_stream", "stream_id"),
+                    ("device_lists_outbound_pokes", "stream_id"),
+                    ("device_lists_changes_in_room", "stream_id"),
+                ],
+            )
 
         # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
         # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).