From ddf9e7b3027eee61086ebfb447c5fa33e9b640fe Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 3 Mar 2016 14:57:45 +0000 Subject: Hook up the push rules to the notifier --- synapse/rest/client/v1/push_rule.py | 44 ++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 15 deletions(-) (limited to 'synapse/rest') diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 970a019223..cf68725ca1 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -36,6 +36,11 @@ class PushRuleRestServlet(ClientV1RestServlet): SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( "Unrecognised request: You probably wanted a trailing slash") + def __init__(self, hs): + super(PushRuleRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + @defer.inlineCallbacks def on_PUT(self, request): spec = _rule_spec_from_path(request.postpath) @@ -51,8 +56,11 @@ class PushRuleRestServlet(ClientV1RestServlet): content = _parse_json(request) + user_id = requester.user.to_string() + if 'attr' in spec: - yield self.set_rule_attr(requester.user.to_string(), spec, content) + yield self.set_rule_attr(user_id, spec, content) + self.notify_user(user_id) defer.returnValue((200, {})) if spec['rule_id'].startswith('.'): @@ -77,8 +85,8 @@ class PushRuleRestServlet(ClientV1RestServlet): after = _namespaced_rule_id(spec, after[0]) try: - yield self.hs.get_datastore().add_push_rule( - user_id=requester.user.to_string(), + yield self.store.add_push_rule( + user_id=user_id, rule_id=_namespaced_rule_id_from_spec(spec), priority_class=priority_class, conditions=conditions, @@ -86,6 +94,7 @@ class PushRuleRestServlet(ClientV1RestServlet): before=before, after=after ) + self.notify_user(user_id) except InconsistentRuleException as e: raise SynapseError(400, e.message) except RuleNotFoundException as e: @@ -98,13 +107,15 @@ class PushRuleRestServlet(ClientV1RestServlet): spec = _rule_spec_from_path(request.postpath) requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() namespaced_rule_id = _namespaced_rule_id_from_spec(spec) try: - yield self.hs.get_datastore().delete_push_rule( - requester.user.to_string(), namespaced_rule_id + yield self.store.delete_push_rule( + user_id, namespaced_rule_id ) + self.notify_user(user_id) defer.returnValue((200, {})) except StoreError as e: if e.code == 404: @@ -115,14 +126,12 @@ class PushRuleRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - user = requester.user + user_id = requester.user.to_string() # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rawrules = yield self.hs.get_datastore().get_push_rules_for_user( - user.to_string() - ) + rawrules = yield self.store.get_push_rules_for_user(user_id) ruleslist = [] for rawrule in rawrules: @@ -138,8 +147,7 @@ class PushRuleRestServlet(ClientV1RestServlet): rules['global'] = _add_empty_priority_class_arrays(rules['global']) - enabled_map = yield self.hs.get_datastore().\ - get_push_rules_enabled_for_user(user.to_string()) + enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) for r in ruleslist: rulearray = None @@ -152,9 +160,9 @@ class PushRuleRestServlet(ClientV1RestServlet): pattern_type = c.pop("pattern_type", None) if pattern_type == "user_id": - c["pattern"] = user.to_string() + c["pattern"] = user_id elif pattern_type == "user_localpart": - c["pattern"] = user.localpart + c["pattern"] = requester.user.localpart rulearray = rules['global'][template_name] @@ -188,6 +196,12 @@ class PushRuleRestServlet(ClientV1RestServlet): def on_OPTIONS(self, _): return 200, {} + def notify_user(self, user_id): + stream_id = self.store.get_push_rules_stream_token() + self.notifier.on_new_event( + "push_rules_key", stream_id, users=[user_id] + ) + def set_rule_attr(self, user_id, spec, val): if spec['attr'] == 'enabled': if isinstance(val, dict) and "enabled" in val: @@ -198,7 +212,7 @@ class PushRuleRestServlet(ClientV1RestServlet): # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - return self.hs.get_datastore().set_push_rule_enabled( + return self.store.set_push_rule_enabled( user_id, namespaced_rule_id, val ) elif spec['attr'] == 'actions': @@ -210,7 +224,7 @@ class PushRuleRestServlet(ClientV1RestServlet): if is_default_rule: if namespaced_rule_id not in BASE_RULE_IDS: raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - return self.hs.get_datastore().set_push_rule_actions( + return self.store.set_push_rule_actions( user_id, namespaced_rule_id, actions, is_default_rule ) else: -- cgit 1.5.1 From 3406eba4ef40de888ebb5b22c0ea4925b2dddeb1 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 3 Mar 2016 16:11:59 +0000 Subject: Move the code for formatting push rules into a separate function --- synapse/push/clientformat.py | 112 ++++++++++++++++++++++++++++++++++++ synapse/rest/client/v1/push_rule.py | 90 ++--------------------------- 2 files changed, 116 insertions(+), 86 deletions(-) create mode 100644 synapse/push/clientformat.py (limited to 'synapse/rest') diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py new file mode 100644 index 0000000000..ae9db9ec2f --- /dev/null +++ b/synapse/push/clientformat.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.push.baserules import list_with_base_rules + +from synapse.push.rulekinds import ( + PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP +) + +import copy +import simplejson as json + + +def format_push_rules_for_user(user, rawrules, enabled_map): + """Converts a list of rawrules and a enabled map into nested dictionaries + to match the Matrix client-server format for push rules""" + + ruleslist = [] + for rawrule in rawrules: + rule = dict(rawrule) + rule["conditions"] = json.loads(rawrule["conditions"]) + rule["actions"] = json.loads(rawrule["actions"]) + ruleslist.append(rule) + + # We're going to be mutating this a lot, so do a deep copy + ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) + + rules = {'global': {}, 'device': {}} + + rules['global'] = _add_empty_priority_class_arrays(rules['global']) + + for r in ruleslist: + rulearray = None + + template_name = _priority_class_to_template_name(r['priority_class']) + + # Remove internal stuff. + for c in r["conditions"]: + c.pop("_id", None) + + pattern_type = c.pop("pattern_type", None) + if pattern_type == "user_id": + c["pattern"] = user.to_string() + elif pattern_type == "user_localpart": + c["pattern"] = user.localpart + + rulearray = rules['global'][template_name] + + template_rule = _rule_to_template(r) + if template_rule: + if r['rule_id'] in enabled_map: + template_rule['enabled'] = enabled_map[r['rule_id']] + elif 'enabled' in r: + template_rule['enabled'] = r['enabled'] + else: + template_rule['enabled'] = True + rulearray.append(template_rule) + + return rules + + +def _add_empty_priority_class_arrays(d): + for pc in PRIORITY_CLASS_MAP.keys(): + d[pc] = [] + return d + + +def _rule_to_template(rule): + unscoped_rule_id = None + if 'rule_id' in rule: + unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id']) + + template_name = _priority_class_to_template_name(rule['priority_class']) + if template_name in ['override', 'underride']: + templaterule = {k: rule[k] for k in ["conditions", "actions"]} + elif template_name in ["sender", "room"]: + templaterule = {'actions': rule['actions']} + unscoped_rule_id = rule['conditions'][0]['pattern'] + elif template_name == 'content': + if len(rule["conditions"]) != 1: + return None + thecond = rule["conditions"][0] + if "pattern" not in thecond: + return None + templaterule = {'actions': rule['actions']} + templaterule["pattern"] = thecond["pattern"] + + if unscoped_rule_id: + templaterule['rule_id'] = unscoped_rule_id + if 'default' in rule: + templaterule['default'] = rule['default'] + return templaterule + + +def _rule_id_from_namespaced(in_rule_id): + return in_rule_id.split('/')[-1] + + +def _priority_class_to_template_name(pc): + return PRIORITY_CLASS_INVERSE_MAP[pc] diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index cf68725ca1..edfe28c79b 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -22,12 +22,10 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.storage.push_rule import ( InconsistentRuleException, RuleNotFoundException ) -from synapse.push.baserules import list_with_base_rules, BASE_RULE_IDS -from synapse.push.rulekinds import ( - PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP -) +from synapse.push.clientformat import format_push_rules_for_user +from synapse.push.baserules import BASE_RULE_IDS +from synapse.push.rulekinds import PRIORITY_CLASS_MAP -import copy import simplejson as json @@ -133,48 +131,9 @@ class PushRuleRestServlet(ClientV1RestServlet): # is probably not going to make a whole lot of difference rawrules = yield self.store.get_push_rules_for_user(user_id) - ruleslist = [] - for rawrule in rawrules: - rule = dict(rawrule) - rule["conditions"] = json.loads(rawrule["conditions"]) - rule["actions"] = json.loads(rawrule["actions"]) - ruleslist.append(rule) - - # We're going to be mutating this a lot, so do a deep copy - ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) - - rules = {'global': {}, 'device': {}} - - rules['global'] = _add_empty_priority_class_arrays(rules['global']) - enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) - for r in ruleslist: - rulearray = None - - template_name = _priority_class_to_template_name(r['priority_class']) - - # Remove internal stuff. - for c in r["conditions"]: - c.pop("_id", None) - - pattern_type = c.pop("pattern_type", None) - if pattern_type == "user_id": - c["pattern"] = user_id - elif pattern_type == "user_localpart": - c["pattern"] = requester.user.localpart - - rulearray = rules['global'][template_name] - - template_rule = _rule_to_template(r) - if template_rule: - if r['rule_id'] in enabled_map: - template_rule['enabled'] = enabled_map[r['rule_id']] - elif 'enabled' in r: - template_rule['enabled'] = r['enabled'] - else: - template_rule['enabled'] = True - rulearray.append(template_rule) + rules = format_push_rules_for_user(requester.user, rawrules, enabled_map) path = request.postpath[1:] @@ -322,12 +281,6 @@ def _check_actions(actions): raise InvalidRuleException("Unrecognised action") -def _add_empty_priority_class_arrays(d): - for pc in PRIORITY_CLASS_MAP.keys(): - d[pc] = [] - return d - - def _filter_ruleset_with_path(ruleset, path): if path == []: raise UnrecognizedRequestError( @@ -376,37 +329,6 @@ def _priority_class_from_spec(spec): return pc -def _priority_class_to_template_name(pc): - return PRIORITY_CLASS_INVERSE_MAP[pc] - - -def _rule_to_template(rule): - unscoped_rule_id = None - if 'rule_id' in rule: - unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id']) - - template_name = _priority_class_to_template_name(rule['priority_class']) - if template_name in ['override', 'underride']: - templaterule = {k: rule[k] for k in ["conditions", "actions"]} - elif template_name in ["sender", "room"]: - templaterule = {'actions': rule['actions']} - unscoped_rule_id = rule['conditions'][0]['pattern'] - elif template_name == 'content': - if len(rule["conditions"]) != 1: - return None - thecond = rule["conditions"][0] - if "pattern" not in thecond: - return None - templaterule = {'actions': rule['actions']} - templaterule["pattern"] = thecond["pattern"] - - if unscoped_rule_id: - templaterule['rule_id'] = unscoped_rule_id - if 'default' in rule: - templaterule['default'] = rule['default'] - return templaterule - - def _namespaced_rule_id_from_spec(spec): return _namespaced_rule_id(spec, spec['rule_id']) @@ -415,10 +337,6 @@ def _namespaced_rule_id(spec, rule_id): return "global/%s/%s" % (spec['template'], rule_id) -def _rule_id_from_namespaced(in_rule_id): - return in_rule_id.split('/')[-1] - - class InvalidRuleException(Exception): pass -- cgit 1.5.1 From 1b4f4a936fb416d81203fcd66be690f9a04b2b62 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 4 Mar 2016 14:44:01 +0000 Subject: Hook up the push rules stream to account_data in /sync --- synapse/handlers/sync.py | 22 +++++++ synapse/rest/client/v1/push_rule.py | 2 +- synapse/storage/__init__.py | 5 ++ synapse/storage/push_rule.py | 125 ++++++++++++++++-------------------- 4 files changed, 85 insertions(+), 69 deletions(-) (limited to 'synapse/rest') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index fded6e4009..92eab20c7c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -20,6 +20,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.util import unwrapFirstError from synapse.util.logcontext import LoggingContext, preserve_fn from synapse.util.metrics import Measure +from synapse.push.clientformat import format_push_rules_for_user from twisted.internet import defer @@ -224,6 +225,10 @@ class SyncHandler(BaseHandler): ) ) + account_data['m.push_rules'] = yield self.push_rules_for_user( + sync_config.user + ) + tags_by_room = yield self.store.get_tags_for_user( sync_config.user.to_string() ) @@ -322,6 +327,14 @@ class SyncHandler(BaseHandler): defer.returnValue(room_sync) + @defer.inlineCallbacks + def push_rules_for_user(self, user): + user_id = user.to_string() + rawrules = yield self.store.get_push_rules_for_user(user_id) + enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) + rules = format_push_rules_for_user(user, rawrules, enabled_map) + defer.returnValue(rules) + def account_data_for_user(self, account_data): account_data_events = [] @@ -481,6 +494,15 @@ class SyncHandler(BaseHandler): ) ) + push_rules_changed = yield self.store.have_push_rules_changed_for_user( + user_id, int(since_token.push_rules_key) + ) + + if push_rules_changed: + account_data["m.push_rules"] = yield self.push_rules_for_user( + sync_config.user + ) + # Get a list of membership change events that have happened. rooms_changed = yield self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index edfe28c79b..981d7708db 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -156,7 +156,7 @@ class PushRuleRestServlet(ClientV1RestServlet): return 200, {} def notify_user(self, user_id): - stream_id = self.store.get_push_rules_stream_token() + stream_id, _ = self.store.get_push_rules_stream_token() self.notifier.on_new_event( "push_rules_key", stream_id, users=[user_id] ) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e2d7b52569..7b7b03d052 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -160,6 +160,11 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=presence_cache_prefill ) + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", + self._push_rules_stream_id_gen.get_max_token()[0], + ) + super(DataStore, self).__init__(hs) def take_presence_startup_info(self): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index e034024108..792fcbdf5b 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -244,15 +244,10 @@ class PushRuleStore(SQLBaseStore): ) if update_stream: - self._simple_insert_txn( - txn, - table="push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "ADD", + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="ADD", + data={ "priority_class": priority_class, "priority": priority, "conditions": conditions_json, @@ -260,13 +255,6 @@ class PushRuleStore(SQLBaseStore): } ) - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) - ) - @defer.inlineCallbacks def delete_push_rule(self, user_id, rule_id): """ @@ -284,22 +272,10 @@ class PushRuleStore(SQLBaseStore): "push_rules", {'user_name': user_id, 'rule_id': rule_id}, ) - self._simple_insert_txn( - txn, - table="push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "DELETE", - } - ) - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) + + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="DELETE" ) with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): @@ -328,23 +304,9 @@ class PushRuleStore(SQLBaseStore): {'id': new_id}, ) - self._simple_insert_txn( - txn, - "push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "ENABLE" if enabled else "DISABLE", - } - ) - - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="ENABLE" if enabled else "DISABLE" ) @defer.inlineCallbacks @@ -370,24 +332,9 @@ class PushRuleStore(SQLBaseStore): {'actions': actions_json}, ) - self._simple_insert_txn( - txn, - "push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "ACTIONS", - "actions": actions_json, - } - ) - - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="ACTIONS", data={"actions": actions_json} ) with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): @@ -396,6 +343,31 @@ class PushRuleStore(SQLBaseStore): stream_id, stream_ordering ) + def _insert_push_rules_update_txn( + self, txn, stream_id, stream_ordering, user_id, rule_id, op, data=None + ): + values = { + "stream_id": stream_id, + "stream_ordering": stream_ordering, + "user_id": user_id, + "rule_id": rule_id, + "op": op, + } + if data is not None: + values.update(data) + + self._simple_insert_txn(txn, "push_rules_stream", values=values) + + txn.call_after( + self.get_push_rules_for_user.invalidate, (user_id,) + ) + txn.call_after( + self.get_push_rules_enabled_for_user.invalidate, (user_id,) + ) + txn.call_after( + self.push_rules_stream_cache.entity_has_changed, user_id, stream_id + ) + def get_all_push_rule_updates(self, last_id, current_id, limit): """Get all the push rules changes that have happend on the server""" def get_all_push_rule_updates_txn(txn): @@ -403,7 +375,7 @@ class PushRuleStore(SQLBaseStore): "SELECT stream_id, stream_ordering, user_id, rule_id," " op, priority_class, priority, conditions, actions" " FROM push_rules_stream" - " WHERE ? < stream_id and stream_id <= ?" + " WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) @@ -418,6 +390,23 @@ class PushRuleStore(SQLBaseStore): room stream ordering it corresponds to.""" return self._push_rules_stream_id_gen.get_max_token() + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + logger.error("FNARG") + return defer.succeed(False) + else: + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? AND ? < stream_id" + ) + txn.execute(sql, (user_id, last_id)) + count, = txn.fetchone() + return bool(count) + return self.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) + class RuleNotFoundException(Exception): pass -- cgit 1.5.1