summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorDavid Baker <dave@matrix.org>2015-01-22 19:32:17 +0000
committerDavid Baker <dave@matrix.org>2015-01-22 19:32:17 +0000
commit8a850573c9cf50dd83ba47c033b28fe2bbbaf9d4 (patch)
tree2d61c19377aebe7d855061dde7bb91fc154a0e63 /synapse
parentoops, this is not its own schema file (diff)
downloadsynapse-8a850573c9cf50dd83ba47c033b28fe2bbbaf9d4.tar.xz
As yet fairly untested GET API for push rules
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/errors.py14
-rw-r--r--synapse/rest/client/v1/push_rule.py138
-rw-r--r--synapse/storage/push_rule.py8
3 files changed, 145 insertions, 15 deletions
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 55181fe77e..01207282d6 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -87,13 +87,25 @@ class UnrecognizedRequestError(SynapseError):
     """An error indicating we don't understand the request you're trying to make"""
     def __init__(self, *args, **kwargs):
         if "errcode" not in kwargs:
-            kwargs["errcode"] = Codes.NOT_FOUND
+            kwargs["errcode"] = Codes.UNRECOGNIZED
         super(UnrecognizedRequestError, self).__init__(
             400,
             "Unrecognized request",
             **kwargs
         )
 
+
+class NotFoundError(SynapseError):
+    """An error indicating we can't find the thing you asked for"""
+    def __init__(self, *args, **kwargs):
+        if "errcode" not in kwargs:
+            kwargs["errcode"] = Codes.NOT_FOUND
+        super(UnrecognizedRequestError, self).__init__(
+            404,
+            "Not found",
+            **kwargs
+        )
+
 class AuthError(SynapseError):
     """An error raised when there was a problem authorising an event."""
 
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index b5e74479cf..2803c1f071 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
+from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError, NotFoundError
 from base import RestServlet, client_path_pattern
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
 
@@ -24,6 +24,14 @@ import json
 
 class PushRuleRestServlet(RestServlet):
     PATTERN = client_path_pattern("/pushrules/.*$")
+    PRIORITY_CLASS_MAP = {
+        'underride': 0,
+        'sender': 1,
+        'room': 2,
+        'content': 3,
+        'override': 4
+    }
+    PRIORITY_CLASS_INVERSE_MAP = {v: k for k,v in PRIORITY_CLASS_MAP.items()}
 
     def rule_spec_from_path(self, path):
         if len(path) < 2:
@@ -109,15 +117,7 @@ class PushRuleRestServlet(RestServlet):
         return (conditions, actions)
 
     def priority_class_from_spec(self, spec):
-        map = {
-            'underride': 0,
-            'sender': 1,
-            'room': 2,
-            'content': 3,
-            'override': 4
-        }
-
-        if spec['template'] not in map.keys():
+        if spec['template'] not in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys():
             raise InvalidRuleException("Unknown template: %s" % (spec['kind']))
         pc = map[spec['template']]
 
@@ -171,10 +171,128 @@ class PushRuleRestServlet(RestServlet):
 
         defer.returnValue((200, {}))
 
