summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/expire_cache.py60
-rw-r--r--synapse/replication/resource.py37
-rw-r--r--synapse/replication/slave/storage/_base.py19
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py20
-rw-r--r--synapse/replication/slave/storage/events.py14
-rw-r--r--synapse/replication/slave/storage/room.py3
-rw-r--r--synapse/replication/slave/storage/transactions.py10
7 files changed, 150 insertions, 13 deletions
diff --git a/synapse/replication/expire_cache.py b/synapse/replication/expire_cache.py
new file mode 100644
index 0000000000..c05a50d7a6
--- /dev/null
+++ b/synapse/replication/expire_cache.py
@@ -0,0 +1,60 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.server import respond_with_json_bytes, request_handler
+from synapse.http.servlet import parse_json_object_from_request
+
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+
+class ExpireCacheResource(Resource):
+    """
+    HTTP endpoint for expiring storage caches.
+
+    POST /_synapse/replication/expire_cache HTTP/1.1
+    Content-Type: application/json
+
+    {
+        "invalidate": [
+            {
+                "name": "func_name",
+                "keys": ["key1", "key2"]
+            }
+        ]
+    }
+    """
+
+    def __init__(self, hs):
+        Resource.__init__(self)  # Resource is old-style, so no super()
+
+        self.store = hs.get_datastore()
+        self.version_string = hs.version_string
+        self.clock = hs.get_clock()
+
+    def render_POST(self, request):
+        self._async_render_POST(request)
+        return NOT_DONE_YET
+
+    @request_handler()
+    def _async_render_POST(self, request):
+        content = parse_json_object_from_request(request)
+
+        for row in content["invalidate"]:
+            name = row["name"]
+            keys = tuple(row["keys"])
+
+            getattr(self.store, name).invalidate(keys)
+
+        respond_with_json_bytes(request, 200, "{}")
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index 5a14c51d23..4616e9b34a 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -17,6 +17,7 @@ from synapse.http.servlet import parse_integer, parse_string
 from synapse.http.server import request_handler, finish_request
 from synapse.replication.pusher_resource import PusherResource
 from synapse.replication.presence_resource import PresenceResource
+from synapse.replication.expire_cache import ExpireCacheResource
 from synapse.api.errors import SynapseError
 
 from twisted.web.resource import Resource
@@ -44,6 +45,7 @@ STREAM_NAMES = (
     ("caches",),
     ("to_device",),
     ("public_rooms",),
+    ("federation",),
 )
 
 
@@ -116,11 +118,14 @@ class ReplicationResource(Resource):
         self.sources = hs.get_event_sources()
         self.presence_handler = hs.get_presence_handler()
         self.typing_handler = hs.get_typing_handler()
+        self.federation_sender = hs.get_federation_sender()
         self.notifier = hs.notifier
         self.clock = hs.get_clock()
+        self.config = hs.get_config()
 
         self.putChild("remove_pushers", PusherResource(hs))
         self.putChild("syncing_users", PresenceResource(hs))
+        self.putChild("expire_cache", ExpireCacheResource(hs))
 
     def render_GET(self, request):
         self._async_render_GET(request)
@@ -134,6 +139,7 @@ class ReplicationResource(Resource):
         pushers_token = self.store.get_pushers_stream_token()
         caches_token = self.store.get_cache_stream_token()
         public_rooms_token = self.store.get_current_public_room_stream_id()
+        federation_token = self.federation_sender.get_current_token()
 
         defer.returnValue(_ReplicationToken(
             room_stream_token,
@@ -148,6 +154,7 @@ class ReplicationResource(Resource):
             caches_token,
             int(stream_token.to_device_key),
             int(public_rooms_token),
+            int(federation_token),
         ))
 
     @request_handler()
@@ -164,8 +171,13 @@ class ReplicationResource(Resource):
         }
         request_streams["streams"] = parse_string(request, "streams")
 
+        federation_ack = parse_integer(request, "federation_ack", None)
+
         def replicate():
-            return self.replicate(request_streams, limit)
+            return self.replicate(
+                request_streams, limit,
+                federation_ack=federation_ack
+            )
 
         writer = yield self.notifier.wait_for_replication(replicate, timeout)
         result = writer.finish()
@@ -183,7 +195,7 @@ class ReplicationResource(Resource):
         finish_request(request)
 
     @defer.inlineCallbacks
-    def replicate(self, request_streams, limit):
+    def replicate(self, request_streams, limit, federation_ack=None):
         writer = _Writer()
         current_token = yield self.current_replication_token()
         logger.debug("Replicating up to %r", current_token)
@@ -202,6 +214,7 @@ class ReplicationResource(Resource):
         yield self.caches(writer, current_token, limit, request_streams)
         yield self.to_device(writer, current_token, limit, request_streams)
         yield self.public_rooms(writer, current_token, limit, request_streams)
+        self.federation(writer, current_token, limit, request_streams, federation_ack)
         self.streams(writer, current_token, request_streams)
 
         logger.debug("Replicated %d rows", writer.total)
@@ -462,7 +475,24 @@ class ReplicationResource(Resource):
             )
             upto_token = _position_from_rows(public_rooms_rows, current_position)
             writer.write_header_and_rows("public_rooms", public_rooms_rows, (
-                "position", "room_id", "visibility"
+                "position", "room_id", "visibility", "appservice_id", "network_id",
+            ), position=upto_token)
+
+    def federation(self, writer, current_token, limit, request_streams, federation_ack):
+        if self.config.send_federation:
+            return
+
+        current_position = current_token.federation
+
+        federation = request_streams.get("federation")
+
+        if federation is not None and federation != current_position:
+            federation_rows = self.federation_sender.get_replication_rows(
+                federation, limit, federation_ack=federation_ack,
+            )
+            upto_token = _position_from_rows(federation_rows, current_position)
+            writer.write_header_and_rows("federation", federation_rows, (
+                "position", "type", "content",
             ), position=upto_token)
 
 
