summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/replication/tcp/streams/_base.py2
-rw-r--r--synapse/storage/databases/main/deviceinbox.py31
2 files changed, 20 insertions, 13 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 1f6402c2da..7514429d99 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -621,7 +621,7 @@ class ToDeviceStream(_StreamFromIdGen):
         super().__init__(
             hs.get_instance_name(),
             store.get_all_new_device_messages,
-            store._device_inbox_id_gen,
+            store._to_device_msg_id_gen,
         )
 
 
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 02dddd1da4..0541f4e93d 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -87,25 +87,32 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 self._instance_name in hs.config.worker.writers.to_device
             )
 
-            self._device_inbox_id_gen: AbstractStreamIdGenerator = (
+            self._to_device_msg_id_gen: AbstractStreamIdGenerator = (
                 MultiWriterIdGenerator(
                     db_conn=db_conn,
                     db=database,
                     notifier=hs.get_replication_notifier(),
                     stream_name="to_device",
                     instance_name=self._instance_name,
-                    tables=[("device_inbox", "instance_name", "stream_id")],
+                    tables=[
+                        ("device_inbox", "instance_name", "stream_id"),
+                        ("device_federation_outbox", "instance_name", "stream_id"),
+                    ],
                     sequence_name="device_inbox_sequence",
                     writers=hs.config.worker.writers.to_device,
                 )
             )
         else:
             self._can_write_to_device = True
-            self._device_inbox_id_gen = StreamIdGenerator(
-                db_conn, hs.get_replication_notifier(), "device_inbox", "stream_id"
+            self._to_device_msg_id_gen = StreamIdGenerator(
+                db_conn,
+                hs.get_replication_notifier(),
+                "device_inbox",
+                "stream_id",
+                extra_tables=[("device_federation_outbox", "stream_id")],
             )
 
-        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
+        max_device_inbox_id = self._to_device_msg_id_gen.get_current_token()
         device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
             db_conn,
             "device_inbox",
@@ -145,8 +152,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     ) -> None:
         if stream_name == ToDeviceStream.NAME:
             # If replication is happening than postgres must be being used.
-            assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
-            self._device_inbox_id_gen.advance(instance_name, token)
+            assert isinstance(self._to_device_msg_id_gen, MultiWriterIdGenerator)
+            self._to_device_msg_id_gen.advance(instance_name, token)
             for row in rows:
                 if row.entity.startswith("@"):
                     self._device_inbox_stream_cache.entity_has_changed(
@@ -162,11 +169,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         self, stream_name: str, instance_name: str, token: int
     ) -> None:
         if stream_name == ToDeviceStream.NAME:
-            self._device_inbox_id_gen.advance(instance_name, token)
+            self._to_device_msg_id_gen.advance(instance_name, token)
         super().process_replication_position(stream_name, instance_name, token)
 
     def get_to_device_stream_token(self) -> int:
-        return self._device_inbox_id_gen.get_current_token()
+        return self._to_device_msg_id_gen.get_current_token()
 
     async def get_messages_for_user_devices(
         self,
@@ -801,7 +808,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                                 msg.get(EventContentFields.TO_DEVICE_MSGID),
                             )
 
-        async with self._device_inbox_id_gen.get_next() as stream_id:
+        async with self._to_device_msg_id_gen.get_next() as stream_id:
             now_ms = self._clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -813,7 +820,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                     destination, stream_id
                 )
 
-        return self._device_inbox_id_gen.get_current_token()
+        return self._to_device_msg_id_gen.get_current_token()
 
     async def add_messages_from_remote_to_device_inbox(
         self,
@@ -857,7 +864,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 txn, stream_id, local_messages_by_user_then_device
             )
 
-        async with self._device_inbox_id_gen.get_next() as stream_id:
+        async with self._to_device_msg_id_gen.get_next() as stream_id:
             now_ms = self._clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",