summary refs log tree commit diff
path: root/synapse/rest/client/v1/push_rule.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1/push_rule.py')
-rw-r--r--synapse/rest/client/v1/push_rule.py280
1 files changed, 72 insertions, 208 deletions
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 2272d66dc7..02d837ee6a 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -16,19 +16,16 @@
 from twisted.internet import defer
 
 from synapse.api.errors import (
-    SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError
+    SynapseError, UnrecognizedRequestError, NotFoundError, StoreError
 )
 from .base import ClientV1RestServlet, client_path_patterns
 from synapse.storage.push_rule import (
     InconsistentRuleException, RuleNotFoundException
 )
-import synapse.push.baserules as baserules
-from synapse.push.rulekinds import (
-    PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
-)
-
-import copy
-import simplejson as json
+from synapse.push.clientformat import format_push_rules_for_user
+from synapse.push.baserules import BASE_RULE_IDS
+from synapse.push.rulekinds import PRIORITY_CLASS_MAP
+from synapse.http.servlet import parse_json_value_from_request
 
 
 class PushRuleRestServlet(ClientV1RestServlet):
@@ -36,6 +33,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
     SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
         "Unrecognised request: You probably wanted a trailing slash")
 
+    def __init__(self, hs):
+        super(PushRuleRestServlet, self).__init__(hs)
+        self.store = hs.get_datastore()
+        self.notifier = hs.get_notifier()
+
     @defer.inlineCallbacks
     def on_PUT(self, request):
         spec = _rule_spec_from_path(request.postpath)
@@ -49,32 +51,39 @@ class PushRuleRestServlet(ClientV1RestServlet):
         if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
             raise SynapseError(400, "rule_id may not contain slashes")
 
-        content = _parse_json(request)
+        content = parse_json_value_from_request(request)
+
+        user_id = requester.user.to_string()
 
         if 'attr' in spec:
-            self.set_rule_attr(requester.user.to_string(), spec, content)
+            yield self.set_rule_attr(user_id, spec, content)
+            self.notify_user(user_id)
             defer.returnValue((200, {}))
 
+        if spec['rule_id'].startswith('.'):
+            # Rule ids starting with '.' are reserved for server default rules.
+            raise SynapseError(400, "cannot add new rule_ids that start with '.'")
+
         try:
             (conditions, actions) = _rule_tuple_from_request_object(
                 spec['template'],
                 spec['rule_id'],
                 content,
-                device=spec['device'] if 'device' in spec else None
             )
         except InvalidRuleException as e:
             raise SynapseError(400, e.message)
 
         before = request.args.get("before", None)
-        if before and len(before):
-            before = before[0]
+        if before:
+            before = _namespaced_rule_id(spec, before[0])
+
         after = request.args.get("after", None)
-        if after and len(after):
-            after = after[0]
+        if after:
+            after = _namespaced_rule_id(spec, after[0])
 
         try:
-            yield self.hs.get_datastore().add_push_rule(
-                user_id=requester.user.to_string(),
+            yield self.store.add_push_rule(
+                user_id=user_id,
                 rule_id=_namespaced_rule_id_from_spec(spec),
                 priority_class=priority_class,
                 conditions=conditions,
@@ -82,6 +91,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
                 before=before,
                 after=after
             )
+            self.notify_user(user_id)
         except InconsistentRuleException as e:
             raise SynapseError(400, e.message)
         except RuleNotFoundException as e:
@@ -94,13 +104,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
         spec = _rule_spec_from_path(request.postpath)
 
         requester = yield self.auth.get_user_by_req(request)
+        user_id = requester.user.to_string()
 
         namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
 
         try:
-            yield self.hs.get_datastore().delete_push_rule(
-                requester.user.to_string(), namespaced_rule_id
+            yield self.store.delete_push_rule(
+                user_id, namespaced_rule_id
             )
+            self.notify_user(user_id)
             defer.returnValue((200, {}))
         except StoreError as e:
             if e.code == 404:
@@ -111,74 +123,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request):
         requester = yield self.auth.get_user_by_req(request)
-        user = requester.user
+        user_id = requester.user.to_string()
 
         # we build up the full structure and then decide which bits of it
         # to send which means doing unnecessary work sometimes but is
         # is probably not going to make a whole lot of difference
-        rawrules = yield self.hs.get_datastore().get_push_rules_for_user(
-            user.to_string()
-        )
+        rawrules = yield self.store.get_push_rules_for_user(user_id)
 
-        ruleslist = []
-        for rawrule in rawrules:
-            rule = dict(rawrule)
-            rule["conditions"] = json.loads(rawrule["conditions"])
-            rule["actions"] = json.loads(rawrule["actions"])
-            ruleslist.append(rule)
-
-        # We're going to be mutating this a lot, so do a deep copy
-        ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist))
-
-        rules = {'global': {}, 'device': {}}
-
-        rules['global'] = _add_empty_priority_class_arrays(rules['global'])
-
-        enabled_map = yield self.hs.get_datastore().\
-            get_push_rules_enabled_for_user(user.to_string())
-
-        for r in ruleslist:
-            rulearray = None
-
-            template_name = _priority_class_to_template_name(r['priority_class'])
-
-            # Remove internal stuff.
-            for c in r["conditions"]:
-                c.pop("_id", None)
-
-                pattern_type = c.pop("pattern_type", None)
-                if pattern_type == "user_id":
-                    c["pattern"] = user.to_string()
-                elif pattern_type == "user_localpart":
-                    c["pattern"] = user.localpart
-
-            if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
-                # per-device rule
-                profile_tag = _profile_tag_from_conditions(r["conditions"])
-                r = _strip_device_condition(r)
-                if not profile_tag:
-                    continue
-                if profile_tag not in rules['device']:
-                    rules['device'][profile_tag] = {}
-                    rules['device'][profile_tag] = (
-                        _add_empty_priority_class_arrays(
-                            rules['device'][profile_tag]
-                        )
-                    )
-
-                rulearray = rules['device'][profile_tag][template_name]
-            else:
-                rulearray = rules['global'][template_name]
-
-            template_rule = _rule_to_template(r)
-            if template_rule:
-                if r['rule_id'] in enabled_map:
-                    template_rule['enabled'] = enabled_map[r['rule_id']]
-                elif 'enabled' in r:
-                    template_rule['enabled'] = r['enabled']
-                else:
-                    template_rule['enabled'] = True
-                rulearray.append(template_rule)
+        enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
+
+        rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
 
         path = request.postpath[1:]
 
