diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 1851e7d525..567b4a5cc1 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -15,13 +15,14 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import GroupServerStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.group_server import GroupServerWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
self.hs = hs
@@ -38,8 +39,8 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
return self._group_updates_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "groups":
- self._group_updates_id_gen.advance(token)
+ if stream_name == GroupServerStream.NAME:
+ self._group_updates_id_gen.advance(instance_name, token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
|