diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 9270bdd079..96633a176c 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@ from synapse.push.rulekinds import (
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
)
+import copy
import simplejson as json
@@ -43,7 +44,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e:
raise SynapseError(400, e.message)
- user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
raise SynapseError(400, "rule_id may not contain slashes")
@@ -51,7 +52,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
content = _parse_json(request)
if 'attr' in spec:
- self.set_rule_attr(user.to_string(), spec, content)
+ yield self.set_rule_attr(requester.user.to_string(), spec, content)
defer.returnValue((200, {}))
try:
@@ -65,15 +66,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
raise SynapseError(400, e.message)
before = request.args.get("before", None)
- if before and len(before):
- before = before[0]
+ if before:
+ before = _namespaced_rule_id(spec, before[0])
+
after = request.args.get("after", None)
- if after and len(after):
- after = after[0]
+ if after:
+ after = _namespaced_rule_id(spec, after[0])
try:
yield self.hs.get_datastore().add_push_rule(
- user_name=user.to_string(),
+ user_id=requester.user.to_string(),
rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class,
conditions=conditions,
@@ -92,13 +94,13 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_DELETE(self, request):
spec = _rule_spec_from_path(request.postpath)
- user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try:
yield self.hs.get_datastore().delete_push_rule(
- user.to_string(), namespaced_rule_id
+ requester.user.to_string(), namespaced_rule_id
)
defer.returnValue((200, {}))
except StoreError as e:
@@ -109,7 +111,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
+ user = requester.user
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
@@ -125,7 +128,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
- ruleslist = baserules.list_with_base_rules(ruleslist, user)
+ # We're going to be mutating this a lot, so do a deep copy
+ ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist))
rules = {'global': {}, 'device': {}}
@@ -139,6 +143,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
template_name = _priority_class_to_template_name(r['priority_class'])
+ # Remove internal stuff.
+ for c in r["conditions"]:
+ c.pop("_id", None)
+
+ pattern_type = c.pop("pattern_type", None)
+ if pattern_type == "user_id":
+ c["pattern"] = user.to_string()
+ elif pattern_type == "user_localpart":
+ c["pattern"] = user.localpart
+
if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
# per-device rule
profile_tag = _profile_tag_from_conditions(r["conditions"])
@@ -205,7 +219,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_OPTIONS(self, _):
return 200, {}
- def set_rule_attr(self, user_name, spec, val):
+ def set_rule_attr(self, user_id, spec, val):
if spec['attr'] == 'enabled':
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
@@ -215,16 +229,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- self.hs.get_datastore().set_push_rule_enabled(
- user_name, namespaced_rule_id, val
+ return self.hs.get_datastore().set_push_rule_enabled(
+ user_id, namespaced_rule_id, val
)
else:
raise UnrecognizedRequestError()
- def get_rule_attr(self, user_name, namespaced_rule_id, attr):
+ def get_rule_attr(self, user_id, namespaced_rule_id, attr):
if attr == 'enabled':
return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id(
- user_name, namespaced_rule_id
+ user_id, namespaced_rule_id
)
else:
raise UnrecognizedRequestError()
@@ -439,11 +453,15 @@ def _strip_device_condition(rule):
def _namespaced_rule_id_from_spec(spec):
+ return _namespaced_rule_id(spec, spec['rule_id'])
+
+
+def _namespaced_rule_id(spec, rule_id):
if spec['scope'] == 'global':
scope = 'global'
else:
scope = 'device/%s' % (spec['profile_tag'])
- return "%s/%s/%s" % (scope, spec['template'], spec['rule_id'])
+ return "%s/%s/%s" % (scope, spec['template'], rule_id)
def _rule_id_from_namespaced(in_rule_id):
|