diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 6e7fd259d4..533d927701 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -15,17 +15,18 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import ToDeviceStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
- db_conn, "device_max_stream_id", "stream_id"
+ db_conn, "device_inbox", "stream_id"
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache",
@@ -44,8 +45,8 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "to_device":
- self._device_inbox_id_gen.advance(token)
+ if stream_name == ToDeviceStream.NAME:
+ self._device_inbox_id_gen.advance(instance_name, token)
for row in rows:
if row.entity.startswith("@"):
self._device_inbox_stream_cache.entity_has_changed(
|