+    @defer.inlineCallbacks
+    def on_GET(self, request):
+        user = yield self.auth.get_user_by_req(request)
+
+        # we build up the full structure and then decide which bits of it
+        # to send which means doing unnecessary work sometimes but is
+        # is probably not going to make a whole lot of difference
+        rawrules = yield self.hs.get_datastore().get_push_rules_for_user_name(user.to_string())
+
+        rules = {'global': {}, 'device': {}}
+
+        rules['global'] = _add_empty_priority_class_arrays(rules['global'])
+
+        for r in rawrules:
+            rulearray = None
+
+            r["conditions"] = json.loads(r["conditions"])
+            r["actions"] = json.loads(r["actions"])
+
+            template_name = _priority_class_to_template_name(r['priority_class'])
+
+            if r['priority_class'] > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']:
+                # per-device rule
+                instance_handle = _instance_handle_from_conditions(r["conditions"])
+                if not instance_handle:
+                    continue
+                if instance_handle not in rules['device']:
+                    rules['device'][instance_handle] = []
+                    rules['device'][instance_handle] = \
+                        _add_empty_priority_class_arrays(rules['device'][instance_handle])
+
+                rulearray = rules['device'][instance_handle]
+            else:
+                rulearray = rules['global'][template_name]
+
+            template_rule = _rule_to_template(r)
+            if template_rule:
+                rulearray.append(template_rule)
+
+        path = request.postpath[1:]
+        if path == []:
+            defer.returnValue((200, rules))
+
+        if path[0] == 'global':
+            path = path[1:]
+            result = _filter_ruleset_with_path(rules['global'], path)
+            defer.returnValue((200, result))
+        elif path[0] == 'device':
+            path = path[1:]
+            if path == []:
+                raise UnrecognizedRequestError
+            instance_handle = path[0]
+            if instance_handle not in rules['device']:
+                ret = {}
+                ret = _add_empty_priority_class_arrays(ret)
+                defer.returnValue((200, ret))
+            ruleset = rules['device'][instance_handle]
+            result = _filter_ruleset_with_path(ruleset, path)
+            defer.returnValue((200, result))
+        else:
+            raise UnrecognizedRequestError()
+
+
     def on_OPTIONS(self, _):
         return 200, {}
 
 
+def _add_empty_priority_class_arrays(d):
+    for pc in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys():
+        d[pc] = []
+    return d
+
+def _instance_handle_from_conditions(conditions):
+    """
+    Given a list of conditions, return the instance handle of the
+    device rule if there is one
+    """
+    for c in conditions:
+        if c['kind'] == 'device':
+            return c['instance_handle']
+    return None
+
+def _filter_ruleset_with_path(ruleset, path):
+    if path == []:
+        return ruleset
+    template_kind = path[0]
+    if template_kind not in ruleset:
+        raise UnrecognizedRequestError()
+    path = path[1:]
+    if path == []:
+        return ruleset[template_kind]
+    rule_id = path[0]
+    for r in ruleset[template_kind]:
+        if r['rule_id'] == rule_id:
+            return r
+    raise NotFoundError
+
+def _priority_class_to_template_name(pc):
+    if pc > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']:
+        # per-device
+        prio_class_index = pc - PushRuleRestServlet.PRIORITY_CLASS_MAP['override']
+        return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[prio_class_index]
+    else:
+        return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[pc]
+
+def _rule_to_template(rule):
+    template_name = _priority_class_to_template_name(rule['priority_class'])
+    if template_name in ['override', 'underride']:
+        return {k:rule[k] for k in ["rule_id", "conditions", "actions"]}
+    elif template_name in ["sender", "room"]:
+        return {k:rule[k] for k in ["rule_id", "actions"]}
+    elif template_name == 'content':
+        if len(rule["conditions"]) != 1:
+            return None
+        thecond = rule["conditions"][0]
+        if "pattern" not in thecond:
+            return None
+        ret = {k:rule[k] for k in ["rule_id", "actions"]}
+        ret["pattern"] = thecond["pattern"]
+        return ret
+
+
 class InvalidRuleException(Exception):
     pass
 
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index dbbb35b2ab..d087257ffc 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -29,11 +29,11 @@ class PushRuleStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_push_rules_for_user_name(self, user_name):
         sql = (
-            "SELECT "+",".join(PushRuleTable.fields)+
-            "FROM pushers "
-            "WHERE user_name = ?"
+            "SELECT "+",".join(PushRuleTable.fields)+" "
+            "FROM "+PushRuleTable.table_name+" "
+            "WHERE user_name = ? "
+            "ORDER BY priority_class DESC, priority DESC"
         )
-
         rows = yield self._execute(None, sql, user_name)
 
         dicts = []