From e18257f0e52cf587ab4d7bb8937dbcaefeb26e54 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 25 Jan 2016 16:51:56 +0000 Subject: Add missing yield in push_rules set enabled --- synapse/rest/client/v1/push_rule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/rest/client/v1/push_rule.py') diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 2272d66dc7..cb3ec23872 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -52,7 +52,7 @@ class PushRuleRestServlet(ClientV1RestServlet): content = _parse_json(request) if 'attr' in spec: - self.set_rule_attr(requester.user.to_string(), spec, content) + yield self.set_rule_attr(requester.user.to_string(), spec, content) defer.returnValue((200, {})) try: @@ -228,7 +228,7 @@ class PushRuleRestServlet(ClientV1RestServlet): # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - self.hs.get_datastore().set_push_rule_enabled( + return self.hs.get_datastore().set_push_rule_enabled( user_id, namespaced_rule_id, val ) else: -- cgit 1.5.1 From 8c94833b72f27d18a57e424a0f4f0c823c3a0aa1 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 27 Jan 2016 10:24:20 +0000 Subject: Fix adding push rules relative to other rules --- synapse/rest/client/v1/push_rule.py | 15 ++++++++++----- synapse/storage/push_rule.py | 3 ++- 2 files changed, 12 insertions(+), 6 deletions(-) (limited to 'synapse/rest/client/v1/push_rule.py') diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index cb3ec23872..96633a176c 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -66,11 +66,12 @@ class PushRuleRestServlet(ClientV1RestServlet): raise SynapseError(400, e.message) before = request.args.get("before", None) - if before and len(before): - before = before[0] + if before: + before = _namespaced_rule_id(spec, before[0]) + after = request.args.get("after", None) - if after and len(after): - after = after[0] + if after: + after = _namespaced_rule_id(spec, after[0]) try: yield self.hs.get_datastore().add_push_rule( @@ -452,11 +453,15 @@ def _strip_device_condition(rule): def _namespaced_rule_id_from_spec(spec): + return _namespaced_rule_id(spec, spec['rule_id']) + + +def _namespaced_rule_id(spec, rule_id): if spec['scope'] == 'global': scope = 'global' else: scope = 'device/%s' % (spec['profile_tag']) - return "%s/%s/%s" % (scope, spec['template'], spec['rule_id']) + return "%s/%s/%s" % (scope, spec['template'], rule_id) def _rule_id_from_namespaced(in_rule_id): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 1f51c90ee5..f9a48171ba 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -130,7 +130,8 @@ class PushRuleStore(SQLBaseStore): def _add_push_rule_relative_txn(self, txn, user_id, **kwargs): after = kwargs.pop("after", None) - relative_to_rule = kwargs.pop("before", after) + before = kwargs.pop("before", None) + relative_to_rule = before or after res = self._simple_select_one_txn( txn, -- cgit 1.5.1 From 458782bf67ef7c188af752b0f455d4a0f9f4cdd5 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 16 Feb 2016 18:00:30 +0000 Subject: Fix typo in request validation for adding push rules. --- synapse/rest/client/v1/push_rule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/rest/client/v1/push_rule.py') diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 96633a176c..7766b8be1d 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -400,7 +400,7 @@ def _filter_ruleset_with_path(ruleset, path): def _priority_class_from_spec(spec): if spec['template'] not in PRIORITY_CLASS_MAP.keys(): - raise InvalidRuleException("Unknown template: %s" % (spec['kind'])) + raise InvalidRuleException("Unknown template: %s" % (spec['template'])) pc = PRIORITY_CLASS_MAP[spec['template']] if spec['scope'] == 'device': -- cgit 1.5.1 From b9977ea667889f6cf89464c92fc57cbcae7cca28 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 18 Feb 2016 16:05:13 +0000 Subject: Remove dead code for setting device specific rules. It wasn't possible to hit the code from the API because of a typo in parsing the request path. Since no-one was using the feature we might as well remove the dead code. --- synapse/push/__init__.py | 7 ++- synapse/push/action_generator.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/push/httppusher.py | 3 +- synapse/push/push_rule_evaluator.py | 15 ++---- synapse/push/pusherpool.py | 48 +++++++---------- synapse/rest/client/v1/push_rule.py | 90 ++------------------------------ synapse/rest/client/v1/pusher.py | 6 +-- synapse/storage/event_push_actions.py | 7 ++- synapse/storage/pusher.py | 6 +-- 10 files changed, 45 insertions(+), 141 deletions(-) (limited to 'synapse/rest/client/v1/push_rule.py') diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 8da2d8716c..4c6c3b83a2 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -47,14 +47,13 @@ class Pusher(object): MAX_BACKOFF = 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000 - def __init__(self, _hs, profile_tag, user_id, app_id, + def __init__(self, _hs, user_id, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, data, last_token, last_success, failing_since): self.hs = _hs self.evStreamHandler = self.hs.get_handlers().event_stream_handler self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self.profile_tag = profile_tag self.user_id = user_id self.app_id = app_id self.app_display_name = app_display_name @@ -186,8 +185,8 @@ class Pusher(object): processed = False rule_evaluator = yield \ - push_rule_evaluator.evaluator_for_user_id_and_profile_tag( - self.user_id, self.profile_tag, single_event['room_id'], self.store + push_rule_evaluator.evaluator_for_user_id( + self.user_id, single_event['room_id'], self.store ) actions = yield rule_evaluator.actions_for_event(single_event) diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index e0da0868ec..c6c1dc769e 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -44,5 +44,5 @@ class ActionGenerator: ) context.push_actions = [ - (uid, None, actions) for uid, actions in actions_by_user.items() + (uid, actions) for uid, actions in actions_by_user.items() ] diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8ac5ceb9ef..0a23b3f102 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -152,7 +152,7 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache): elif res is True: continue - res = evaluator.matches(cond, uid, display_name, None) + res = evaluator.matches(cond, uid, display_name) if _id: cache[_id] = bool(res) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index cdc4494928..9be4869360 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -23,12 +23,11 @@ logger = logging.getLogger(__name__) class HttpPusher(Pusher): - def __init__(self, _hs, profile_tag, user_id, app_id, + def __init__(self, _hs, user_id, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, data, last_token, last_success, failing_since): super(HttpPusher, self).__init__( _hs, - profile_tag, user_id, app_id, app_display_name, diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 2a2b4437dc..98e2a2015e 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -33,7 +33,7 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") @defer.inlineCallbacks -def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): +def evaluator_for_user_id(user_id, room_id, store): rawrules = yield store.get_push_rules_for_user(user_id) enabled_map = yield store.get_push_rules_enabled_for_user(user_id) our_member_event = yield store.get_current_state( @@ -43,7 +43,7 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): ) defer.returnValue(PushRuleEvaluator( - user_id, profile_tag, rawrules, enabled_map, + user_id, rawrules, enabled_map, room_id, our_member_event, store )) @@ -77,10 +77,9 @@ def _room_member_count(ev, condition, room_member_count): class PushRuleEvaluator: DEFAULT_ACTIONS = [] - def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id, + def __init__(self, user_id, raw_rules, enabled_map, room_id, our_member_event, store): self.user_id = user_id - self.profile_tag = profile_tag self.room_id = room_id self.our_member_event = our_member_event self.store = store @@ -152,7 +151,7 @@ class PushRuleEvaluator: matches = True for c in conditions: matches = evaluator.matches( - c, self.user_id, my_display_name, self.profile_tag + c, self.user_id, my_display_name ) if not matches: break @@ -189,13 +188,9 @@ class PushRuleEvaluatorForEvent(object): # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) - def matches(self, condition, user_id, display_name, profile_tag): + def matches(self, condition, user_id, display_name): if condition['kind'] == 'event_match': return self._event_match(condition, user_id) - elif condition['kind'] == 'device': - if 'profile_tag' not in condition: - return True - return condition['profile_tag'] == profile_tag elif condition['kind'] == 'contains_display_name': return self._contains_display_name(display_name) elif condition['kind'] == 'room_member_count': diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index d7dcb2de4b..a05aa5f661 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -29,6 +29,7 @@ class PusherPool: def __init__(self, _hs): self.hs = _hs self.store = self.hs.get_datastore() + self.clock = self.hs.get_clock() self.pushers = {} self.last_pusher_started = -1 @@ -38,8 +39,11 @@ class PusherPool: self._start_pushers(pushers) @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, - app_display_name, device_display_name, pushkey, lang, data): + def add_pusher(self, user_id, access_token, kind, app_id, + app_display_name, device_display_name, pushkey, lang, data, + profile_tag=""): + time_now_msec = self.clock.time_msec() + # we try to create the pusher just to validate the config: it # will then get pulled out of the database, # recreated, added and started: this means we have only one @@ -47,23 +51,31 @@ class PusherPool: self._create_pusher({ "user_name": user_id, "kind": kind, - "profile_tag": profile_tag, "app_id": app_id, "app_display_name": app_display_name, "device_display_name": device_display_name, "pushkey": pushkey, - "ts": self.hs.get_clock().time_msec(), + "ts": time_now_msec, "lang": lang, "data": data, "last_token": None, "last_success": None, "failing_since": None }) - yield self._add_pusher_to_store( - user_id, access_token, profile_tag, kind, app_id, - app_display_name, device_display_name, - pushkey, lang, data + yield self.store.add_pusher( + user_id=user_id, + access_token=access_token, + kind=kind, + app_id=app_id, + app_display_name=app_display_name, + device_display_name=device_display_name, + pushkey=pushkey, + pushkey_ts=time_now_msec, + lang=lang, + data=data, + profile_tag=profile_tag, ) + yield self._refresh_pusher(app_id, pushkey, user_id) @defer.inlineCallbacks def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, @@ -94,30 +106,10 @@ class PusherPool: ) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) - @defer.inlineCallbacks - def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, lang, data): - yield self.store.add_pusher( - user_id=user_id, - access_token=access_token, - profile_tag=profile_tag, - kind=kind, - app_id=app_id, - app_display_name=app_display_name, - device_display_name=device_display_name, - pushkey=pushkey, - pushkey_ts=self.hs.get_clock().time_msec(), - lang=lang, - data=data, - ) - yield self._refresh_pusher(app_id, pushkey, user_id) - def _create_pusher(self, pusherdict): if pusherdict['kind'] == 'http': return HttpPusher( self.hs, - profile_tag=pusherdict['profile_tag'], user_id=pusherdict['user_name'], app_id=pusherdict['app_id'], app_display_name=pusherdict['app_display_name'], diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 7766b8be1d..5db2805d68 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -60,7 +60,6 @@ class PushRuleRestServlet(ClientV1RestServlet): spec['template'], spec['rule_id'], content, - device=spec['device'] if 'device' in spec else None ) except InvalidRuleException as e: raise SynapseError(400, e.message) @@ -153,23 +152,7 @@ class PushRuleRestServlet(ClientV1RestServlet): elif pattern_type == "user_localpart": c["pattern"] = user.localpart - if r['priority_class'] > PRIORITY_CLASS_MAP['override']: - # per-device rule - profile_tag = _profile_tag_from_conditions(r["conditions"]) - r = _strip_device_condition(r) - if not profile_tag: - continue - if profile_tag not in rules['device']: - rules['device'][profile_tag] = {} - rules['device'][profile_tag] = ( - _add_empty_priority_class_arrays( - rules['device'][profile_tag] - ) - ) - - rulearray = rules['device'][profile_tag][template_name] - else: - rulearray = rules['global'][template_name] + rulearray = rules['global'][template_name] template_rule = _rule_to_template(r) if template_rule: @@ -195,24 +178,6 @@ class PushRuleRestServlet(ClientV1RestServlet): path = path[1:] result = _filter_ruleset_with_path(rules['global'], path) defer.returnValue((200, result)) - elif path[0] == 'device': - path = path[1:] - if path == []: - raise UnrecognizedRequestError( - PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR - ) - if path[0] == '': - defer.returnValue((200, rules['device'])) - - profile_tag = path[0] - path = path[1:] - if profile_tag not in rules['device']: - ret = {} - ret = _add_empty_priority_class_arrays(ret) - defer.returnValue((200, ret)) - ruleset = rules['device'][profile_tag] - result = _filter_ruleset_with_path(ruleset, path) - defer.returnValue((200, result)) else: raise UnrecognizedRequestError() @@ -252,16 +217,9 @@ def _rule_spec_from_path(path): scope = path[1] path = path[2:] - if scope not in ['global', 'device']: + if scope != 'global': raise UnrecognizedRequestError() - device = None - if scope == 'device': - if len(path) == 0: - raise UnrecognizedRequestError() - device = path[0] - path = path[1:] - if len(path) == 0: raise UnrecognizedRequestError() @@ -278,8 +236,6 @@ def _rule_spec_from_path(path): 'template': template, 'rule_id': rule_id } - if device: - spec['profile_tag'] = device path = path[1:] @@ -289,7 +245,7 @@ def _rule_spec_from_path(path): return spec -def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None): +def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): if rule_template in ['override', 'underride']: if 'conditions' not in req_obj: raise InvalidRuleException("Missing 'conditions'") @@ -322,12 +278,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None else: raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) - if device: - conditions.append({ - 'kind': 'device', - 'profile_tag': device - }) - if 'actions' not in req_obj: raise InvalidRuleException("No actions found") actions = req_obj['actions'] @@ -349,17 +299,6 @@ def _add_empty_priority_class_arrays(d): return d -def _profile_tag_from_conditions(conditions): - """ - Given a list of conditions, return the profile tag of the - device rule if there is one - """ - for c in conditions: - if c['kind'] == 'device': - return c['profile_tag'] - return None - - def _filter_ruleset_with_path(ruleset, path): if path == []: raise UnrecognizedRequestError( @@ -403,19 +342,11 @@ def _priority_class_from_spec(spec): raise InvalidRuleException("Unknown template: %s" % (spec['template'])) pc = PRIORITY_CLASS_MAP[spec['template']] - if spec['scope'] == 'device': - pc += len(PRIORITY_CLASS_MAP) - return pc def _priority_class_to_template_name(pc): - if pc > PRIORITY_CLASS_MAP['override']: - # per-device - prio_class_index = pc - len(PRIORITY_CLASS_MAP) - return PRIORITY_CLASS_INVERSE_MAP[prio_class_index] - else: - return PRIORITY_CLASS_INVERSE_MAP[pc] + return PRIORITY_CLASS_INVERSE_MAP[pc] def _rule_to_template(rule): @@ -445,23 +376,12 @@ def _rule_to_template(rule): return templaterule -def _strip_device_condition(rule): - for i, c in enumerate(rule['conditions']): - if c['kind'] == 'device': - del rule['conditions'][i] - return rule - - def _namespaced_rule_id_from_spec(spec): return _namespaced_rule_id(spec, spec['rule_id']) def _namespaced_rule_id(spec, rule_id): - if spec['scope'] == 'global': - scope = 'global' - else: - scope = 'device/%s' % (spec['profile_tag']) - return "%s/%s/%s" % (scope, spec['template'], rule_id) + return "global/%s/%s" % (spec['template'], rule_id) def _rule_id_from_namespaced(in_rule_id): diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 5547f1b112..4c662e6e3c 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -45,7 +45,7 @@ class PusherRestServlet(ClientV1RestServlet): ) defer.returnValue((200, {})) - reqd = ['profile_tag', 'kind', 'app_id', 'app_display_name', + reqd = ['kind', 'app_id', 'app_display_name', 'device_display_name', 'pushkey', 'lang', 'data'] missing = [] for i in reqd: @@ -73,14 +73,14 @@ class PusherRestServlet(ClientV1RestServlet): yield pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, - profile_tag=content['profile_tag'], kind=content['kind'], app_id=content['app_id'], app_display_name=content['app_display_name'], device_display_name=content['device_display_name'], pushkey=content['pushkey'], lang=content['lang'], - data=content['data'] + data=content['data'], + profile_tag=content.get('profile_tag', ""), ) except PusherConfigException as pce: raise SynapseError(400, "Config Error: " + pce.message, diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index d77a817682..5820539a92 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -27,15 +27,14 @@ class EventPushActionsStore(SQLBaseStore): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ :param event: the event set actions for - :param tuples: list of tuples of (user_id, profile_tag, actions) + :param tuples: list of tuples of (user_id, actions) """ values = [] - for uid, profile_tag, actions in tuples: + for uid, actions in tuples: values.append({ 'room_id': event.room_id, 'event_id': event.event_id, 'user_id': uid, - 'profile_tag': profile_tag, 'actions': json.dumps(actions), 'stream_ordering': event.internal_metadata.stream_ordering, 'topological_ordering': event.depth, @@ -43,7 +42,7 @@ class EventPushActionsStore(SQLBaseStore): 'highlight': 1 if _action_has_highlight(actions) else 0, }) - for uid, _, __ in tuples: + for uid, __ in tuples: txn.call_after( self.get_unread_event_push_actions_by_room_for_user.invalidate_many, (event.room_id, uid) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8ec706178a..c23648cdbc 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -80,9 +80,9 @@ class PusherStore(SQLBaseStore): defer.returnValue(rows) @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, + def add_pusher(self, user_id, access_token, kind, app_id, app_display_name, device_display_name, - pushkey, pushkey_ts, lang, data): + pushkey, pushkey_ts, lang, data, profile_tag=""): try: next_id = yield self._pushers_id_gen.get_next() yield self._simple_upsert( @@ -95,12 +95,12 @@ class PusherStore(SQLBaseStore): dict( access_token=access_token, kind=kind, - profile_tag=profile_tag, app_display_name=app_display_name, device_display_name=device_display_name, ts=pushkey_ts, lang=lang, data=encode_canonical_json(data), + profile_tag=profile_tag, ), insertion_values=dict( id=next_id, -- cgit 1.5.1 From 9892d017b25380184fb8db47ef859b07053f00f9 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 24 Feb 2016 16:31:07 +0000 Subject: Remove unused get_rule_attr method --- synapse/rest/client/v1/push_rule.py | 8 -------- 1 file changed, 8 deletions(-) (limited to 'synapse/rest/client/v1/push_rule.py') diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 5db2805d68..6c8f09e898 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -200,14 +200,6 @@ class PushRuleRestServlet(ClientV1RestServlet): else: raise UnrecognizedRequestError() - def get_rule_attr(self, user_id, namespaced_rule_id, attr): - if attr == 'enabled': - return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id( - user_id, namespaced_rule_id - ) - else: - raise UnrecognizedRequestError() - def _rule_spec_from_path(path): if len(path) < 2: -- cgit 1.5.1 From 15c2ac2cac7377e48eb0531f911c4b7ea1891457 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 25 Feb 2016 15:13:07 +0000 Subject: Make sure we return a JSON object when returning the values of specif… MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ic keys from a push rule --- synapse/rest/client/v1/push_rule.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'synapse/rest/client/v1/push_rule.py') diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 6c8f09e898..d26e4cde3e 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -324,7 +324,9 @@ def _filter_ruleset_with_path(ruleset, path): attr = path[0] if attr in the_rule: - return the_rule[attr] + # Make sure we return a JSON object as the attribute may be a + # JSON value. + return {attr: the_rule[attr]} else: raise UnrecognizedRequestError() -- cgit 1.5.1 From de27f7fc79b785961181d13749468ae3e2019772 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 26 Feb 2016 14:28:19 +0000 Subject: Add support for changing the actions for default rules See matrix-org/matrix-doc#283 Works by adding dummy rules to the push rules table with a negative priority class and then using those rules to clobber the default rule actions when adding the default rules in ``list_with_base_rules`` --- synapse/push/baserules.py | 57 ++++++++++++++++++++++++++++++++----- synapse/rest/client/v1/push_rule.py | 31 +++++++++++++++++--- synapse/storage/push_rule.py | 25 ++++++++++++++++ 3 files changed, 102 insertions(+), 11 deletions(-) (limited to 'synapse/rest/client/v1/push_rule.py') diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 0832c77cb4..86a2998bcc 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -13,46 +13,67 @@ # limitations under the License. from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP +import copy def list_with_base_rules(rawrules): + """Combine the list of rules set by the user with the default push rules + + :param list rawrules: The rules the user has modified or set. + :returns: A new list with the rules set by the user combined with the + defaults. + """ ruleslist = [] + # Grab the base rules that the user has modified. + # The modified base rules have a priority_class of -1. + modified_base_rules = { + r['rule_id']: r for r in rawrules if r['priority_class'] < 0 + } + + # Remove the modified base rules from the list, They'll be added back + # in the default postions in the list. + rawrules = [r for r in rawrules if r['priority_class'] >= 0] + # shove the server default rules for each kind onto the end of each current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1] ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules )) for r in rawrules: if r['priority_class'] < current_prio_class: while r['priority_class'] < current_prio_class: ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) ruleslist.append(r) while current_prio_class > 0: ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) return ruleslist -def make_base_append_rules(kind): +def make_base_append_rules(kind, modified_base_rules): rules = [] if kind == 'override': @@ -62,15 +83,31 @@ def make_base_append_rules(kind): elif kind == 'content': rules = BASE_APPEND_CONTENT_RULES + # Copy the rules before modifying them + rules = copy.deepcopy(rules) + for r in rules: + # Only modify the actions, keep the conditions the same. + modified = modified_base_rules.get(r['rule_id']) + if modified: + r['actions'] = modified['actions'] + return rules -def make_base_prepend_rules(kind): +def make_base_prepend_rules(kind, modified_base_rules): rules = [] if kind == 'override': rules = BASE_PREPEND_OVERRIDE_RULES + # Copy the rules before modifying them + rules = copy.deepcopy(rules) + for r in rules: + # Only modify the actions, keep the conditions the same. + modified = modified_base_rules.get(r['rule_id']) + if modified: + r['actions'] = modified['actions'] + return rules @@ -263,18 +300,24 @@ BASE_APPEND_UNDERRIDE_RULES = [ ] +BASE_RULE_IDS = set() + for r in BASE_APPEND_CONTENT_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['content'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_PREPEND_OVERRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_APPEND_OVRRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_APPEND_UNDERRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['underride'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index d26e4cde3e..970a019223 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -22,7 +22,7 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.storage.push_rule import ( InconsistentRuleException, RuleNotFoundException ) -import synapse.push.baserules as baserules +from synapse.push.baserules import list_with_base_rules, BASE_RULE_IDS from synapse.push.rulekinds import ( PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP ) @@ -55,6 +55,10 @@ class PushRuleRestServlet(ClientV1RestServlet): yield self.set_rule_attr(requester.user.to_string(), spec, content) defer.returnValue((200, {})) + if spec['rule_id'].startswith('.'): + # Rule ids starting with '.' are reserved for server default rules. + raise SynapseError(400, "cannot add new rule_ids that start with '.'") + try: (conditions, actions) = _rule_tuple_from_request_object( spec['template'], @@ -128,7 +132,7 @@ class PushRuleRestServlet(ClientV1RestServlet): ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy - ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist)) + ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) rules = {'global': {}, 'device': {}} @@ -197,6 +201,18 @@ class PushRuleRestServlet(ClientV1RestServlet): return self.hs.get_datastore().set_push_rule_enabled( user_id, namespaced_rule_id, val ) + elif spec['attr'] == 'actions': + actions = val.get('actions') + _check_actions(actions) + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + rule_id = spec['rule_id'] + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) + return self.hs.get_datastore().set_push_rule_actions( + user_id, namespaced_rule_id, actions, is_default_rule + ) else: raise UnrecognizedRequestError() @@ -274,6 +290,15 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): raise InvalidRuleException("No actions found") actions = req_obj['actions'] + _check_actions(actions) + + return conditions, actions + + +def _check_actions(actions): + if not isinstance(actions, list): + raise InvalidRuleException("No actions found") + for a in actions: if a in ['notify', 'dont_notify', 'coalesce']: pass @@ -282,8 +307,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): else: raise InvalidRuleException("Unrecognised action") - return conditions, actions - def _add_empty_priority_class_arrays(d): for pc in PRIORITY_CLASS_MAP.keys(): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index e19a81e41f..bb5c14d912 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -294,6 +294,31 @@ class PushRuleStore(SQLBaseStore): self.get_push_rules_enabled_for_user.invalidate, (user_id,) ) + def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): + actions_json = json.dumps(actions) + + def set_push_rule_actions_txn(txn): + if is_default_rule: + # Add a dummy rule to the rules table with the user specified + # actions. + priority_class = -1 + priority = 1 + self._upsert_push_rule_txn( + txn, user_id, rule_id, priority_class, priority, + "[]", actions_json + ) + else: + self._simple_update_one_txn( + txn, + "push_rules", + {'user_name': user_id, 'rule_id': rule_id}, + {'actions': actions_json}, + ) + + return self.runInteraction( + "set_push_rule_actions", set_push_rule_actions_txn, + ) + class RuleNotFoundException(Exception): pass -- cgit 1.5.1 From ddf9e7b3027eee61086ebfb447c5fa33e9b640fe Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 3 Mar 2016 14:57:45 +0000 Subject: Hook up the push rules to the notifier --- synapse/handlers/message.py | 4 ++-- synapse/notifier.py | 2 +- synapse/rest/client/v1/push_rule.py | 44 ++++++++++++++++++++++++------------- synapse/streams/events.py | 4 ++++ synapse/types.py | 7 ++++++ 5 files changed, 43 insertions(+), 18 deletions(-) (limited to 'synapse/rest/client/v1/push_rule.py') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index afa7c9c36c..2fa12c8f2b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -647,8 +647,8 @@ class MessageHandler(BaseHandler): user_id, messages, is_peeking=is_peeking ) - start_token = StreamToken(token[0], 0, 0, 0, 0) - end_token = StreamToken(token[1], 0, 0, 0, 0) + start_token = StreamToken.START.copy_and_replace("room_key", token[0]) + end_token = StreamToken.START.copy_and_replace("room_key", token[1]) time_now = self.clock.time_msec() diff --git a/synapse/notifier.py b/synapse/notifier.py index 3c36a20868..9b69b0333a 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -284,7 +284,7 @@ class Notifier(object): @defer.inlineCallbacks def wait_for_events(self, user_id, timeout, callback, room_ids=None, - from_token=StreamToken("s0", "0", "0", "0", "0")): + from_token=StreamToken.START): """Wait until the callback returns a non empty response or the timeout fires. """ diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 970a019223..cf68725ca1 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -36,6 +36,11 @@ class PushRuleRestServlet(ClientV1RestServlet): SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( "Unrecognised request: You probably wanted a trailing slash") + def __init__(self, hs): + super(PushRuleRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + @defer.inlineCallbacks def on_PUT(self, request): spec = _rule_spec_from_path(request.postpath) @@ -51,8 +56,11 @@ class PushRuleRestServlet(ClientV1RestServlet): content = _parse_json(request) + user_id = requester.user.to_string() + if 'attr' in spec: - yield self.set_rule_attr(requester.user.to_string(), spec, content) + yield self.set_rule_attr(user_id, spec, content) + self.notify_user(user_id) defer.returnValue((200, {})) if spec['rule_id'].startswith('.'): @@ -77,8 +85,8 @@ class PushRuleRestServlet(ClientV1RestServlet): after = _namespaced_rule_id(spec, after[0]) try: - yield self.hs.get_datastore().add_push_rule( - user_id=requester.user.to_string(), + yield self.store.add_push_rule( + user_id=user_id, rule_id=_namespaced_rule_id_from_spec(spec), priority_class=priority_class, conditions=conditions, @@ -86,6 +94,7 @@ class PushRuleRestServlet(ClientV1RestServlet): before=before, after=after ) + self.notify_user(user_id) except InconsistentRuleException as e: raise SynapseError(400, e.message) except RuleNotFoundException as e: @@ -98,13 +107,15 @@ class PushRuleRestServlet(ClientV1RestServlet): spec = _rule_spec_from_path(request.postpath) requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() namespaced_rule_id = _namespaced_rule_id_from_spec(spec) try: - yield self.hs.get_datastore().delete_push_rule( - requester.user.to_string(), namespaced_rule_id + yield self.store.delete_push_rule( + user_id, namespaced_rule_id ) + self.notify_user(user_id) defer.returnValue((200, {})) except StoreError as e: if e.code == 404: @@ -115,14 +126,12 @@ class PushRuleRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - user = requester.user + user_id = requester.user.to_string() # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rawrules = yield self.hs.get_datastore().get_push_rules_for_user( - user.to_string() - ) + rawrules = yield self.store.get_push_rules_for_user(user_id) ruleslist = [] for rawrule in rawrules: @@ -138,8 +147,7 @@ class PushRuleRestServlet(ClientV1RestServlet): rules['global'] = _add_empty_priority_class_arrays(rules['global']) - enabled_map = yield self.hs.get_datastore().\ - get_push_rules_enabled_for_user(user.to_string()) + enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) for r in ruleslist: rulearray = None @@ -152,9 +160,9 @@ class PushRuleRestServlet(ClientV1RestServlet): pattern_type = c.pop("pattern_type", None) if pattern_type == "user_id": - c["pattern"] = user.to_string() + c["pattern"] = user_id elif pattern_type == "user_localpart": - c["pattern"] = user.localpart + c["pattern"] = requester.user.localpart rulearray = rules['global'][template_name] @@ -188,6 +196,12 @@ class PushRuleRestServlet(ClientV1RestServlet): def on_OPTIONS(self, _): return 200, {} + def notify_user(self, user_id): + stream_id = self.store.get_push_rules_stream_token() + self.notifier.on_new_event( + "push_rules_key", stream_id, users=[user_id] + ) + def set_rule_attr(self, user_id, spec, val): if spec['attr'] == 'enabled': if isinstance(val, dict) and "enabled" in val: @@ -198,7 +212,7 @@ class PushRuleRestServlet(ClientV1RestServlet): # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - return self.hs.get_datastore().set_push_rule_enabled( + return self.store.set_push_rule_enabled( user_id, namespaced_rule_id, val ) elif spec['attr'] == 'actions': @@ -210,7 +224,7 @@ class PushRuleRestServlet(ClientV1RestServlet): if is_default_rule: if namespaced_rule_id not in BASE_RULE_IDS: raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - return self.hs.get_datastore().set_push_rule_actions( + return self.store.set_push_rule_actions( user_id, namespaced_rule_id, actions, is_default_rule ) else: diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 5ddf4e988b..d4c0bb6732 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -38,9 +38,12 @@ class EventSources(object): name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items() } + self.store = hs.get_datastore() @defer.inlineCallbacks def get_current_token(self, direction='f'): + push_rules_key, _ = self.store.get_push_rules_stream_token() + token = StreamToken( room_key=( yield self.sources["room"].get_current_key(direction) @@ -57,5 +60,6 @@ class EventSources(object): account_data_key=( yield self.sources["account_data"].get_current_key() ), + push_rules_key=push_rules_key, ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index d5bd95cbd3..5b166835bd 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -115,6 +115,7 @@ class StreamToken( "typing_key", "receipt_key", "account_data_key", + "push_rules_key", )) ): _SEPARATOR = "_" @@ -150,6 +151,7 @@ class StreamToken( or (int(other.typing_key) < int(self.typing_key)) or (int(other.receipt_key) < int(self.receipt_key)) or (int(other.account_data_key) < int(self.account_data_key)) + or (int(other.push_rules_key) < int(self.push_rules_key)) ) def copy_and_advance(self, key, new_value): @@ -174,6 +176,11 @@ class StreamToken( return StreamToken(**d) +StreamToken.START = StreamToken( + *(["s0"] + ["0"] * (len(StreamToken._fields) - 1)) +) + + class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): """Tokens are positions between events. The token "s1" comes after event 1. -- cgit 1.5.1 From 3406eba4ef40de888ebb5b22c0ea4925b2dddeb1 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 3 Mar 2016 16:11:59 +0000 Subject: Move the code for formatting push rules into a separate function --- synapse/push/clientformat.py | 112 ++++++++++++++++++++++++++++++++++++ synapse/rest/client/v1/push_rule.py | 90 ++--------------------------- 2 files changed, 116 insertions(+), 86 deletions(-) create mode 100644 synapse/push/clientformat.py (limited to 'synapse/rest/client/v1/push_rule.py') diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py new file mode 100644 index 0000000000..ae9db9ec2f --- /dev/null +++ b/synapse/push/clientformat.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.push.baserules import list_with_base_rules + +from synapse.push.rulekinds import ( + PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP +) + +import copy +import simplejson as json + + +def format_push_rules_for_user(user, rawrules, enabled_map): + """Converts a list of rawrules and a enabled map into nested dictionaries + to match the Matrix client-server format for push rules""" + + ruleslist = [] + for rawrule in rawrules: + rule = dict(rawrule) + rule["conditions"] = json.loads(rawrule["conditions"]) + rule["actions"] = json.loads(rawrule["actions"]) + ruleslist.append(rule) + + # We're going to be mutating this a lot, so do a deep copy + ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) + + rules = {'global': {}, 'device': {}} + + rules['global'] = _add_empty_priority_class_arrays(rules['global']) + + for r in ruleslist: + rulearray = None + + template_name = _priority_class_to_template_name(r['priority_class']) + + # Remove internal stuff. + for c in r["conditions"]: + c.pop("_id", None) + + pattern_type = c.pop("pattern_type", None) + if pattern_type == "user_id": + c["pattern"] = user.to_string() + elif pattern_type == "user_localpart": + c["pattern"] = user.localpart + + rulearray = rules['global'][template_name] + + template_rule = _rule_to_template(r) + if template_rule: + if r['rule_id'] in enabled_map: + template_rule['enabled'] = enabled_map[r['rule_id']] + elif 'enabled' in r: + template_rule['enabled'] = r['enabled'] + else: + template_rule['enabled'] = True + rulearray.append(template_rule) + + return rules + + +def _add_empty_priority_class_arrays(d): + for pc in PRIORITY_CLASS_MAP.keys(): + d[pc] = [] + return d + + +def _rule_to_template(rule): + unscoped_rule_id = None + if 'rule_id' in rule: + unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id']) + + template_name = _priority_class_to_template_name(rule['priority_class']) + if template_name in ['override', 'underride']: + templaterule = {k: rule[k] for k in ["conditions", "actions"]} + elif template_name in ["sender", "room"]: + templaterule = {'actions': rule['actions']} + unscoped_rule_id = rule['conditions'][0]['pattern'] + elif template_name == 'content': + if len(rule["conditions"]) != 1: + return None + thecond = rule["conditions"][0] + if "pattern" not in thecond: + return None + templaterule = {'actions': rule['actions']} + templaterule["pattern"] = thecond["pattern"] + + if unscoped_rule_id: + templaterule['rule_id'] = unscoped_rule_id + if 'default' in rule: + templaterule['default'] = rule['default'] + return templaterule + + +def _rule_id_from_namespaced(in_rule_id): + return in_rule_id.split('/')[-1] + + +def _priority_class_to_template_name(pc): + return PRIORITY_CLASS_INVERSE_MAP[pc] diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index cf68725ca1..edfe28c79b 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -22,12 +22,10 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.storage.push_rule import ( InconsistentRuleException, RuleNotFoundException ) -from synapse.push.baserules import list_with_base_rules, BASE_RULE_IDS -from synapse.push.rulekinds import ( - PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP -) +from synapse.push.clientformat import format_push_rules_for_user +from synapse.push.baserules import BASE_RULE_IDS +from synapse.push.rulekinds import PRIORITY_CLASS_MAP -import copy import simplejson as json @@ -133,48 +131,9 @@ class PushRuleRestServlet(ClientV1RestServlet): # is probably not going to make a whole lot of difference rawrules = yield self.store.get_push_rules_for_user(user_id) - ruleslist = [] - for rawrule in rawrules: - rule = dict(rawrule) - rule["conditions"] = json.loads(rawrule["conditions"]) - rule["actions"] = json.loads(rawrule["actions"]) - ruleslist.append(rule) - - # We're going to be mutating this a lot, so do a deep copy - ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) - - rules = {'global': {}, 'device': {}} - - rules['global'] = _add_empty_priority_class_arrays(rules['global']) - enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) - for r in ruleslist: - rulearray = None - - template_name = _priority_class_to_template_name(r['priority_class']) - - # Remove internal stuff. - for c in r["conditions"]: - c.pop("_id", None) - - pattern_type = c.pop("pattern_type", None) - if pattern_type == "user_id": - c["pattern"] = user_id - elif pattern_type == "user_localpart": - c["pattern"] = requester.user.localpart - - rulearray = rules['global'][template_name] - - template_rule = _rule_to_template(r) - if template_rule: - if r['rule_id'] in enabled_map: - template_rule['enabled'] = enabled_map[r['rule_id']] - elif 'enabled' in r: - template_rule['enabled'] = r['enabled'] - else: - template_rule['enabled'] = True - rulearray.append(template_rule) + rules = format_push_rules_for_user(requester.user, rawrules, enabled_map) path = request.postpath[1:] @@ -322,12 +281,6 @@ def _check_actions(actions): raise InvalidRuleException("Unrecognised action") -def _add_empty_priority_class_arrays(d): - for pc in PRIORITY_CLASS_MAP.keys(): - d[pc] = [] - return d - - def _filter_ruleset_with_path(ruleset, path): if path == []: raise UnrecognizedRequestError( @@ -376,37 +329,6 @@ def _priority_class_from_spec(spec): return pc -def _priority_class_to_template_name(pc): - return PRIORITY_CLASS_INVERSE_MAP[pc] - - -def _rule_to_template(rule): - unscoped_rule_id = None - if 'rule_id' in rule: - unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id']) - - template_name = _priority_class_to_template_name(rule['priority_class']) - if template_name in ['override', 'underride']: - templaterule = {k: rule[k] for k in ["conditions", "actions"]} - elif template_name in ["sender", "room"]: - templaterule = {'actions': rule['actions']} - unscoped_rule_id = rule['conditions'][0]['pattern'] - elif template_name == 'content': - if len(rule["conditions"]) != 1: - return None - thecond = rule["conditions"][0] - if "pattern" not in thecond: - return None - templaterule = {'actions': rule['actions']} - templaterule["pattern"] = thecond["pattern"] - - if unscoped_rule_id: - templaterule['rule_id'] = unscoped_rule_id - if 'default' in rule: - templaterule['default'] = rule['default'] - return templaterule - - def _namespaced_rule_id_from_spec(spec): return _namespaced_rule_id(spec, spec['rule_id']) @@ -415,10 +337,6 @@ def _namespaced_rule_id(spec, rule_id): return "global/%s/%s" % (spec['template'], rule_id) -def _rule_id_from_namespaced(in_rule_id): - return in_rule_id.split('/')[-1] - - class InvalidRuleException(Exception): pass -- cgit 1.5.1 From 1b4f4a936fb416d81203fcd66be690f9a04b2b62 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 4 Mar 2016 14:44:01 +0000 Subject: Hook up the push rules stream to account_data in /sync --- synapse/handlers/sync.py | 22 +++++++ synapse/rest/client/v1/push_rule.py | 2 +- synapse/storage/__init__.py | 5 ++ synapse/storage/push_rule.py | 125 ++++++++++++++++-------------------- 4 files changed, 85 insertions(+), 69 deletions(-) (limited to 'synapse/rest/client/v1/push_rule.py') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index fded6e4009..92eab20c7c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -20,6 +20,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.util import unwrapFirstError from synapse.util.logcontext import LoggingContext, preserve_fn from synapse.util.metrics import Measure +from synapse.push.clientformat import format_push_rules_for_user from twisted.internet import defer @@ -224,6 +225,10 @@ class SyncHandler(BaseHandler): ) ) + account_data['m.push_rules'] = yield self.push_rules_for_user( + sync_config.user + ) + tags_by_room = yield self.store.get_tags_for_user( sync_config.user.to_string() ) @@ -322,6 +327,14 @@ class SyncHandler(BaseHandler): defer.returnValue(room_sync) + @defer.inlineCallbacks + def push_rules_for_user(self, user): + user_id = user.to_string() + rawrules = yield self.store.get_push_rules_for_user(user_id) + enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) + rules = format_push_rules_for_user(user, rawrules, enabled_map) + defer.returnValue(rules) + def account_data_for_user(self, account_data): account_data_events = [] @@ -481,6 +494,15 @@ class SyncHandler(BaseHandler): ) ) + push_rules_changed = yield self.store.have_push_rules_changed_for_user( + user_id, int(since_token.push_rules_key) + ) + + if push_rules_changed: + account_data["m.push_rules"] = yield self.push_rules_for_user( + sync_config.user + ) + # Get a list of membership change events that have happened. rooms_changed = yield self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index edfe28c79b..981d7708db 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -156,7 +156,7 @@ class PushRuleRestServlet(ClientV1RestServlet): return 200, {} def notify_user(self, user_id): - stream_id = self.store.get_push_rules_stream_token() + stream_id, _ = self.store.get_push_rules_stream_token() self.notifier.on_new_event( "push_rules_key", stream_id, users=[user_id] ) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e2d7b52569..7b7b03d052 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -160,6 +160,11 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=presence_cache_prefill ) + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", + self._push_rules_stream_id_gen.get_max_token()[0], + ) + super(DataStore, self).__init__(hs) def take_presence_startup_info(self): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index e034024108..792fcbdf5b 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -244,15 +244,10 @@ class PushRuleStore(SQLBaseStore): ) if update_stream: - self._simple_insert_txn( - txn, - table="push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "ADD", + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="ADD", + data={ "priority_class": priority_class, "priority": priority, "conditions": conditions_json, @@ -260,13 +255,6 @@ class PushRuleStore(SQLBaseStore): } ) - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) - ) - @defer.inlineCallbacks def delete_push_rule(self, user_id, rule_id): """ @@ -284,22 +272,10 @@ class PushRuleStore(SQLBaseStore): "push_rules", {'user_name': user_id, 'rule_id': rule_id}, ) - self._simple_insert_txn( - txn, - table="push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "DELETE", - } - ) - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) + + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="DELETE" ) with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): @@ -328,23 +304,9 @@ class PushRuleStore(SQLBaseStore): {'id': new_id}, ) - self._simple_insert_txn( - txn, - "push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "ENABLE" if enabled else "DISABLE", - } - ) - - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="ENABLE" if enabled else "DISABLE" ) @defer.inlineCallbacks @@ -370,24 +332,9 @@ class PushRuleStore(SQLBaseStore): {'actions': actions_json}, ) - self._simple_insert_txn( - txn, - "push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "ACTIONS", - "actions": actions_json, - } - ) - - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="ACTIONS", data={"actions": actions_json} ) with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): @@ -396,6 +343,31 @@ class PushRuleStore(SQLBaseStore): stream_id, stream_ordering ) + def _insert_push_rules_update_txn( + self, txn, stream_id, stream_ordering, user_id, rule_id, op, data=None + ): + values = { + "stream_id": stream_id, + "stream_ordering": stream_ordering, + "user_id": user_id, + "rule_id": rule_id, + "op": op, + } + if data is not None: + values.update(data) + + self._simple_insert_txn(txn, "push_rules_stream", values=values) + + txn.call_after( + self.get_push_rules_for_user.invalidate, (user_id,) + ) + txn.call_after( + self.get_push_rules_enabled_for_user.invalidate, (user_id,) + ) + txn.call_after( + self.push_rules_stream_cache.entity_has_changed, user_id, stream_id + ) + def get_all_push_rule_updates(self, last_id, current_id, limit): """Get all the push rules changes that have happend on the server""" def get_all_push_rule_updates_txn(txn): @@ -403,7 +375,7 @@ class PushRuleStore(SQLBaseStore): "SELECT stream_id, stream_ordering, user_id, rule_id," " op, priority_class, priority, conditions, actions" " FROM push_rules_stream" - " WHERE ? < stream_id and stream_id <= ?" + " WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) @@ -418,6 +390,23 @@ class PushRuleStore(SQLBaseStore): room stream ordering it corresponds to.""" return self._push_rules_stream_id_gen.get_max_token() + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + logger.error("FNARG") + return defer.succeed(False) + else: + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? AND ? < stream_id" + ) + txn.execute(sql, (user_id, last_id)) + count, = txn.fetchone() + return bool(count) + return self.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) + class RuleNotFoundException(Exception): pass -- cgit 1.5.1 From b7dbe5147a85bea8f14b78a27ff499fe5a0d444a Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 9 Mar 2016 11:26:26 +0000 Subject: Add a parse_json_object function to deduplicate all the copy+pasted _parse_json functions. Also document the parse_.* functions. --- synapse/http/servlet.py | 70 ++++++++++++++++++++++++++-- synapse/rest/client/v1/directory.py | 16 ++----- synapse/rest/client/v1/login.py | 15 +----- synapse/rest/client/v1/push_rule.py | 16 ++----- synapse/rest/client/v1/pusher.py | 17 ++----- synapse/rest/client/v1/register.py | 14 +----- synapse/rest/client/v1/room.py | 26 ++++------- synapse/rest/client/v2_alpha/_base.py | 22 --------- synapse/rest/client/v2_alpha/account.py | 8 ++-- synapse/rest/client/v2_alpha/register.py | 8 ++-- synapse/rest/client/v2_alpha/tokenrefresh.py | 6 +-- 11 files changed, 97 insertions(+), 121 deletions(-) (limited to 'synapse/rest/client/v1/push_rule.py') diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 7bd87940b4..41c519c000 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -15,14 +15,27 @@ """ This module contains base REST classes for constructing REST servlets. """ -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, Codes import logging +import simplejson logger = logging.getLogger(__name__) def parse_integer(request, name, default=None, required=False): + """Parse an integer parameter from the request string + + :param request: the twisted HTTP request. + :param name (str): the name of the query parameter. + :param default: value to use if the parameter is absent, defaults to None. + :param required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + :return: An int value or the default. + :raises + SynapseError if the parameter is absent and required, or if the + parameter is present and not an integer. + """ if name in request.args: try: return int(request.args[name][0]) @@ -32,12 +45,25 @@ def parse_integer(request, name, default=None, required=False): else: if required: message = "Missing integer query parameter %r" % (name,) - raise SynapseError(400, message) + raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) else: return default def parse_boolean(request, name, default=None, required=False): + """Parse a boolean parameter from the request query string + + :param request: the twisted HTTP request. + :param name (str): the name of the query parameter. + :param default: value to use if the parameter is absent, defaults to None. + :param required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + :return: A bool value or the default. + :raises + SynapseError if the parameter is absent and required, or if the + parameter is present and not one of "true" or "false". + """ + if name in request.args: try: return { @@ -53,30 +79,64 @@ def parse_boolean(request, name, default=None, required=False): else: if required: message = "Missing boolean query parameter %r" % (name,) - raise SynapseError(400, message) + raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) else: return default def parse_string(request, name, default=None, required=False, allowed_values=None, param_type="string"): + """Parse a string parameter from the request query string. + + :param request: the twisted HTTP request. + :param name (str): the name of the query parameter. + :param default: value to use if the parameter is absent, defaults to None. + :param required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + :param allowed_values (list): List of allowed values for the string, + or None if any value is allowed, defaults to None + :return: A string value or the default. + :raises + SynapseError if the parameter is absent and required, or if the + parameter is present, must be one of a list of allowed values and + is not one of those allowed values. + """ + if name in request.args: value = request.args[name][0] if allowed_values is not None and value not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( name, ", ".join(repr(v) for v in allowed_values) ) - raise SynapseError(message) + raise SynapseError(400, message) else: return value else: if required: message = "Missing %s query parameter %r" % (param_type, name) - raise SynapseError(400, message) + raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) else: return default +def parse_json_object_from_request(request): + """Parse a JSON object from the body of a twisted HTTP request. + + :param request: the twisted HTTP request. + :raises + SynapseError if the request body couldn't be decoded as JSON or + if it wasn't a JSON object. + """ + try: + content = simplejson.loads(request.content.read()) + if type(content) != dict: + message = "Content must be a JSON object." + raise SynapseError(400, message, errcode=Codes.BAD_JSON) + return content + except simplejson.JSONDecodeError: + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + + class RestServlet(object): """ A Synapse REST Servlet. diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 8bfe9fdea8..60c5ec77aa 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -18,9 +18,10 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError, Codes from synapse.types import RoomAlias +from synapse.http.servlet import parse_json_object_from_request + from .base import ClientV1RestServlet, client_path_patterns -import simplejson as json import logging @@ -45,7 +46,7 @@ class ClientDirectoryServer(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_alias): - content = _parse_json(request) + content = parse_json_object_from_request(request) if "room_id" not in content: raise SynapseError(400, "Missing room_id key", errcode=Codes.BAD_JSON) @@ -135,14 +136,3 @@ class ClientDirectoryServer(ClientV1RestServlet): ) defer.returnValue((200, {})) - - -def _parse_json(request): - try: - content = json.loads(request.content.read()) - if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.", - errcode=Codes.NOT_JSON) - return content - except ValueError: - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index f6902a60a8..fe593d07ce 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, LoginError, Codes from synapse.types import UserID from synapse.http.server import finish_request +from synapse.http.servlet import parse_json_object_from_request from .base import ClientV1RestServlet, client_path_patterns @@ -79,7 +80,7 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - login_submission = _parse_json(request) + login_submission = parse_json_object_from_request(request) try: if login_submission["type"] == LoginRestServlet.PASS_TYPE: if not self.password_enabled: @@ -400,18 +401,6 @@ class CasTicketServlet(ClientV1RestServlet): return (user, attributes) -def _parse_json(request): - try: - content = json.loads(request.content.read()) - if type(content) != dict: - raise SynapseError( - 400, "Content must be a JSON object.", errcode=Codes.BAD_JSON - ) - return content - except ValueError: - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) - - def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) if hs.config.saml2_enabled: diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 981d7708db..b5695d427a 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.errors import ( - SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError + SynapseError, UnrecognizedRequestError, NotFoundError, StoreError ) from .base import ClientV1RestServlet, client_path_patterns from synapse.storage.push_rule import ( @@ -25,8 +25,7 @@ from synapse.storage.push_rule import ( from synapse.push.clientformat import format_push_rules_for_user from synapse.push.baserules import BASE_RULE_IDS from synapse.push.rulekinds import PRIORITY_CLASS_MAP - -import simplejson as json +from synapse.http.servlet import parse_json_object_from_request class PushRuleRestServlet(ClientV1RestServlet): @@ -52,7 +51,7 @@ class PushRuleRestServlet(ClientV1RestServlet): if '/' in spec['rule_id'] or '\\' in spec['rule_id']: raise SynapseError(400, "rule_id may not contain slashes") - content = _parse_json(request) + content = parse_json_object_from_request(request) user_id = requester.user.to_string() @@ -341,14 +340,5 @@ class InvalidRuleException(Exception): pass -# XXX: C+ped from rest/room.py - surely this should be common? -def _parse_json(request): - try: - content = json.loads(request.content.read()) - return content - except ValueError: - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) - - def register_servlets(hs, http_server): PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 4c662e6e3c..ee029b4f77 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -17,9 +17,10 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.push import PusherConfigException +from synapse.http.servlet import parse_json_object_from_request + from .base import ClientV1RestServlet, client_path_patterns -import simplejson as json import logging logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ class PusherRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) user = requester.user - content = _parse_json(request) + content = parse_json_object_from_request(request) pusher_pool = self.hs.get_pusherpool() @@ -92,17 +93,5 @@ class PusherRestServlet(ClientV1RestServlet): return 200, {} -# XXX: C+ped from rest/room.py - surely this should be common? -def _parse_json(request): - try: - content = json.loads(request.content.read()) - if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.", - errcode=Codes.NOT_JSON) - return content - except ValueError: - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) - - def register_servlets(hs, http_server): PusherRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 040a7a7ffa..c6a2ef2ccc 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -20,12 +20,12 @@ from synapse.api.errors import SynapseError, Codes from synapse.api.constants import LoginType from .base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils +from synapse.http.servlet import parse_json_object_from_request from synapse.util.async import run_on_reactor from hashlib import sha1 import hmac -import simplejson as json import logging logger = logging.getLogger(__name__) @@ -98,7 +98,7 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - register_json = _parse_json(request) + register_json = parse_json_object_from_request(request) session = (register_json["session"] if "session" in register_json else None) @@ -355,15 +355,5 @@ class RegisterRestServlet(ClientV1RestServlet): ) -def _parse_json(request): - try: - content = json.loads(request.content.read()) - if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.") - return content - except ValueError: - raise SynapseError(400, "Content not JSON.") - - def register_servlets(hs, http_server): RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 4b7d198c52..7a9f3d11b9 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -22,6 +22,7 @@ from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership from synapse.types import UserID, RoomID, RoomAlias from synapse.events.utils import serialize_event +from synapse.http.servlet import parse_json_object_from_request import simplejson as json import logging @@ -137,7 +138,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): requester = yield self.auth.get_user_by_req(request) - content = _parse_json(request) + content = parse_json_object_from_request(request) event_dict = { "type": event_type, @@ -179,7 +180,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_type, txn_id=None): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - content = _parse_json(request) + content = parse_json_object_from_request(request) msg_handler = self.handlers.message_handler event = yield msg_handler.create_and_send_nonmember_event( @@ -229,7 +230,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): ) try: - content = _parse_json(request) + content = parse_json_object_from_request(request) except: # Turns out we used to ignore the body entirely, and some clients # cheekily send invalid bodies. @@ -433,7 +434,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): raise AuthError(403, "Guest access not allowed") try: - content = _parse_json(request) + content = parse_json_object_from_request(request) except: # Turns out we used to ignore the body entirely, and some clients # cheekily send invalid bodies. @@ -500,7 +501,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_id, txn_id=None): requester = yield self.auth.get_user_by_req(request) - content = _parse_json(request) + content = parse_json_object_from_request(request) msg_handler = self.handlers.message_handler event = yield msg_handler.create_and_send_nonmember_event( @@ -548,7 +549,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): room_id = urllib.unquote(room_id) target_user = UserID.from_string(urllib.unquote(user_id)) - content = _parse_json(request) + content = parse_json_object_from_request(request) typing_handler = self.handlers.typing_notification_handler @@ -580,7 +581,7 @@ class SearchRestServlet(ClientV1RestServlet): def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) - content = _parse_json(request) + content = parse_json_object_from_request(request) batch = request.args.get("next_batch", [None])[0] results = yield self.handlers.search_handler.search( @@ -592,17 +593,6 @@ class SearchRestServlet(ClientV1RestServlet): defer.returnValue((200, results)) -def _parse_json(request): - try: - content = json.loads(request.content.read()) - if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.", - errcode=Codes.NOT_JSON) - return content - except ValueError: - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) - - def register_txn_path(servlet, regex_string, http_server, with_get=False): """Registers a transaction-based path. diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 24af322126..b6faa2b0e6 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -17,11 +17,9 @@ """ from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX -from synapse.api.errors import SynapseError import re import logging -import simplejson logger = logging.getLogger(__name__) @@ -44,23 +42,3 @@ def client_v2_patterns(path_regex, releases=(0,)): new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) patterns.append(re.compile("^" + new_prefix + path_regex)) return patterns - - -def parse_request_allow_empty(request): - content = request.content.read() - if content is None or content == '': - return None - try: - return simplejson.loads(content) - except simplejson.JSONDecodeError: - raise SynapseError(400, "Content not JSON.") - - -def parse_json_dict_from_request(request): - try: - content = simplejson.loads(request.content.read()) - if type(content) != dict: - raise SynapseError(400, "Content must be a JSON object.") - return content - except simplejson.JSONDecodeError: - raise SynapseError(400, "Content not JSON.") diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index a614b79d45..688b051580 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -17,10 +17,10 @@ from twisted.internet import defer from synapse.api.constants import LoginType from synapse.api.errors import LoginError, SynapseError, Codes -from synapse.http.servlet import RestServlet +from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.util.async import run_on_reactor -from ._base import client_v2_patterns, parse_json_dict_from_request +from ._base import client_v2_patterns import logging @@ -41,7 +41,7 @@ class PasswordRestServlet(RestServlet): def on_POST(self, request): yield run_on_reactor() - body = parse_json_dict_from_request(request) + body = parse_json_object_from_request(request) authed, result, params = yield self.auth_handler.check_auth([ [LoginType.PASSWORD], @@ -114,7 +114,7 @@ class ThreepidRestServlet(RestServlet): def on_POST(self, request): yield run_on_reactor() - body = parse_json_dict_from_request(request) + body = parse_json_object_from_request(request) threePidCreds = body.get('threePidCreds') threePidCreds = body.get('three_pid_creds', threePidCreds) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index ec5c21fa1f..b090e66bcf 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -17,9 +17,9 @@ from twisted.internet import defer from synapse.api.constants import LoginType from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError -from synapse.http.servlet import RestServlet +from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns, parse_json_dict_from_request +from ._base import client_v2_patterns import logging import hmac @@ -73,7 +73,7 @@ class RegisterRestServlet(RestServlet): ret = yield self.onEmailTokenRequest(request) defer.returnValue(ret) - body = parse_json_dict_from_request(request) + body = parse_json_object_from_request(request) # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the username/password provided to us. @@ -236,7 +236,7 @@ class RegisterRestServlet(RestServlet): @defer.inlineCallbacks def onEmailTokenRequest(self, request): - body = parse_json_dict_from_request(request) + body = parse_json_object_from_request(request) required = ['id_server', 'client_secret', 'email', 'send_attempt'] absent = [] diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 3553f6b040..a158c2209a 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -16,9 +16,9 @@ from twisted.internet import defer from synapse.api.errors import AuthError, StoreError, SynapseError -from synapse.http.servlet import RestServlet +from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns, parse_json_dict_from_request +from ._base import client_v2_patterns class TokenRefreshRestServlet(RestServlet): @@ -35,7 +35,7 @@ class TokenRefreshRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): - body = parse_json_dict_from_request(request) + body = parse_json_object_from_request(request) try: old_refresh_token = body["refresh_token"] auth_handler = self.hs.get_handlers().auth_handler -- cgit 1.5.1 From 398cd1edfbefa207e44047bd63adcdcc6e859f2e Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 14 Mar 2016 14:16:41 +0000 Subject: Fix regression where synapse checked whether push rules were valid JSON before the compatibility hack that handled clients sending invalid JSON --- synapse/http/servlet.py | 21 +++++++++++++++++---- synapse/rest/client/v1/push_rule.py | 4 ++-- 2 files changed, 19 insertions(+), 6 deletions(-) (limited to 'synapse/rest/client/v1/push_rule.py') diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 1996f8b136..1c8bd8666f 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -119,13 +119,13 @@ def parse_string(request, name, default=None, required=False, return default -def parse_json_object_from_request(request): - """Parse a JSON object from the body of a twisted HTTP request. +def parse_json_value_from_request(request): + """Parse a JSON value from the body of a twisted HTTP request. :param request: the twisted HTTP request. + :returns: The JSON value. :raises - SynapseError if the request body couldn't be decoded as JSON or - if it wasn't a JSON object. + SynapseError if the request body couldn't be decoded as JSON. """ try: content_bytes = request.content.read() @@ -137,6 +137,19 @@ def parse_json_object_from_request(request): except simplejson.JSONDecodeError: raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + return content + + +def parse_json_object_from_request(request): + """Parse a JSON object from the body of a twisted HTTP request. + + :param request: the twisted HTTP request. + :raises + SynapseError if the request body couldn't be decoded as JSON or + if it wasn't a JSON object. + """ + content = parse_json_value_from_request(request) + if type(content) != dict: message = "Content must be a JSON object." raise SynapseError(400, message, errcode=Codes.BAD_JSON) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index b5695d427a..02d837ee6a 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -25,7 +25,7 @@ from synapse.storage.push_rule import ( from synapse.push.clientformat import format_push_rules_for_user from synapse.push.baserules import BASE_RULE_IDS from synapse.push.rulekinds import PRIORITY_CLASS_MAP -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import parse_json_value_from_request class PushRuleRestServlet(ClientV1RestServlet): @@ -51,7 +51,7 @@ class PushRuleRestServlet(ClientV1RestServlet): if '/' in spec['rule_id'] or '\\' in spec['rule_id']: raise SynapseError(400, "rule_id may not contain slashes") - content = parse_json_object_from_request(request) + content = parse_json_value_from_request(request) user_id = requester.user.to_string() -- cgit 1.5.1