summary refs log tree commit diff
diff options
context:
space:
mode:
authorMark Haines <mjark@negativecurvature.net>2016-03-16 11:26:22 +0000
committerMark Haines <mjark@negativecurvature.net>2016-03-16 11:26:22 +0000
commitadd89a03a68c9b4471546945389235184e640726 (patch)
treebab9829607686f8b03dce2f542716ac8b418f31e
parentMerge pull request #648 from matrix-org/rav/password_reset (diff)
parentAdd a comment to offer a hint to an explanation for why we have a unique cons... (diff)
downloadsynapse-add89a03a68c9b4471546945389235184e640726.tar.xz
Merge pull request #647 from matrix-org/markjh/pushers_stream
Add replication stream for pushers
-rw-r--r--synapse/notifier.py6
-rw-r--r--synapse/replication/resource.py25
-rw-r--r--synapse/rest/client/v1/pusher.py6
-rw-r--r--synapse/storage/__init__.py5
-rw-r--r--synapse/storage/pusher.py63
-rw-r--r--synapse/storage/schema/delta/30/deleted_pushers.sql25
-rw-r--r--synapse/storage/util/id_generators.py7
-rw-r--r--tests/replication/test_resource.py1
8 files changed, 120 insertions, 18 deletions
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 9b69b0333a..f00cd8c588 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -282,6 +282,12 @@ class Notifier(object):
 
             self.notify_replication()
 
+    def on_new_replication_data(self):
+        """Used to inform replication listeners that something has happend
+        without waking up any of the normal user event streams"""
+        with PreserveLoggingContext():
+            self.notify_replication()
+
     @defer.inlineCallbacks
     def wait_for_events(self, user_id, timeout, callback, room_ids=None,
                         from_token=StreamToken.START):
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index adc1eb1d0b..8c1ae0fbc7 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -37,6 +37,7 @@ STREAM_NAMES = (
     ("user_account_data", "room_account_data", "tag_account_data",),
     ("backfill",),
     ("push_rules",),
+    ("pushers",),
 )
 
 
@@ -65,6 +66,7 @@ class ReplicationResource(Resource):
     * "tag_account_data": Per room per user tags.
     * "backfill": Old events that have been backfilled from other servers.
     * "push_rules": Per user changes to push rules.
+    * "pushers": Per user changes to their pushers.
 
     The API takes two additional query parameters:
 
@@ -120,6 +122,7 @@ class ReplicationResource(Resource):
         stream_token = yield self.sources.get_current_token()
         backfill_token = yield self.store.get_current_backfill_token()
         push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
+        pushers_token = self.store.get_pushers_stream_token()
 
         defer.returnValue(_ReplicationToken(
             room_stream_token,
@@ -129,6 +132,7 @@ class ReplicationResource(Resource):
             int(stream_token.account_data_key),
             backfill_token,
             push_rules_token,
+            pushers_token,
         ))
 
     @request_handler
@@ -151,6 +155,7 @@ class ReplicationResource(Resource):
             yield self.typing(writer, current_token)  # TODO: implement limit
             yield self.receipts(writer, current_token, limit)
             yield self.push_rules(writer, current_token, limit)
+            yield self.pushers(writer, current_token, limit)
             self.streams(writer, current_token)
 
             logger.info("Replicated %d rows", writer.total)
@@ -297,6 +302,24 @@ class ReplicationResource(Resource):
                 "priority_class", "priority", "conditions", "actions"
             ))
 
+    @defer.inlineCallbacks
+    def pushers(self, writer, current_token, limit):
+        current_position = current_token.pushers
+
+        pushers = parse_integer(writer.request, "pushers")
+        if pushers is not None:
+            updated, deleted = yield self.store.get_all_updated_pushers(
+                pushers, current_position, limit
+            )
+            writer.write_header_and_rows("pushers", updated, (
+                "position", "user_id", "access_token", "profile_tag", "kind",
+                "app_id", "app_display_name", "device_display_name", "pushkey",
+                "ts", "lang", "data"
+            ))
+            writer.write_header_and_rows("deleted", deleted, (
+                "position", "user_id", "app_id", "pushkey"
+            ))
+
 
 class _Writer(object):
     """Writes the streams as a JSON object as the response to the request"""
@@ -327,7 +350,7 @@ class _Writer(object):
 
 class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
     "events", "presence", "typing", "receipts", "account_data", "backfill",
-    "push_rules"
+    "push_rules", "pushers"
 ))):
     __slots__ = []
 
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index ee029b4f77..9881f068c3 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -29,6 +29,10 @@ logger = logging.getLogger(__name__)
 class PusherRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/pushers/set$")
 
+    def __init__(self, hs):
+        super(PusherRestServlet, self).__init__(hs)
+        self.notifier = hs.get_notifier()
+
     @defer.inlineCallbacks
     def on_POST(self, request):
         requester = yield self.auth.get_user_by_req(request)
@@ -87,6 +91,8 @@ class PusherRestServlet(ClientV1RestServlet):
             raise SynapseError(400, "Config Error: " + pce.message,
                                errcode=Codes.MISSING_PARAM)
 
