Diffstat (limited to 'synapse/storage/push_rule.py')
-rw-r--r-- | synapse/storage/push_rule.py | 113
1 file changed, 77 insertions, 36 deletions
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 8758b1c0c7..6a5028961d 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,14 +14,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
-from synapse.push.baserules import list_with_base_rules
-from synapse.api.constants import EventTypes
+import abc
+import logging
+
+from canonicaljson import json
+
 from twisted.internet import defer
 
-import logging
-import simplejson as json
+from synapse.push.baserules import list_with_base_rules
+from synapse.storage.appservice import ApplicationServiceWorkerStore
+from synapse.storage.pusher import PusherWorkerStore
+from synapse.storage.receipts import ReceiptsWorkerStore
+from synapse.storage.roommember import RoomMemberWorkerStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
@@ -48,7 +57,43 @@ def _load_rules(rawrules, enabled_map):
     return rules
 
 
-class PushRuleStore(SQLBaseStore):
+class PushRulesWorkerStore(ApplicationServiceWorkerStore,
+                           ReceiptsWorkerStore,
+                           PusherWorkerStore,
+                           RoomMemberWorkerStore,
+                           SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_push_rules_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+
+        push_rules_prefill, push_rules_id = self._get_cache_dict(
+            db_conn, "push_rules_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self.get_max_push_rules_stream_id(),
+        )
+
+        self.push_rules_stream_cache = StreamChangeCache(
+            "PushRulesStreamChangeCache", push_rules_id,
+            prefilled_cache=push_rules_prefill,
+        )
+
+    @abc.abstractmethod
+    def get_max_push_rules_stream_id(self):
+        """Get the position of the push rules stream.
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
+
     @cachedInlineCallbacks(max_entries=5000)
     def get_push_rules_for_user(self, user_id):
         rows = yield self._simple_select_list(
@@ -89,6 +134,22 @@ class PushRuleStore(SQLBaseStore):
             r['rule_id']: False if r['enabled'] == 0 else True for r in results
         })
 
+    def have_push_rules_changed_for_user(self, user_id, last_id):
+        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+            return defer.succeed(False)
+        else:
+            def have_push_rules_changed_txn(txn):
+                sql = (
+                    "SELECT COUNT(stream_id) FROM push_rules_stream"
+                    " WHERE user_id = ? AND ? < stream_id"
+                )
+                txn.execute(sql, (user_id, last_id))
+                count, = txn.fetchone()
+                return bool(count)
+            return self.runInteraction(
+                "have_push_rules_changed", have_push_rules_changed_txn
+            )
+
     @cachedList(cached_method_name="get_push_rules_for_user",
                 list_name="user_ids", num_args=1, inlineCallbacks=True)
     def bulk_get_push_rules(self, user_ids):
@@ -124,6 +185,7 @@ class PushRuleStore(SQLBaseStore):
 
         defer.returnValue(results)
 
+    @defer.inlineCallbacks
     def bulk_get_push_rules_for_room(self, event, context):
         state_group = context.state_group
         if not state_group:
@@ -133,9 +195,11 @@ class PushRuleStore(SQLBaseStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        return self._bulk_get_push_rules_for_room(
-            event.room_id, state_group, context.current_state_ids, event=event
+        current_state_ids = yield context.get_current_state_ids(self)
+        result = yield self._bulk_get_push_rules_for_room(
+            event.room_id, state_group, current_state_ids, event=event
         )
+        defer.returnValue(result)
 
     @cachedInlineCallbacks(num_args=2, cache_context=True)
     def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
@@ -185,18 +249,6 @@ class PushRuleStore(SQLBaseStore):
                 if uid in local_users_in_room:
                     user_ids.add(uid)
 
-        forgotten = yield self.who_forgot_in_room(
-            event.room_id, on_invalidate=cache_context.invalidate,
-        )
-
-        for row in forgotten:
-            user_id = row["user_id"]
-            event_id = row["event_id"]
-
-            mem_id = current_state_ids.get((EventTypes.Member, user_id), None)
-            if event_id == mem_id:
-                user_ids.discard(user_id)
-
         rules_by_user = yield self.bulk_get_push_rules(
            user_ids, on_invalidate=cache_context.invalidate,
        )
@@ -228,6 +280,8 @@ class PushRuleStore(SQLBaseStore):
             results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
         defer.returnValue(results)
 
+
+class PushRuleStore(PushRulesWorkerStore):
     @defer.inlineCallbacks
     def add_push_rule(
         self, user_id, rule_id, priority_class, conditions, actions,
@@ -526,21 +580,8 @@ class PushRuleStore(SQLBaseStore):
         room stream ordering it corresponds to."""
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def have_push_rules_changed_for_user(self, user_id, last_id):
-        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
-            return defer.succeed(False)
-        else:
-            def have_push_rules_changed_txn(txn):
-                sql = (
-                    "SELECT COUNT(stream_id) FROM push_rules_stream"
-                    " WHERE user_id = ? AND ? < stream_id"
-                )
-                txn.execute(sql, (user_id, last_id))
-                count, = txn.fetchone()
-                return bool(count)
-            return self.runInteraction(
-                "have_push_rules_changed", have_push_rules_changed_txn
-            )
+    def get_max_push_rules_stream_id(self):
+        return self.get_push_rules_stream_token()[0]
 
 
 class RuleNotFoundException(Exception):
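
The central change in this diff is the new PushRulesWorkerStore base class: read-only worker processes inherit the query and cache machinery, while the abstract get_max_push_rules_stream_id() lets the writer (PushRuleStore) and the workers each supply their own source for the current stream position, which __init__ needs in order to prime the StreamChangeCache. Below is a minimal, self-contained sketch of that pattern, not Synapse code; the class names and the constant stream position are illustrative, and it uses the Python 3 abc.ABC spelling where the diff uses the Python 2 __metaclass__ attribute.

import abc


class WorkerStoreSketch(abc.ABC):
    """Plays the role of PushRulesWorkerStore: usable only via subclasses."""

    def __init__(self):
        # The base constructor may call the abstract method, e.g. to seed a
        # stream-change cache with the current position; by the time an
        # instance exists, some subclass has provided the implementation.
        self.current_stream_id = self.get_max_push_rules_stream_id()

    @abc.abstractmethod
    def get_max_push_rules_stream_id(self):
        """Return the current position of the push rules stream as an int."""
        raise NotImplementedError()


class WriterStoreSketch(WorkerStoreSketch):
    """Plays the role of PushRuleStore on the writer process."""

    def get_max_push_rules_stream_id(self):
        # The real store asks its stream ID generator; a constant stands in here.
        return 42


store = WriterStoreSketch()
print(store.current_stream_id)  # 42
# WorkerStoreSketch() itself raises TypeError: abstract method not implemented.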
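
have_push_rules_changed_for_user, which moves from the bottom of PushRuleStore into the worker store, only touches the database when the in-memory StreamChangeCache cannot rule out a change for that user after last_id. The following rough sketch reproduces that short-circuit with sqlite3 and a toy cache standing in for synapse.util.caches.stream_change_cache.StreamChangeCache; the names are illustrative, and the toy cache skips the real class's handling of positions older than its earliest known entry.

import sqlite3


class ToyStreamChangeCache(object):
    """Remembers the latest stream_id at which each entity changed."""

    def __init__(self):
        self._latest_change = {}  # entity -> stream_id

    def entity_has_changed(self, entity, stream_id):
        prev = self._latest_change.get(entity, 0)
        self._latest_change[entity] = max(prev, stream_id)

    def has_entity_changed(self, entity, stream_id):
        # True if the entity may have changed strictly after `stream_id`.
        return self._latest_change.get(entity, 0) > stream_id


def have_push_rules_changed_for_user(conn, cache, user_id, last_id):
    # Fast path: the cache proves nothing changed since last_id, so skip the DB.
    if not cache.has_entity_changed(user_id, last_id):
        return False
    # Slow path: confirm against the push_rules_stream table, as in the diff's SQL.
    row = conn.execute(
        "SELECT COUNT(stream_id) FROM push_rules_stream"
        " WHERE user_id = ? AND ? < stream_id",
        (user_id, last_id),
    ).fetchone()
    return bool(row[0])


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE push_rules_stream (stream_id INTEGER, user_id TEXT)")
conn.execute("INSERT INTO push_rules_stream VALUES (7, '@alice:example.com')")

cache = ToyStreamChangeCache()
cache.entity_has_changed("@alice:example.com", 7)

print(have_push_rules_changed_for_user(conn, cache, "@alice:example.com", 5))  # True
print(have_push_rules_changed_for_user(conn, cache, "@alice:example.com", 9))  # False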
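
bulk_get_push_rules_for_room gains @defer.inlineCallbacks because the synchronous context.current_state_ids attribute is replaced by context.get_current_state_ids(self), which returns a Deferred and therefore has to be yielded, with the result handed back via defer.returnValue. A small Twisted illustration of that shape, using a stand-in for the real EventContext call (function names here are hypothetical):

from twisted.internet import defer


def get_current_state_ids():
    # Stand-in for context.get_current_state_ids(store), which returns a Deferred.
    return defer.succeed({("m.room.member", "@alice:example.com"): "$member_event"})


@defer.inlineCallbacks
def bulk_get_push_rules_for_room_sketch():
    # Yielding waits for the Deferred; the generator resumes with its result.
    current_state_ids = yield get_current_state_ids()
    # The real method forwards current_state_ids to a cached helper; here we
    # just return how many state entries we saw.
    defer.returnValue(len(current_state_ids))


d = bulk_get_push_rules_for_room_sketch()
d.addCallback(print)  # prints 1 once the (already-fired) Deferred resolves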