summary refs log tree commit diff
path: root/synapse/push
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/push/__init__.py11
-rw-r--r--synapse/push/action_generator.py6
-rw-r--r--synapse/push/baserules.py57
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py14
-rw-r--r--synapse/push/clientformat.py112
-rw-r--r--synapse/push/httppusher.py3
-rw-r--r--synapse/push/push_rule_evaluator.py19
-rw-r--r--synapse/push/pusherpool.py58
8 files changed, 211 insertions, 69 deletions
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 8da2d8716c..296c4447ec 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -21,7 +21,7 @@ from synapse.util.logcontext import LoggingContext
 from synapse.util.metrics import Measure
 
 import synapse.util.async
-import push_rule_evaluator as push_rule_evaluator
+from .push_rule_evaluator import evaluator_for_user_id
 
 import logging
 import random
@@ -47,14 +47,13 @@ class Pusher(object):
     MAX_BACKOFF = 60 * 60 * 1000
     GIVE_UP_AFTER = 24 * 60 * 60 * 1000
 
-    def __init__(self, _hs, profile_tag, user_id, app_id,
+    def __init__(self, _hs, user_id, app_id,
                  app_display_name, device_display_name, pushkey, pushkey_ts,
                  data, last_token, last_success, failing_since):
         self.hs = _hs
         self.evStreamHandler = self.hs.get_handlers().event_stream_handler
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
-        self.profile_tag = profile_tag
         self.user_id = user_id
         self.app_id = app_id
         self.app_display_name = app_display_name
@@ -186,8 +185,8 @@ class Pusher(object):
         processed = False
 
         rule_evaluator = yield \
-            push_rule_evaluator.evaluator_for_user_id_and_profile_tag(
-                self.user_id, self.profile_tag, single_event['room_id'], self.store
+            evaluator_for_user_id(
+                self.user_id, single_event['room_id'], self.store
             )
 
         actions = yield rule_evaluator.actions_for_event(single_event)
@@ -318,7 +317,7 @@ class Pusher(object):
     @defer.inlineCallbacks
     def _get_badge_count(self):
         invites, joins = yield defer.gatherResults([
-            self.store.get_invites_for_user(self.user_id),
+            self.store.get_invited_rooms_for_user(self.user_id),
             self.store.get_rooms_for_user(self.user_id),
         ], consumeErrors=True)
 
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index e0da0868ec..84efcdd184 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-import bulk_push_rule_evaluator
+from .bulk_push_rule_evaluator import evaluator_for_room_id
 
 import logging
 
@@ -35,7 +35,7 @@ class ActionGenerator:
 
     @defer.inlineCallbacks
     def handle_push_actions_for_event(self, event, context, handler):
-        bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
+        bulk_evaluator = yield evaluator_for_room_id(
             event.room_id, self.hs, self.store
         )
 
@@ -44,5 +44,5 @@ class ActionGenerator:
         )
 
         context.push_actions = [
-            (uid, None, actions) for uid, actions in actions_by_user.items()
+            (uid, actions) for uid, actions in actions_by_user.items()
         ]
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 0832c77cb4..86a2998bcc 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -13,46 +13,67 @@
 # limitations under the License.
 
 from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
+import copy
 
 
 def list_with_base_rules(rawrules):
+    """Combine the list of rules set by the user with the default push rules
+
+    :param list rawrules: The rules the user has modified or set.
+    :returns: A new list with the rules set by the user combined with the
+        defaults.
+    """
     ruleslist = []
 
+    # Grab the base rules that the user has modified.
+    # The modified base rules have a priority_class of -1.
+    modified_base_rules = {
+        r['rule_id']: r for r in rawrules if r['priority_class'] < 0
+    }
+
+    # Remove the modified base rules from the list, They'll be added back
+    # in the default postions in the list.
+    rawrules = [r for r in rawrules if r['priority_class'] >= 0]
+
     # shove the server default rules for each kind onto the end of each
     current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
 
     ruleslist.extend(make_base_prepend_rules(
-        PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+        PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
     ))
 
     for r in rawrules:
         if r['priority_class'] < current_prio_class:
             while r['priority_class'] < current_prio_class:
                 ruleslist.extend(make_base_append_rules(
-                    PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+                    PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+                    modified_base_rules,
                 ))
                 current_prio_class -= 1
                 if current_prio_class > 0:
                     ruleslist.extend(make_base_prepend_rules(
-                        PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+                        PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+                        modified_base_rules,
                     ))
 
         ruleslist.append(r)
 
     while current_prio_class > 0:
         ruleslist.extend(make_base_append_rules(
-            PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+            PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+            modified_base_rules,
         ))
         current_prio_class -= 1
         if current_prio_class > 0:
             ruleslist.extend(make_base_prepend_rules(
-                PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+                PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+                modified_base_rules,
             ))
 
     return ruleslist
 
 
-def make_base_append_rules(kind):
+def make_base_append_rules(kind, modified_base_rules):
     rules = []
 
     if kind == 'override':
@@ -62,15 +83,31 @@ def make_base_append_rules(kind):
     elif kind == 'content':
         rules = BASE_APPEND_CONTENT_RULES
 
+    # Copy the rules before modifying them
+    rules = copy.deepcopy(rules)
+    for r in rules:
+        # Only modify the actions, keep the conditions the same.
+        modified = modified_base_rules.get(r['rule_id'])
+        if modified:
+            r['actions'] = modified['actions']
+
     return rules
 
 
-def make_base_prepend_rules(kind):
+def make_base_prepend_rules(kind, modified_base_rules):
     rules = []
 
     if kind == 'override':
         rules = BASE_PREPEND_OVERRIDE_RULES
 
+    # Copy the rules before modifying them
+    rules = copy.deepcopy(rules)
+    for r in rules:
+        # Only modify the actions, keep the conditions the same.
+        modified = modified_base_rules.get(r['rule_id'])
+        if modified:
+            r['actions'] = modified['actions']
+
     return rules
 
 
@@ -263,18 +300,24 @@ BASE_APPEND_UNDERRIDE_RULES = [
 ]
 
 
+BASE_RULE_IDS = set()
+
 for r in BASE_APPEND_CONTENT_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['content']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
 
 for r in BASE_PREPEND_OVERRIDE_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['override']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
 
 for r in BASE_APPEND_OVRRIDE_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['override']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
 
 for r in BASE_APPEND_UNDERRIDE_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['underride']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 8ac5ceb9ef..76d7eb7ce0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -18,8 +18,8 @@ import ujson as json
 
 from twisted.internet import defer
 
-import baserules
-from push_rule_evaluator import PushRuleEvaluatorForEvent
+from .baserules import list_with_base_rules
+from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
 from synapse.api.constants import EventTypes
 
@@ -39,7 +39,7 @@ def _get_rules(room_id, user_ids, store):
     rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
 
     rules_by_user = {
-        uid: baserules.list_with_base_rules([
+        uid: list_with_base_rules([
             decode_rule_json(rule_list)
             for rule_list in rules_by_user.get(uid, [])
         ])
@@ -103,11 +103,13 @@ class BulkPushRuleEvaluator:
 
         users_dict = yield self.store.are_guests(self.rules_by_user.keys())
 
-        filtered_by_user = yield handler._filter_events_for_clients(
+        filtered_by_user = yield handler.filter_events_for_clients(
             users_dict.items(), [event], {event.event_id: current_state}
         )
 
-        evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
+        room_members = yield self.store.get_users_in_room(self.room_id)
+
+        evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
 
         condition_cache = {}
 
@@ -152,7 +154,7 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache):
             elif res is True:
                 continue
 
-        res = evaluator.matches(cond, uid, display_name, None)
+        res = evaluator.matches(cond, uid, display_name)
         if _id:
             cache[_id] = bool(res)
 
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
new file mode 100644
index 0000000000..ae9db9ec2f
--- /dev/null
+++ b/synapse/push/clientformat.py
@@ -0,0 +1,112 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.push.baserules import list_with_base_rules
+
+from synapse.push.rulekinds import (
+    PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
+)
+
+import copy
+import simplejson as json
+
+
+def format_push_rules_for_user(user, rawrules, enabled_map):
+    """Converts a list of rawrules and a enabled map into nested dictionaries
+    to match the Matrix client-server format for push rules"""
+
+    ruleslist = []
+    for rawrule in rawrules:
+        rule = dict(rawrule)
+        rule["conditions"] = json.loads(rawrule["conditions"])
+        rule["actions"] = json.loads(rawrule["actions"])
+        ruleslist.append(rule)
+
+    # We're going to be mutating this a lot, so do a deep copy
+    ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
+
+    rules = {'global': {}, 'device': {}}
+
+    rules['global'] = _add_empty_priority_class_arrays(rules['global'])
+
+    for r in ruleslist:
+        rulearray = None
+
+        template_name = _priority_class_to_template_name(r['priority_class'])
+
+        # Remove internal stuff.
+        for c in r["conditions"]:
+            c.pop("_id", None)
+
+            pattern_type = c.pop("pattern_type", None)
+            if pattern_type == "user_id":
+                c["pattern"] = user.to_string()
+            elif pattern_type == "user_localpart":
+                c["pattern"] = user.localpart
+
+        rulearray = rules['global'][template_name]
+
+        template_rule = _rule_to_template(r)
+        if template_rule:
+            if r['rule_id'] in enabled_map:
+                template_rule['enabled'] = enabled_map[r['rule_id']]
+            elif 'enabled' in r:
+                template_rule['enabled'] = r['enabled']
+            else:
+                template_rule['enabled'] = True
+            rulearray.append(template_rule)
+
+    return rules
+
+
+def _add_empty_priority_class_arrays(d):
+    for pc in PRIORITY_CLASS_MAP.keys():
+        d[pc] = []
+    return d
+
+
+def _rule_to_template(rule):
+    unscoped_rule_id = None
+    if 'rule_id' in rule:
+        unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
+
+    template_name = _priority_class_to_template_name(rule['priority_class'])
+    if template_name in ['override', 'underride']:
+        templaterule = {k: rule[k] for k in ["conditions", "actions"]}
+    elif template_name in ["sender", "room"]:
+        templaterule = {'actions': rule['actions']}
+        unscoped_rule_id = rule['conditions'][0]['pattern']
+    elif template_name == 'content':
+        if len(rule["conditions"]) != 1:
+            return None
+        thecond = rule["conditions"][0]
+        if "pattern" not in thecond:
+            return None
+        templaterule = {'actions': rule['actions']}
+        templaterule["pattern"] = thecond["pattern"]
+
+    if unscoped_rule_id:
+            templaterule['rule_id'] = unscoped_rule_id
+    if 'default' in rule:
+        templaterule['default'] = rule['default']
+    return templaterule
+
+
+def _rule_id_from_namespaced(in_rule_id):
+    return in_rule_id.split('/')[-1]
+
+
+def _priority_class_to_template_name(pc):
+    return PRIORITY_CLASS_INVERSE_MAP[pc]
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index cdc4494928..9be4869360 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -23,12 +23,11 @@ logger = logging.getLogger(__name__)
 
 
 class HttpPusher(Pusher):
-    def __init__(self, _hs, profile_tag, user_id, app_id,
+    def __init__(self, _hs, user_id, app_id,
                  app_display_name, device_display_name, pushkey, pushkey_ts,
                  data, last_token, last_success, failing_since):
         super(HttpPusher, self).__init__(
             _hs,
-            profile_tag,
             user_id,
             app_id,
             app_display_name,
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 2a2b4437dc..51f73a5b78 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-import baserules
+from .baserules import list_with_base_rules
 
 import logging
 import simplejson as json
@@ -33,7 +33,7 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
 
 
 @defer.inlineCallbacks
-def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
+def evaluator_for_user_id(user_id, room_id, store):
     rawrules = yield store.get_push_rules_for_user(user_id)
     enabled_map = yield store.get_push_rules_enabled_for_user(user_id)
     our_member_event = yield store.get_current_state(
@@ -43,7 +43,7 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
     )
 
     defer.returnValue(PushRuleEvaluator(
-        user_id, profile_tag, rawrules, enabled_map,
+        user_id, rawrules, enabled_map,
         room_id, our_member_event, store
     ))
 
@@ -77,10 +77,9 @@ def _room_member_count(ev, condition, room_member_count):
 class PushRuleEvaluator:
     DEFAULT_ACTIONS = []
 
-    def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id,
+    def __init__(self, user_id, raw_rules, enabled_map, room_id,
                  our_member_event, store):
         self.user_id = user_id
-        self.profile_tag = profile_tag
         self.room_id = room_id
         self.our_member_event = our_member_event
         self.store = store
@@ -92,7 +91,7 @@ class PushRuleEvaluator:
             rule['actions'] = json.loads(raw_rule['actions'])
             rules.append(rule)
 
-        self.rules = baserules.list_with_base_rules(rules)
+        self.rules = list_with_base_rules(rules)
 
         self.enabled_map = enabled_map
 
@@ -152,7 +151,7 @@ class PushRuleEvaluator:
             matches = True
             for c in conditions:
                 matches = evaluator.matches(
-                    c, self.user_id, my_display_name, self.profile_tag
+                    c, self.user_id, my_display_name
                 )
                 if not matches:
                     break
@@ -189,13 +188,9 @@ class PushRuleEvaluatorForEvent(object):
         # Maps strings of e.g. 'content.body' -> event["content"]["body"]
         self._value_cache = _flatten_dict(event)
 
-    def matches(self, condition, user_id, display_name, profile_tag):
+    def matches(self, condition, user_id, display_name):
         if condition['kind'] == 'event_match':
             return self._event_match(condition, user_id)
-        elif condition['kind'] == 'device':
-            if 'profile_tag' not in condition:
-                return True
-            return condition['profile_tag'] == profile_tag
         elif condition['kind'] == 'contains_display_name':
             return self._contains_display_name(display_name)
         elif condition['kind'] == 'room_member_count':
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index d7dcb2de4b..0b463c6fdb 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -16,7 +16,7 @@
 
 from twisted.internet import defer
 
-from httppusher import HttpPusher
+from .httppusher import HttpPusher
 from synapse.push import PusherConfigException
 from synapse.util.logcontext import preserve_fn
 
@@ -29,6 +29,7 @@ class PusherPool:
     def __init__(self, _hs):
         self.hs = _hs
         self.store = self.hs.get_datastore()
+        self.clock = self.hs.get_clock()
         self.pushers = {}
         self.last_pusher_started = -1
 
@@ -38,8 +39,11 @@ class PusherPool:
         self._start_pushers(pushers)
 
     @defer.inlineCallbacks
-    def add_pusher(self, user_id, access_token, profile_tag, kind, app_id,
-                   app_display_name, device_display_name, pushkey, lang, data):
+    def add_pusher(self, user_id, access_token, kind, app_id,
+                   app_display_name, device_display_name, pushkey, lang, data,
+                   profile_tag=""):
+        time_now_msec = self.clock.time_msec()
+
         # we try to create the pusher just to validate the config: it
         # will then get pulled out of the database,
         # recreated, added and started: this means we have only one
@@ -47,23 +51,31 @@ class PusherPool:
         self._create_pusher({
             "user_name": user_id,
             "kind": kind,
-            "profile_tag": profile_tag,
             "app_id": app_id,
             "app_display_name": app_display_name,
             "device_display_name": device_display_name,
             "pushkey": pushkey,
-            "ts": self.hs.get_clock().time_msec(),
+            "ts": time_now_msec,
             "lang": lang,
             "data": data,
             "last_token": None,
             "last_success": None,
             "failing_since": None
         })
-        yield self._add_pusher_to_store(
-            user_id, access_token, profile_tag, kind, app_id,
-            app_display_name, device_display_name,
-            pushkey, lang, data
+        yield self.store.add_pusher(
+            user_id=user_id,
+            access_token=access_token,
+            kind=kind,
+            app_id=app_id,
+            app_display_name=app_display_name,
+            device_display_name=device_display_name,
+            pushkey=pushkey,
+            pushkey_ts=time_now_msec,
+            lang=lang,
+            data=data,
+            profile_tag=profile_tag,
         )
+        yield self._refresh_pusher(app_id, pushkey, user_id)
 
     @defer.inlineCallbacks
     def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey,
@@ -80,44 +92,24 @@ class PusherPool:
                 yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
 
     @defer.inlineCallbacks
-    def remove_pushers_by_user(self, user_id):
+    def remove_pushers_by_user(self, user_id, except_token_ids=[]):
         all = yield self.store.get_all_pushers()
         logger.info(
-            "Removing all pushers for user %s",
-            user_id,
+            "Removing all pushers for user %s except access tokens ids %r",
+            user_id, except_token_ids
         )
         for p in all:
-            if p['user_name'] == user_id:
+            if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     p['app_id'], p['pushkey'], p['user_name']
                 )
                 yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
 
-    @defer.inlineCallbacks
-    def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
-                             app_id, app_display_name, device_display_name,
-                             pushkey, lang, data):
-        yield self.store.add_pusher(
-            user_id=user_id,
-            access_token=access_token,
-            profile_tag=profile_tag,
-            kind=kind,
-            app_id=app_id,
-            app_display_name=app_display_name,
-            device_display_name=device_display_name,
-            pushkey=pushkey,
-            pushkey_ts=self.hs.get_clock().time_msec(),
-            lang=lang,
-            data=data,
-        )
-        yield self._refresh_pusher(app_id, pushkey, user_id)
-
     def _create_pusher(self, pusherdict):
         if pusherdict['kind'] == 'http':
             return HttpPusher(
                 self.hs,
-                profile_tag=pusherdict['profile_tag'],
                 user_id=pusherdict['user_name'],
                 app_id=pusherdict['app_id'],
                 app_display_name=pusherdict['app_display_name'],