summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/message.py4
-rw-r--r--synapse/notifier.py2
-rw-r--r--synapse/rest/client/v1/push_rule.py44
-rw-r--r--synapse/streams/events.py4
-rw-r--r--synapse/types.py7
5 files changed, 43 insertions, 18 deletions
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index afa7c9c36c..2fa12c8f2b 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -647,8 +647,8 @@ class MessageHandler(BaseHandler):
             user_id, messages, is_peeking=is_peeking
         )
 
-        start_token = StreamToken(token[0], 0, 0, 0, 0)
-        end_token = StreamToken(token[1], 0, 0, 0, 0)
+        start_token = StreamToken.START.copy_and_replace("room_key", token[0])
+        end_token = StreamToken.START.copy_and_replace("room_key", token[1])
 
         time_now = self.clock.time_msec()
 
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 3c36a20868..9b69b0333a 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -284,7 +284,7 @@ class Notifier(object):
 
     @defer.inlineCallbacks
     def wait_for_events(self, user_id, timeout, callback, room_ids=None,
-                        from_token=StreamToken("s0", "0", "0", "0", "0")):
+                        from_token=StreamToken.START):
         """Wait until the callback returns a non empty response or the
         timeout fires.
         """
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 970a019223..cf68725ca1 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -36,6 +36,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
     SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
         "Unrecognised request: You probably wanted a trailing slash")
 
+    def __init__(self, hs):
+        super(PushRuleRestServlet, self).__init__(hs)
+        self.store = hs.get_datastore()
+        self.notifier = hs.get_notifier()
+
     @defer.inlineCallbacks
     def on_PUT(self, request):
         spec = _rule_spec_from_path(request.postpath)
@@ -51,8 +56,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
 
         content = _parse_json(request)
 
+        user_id = requester.user.to_string()
+
         if 'attr' in spec:
-            yield self.set_rule_attr(requester.user.to_string(), spec, content)
+            yield self.set_rule_attr(user_id, spec, content)
+            self.notify_user(user_id)
             defer.returnValue((200, {}))
 
         if spec['rule_id'].startswith('.'):
@@ -77,8 +85,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
             after = _namespaced_rule_id(spec, after[0])
 
         try:
-            yield self.hs.get_datastore().add_push_rule(
-                user_id=requester.user.to_string(),
+            yield self.store.add_push_rule(
+                user_id=user_id,
                 rule_id=_namespaced_rule_id_from_spec(spec),
                 priority_class=priority_class,
                 conditions=conditions,
@@ -86,6 +94,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
                 before=before,
                 after=after
             )
+            self.notify_user(user_id)
         except InconsistentRuleException as e:
             raise SynapseError(400, e.message)
         except RuleNotFoundException as e:
@@ -98,13 +107,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
         spec = _rule_spec_from_path(request.postpath)
 
         requester = yield self.auth.get_user_by_req(request)
+        user_id = requester.user.to_string()
 
         namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
 
         try:
-            yield self.hs.get_datastore().delete_push_rule(
-                requester.user.to_string(), namespaced_rule_id
+            yield self.store.delete_push_rule(
+                user_id, namespaced_rule_id
             )
+            self.notify_user(user_id)
             defer.returnValue((200, {}))
         except StoreError as e:
             if e.code == 404:
@@ -115,14 +126,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request):
         requester = yield self.auth.get_user_by_req(request)
-        user = requester.user
+        user_id = requester.user.to_string()
 
         # we build up the full structure and then decide which bits of it
         # to send which means doing unnecessary work sometimes but is
         # is probably not going to make a whole lot of difference
-        rawrules = yield self.hs.get_datastore().get_push_rules_for_user(
-            user.to_string()
-        )
+        rawrules = yield self.store.get_push_rules_for_user(user_id)
 
         ruleslist = []
         for rawrule in rawrules:
@@ -138,8 +147,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
 
         rules['global'] = _add_empty_priority_class_arrays(rules['global'])
 
-        enabled_map = yield self.hs.get_datastore().\
-            get_push_rules_enabled_for_user(user.to_string())
+        enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
 
         for r in ruleslist:
             rulearray = None
@@ -152,9 +160,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
 
                 pattern_type = c.pop("pattern_type", None)
                 if pattern_type == "user_id":
-                    c["pattern"] = user.to_string()
+                    c["pattern"] = user_id
                 elif pattern_type == "user_localpart":
-                    c["pattern"] = user.localpart
+                    c["pattern"] = requester.user.localpart
 
             rulearray = rules['global'][template_name]
 
@@ -188,6 +196,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
     def on_OPTIONS(self, _):
         return 200, {}
 
+    def notify_user(self, user_id):
+        stream_id = self.store.get_push_rules_stream_token()
+        self.notifier.on_new_event(
+            "push_rules_key", stream_id, users=[user_id]
+        )
+
     def set_rule_attr(self, user_id, spec, val):
         if spec['attr'] == 'enabled':
             if isinstance(val, dict) and "enabled" in val:
@@ -198,7 +212,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
                 # bools directly, so let's not break them.
                 raise SynapseError(400, "Value for 'enabled' must be boolean")
             namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
-            return self.hs.get_datastore().set_push_rule_enabled(
+            return self.store.set_push_rule_enabled(
                 user_id, namespaced_rule_id, val
             )
         elif spec['attr'] == 'actions':
@@ -210,7 +224,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
             if is_default_rule:
                 if namespaced_rule_id not in BASE_RULE_IDS:
                     raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
-            return self.hs.get_datastore().set_push_rule_actions(
+            return self.store.set_push_rule_actions(
                 user_id, namespaced_rule_id, actions, is_default_rule
             )
         else:
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 5ddf4e988b..d4c0bb6732 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -38,9 +38,12 @@ class EventSources(object):
             name: cls(hs)
             for name, cls in EventSources.SOURCE_TYPES.items()
         }
+        self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
     def get_current_token(self, direction='f'):
+        push_rules_key, _ = self.store.get_push_rules_stream_token()
+
         token = StreamToken(
             room_key=(
                 yield self.sources["room"].get_current_key(direction)
@@ -57,5 +60,6 @@ class EventSources(object):
             account_data_key=(
                 yield self.sources["account_data"].get_current_key()
             ),
+            push_rules_key=push_rules_key,
         )
         defer.returnValue(token)
diff --git a/synapse/types.py b/synapse/types.py
index d5bd95cbd3..5b166835bd 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -115,6 +115,7 @@ class StreamToken(
         "typing_key",
         "receipt_key",
         "account_data_key",
+        "push_rules_key",
     ))
 ):
     _SEPARATOR = "_"
@@ -150,6 +151,7 @@ class StreamToken(
             or (int(other.typing_key) < int(self.typing_key))
             or (int(other.receipt_key) < int(self.receipt_key))
             or (int(other.account_data_key) < int(self.account_data_key))
+            or (int(other.push_rules_key) < int(self.push_rules_key))
         )
 
     def copy_and_advance(self, key, new_value):
@@ -174,6 +176,11 @@ class StreamToken(
         return StreamToken(**d)
 
 
+StreamToken.START = StreamToken(
+    *(["s0"] + ["0"] * (len(StreamToken._fields) - 1))
+)
+
+
 class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
     """Tokens are positions between events. The token "s1" comes after event 1.