+        self.notifier.on_new_replication_data()
+
         defer.returnValue((200, {}))
 
     def on_OPTIONS(self, _):
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 168eb27b03..250ba536ea 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -119,12 +119,15 @@ class DataStore(RoomMemberStore, RoomStore,
         self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
-        self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
         self._push_rules_stream_id_gen = ChainedIdGenerator(
             self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
         )
+        self._pushers_id_gen = StreamIdGenerator(
+            db_conn, "pushers", "id",
+            extra_tables=[("deleted_pushers", "stream_id")],
+        )
 
         events_max = self._stream_id_gen.get_max_token()
         event_cache_prefill, min_event_val = self._get_cache_dict(
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 7693ab9082..87b2ac5773 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -16,8 +16,6 @@
 from ._base import SQLBaseStore
 from twisted.internet import defer
 
-from synapse.api.errors import StoreError
-
 from canonicaljson import encode_canonical_json
 
 import logging
@@ -79,12 +77,41 @@ class PusherStore(SQLBaseStore):
         rows = yield self.runInteraction("get_all_pushers", get_pushers)
         defer.returnValue(rows)
 
+    def get_pushers_stream_token(self):
+        return self._pushers_id_gen.get_max_token()
+
+    def get_all_updated_pushers(self, last_id, current_id, limit):
+        def get_all_updated_pushers_txn(txn):
+            sql = (
+                "SELECT id, user_name, access_token, profile_tag, kind,"
+                " app_id, app_display_name, device_display_name, pushkey, ts,"
+                " lang, data"
+                " FROM pushers"
+                " WHERE ? < id AND id <= ?"
+                " ORDER BY id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            updated = txn.fetchall()
+
+            sql = (
+                "SELECT stream_id, user_id, app_id, pushkey"
+                " FROM deleted_pushers"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            deleted = txn.fetchall()
+
+            return (updated, deleted)
+        return self.runInteraction(
+            "get_all_updated_pushers", get_all_updated_pushers_txn
+        )
+
     @defer.inlineCallbacks
     def add_pusher(self, user_id, access_token, kind, app_id,
                    app_display_name, device_display_name,
                    pushkey, pushkey_ts, lang, data, profile_tag=""):
-        try:
-            next_id = self._pushers_id_gen.get_next()
+        with self._pushers_id_gen.get_next() as stream_id:
             yield self._simple_upsert(
                 "pushers",
                 dict(
@@ -101,23 +128,29 @@ class PusherStore(SQLBaseStore):
                     lang=lang,
                     data=encode_canonical_json(data),
                     profile_tag=profile_tag,
-                ),
-                insertion_values=dict(
-                    id=next_id,
+                    id=stream_id,
                 ),
                 desc="add_pusher",
             )
-        except Exception as e:
-            logger.error("create_pusher with failed: %s", e)
-            raise StoreError(500, "Problem creating pusher.")
 
     @defer.inlineCallbacks
     def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
-        yield self._simple_delete_one(
-            "pushers",
-            {"app_id": app_id, "pushkey": pushkey, 'user_name': user_id},
-            desc="delete_pusher_by_app_id_pushkey_user_id",
-        )
+        def delete_pusher_txn(txn, stream_id):
+            self._simple_delete_one_txn(
+                txn,
+                "pushers",
+                {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
+            )
+            self._simple_upsert_txn(
+                txn,
+                "deleted_pushers",
+                {"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
+                {"stream_id": stream_id},
+            )
+        with self._pushers_id_gen.get_next() as stream_id:
+            yield self.runInteraction(
+                "delete_pusher", delete_pusher_txn, stream_id
+            )
 
     @defer.inlineCallbacks
     def update_pusher_last_token(self, app_id, pushkey, user_id, last_token):
diff --git a/synapse/storage/schema/delta/30/deleted_pushers.sql b/synapse/storage/schema/delta/30/deleted_pushers.sql
new file mode 100644
index 0000000000..712c454aa1
--- /dev/null
+++ b/synapse/storage/schema/delta/30/deleted_pushers.sql
@@ -0,0 +1,25 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS deleted_pushers(
+    stream_id BIGINT NOT NULL,
+    app_id TEXT NOT NULL,
+    pushkey TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    /* We only track the most recent delete for each app_id, pushkey and user_id. */
+    UNIQUE (app_id, pushkey, user_id)
+);
+
+CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id);
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 610ddad423..a02dfc7d58 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -49,9 +49,14 @@ class StreamIdGenerator(object):
         with stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
-    def __init__(self, db_conn, table, column):
+    def __init__(self, db_conn, table, column, extra_tables=[]):
         self._lock = threading.Lock()
         self._current_max = _load_max_id(db_conn, table, column)
+        for table, column in extra_tables:
+            self._current_max = max(
+                self._current_max,
+                _load_max_id(db_conn, table, column)
+            )
         self._unfinished_ids = deque()
 
     def get_next(self):
diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py
index 4a42eb3365..f4b5fb3328 100644
--- a/tests/replication/test_resource.py
+++ b/tests/replication/test_resource.py
@@ -131,6 +131,7 @@ class ReplicationResourceCase(unittest.TestCase):
     test_timeout_tag_account_data = _test_timeout("tag_account_data")
     test_timeout_backfill = _test_timeout("backfill")
     test_timeout_push_rules = _test_timeout("push_rules")
+    test_timeout_pushers = _test_timeout("pushers")
 
     @defer.inlineCallbacks
     def send_text_message(self, room_id, message):