diff options
-rw-r--r-- | synapse/api/filtering.py | 48 | ||||
-rw-r--r-- | synapse/push/bulk_push_rule_evaluator.py | 23 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/filter.py | 2 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/sync.py | 24 | ||||
-rw-r--r-- | synapse/storage/push_rule.py | 29 | ||||
-rw-r--r-- | tests/api/test_filtering.py | 2 |
6 files changed, 92 insertions, 36 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index c7f021d1ff..5530b8c48f 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -28,14 +28,14 @@ class Filtering(object): return result def add_user_filter(self, user_localpart, user_filter): - self._check_valid_filter(user_filter) + self.check_valid_filter(user_filter) return self.store.add_user_filter(user_localpart, user_filter) # TODO(paul): surely we should probably add a delete_user_filter or # replace_user_filter at some point? There's no REST API specified for # them however - def _check_valid_filter(self, user_filter_json): + def check_valid_filter(self, user_filter_json): """Check if the provided filter is valid. This inspects all definitions contained within the filter. @@ -129,52 +129,55 @@ class Filtering(object): class FilterCollection(object): def __init__(self, filter_json): - self.filter_json = filter_json + self._filter_json = filter_json - room_filter_json = self.filter_json.get("room", {}) + room_filter_json = self._filter_json.get("room", {}) - self.room_filter = Filter({ + self._room_filter = Filter({ k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms") }) - self.room_timeline_filter = Filter(room_filter_json.get("timeline", {})) - self.room_state_filter = Filter(room_filter_json.get("state", {})) - self.room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {})) - self.room_account_data = Filter(room_filter_json.get("account_data", {})) - self.presence_filter = Filter(self.filter_json.get("presence", {})) - self.account_data = Filter(self.filter_json.get("account_data", {})) + self._room_timeline_filter = Filter(room_filter_json.get("timeline", {})) + self._room_state_filter = Filter(room_filter_json.get("state", {})) + self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {})) + self._room_account_data = Filter(room_filter_json.get("account_data", {})) + self._presence_filter = Filter(filter_json.get("presence", {})) + self._account_data = Filter(filter_json.get("account_data", {})) - self.include_leave = self.filter_json.get("room", {}).get( + self.include_leave = filter_json.get("room", {}).get( "include_leave", False ) + def get_filter_json(self): + return self._filter_json + def timeline_limit(self): - return self.room_timeline_filter.limit() + return self._room_timeline_filter.limit() def presence_limit(self): - return self.presence_filter.limit() + return self._presence_filter.limit() def ephemeral_limit(self): - return self.room_ephemeral_filter.limit() + return self._room_ephemeral_filter.limit() def filter_presence(self, events): - return self.presence_filter.filter(events) + return self._presence_filter.filter(events) def filter_account_data(self, events): - return self.account_data.filter(events) + return self._account_data.filter(events) def filter_room_state(self, events): - return self.room_state_filter.filter(self.room_filter.filter(events)) + return self._room_state_filter.filter(self._room_filter.filter(events)) def filter_room_timeline(self, events): - return self.room_timeline_filter.filter(self.room_filter.filter(events)) + return self._room_timeline_filter.filter(self._room_filter.filter(events)) def filter_room_ephemeral(self, events): - return self.room_ephemeral_filter.filter(self.room_filter.filter(events)) + return self._room_ephemeral_filter.filter(self._room_filter.filter(events)) def filter_room_account_data(self, events): - return self.room_account_data.filter(self.room_filter.filter(events)) + return self._room_account_data.filter(self._room_filter.filter(events)) class Filter(object): @@ -258,3 +261,6 @@ def _matches_wildcard(actual_value, filter_value): return actual_value.startswith(type_prefix) else: return actual_value == filter_value + + +DEFAULT_FILTER_COLLECTION = FilterCollection({}) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index b91c165e2b..20c60422bf 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -36,6 +36,7 @@ def decode_rule_json(rule): @defer.inlineCallbacks def _get_rules(room_id, user_ids, store): rules_by_user = yield store.bulk_get_push_rules(user_ids) + rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids) rules_by_user = { uid: baserules.list_with_base_rules([ @@ -44,6 +45,26 @@ def _get_rules(room_id, user_ids, store): ]) for uid in user_ids } + + # We apply the rules-enabled map here: bulk_get_push_rules doesn't + # fetch disabled rules, but this won't account for any server default + # rules the user has disabled, so we need to do this too. + for uid in user_ids: + if uid not in rules_enabled_by_user: + continue + + user_enabled_map = rules_enabled_by_user[uid] + + for i, rule in enumerate(rules_by_user[uid]): + rule_id = rule['rule_id'] + + if rule_id in user_enabled_map: + if rule.get('enabled', True) != bool(user_enabled_map[rule_id]): + # Rules are cached across users. + rule = dict(rule) + rule['enabled'] = bool(user_enabled_map[rule_id]) + rules_by_user[uid][i] = rule + defer.returnValue(rules_by_user) @@ -119,7 +140,7 @@ class BulkPushRuleEvaluator: ) if matches: actions = [x for x in rule['actions'] if x != 'dont_notify'] - if actions: + if actions and 'notify' in actions: actions_by_user[uid] = actions break defer.returnValue(actions_by_user) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 7695bebc28..7c94f6ec41 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -59,7 +59,7 @@ class GetFilterRestServlet(RestServlet): filter_id=filter_id, ) - defer.returnValue((200, filter.filter_json)) + defer.returnValue((200, filter.get_filter_json())) except KeyError: raise SynapseError(400, "No such filter") diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 4114a7e430..ab924ad9e0 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -24,7 +24,7 @@ from synapse.events import FrozenEvent from synapse.events.utils import ( serialize_event, format_event_for_client_v2_without_room_id, ) -from synapse.api.filtering import FilterCollection +from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION from synapse.api.errors import SynapseError from ._base import client_v2_patterns @@ -113,20 +113,20 @@ class SyncRestServlet(RestServlet): ) ) - if filter_id and filter_id.startswith('{'): - try: - filter_object = json.loads(filter_id) - except: - raise SynapseError(400, "Invalid filter JSON") - self.filtering._check_valid_filter(filter_object) - filter = FilterCollection(filter_object) - else: - try: + if filter_id: + if filter_id.startswith('{'): + try: + filter_object = json.loads(filter_id) + except: + raise SynapseError(400, "Invalid filter JSON") + self.filtering.check_valid_filter(filter_object) + filter = FilterCollection(filter_object) + else: filter = yield self.filtering.get_user_filter( user.localpart, filter_id ) - except: - filter = FilterCollection({}) + else: + filter = DEFAULT_FILTER_COLLECTION sync_config = SyncConfig( user=user, diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 2adfefd994..35ec7e8cef 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -94,6 +94,35 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(results) @defer.inlineCallbacks + def bulk_get_push_rules_enabled(self, user_ids): + if not user_ids: + defer.returnValue({}) + + batch_size = 100 + + def f(txn, user_ids_to_fetch): + sql = ( + "SELECT user_name, rule_id, enabled" + " FROM push_rules_enable" + " WHERE user_name" + " IN (" + ",".join("?" for _ in user_ids_to_fetch) + ")" + ) + txn.execute(sql, user_ids_to_fetch) + return self.cursor_to_dict(txn) + + results = {} + + chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)] + for batch_user_ids in chunks: + rows = yield self.runInteraction( + "bulk_get_push_rules_enabled", f, batch_user_ids + ) + + for row in rows: + results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled'] + defer.returnValue(results) + + @defer.inlineCallbacks def add_push_rule(self, before, after, **kwargs): vals = kwargs if 'conditions' in vals: diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 14cddee679..16ee6bbe6a 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -504,4 +504,4 @@ class FilteringTestCase(unittest.TestCase): filter_id=filter_id, ) - self.assertEquals(filter.filter_json, user_filter_json) + self.assertEquals(filter.get_filter_json(), user_filter_json) |