summary refs log tree commit diff
path: root/synapse/replication/slave
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/slave')
-rw-r--r--synapse/replication/slave/storage/_base.py19
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py20
-rw-r--r--synapse/replication/slave/storage/events.py14
-rw-r--r--synapse/replication/slave/storage/room.py3
-rw-r--r--synapse/replication/slave/storage/transactions.py10
5 files changed, 56 insertions, 10 deletions
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f19540d6bb..18076e0f3b 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -34,6 +34,9 @@ class BaseSlavedStore(SQLBaseStore):
         else:
             self._cache_id_gen = None
 
+        self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
+        self.http_client = hs.get_simple_http_client()
+
     def stream_positions(self):
         pos = {}
         if self._cache_id_gen:
@@ -54,3 +57,19 @@ class BaseSlavedStore(SQLBaseStore):
                     logger.info("Got unexpected cache_func: %r", cache_func)
             self._cache_id_gen.advance(int(stream["position"]))
         return defer.succeed(None)
+
+    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+        txn.call_after(cache_func.invalidate, keys)
+        txn.call_after(self._send_invalidation_poke, cache_func, keys)
+
+    @defer.inlineCallbacks
+    def _send_invalidation_poke(self, cache_func, keys):
+        try:
+            yield self.http_client.post_json_get_json(self.expire_cache_url, {
+                "invalidate": [{
+                    "name": cache_func.__name__,
+                    "keys": list(keys),
+                }]
+            })
+        except:
+            logger.exception("Failed to poke on expire_cache")
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 3bfd5e8213..cc860f9f9b 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -29,10 +29,16 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
             "DeviceInboxStreamChangeCache",
             self._device_inbox_id_gen.get_current_token()
         )
+        self._device_federation_outbox_stream_cache = StreamChangeCache(
+            "DeviceFederationOutboxStreamChangeCache",
+            self._device_inbox_id_gen.get_current_token()
+        )
 
     get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
     get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
+    get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
     delete_messages_for_device = DataStore.delete_messages_for_device.__func__
+    delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__
 
     def stream_positions(self):
         result = super(SlavedDeviceInboxStore, self).stream_positions()
@@ -45,9 +51,15 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
             self._device_inbox_id_gen.advance(int(stream["position"]))
             for row in stream["rows"]:
                 stream_id = row[0]
-                user_id = row[1]
-                self._device_inbox_stream_cache.entity_has_changed(
-                    user_id, stream_id
-                )
+                entity = row[1]
+
+                if entity.startswith("@"):
+                    self._device_inbox_stream_cache.entity_has_changed(
+                        entity, stream_id
+                    )
+                else:
+                    self._device_federation_outbox_stream_cache.entity_has_changed(
+                        entity, stream_id
+                    )
 
         return super(SlavedDeviceInboxStore, self).process_replication(result)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 0c26e96e98..64f18bbb3e 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -26,6 +26,11 @@ from synapse.storage.stream import StreamStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 import ujson as json
+import logging
+
+
+logger = logging.getLogger(__name__)
+
 
 # So, um, we want to borrow a load of functions intended for reading from
 # a DataStore, but we don't want to take functions that either write to the
@@ -180,6 +185,11 @@ class SlavedEventStore(BaseSlavedStore):
         EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
     )
 
+    get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
+
+    get_federation_out_pos = DataStore.get_federation_out_pos.__func__
+    update_federation_out_pos = DataStore.update_federation_out_pos.__func__
+
     def stream_positions(self):
         result = super(SlavedEventStore, self).stream_positions()
         result["events"] = self._stream_id_gen.get_current_token()
@@ -194,6 +204,10 @@ class SlavedEventStore(BaseSlavedStore):
         stream = result.get("events")
         if stream:
             self._stream_id_gen.advance(int(stream["position"]))
+
+            if stream["rows"]:
+                logger.info("Got %d event rows", len(stream["rows"]))
+
             for row in stream["rows"]:
                 self._process_replication_row(
                     row, backfilled=False, state_resets=state_resets
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 23c613863f..6df9a25ef3 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -15,6 +15,7 @@
 
 from ._base import BaseSlavedStore
 from synapse.storage import DataStore
+from synapse.storage.room import RoomStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
@@ -30,7 +31,7 @@ class RoomStore(BaseSlavedStore):
         DataStore.get_current_public_room_stream_id.__func__
     )
     get_public_room_ids_at_stream_id = (
-        DataStore.get_public_room_ids_at_stream_id.__func__
+        RoomStore.__dict__["get_public_room_ids_at_stream_id"]
     )
     get_public_room_ids_at_stream_id_txn = (
         DataStore.get_public_room_ids_at_stream_id_txn.__func__
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index 6f2ba98af5..fbb58f35da 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
 from ._base import BaseSlavedStore
 from synapse.storage import DataStore
 from synapse.storage.transactions import TransactionStore
@@ -22,9 +21,10 @@ from synapse.storage.transactions import TransactionStore
 class TransactionStore(BaseSlavedStore):
     get_destination_retry_timings = TransactionStore.__dict__[
         "get_destination_retry_timings"
-    ].orig
+    ]
     _get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
+    set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__
+    _set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__
 
-    # For now, don't record the destination rety timings
-    def set_destination_retry_timings(*args, **kwargs):
-        return defer.succeed(None)
+    prep_send_transaction = DataStore.prep_send_transaction.__func__
+    delivered_txn = DataStore.delivered_txn.__func__