@@ -194,30 +148,18 @@ class PushRuleRestServlet(ClientV1RestServlet):
             path = path[1:]
             result = _filter_ruleset_with_path(rules['global'], path)
             defer.returnValue((200, result))
-        elif path[0] == 'device':
-            path = path[1:]
-            if path == []:
-                raise UnrecognizedRequestError(
-                    PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
-                )
-            if path[0] == '':
-                defer.returnValue((200, rules['device']))
-
-            profile_tag = path[0]
-            path = path[1:]
-            if profile_tag not in rules['device']:
-                ret = {}
-                ret = _add_empty_priority_class_arrays(ret)
-                defer.returnValue((200, ret))
-            ruleset = rules['device'][profile_tag]
-            result = _filter_ruleset_with_path(ruleset, path)
-            defer.returnValue((200, result))
         else:
             raise UnrecognizedRequestError()
 
     def on_OPTIONS(self, _):
         return 200, {}
 
+    def notify_user(self, user_id):
+        stream_id, _ = self.store.get_push_rules_stream_token()
+        self.notifier.on_new_event(
+            "push_rules_key", stream_id, users=[user_id]
+        )
+
     def set_rule_attr(self, user_id, spec, val):
         if spec['attr'] == 'enabled':
             if isinstance(val, dict) and "enabled" in val:
@@ -228,16 +170,20 @@ class PushRuleRestServlet(ClientV1RestServlet):
                 # bools directly, so let's not break them.
                 raise SynapseError(400, "Value for 'enabled' must be boolean")
             namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
-            self.hs.get_datastore().set_push_rule_enabled(
+            return self.store.set_push_rule_enabled(
                 user_id, namespaced_rule_id, val
             )
-        else:
-            raise UnrecognizedRequestError()
-
-    def get_rule_attr(self, user_id, namespaced_rule_id, attr):
-        if attr == 'enabled':
-            return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id(
-                user_id, namespaced_rule_id
+        elif spec['attr'] == 'actions':
+            actions = val.get('actions')
+            _check_actions(actions)
+            namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+            rule_id = spec['rule_id']
+            is_default_rule = rule_id.startswith(".")
+            if is_default_rule:
+                if namespaced_rule_id not in BASE_RULE_IDS:
+                    raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
+            return self.store.set_push_rule_actions(
+                user_id, namespaced_rule_id, actions, is_default_rule
             )
         else:
             raise UnrecognizedRequestError()
@@ -251,16 +197,9 @@ def _rule_spec_from_path(path):
 
     scope = path[1]
     path = path[2:]
-    if scope not in ['global', 'device']:
+    if scope != 'global':
         raise UnrecognizedRequestError()
 
-    device = None
-    if scope == 'device':
-        if len(path) == 0:
-            raise UnrecognizedRequestError()
-        device = path[0]
-        path = path[1:]
-
     if len(path) == 0:
         raise UnrecognizedRequestError()
 
@@ -277,8 +216,6 @@ def _rule_spec_from_path(path):
         'template': template,
         'rule_id': rule_id
     }
-    if device:
-        spec['profile_tag'] = device
 
     path = path[1:]
 
@@ -288,7 +225,7 @@ def _rule_spec_from_path(path):
     return spec
 
 
-def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None):
+def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
     if rule_template in ['override', 'underride']:
         if 'conditions' not in req_obj:
             raise InvalidRuleException("Missing 'conditions'")
