diff options
Diffstat (limited to 'synapse/replication/slave/storage/groups.py')
-rw-r--r-- | synapse/replication/slave/storage/groups.py | 30 |
1 files changed, 11 insertions, 19 deletions
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 28a46edd28..1851e7d525 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -13,16 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import DataStore +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore, __func__ -from ._slaved_id_tracker import SlavedIdTracker - -class SlavedGroupServerStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedGroupServerStore, self).__init__(db_conn, hs) +class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): + def __init__(self, database: Database, db_conn, hs): + super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) self.hs = hs @@ -34,21 +34,13 @@ class SlavedGroupServerStore(BaseSlavedStore): self._group_updates_id_gen.get_current_token(), ) - get_groups_changes_for_user = __func__(DataStore.get_groups_changes_for_user) - get_group_stream_token = __func__(DataStore.get_group_stream_token) - get_all_groups_for_user = __func__(DataStore.get_all_groups_for_user) - - def stream_positions(self): - result = super(SlavedGroupServerStore, self).stream_positions() - result["groups"] = self._group_updates_id_gen.get_current_token() - return result + def get_group_stream_token(self): + return self._group_updates_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "groups": self._group_updates_id_gen.advance(token) for row in rows: self._group_updates_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedGroupServerStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) |