diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 48384e238c..1c771e48f7 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -57,10 +57,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import (
JsonDict,
JsonMapping,
@@ -99,19 +96,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
- self._device_list_id_gen = StreamIdGenerator(
- db_conn,
- hs.get_replication_notifier(),
- "device_lists_stream",
- "stream_id",
- extra_tables=[
- ("user_signature_stream", "stream_id"),
- ("device_lists_outbound_pokes", "stream_id"),
- ("device_lists_changes_in_room", "stream_id"),
- ("device_lists_remote_pending", "stream_id"),
- ("device_lists_changes_converted_stream_position", "stream_id"),
+ self._device_list_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="device_lists_stream",
+ instance_name=self._instance_name,
+ tables=[
+ ("device_lists_stream", "instance_name", "stream_id"),
+ ("user_signature_stream", "instance_name", "stream_id"),
+ ("device_lists_outbound_pokes", "instance_name", "stream_id"),
+ ("device_lists_changes_in_room", "instance_name", "stream_id"),
+ ("device_lists_remote_pending", "instance_name", "stream_id"),
],
- is_writer=hs.config.worker.worker_app is None,
+ sequence_name="device_lists_sequence",
+ writers=["master"],
)
device_list_max = self._device_list_id_gen.get_current_token()
@@ -762,6 +761,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"stream_id": stream_id,
"from_user_id": from_user_id,
"user_ids": json_encoder.encode(user_ids),
+ "instance_name": self._instance_name,
},
)
@@ -1582,6 +1582,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
+ self._instance_name = hs.get_instance_name()
+
self.db_pool.updates.register_background_index_update(
"device_lists_stream_idx",
index_name="device_lists_stream_user_id",
@@ -1694,6 +1696,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
"device_lists_outbound_pokes",
{
"stream_id": stream_id,
+ "instance_name": self._instance_name,
"destination": destination,
"user_id": user_id,
"device_id": device_id,
@@ -1730,10 +1733,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
- # Because we have write access, this will be a StreamIdGenerator
- # (see DeviceWorkerStore.__init__)
- _device_list_id_gen: AbstractStreamIdGenerator
-
def __init__(
self,
database: DatabasePool,
@@ -2092,9 +2091,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_stream",
- keys=("stream_id", "user_id", "device_id"),
+ keys=("instance_name", "stream_id", "user_id", "device_id"),
values=[
- (stream_id, user_id, device_id)
+ (self._instance_name, stream_id, user_id, device_id)
for stream_id, device_id in zip(stream_ids, device_ids)
],
)
@@ -2124,6 +2123,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values = [
(
destination,
+ self._instance_name,
next(stream_id_iterator),
user_id,
device_id,
@@ -2139,6 +2139,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_outbound_pokes",
keys=(
"destination",
+ "instance_name",
"stream_id",
"user_id",
"device_id",
@@ -2157,7 +2158,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id,
{
stream_id: destination
- for (destination, stream_id, _, _, _, _, _) in values
+ for (destination, _, stream_id, _, _, _, _, _) in values
},
)
@@ -2210,6 +2211,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"device_id",
"room_id",
"stream_id",
+ "instance_name",
"converted_to_destinations",
"opentracing_context",
),
@@ -2219,6 +2221,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id,
room_id,
stream_id,
+ self._instance_name,
# We only need to calculate outbound pokes for local users
not self.hs.is_mine_id(user_id),
encoded_context,
@@ -2338,7 +2341,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"user_id": user_id,
"device_id": device_id,
},
- values={"stream_id": stream_id},
+ values={
+ "stream_id": stream_id,
+ "instance_name": self._instance_name,
+ },
desc="add_remote_device_list_to_pending",
)
|