@@ -321,16 +258,19 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None
     else:
         raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
 
-    if device:
-        conditions.append({
-            'kind': 'device',
-            'profile_tag': device
-        })
-
     if 'actions' not in req_obj:
         raise InvalidRuleException("No actions found")
     actions = req_obj['actions']
 
+    _check_actions(actions)
+
+    return conditions, actions
+
+
+def _check_actions(actions):
+    if not isinstance(actions, list):
+        raise InvalidRuleException("No actions found")
+
     for a in actions:
         if a in ['notify', 'dont_notify', 'coalesce']:
             pass
@@ -339,25 +279,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None
         else:
             raise InvalidRuleException("Unrecognised action")
 
-    return conditions, actions
-
-
-def _add_empty_priority_class_arrays(d):
-    for pc in PRIORITY_CLASS_MAP.keys():
-        d[pc] = []
-    return d
-
-
-def _profile_tag_from_conditions(conditions):
-    """
-    Given a list of conditions, return the profile tag of the
-    device rule if there is one
-    """
-    for c in conditions:
-        if c['kind'] == 'device':
-            return c['profile_tag']
-    return None
-
 
 def _filter_ruleset_with_path(ruleset, path):
     if path == []:
@@ -392,89 +313,32 @@ def _filter_ruleset_with_path(ruleset, path):
 
     attr = path[0]
     if attr in the_rule:
-        return the_rule[attr]
+        # Make sure we return a JSON object as the attribute may be a
+        # JSON value.
+        return {attr: the_rule[attr]}
     else:
         raise UnrecognizedRequestError()
 
 
 def _priority_class_from_spec(spec):
     if spec['template'] not in PRIORITY_CLASS_MAP.keys():
-        raise InvalidRuleException("Unknown template: %s" % (spec['kind']))
+        raise InvalidRuleException("Unknown template: %s" % (spec['template']))
     pc = PRIORITY_CLASS_MAP[spec['template']]
 
-    if spec['scope'] == 'device':
-        pc += len(PRIORITY_CLASS_MAP)
-
     return pc
 
 
-def _priority_class_to_template_name(pc):
-    if pc > PRIORITY_CLASS_MAP['override']:
-        # per-device
-        prio_class_index = pc - len(PRIORITY_CLASS_MAP)
-        return PRIORITY_CLASS_INVERSE_MAP[prio_class_index]
-    else:
-        return PRIORITY_CLASS_INVERSE_MAP[pc]
-
-
-def _rule_to_template(rule):
-    unscoped_rule_id = None
-    if 'rule_id' in rule:
-        unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
-
-    template_name = _priority_class_to_template_name(rule['priority_class'])
-    if template_name in ['override', 'underride']:
-        templaterule = {k: rule[k] for k in ["conditions", "actions"]}
-    elif template_name in ["sender", "room"]:
-        templaterule = {'actions': rule['actions']}
-        unscoped_rule_id = rule['conditions'][0]['pattern']
-    elif template_name == 'content':
-        if len(rule["conditions"]) != 1:
-            return None
-        thecond = rule["conditions"][0]
-        if "pattern" not in thecond:
-            return None
-        templaterule = {'actions': rule['actions']}
-        templaterule["pattern"] = thecond["pattern"]
-
-    if unscoped_rule_id:
-            templaterule['rule_id'] = unscoped_rule_id
-    if 'default' in rule:
-        templaterule['default'] = rule['default']
-    return templaterule
-
-
-def _strip_device_condition(rule):
-    for i, c in enumerate(rule['conditions']):
-        if c['kind'] == 'device':
-            del rule['conditions'][i]
-    return rule
-
-
 def _namespaced_rule_id_from_spec(spec):
-    if spec['scope'] == 'global':
-        scope = 'global'
-    else:
-        scope = 'device/%s' % (spec['profile_tag'])
-    return "%s/%s/%s" % (scope, spec['template'], spec['rule_id'])
+    return _namespaced_rule_id(spec, spec['rule_id'])
 
 
-def _rule_id_from_namespaced(in_rule_id):
-    return in_rule_id.split('/')[-1]
+def _namespaced_rule_id(spec, rule_id):
+    return "global/%s/%s" % (spec['template'], rule_id)
 
 
 class InvalidRuleException(Exception):
     pass
 
 
-# XXX: C+ped from rest/room.py - surely this should be common?
-def _parse_json(request):
-    try:
-        content = json.loads(request.content.read())
-        return content
-    except ValueError:
-        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
-
-
 def register_servlets(hs, http_server):
     PushRuleRestServlet(hs).register(http_server)