diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 8295322b0e..d4c64c46ad 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import abc
import logging
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
+ Iterable,
List,
Mapping,
Optional,
@@ -30,7 +30,7 @@ from typing import (
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -84,14 +84,15 @@ def _load_rules(
push_rules = PushRules(ruleslist)
filtered_rules = FilteredPushRules(
- push_rules, enabled_map, msc3772_enabled=experimental_config.msc3772_enabled
+ push_rules,
+ enabled_map,
+ msc3664_enabled=experimental_config.msc3664_enabled,
+ msc1767_enabled=experimental_config.msc1767_enabled,
)
return filtered_rules
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
class PushRulesWorkerStore(
ApplicationServiceWorkerStore,
PusherWorkerStore,
@@ -99,7 +100,6 @@ class PushRulesWorkerStore(
ReceiptsWorkerStore,
EventsWorkerStore,
SQLBaseStore,
- metaclass=abc.ABCMeta,
):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
@@ -113,14 +113,14 @@ class PushRulesWorkerStore(
):
super().__init__(database, db_conn, hs)
- if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
- db_conn, "push_rules_stream", "stream_id"
- )
- else:
- self._push_rules_stream_id_gen = SlavedIdTracker(
- db_conn, "push_rules_stream", "stream_id"
- )
+ # In the worker store this is an ID tracker which we overwrite in the non-worker
+ # class below that is used on the main process.
+ self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ db_conn,
+ "push_rules_stream",
+ "stream_id",
+ is_writer=hs.config.worker.worker_app is None,
+ )
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
db_conn,
@@ -136,14 +136,23 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill,
)
- @abc.abstractmethod
def get_max_push_rules_stream_id(self) -> int:
"""Get the position of the push rules stream.
Returns:
int
"""
- raise NotImplementedError()
+ return self._push_rules_stream_id_gen.get_current_token()
+
+ def process_replication_rows(
+ self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+ ) -> None:
+ if stream_name == PushRulesStream.NAME:
+ self._push_rules_stream_id_gen.advance(instance_name, token)
+ for row in rows:
+ self.get_push_rules_for_user.invalidate((row.user_id,))
+ self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
|