summary refs log tree commit diff
path: root/synapse/storage/push_rule.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/push_rule.py')
-rw-r--r--synapse/storage/push_rule.py113
1 files changed, 77 insertions, 36 deletions
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 8758b1c0c7..6a5028961d 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,14 +14,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
-from synapse.push.baserules import list_with_base_rules
-from synapse.api.constants import EventTypes
+import abc
+import logging
+
+from canonicaljson import json
+
 from twisted.internet import defer
 
-import logging
-import simplejson as json
+from synapse.push.baserules import list_with_base_rules
+from synapse.storage.appservice import ApplicationServiceWorkerStore
+from synapse.storage.pusher import PusherWorkerStore
+from synapse.storage.receipts import ReceiptsWorkerStore
+from synapse.storage.roommember import RoomMemberWorkerStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
@@ -48,7 +57,43 @@ def _load_rules(rawrules, enabled_map):
     return rules
 
 
-class PushRuleStore(SQLBaseStore):
+class PushRulesWorkerStore(ApplicationServiceWorkerStore,
+                           ReceiptsWorkerStore,
+                           PusherWorkerStore,
+                           RoomMemberWorkerStore,
+                           SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_push_rules_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+
+        push_rules_prefill, push_rules_id = self._get_cache_dict(
+            db_conn, "push_rules_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self.get_max_push_rules_stream_id(),
+        )
+
+        self.push_rules_stream_cache = StreamChangeCache(
+            "PushRulesStreamChangeCache", push_rules_id,
+            prefilled_cache=push_rules_prefill,
+        )
+
+    @abc.abstractmethod
+    def get_max_push_rules_stream_id(self):
+        """Get the position of the push rules stream.
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
+
     @cachedInlineCallbacks(max_entries=5000)
     def get_push_rules_for_user(self, user_id):
         rows = yield self._simple_select_list(
@@ -89,6 +134,22 @@ class PushRuleStore(SQLBaseStore):
             r['rule_id']: False if r['enabled'] == 0 else True for r in results
         })
 
+    def have_push_rules_changed_for_user(self, user_id, last_id):
+        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+            return defer.succeed(False)
+        else:
+            def have_push_rules_changed_txn(txn):
+                sql = (
+                    "SELECT COUNT(stream_id) FROM push_rules_stream"
+                    " WHERE user_id = ? AND ? < stream_id"
+                )
+                txn.execute(sql, (user_id, last_id))
+                count, = txn.fetchone()
+                return bool(count)
+            return self.runInteraction(
+                "have_push_rules_changed", have_push_rules_changed_txn
+            )
+
     @cachedList(cached_method_name="get_push_rules_for_user",
                 list_name="user_ids", num_args=1, inlineCallbacks=True)
     def bulk_get_push_rules(self, user_ids):
@@ -124,6 +185,7 @@ class PushRuleStore(SQLBaseStore):
 
         defer.returnValue(results)
 
+    @defer.inlineCallbacks
     def bulk_get_push_rules_for_room(self, event, context):
         state_group = context.state_group
         if not state_group:
@@ -133,9 +195,11 @@ class PushRuleStore(SQLBaseStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        return self._bulk_get_push_rules_for_room(
-            event.room_id, state_group, context.current_state_ids, event=event
+        current_state_ids = yield context.get_current_state_ids(self)
+        result = yield self._bulk_get_push_rules_for_room(
+            event.room_id, state_group, current_state_ids, event=event
         )
+        defer.returnValue(result)
 
     @cachedInlineCallbacks(num_args=2, cache_context=True)
     def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
@@ -185,18 +249,6 @@ class PushRuleStore(SQLBaseStore):
             if uid in local_users_in_room:
                 user_ids.add(uid)
 
-        forgotten = yield self.who_forgot_in_room(
-            event.room_id, on_invalidate=cache_context.invalidate,
-        )
-
-        for row in forgotten:
-            user_id = row["user_id"]
-            event_id = row["event_id"]
-
-            mem_id = current_state_ids.get((EventTypes.Member, user_id), None)
-            if event_id == mem_id:
-                user_ids.discard(user_id)
-
         rules_by_user = yield self.bulk_get_push_rules(
             user_ids, on_invalidate=cache_context.invalidate,
         )
@@ -228,6 +280,8 @@ class PushRuleStore(SQLBaseStore):
             results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
         defer.returnValue(results)
 
+
+class PushRuleStore(PushRulesWorkerStore):
     @defer.inlineCallbacks
     def add_push_rule(
         self, user_id, rule_id, priority_class, conditions, actions,
@@ -526,21 +580,8 @@ class PushRuleStore(SQLBaseStore):
         room stream ordering it corresponds to."""
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def have_push_rules_changed_for_user(self, user_id, last_id):
-        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
-            return defer.succeed(False)
-        else:
-            def have_push_rules_changed_txn(txn):
-                sql = (
-                    "SELECT COUNT(stream_id) FROM push_rules_stream"
-                    " WHERE user_id = ? AND ? < stream_id"
-                )
-                txn.execute(sql, (user_id, last_id))
-                count, = txn.fetchone()
-                return bool(count)
-            return self.runInteraction(
-                "have_push_rules_changed", have_push_rules_changed_txn
-            )
+    def get_max_push_rules_stream_id(self):
+        return self.get_push_rules_stream_token()[0]
 
 
 class RuleNotFoundException(Exception):