@@ -497,6 +527,7 @@ class _Writer(object):
 class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
     "events", "presence", "typing", "receipts", "account_data", "backfill",
     "push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
+    "federation",
 ))):
     __slots__ = []
 
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f19540d6bb..18076e0f3b 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -34,6 +34,9 @@ class BaseSlavedStore(SQLBaseStore):
         else:
             self._cache_id_gen = None
 
+        self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
+        self.http_client = hs.get_simple_http_client()
+
     def stream_positions(self):
         pos = {}
         if self._cache_id_gen:
@@ -54,3 +57,19 @@ class BaseSlavedStore(SQLBaseStore):
                     logger.info("Got unexpected cache_func: %r", cache_func)
             self._cache_id_gen.advance(int(stream["position"]))
         return defer.succeed(None)
+
+    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+        txn.call_after(cache_func.invalidate, keys)
+        txn.call_after(self._send_invalidation_poke, cache_func, keys)
+
+    @defer.inlineCallbacks
+    def _send_invalidation_poke(self, cache_func, keys):
+        try:
+            yield self.http_client.post_json_get_json(self.expire_cache_url, {
+                "invalidate": [{
+                    "name": cache_func.__name__,
+                    "keys": list(keys),
+                }]
+            })
+        except:
+            logger.exception("Failed to poke on expire_cache")
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 3bfd5e8213..cc860f9f9b 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -29,10 +29,16 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
             "DeviceInboxStreamChangeCache",
             self._device_inbox_id_gen.get_current_token()
         )
+        self._device_federation_outbox_stream_cache = StreamChangeCache(
+            "DeviceFederationOutboxStreamChangeCache",
+            self._device_inbox_id_gen.get_current_token()
+        )
 
     get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
     get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
+    get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
     delete_messages_for_device = DataStore.delete_messages_for_device.__func__
+    delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__
 
     def stream_positions(self):
         result = super(SlavedDeviceInboxStore, self).stream_positions()
@@ -45,9 +51,15 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
             self._device_inbox_id_gen.advance(int(stream["position"]))
             for row in stream["rows"]:
                 stream_id = row[0]
-                user_id = row[1]
-                self._device_inbox_stream_cache.entity_has_changed(
-                    user_id, stream_id
-                )
+                entity = row[1]
+
+                if entity.startswith("@"):
+                    self._device_inbox_stream_cache.entity_has_changed(
+                        entity, stream_id
+                    )
+                else:
+                    self._device_federation_outbox_stream_cache.entity_has_changed(
+                        entity, stream_id
+                    )
 
         return super(SlavedDeviceInboxStore, self).process_replication(result)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 0c26e96e98..64f18bbb3e 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -26,6 +26,11 @@ from synapse.storage.stream import StreamStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 import ujson as json
+import logging
+
+
+logger = logging.getLogger(__name__)
+
 
 # So, um, we want to borrow a load of functions intended for reading from
 # a DataStore, but we don't want to take functions that either write to the
@@ -180,6 +185,11 @@ class SlavedEventStore(BaseSlavedStore):
         EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
     )
 
+    get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
+
+    get_federation_out_pos = DataStore.get_federation_out_pos.__func__
+    update_federation_out_pos = DataStore.update_federation_out_pos.__func__
+
     def stream_positions(self):
         result = super(SlavedEventStore, self).stream_positions()
         result["events"] = self._stream_id_gen.get_current_token()
@@ -194,6 +204,10 @@ class SlavedEventStore(BaseSlavedStore):
         stream = result.get("events")
         if stream:
             self._stream_id_gen.advance(int(stream["position"]))
+
+            if stream["rows"]:
+                logger.info("Got %d event rows", len(stream["rows"]))
+
             for row in stream["rows"]:
                 self._process_replication_row(
                     row, backfilled=False, state_resets=state_resets
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 23c613863f..6df9a25ef3 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -15,6 +15,7 @@
 
 from ._base import BaseSlavedStore
 from synapse.storage import DataStore
+from synapse.storage.room import RoomStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
@@ -30,7 +31,7 @@ class RoomStore(BaseSlavedStore):
         DataStore.get_current_public_room_stream_id.__func__
     )
     get_public_room_ids_at_stream_id = (
-        DataStore.get_public_room_ids_at_stream_id.__func__
+        RoomStore.__dict__["get_public_room_ids_at_stream_id"]
     )
     get_public_room_ids_at_stream_id_txn = (
         DataStore.get_public_room_ids_at_stream_id_txn.__func__
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index 6f2ba98af5..fbb58f35da 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
 from ._base import BaseSlavedStore
 from synapse.storage import DataStore
 from synapse.storage.transactions import TransactionStore
@@ -22,9 +21,10 @@ from synapse.storage.transactions import TransactionStore
 class TransactionStore(BaseSlavedStore):
     get_destination_retry_timings = TransactionStore.__dict__[
         "get_destination_retry_timings"
-    ].orig
+    ]
     _get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
+    set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__
+    _set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__
 
-    # For now, don't record the destination rety timings
-    def set_destination_retry_timings(*args, **kwargs):
-        return defer.succeed(None)
+    prep_send_transaction = DataStore.prep_send_transaction.__func__
+    delivered_txn = DataStore.delivered_txn.__func__