summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/filtering.py48
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py23
-rw-r--r--synapse/rest/client/v2_alpha/filter.py2
-rw-r--r--synapse/rest/client/v2_alpha/sync.py24
-rw-r--r--synapse/storage/push_rule.py29
-rw-r--r--tests/api/test_filtering.py2
6 files changed, 92 insertions, 36 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index c7f021d1ff..5530b8c48f 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -28,14 +28,14 @@ class Filtering(object):
         return result
 
     def add_user_filter(self, user_localpart, user_filter):
-        self._check_valid_filter(user_filter)
+        self.check_valid_filter(user_filter)
         return self.store.add_user_filter(user_localpart, user_filter)
 
     # TODO(paul): surely we should probably add a delete_user_filter or
     #   replace_user_filter at some point? There's no REST API specified for
     #   them however
 
-    def _check_valid_filter(self, user_filter_json):
+    def check_valid_filter(self, user_filter_json):
         """Check if the provided filter is valid.
 
         This inspects all definitions contained within the filter.
@@ -129,52 +129,55 @@ class Filtering(object):
 
 class FilterCollection(object):
     def __init__(self, filter_json):
-        self.filter_json = filter_json
+        self._filter_json = filter_json
 
-        room_filter_json = self.filter_json.get("room", {})
+        room_filter_json = self._filter_json.get("room", {})
 
-        self.room_filter = Filter({
+        self._room_filter = Filter({
             k: v for k, v in room_filter_json.items()
             if k in ("rooms", "not_rooms")
         })
 
-        self.room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
-        self.room_state_filter = Filter(room_filter_json.get("state", {}))
-        self.room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
-        self.room_account_data = Filter(room_filter_json.get("account_data", {}))
-        self.presence_filter = Filter(self.filter_json.get("presence", {}))
-        self.account_data = Filter(self.filter_json.get("account_data", {}))
+        self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
+        self._room_state_filter = Filter(room_filter_json.get("state", {}))
+        self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
+        self._room_account_data = Filter(room_filter_json.get("account_data", {}))
+        self._presence_filter = Filter(filter_json.get("presence", {}))
+        self._account_data = Filter(filter_json.get("account_data", {}))
 
-        self.include_leave = self.filter_json.get("room", {}).get(
+        self.include_leave = filter_json.get("room", {}).get(
             "include_leave", False
         )
 
+    def get_filter_json(self):
+        return self._filter_json
+
     def timeline_limit(self):
-        return self.room_timeline_filter.limit()
+        return self._room_timeline_filter.limit()
 
     def presence_limit(self):
-        return self.presence_filter.limit()
+        return self._presence_filter.limit()
 
     def ephemeral_limit(self):
-        return self.room_ephemeral_filter.limit()
+        return self._room_ephemeral_filter.limit()
 
     def filter_presence(self, events):
-        return self.presence_filter.filter(events)
+        return self._presence_filter.filter(events)
 
     def filter_account_data(self, events):
-        return self.account_data.filter(events)
+        return self._account_data.filter(events)
 
     def filter_room_state(self, events):
-        return self.room_state_filter.filter(self.room_filter.filter(events))
+        return self._room_state_filter.filter(self._room_filter.filter(events))
 
     def filter_room_timeline(self, events):
-        return self.room_timeline_filter.filter(self.room_filter.filter(events))
+        return self._room_timeline_filter.filter(self._room_filter.filter(events))
 
     def filter_room_ephemeral(self, events):
-        return self.room_ephemeral_filter.filter(self.room_filter.filter(events))
+        return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
 
     def filter_room_account_data(self, events):
-        return self.room_account_data.filter(self.room_filter.filter(events))
+        return self._room_account_data.filter(self._room_filter.filter(events))
 
 
 class Filter(object):
@@ -258,3 +261,6 @@ def _matches_wildcard(actual_value, filter_value):
         return actual_value.startswith(type_prefix)
     else:
         return actual_value == filter_value
+
+
+DEFAULT_FILTER_COLLECTION = FilterCollection({})
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index b91c165e2b..20c60422bf 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -36,6 +36,7 @@ def decode_rule_json(rule):
 @defer.inlineCallbacks
 def _get_rules(room_id, user_ids, store):
     rules_by_user = yield store.bulk_get_push_rules(user_ids)
+    rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
 
     rules_by_user = {
         uid: baserules.list_with_base_rules([
@@ -44,6 +45,26 @@ def _get_rules(room_id, user_ids, store):
         ])
         for uid in user_ids
     }
+
+    # We apply the rules-enabled map here: bulk_get_push_rules doesn't
+    # fetch disabled rules, but this won't account for any server default
+    # rules the user has disabled, so we need to do this too.
+    for uid in user_ids:
+        if uid not in rules_enabled_by_user:
+            continue
+
+        user_enabled_map = rules_enabled_by_user[uid]
+
+        for i, rule in enumerate(rules_by_user[uid]):
+            rule_id = rule['rule_id']
+
+            if rule_id in user_enabled_map:
+                if rule.get('enabled', True) != bool(user_enabled_map[rule_id]):
+                    # Rules are cached across users.
+                    rule = dict(rule)
+                    rule['enabled'] = bool(user_enabled_map[rule_id])
+                    rules_by_user[uid][i] = rule
+
     defer.returnValue(rules_by_user)
 
 
@@ -119,7 +140,7 @@ class BulkPushRuleEvaluator:
                 )
                 if matches:
                     actions = [x for x in rule['actions'] if x != 'dont_notify']
-                    if actions:
+                    if actions and 'notify' in actions:
                         actions_by_user[uid] = actions
                     break
         defer.returnValue(actions_by_user)
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index 7695bebc28..7c94f6ec41 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -59,7 +59,7 @@ class GetFilterRestServlet(RestServlet):
                 filter_id=filter_id,
             )
 
-            defer.returnValue((200, filter.filter_json))
+            defer.returnValue((200, filter.get_filter_json()))
         except KeyError:
             raise SynapseError(400, "No such filter")
 
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 4114a7e430..ab924ad9e0 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -24,7 +24,7 @@ from synapse.events import FrozenEvent
 from synapse.events.utils import (
     serialize_event, format_event_for_client_v2_without_room_id,
 )
-from synapse.api.filtering import FilterCollection
+from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
 from synapse.api.errors import SynapseError
 from ._base import client_v2_patterns
 
@@ -113,20 +113,20 @@ class SyncRestServlet(RestServlet):
             )
         )
 
-        if filter_id and filter_id.startswith('{'):
-            try:
-                filter_object = json.loads(filter_id)
-            except:
-                raise SynapseError(400, "Invalid filter JSON")
-            self.filtering._check_valid_filter(filter_object)
-            filter = FilterCollection(filter_object)
-        else:
-            try:
+        if filter_id:
+            if filter_id.startswith('{'):
+                try:
+                    filter_object = json.loads(filter_id)
+                except:
+                    raise SynapseError(400, "Invalid filter JSON")
+                self.filtering.check_valid_filter(filter_object)
+                filter = FilterCollection(filter_object)
+            else:
                 filter = yield self.filtering.get_user_filter(
                     user.localpart, filter_id
                 )
-            except:
-                filter = FilterCollection({})
+        else:
+            filter = DEFAULT_FILTER_COLLECTION
 
         sync_config = SyncConfig(
             user=user,
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 2adfefd994..35ec7e8cef 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -94,6 +94,35 @@ class PushRuleStore(SQLBaseStore):
         defer.returnValue(results)
 
     @defer.inlineCallbacks
+    def bulk_get_push_rules_enabled(self, user_ids):
+        if not user_ids:
+            defer.returnValue({})
+
+        batch_size = 100
+
+        def f(txn, user_ids_to_fetch):
+            sql = (
+                "SELECT user_name, rule_id, enabled"
+                " FROM push_rules_enable"
+                " WHERE user_name"
+                " IN (" + ",".join("?" for _ in user_ids_to_fetch) + ")"
+            )
+            txn.execute(sql, user_ids_to_fetch)
+            return self.cursor_to_dict(txn)
+
+        results = {}
+
+        chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)]
+        for batch_user_ids in chunks:
+            rows = yield self.runInteraction(
+                "bulk_get_push_rules_enabled", f, batch_user_ids
+            )
+
+            for row in rows:
+                results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
+        defer.returnValue(results)
+
+    @defer.inlineCallbacks
     def add_push_rule(self, before, after, **kwargs):
         vals = kwargs
         if 'conditions' in vals:
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 14cddee679..16ee6bbe6a 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -504,4 +504,4 @@ class FilteringTestCase(unittest.TestCase):
             filter_id=filter_id,
         )
 
-        self.assertEquals(filter.filter_json, user_filter_json)
+        self.assertEquals(filter.get_filter_json(), user_filter_json)