summary refs log tree commit diff
path: root/synapse/replication/slave/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/slave/storage')
-rw-r--r--synapse/replication/slave/storage/push_rule.py24
1 files changed, 7 insertions, 17 deletions
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 83e880fdd2..bb2c40b6e3 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,29 +16,15 @@
 
 from .events import SlavedEventStore
 from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.storage.push_rule import PushRuleStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.storage.push_rule import PushRulesWorkerStore
 
 
-class SlavedPushRuleStore(SlavedEventStore):
+class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
     def __init__(self, db_conn, hs):
-        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
         self._push_rules_stream_id_gen = SlavedIdTracker(
             db_conn, "push_rules_stream", "stream_id",
         )
-        self.push_rules_stream_cache = StreamChangeCache(
-            "PushRulesStreamChangeCache",
-            self._push_rules_stream_id_gen.get_current_token(),
-        )
-
-    get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
-    get_push_rules_enabled_for_user = (
-        PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
-    )
-    have_push_rules_changed_for_user = (
-        DataStore.have_push_rules_changed_for_user.__func__
-    )
+        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
 
     def get_push_rules_stream_token(self):
         return (
@@ -45,6 +32,9 @@ class SlavedPushRuleStore(SlavedEventStore):
             self._stream_id_gen.get_current_token(),
         )
 
+    def get_max_push_rules_stream_id(self):
+        return self._push_rules_stream_id_gen.get_current_token()
+
     def stream_positions(self):
         result = super(SlavedPushRuleStore, self).stream_positions()
         result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()