summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/replication/slave/storage/_base.py31
-rw-r--r--synapse/replication/slave/storage/account_data.py49
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py23
-rw-r--r--synapse/replication/slave/storage/devices.py24
-rw-r--r--synapse/replication/slave/storage/events.py57
-rw-r--r--synapse/replication/slave/storage/presence.py19
-rw-r--r--synapse/replication/slave/storage/push_rule.py23
-rw-r--r--synapse/replication/slave/storage/pushers.py16
-rw-r--r--synapse/replication/slave/storage/receipts.py24
-rw-r--r--synapse/replication/slave/storage/room.py11
-rw-r--r--tests/replication/slave/storage/_base.py30
11 files changed, 128 insertions, 179 deletions
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index ab133db872..b962641166 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -15,7 +15,6 @@
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine
-from twisted.internet import defer
 
 from ._slaved_id_tracker import SlavedIdTracker
 
@@ -34,8 +33,7 @@ class BaseSlavedStore(SQLBaseStore):
         else:
             self._cache_id_gen = None
 
-        self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
-        self.http_client = hs.get_simple_http_client()
+        self.hs = hs
 
     def stream_positions(self):
         pos = {}
@@ -43,35 +41,20 @@ class BaseSlavedStore(SQLBaseStore):
             pos["caches"] = self._cache_id_gen.get_current_token()
         return pos
 
-    def process_replication(self, result):
-        stream = result.get("caches")
-        if stream:
-            for row in stream["rows"]:
-                (
-                    position, cache_func, keys, invalidation_ts,
-                ) = row
-
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "caches":
+            self._cache_id_gen.advance(token)
+            for row in rows:
                 try:
-                    getattr(self, cache_func).invalidate(tuple(keys))
+                    getattr(self, row.cache_func).invalidate(tuple(row.keys))
                 except AttributeError:
                     # We probably haven't pulled in the cache in this worker,
                     # which is fine.
                     pass
-            self._cache_id_gen.advance(int(stream["position"]))
-        return defer.succeed(None)
 
     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         txn.call_after(cache_func.invalidate, keys)
         txn.call_after(self._send_invalidation_poke, cache_func, keys)
 
-    @defer.inlineCallbacks
     def _send_invalidation_poke(self, cache_func, keys):
-        try:
-            yield self.http_client.post_json_get_json(self.expire_cache_url, {
-                "invalidate": [{
-                    "name": cache_func.__name__,
-                    "keys": list(keys),
-                }]
-            })
-        except:
-            logger.exception("Failed to poke on expire_cache")
+        self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 77c64722c7..efbd87918e 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -69,38 +69,25 @@ class SlavedAccountDataStore(BaseSlavedStore):
         result["tag_account_data"] = position
         return result
 
-    def process_replication(self, result):
-        stream = result.get("user_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id, data_type = row[:3]
-                self.get_global_account_data_by_type_for_user.invalidate(
-                    (data_type, user_id,)
-                )
-                self.get_account_data_for_user.invalidate((user_id,))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "tag_account_data":
+            self._account_data_id_gen.advance(token)
+            for row in rows:
+                self.get_tags_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-        stream = result.get("room_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
-                self.get_account_data_for_user.invalidate((user_id,))
+        elif stream_name == "account_data":
+            self._account_data_id_gen.advance(token)
+            for row in rows:
+                if not row.room_id:
+                    self.get_global_account_data_by_type_for_user.invalidate(
+                        (row.data_type, row.user_id,)
+                    )
+                self.get_account_data_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-        stream = result.get("tag_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
-                self.get_tags_for_user.invalidate((user_id,))
-                self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
-                )
-
-        return super(SlavedAccountDataStore, self).process_replication(result)
+        return super(SlavedAccountDataStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index f9102e0d89..6f3fb64770 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -53,21 +53,18 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
         result["to_device"] = self._device_inbox_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("to_device")
-        if stream:
-            self._device_inbox_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                stream_id = row[0]
-                entity = row[1]
-
-                if entity.startswith("@"):
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "to_device":
+            self._device_inbox_id_gen.advance(token)
+            for row in rows:
+                if row.entity.startswith("@"):
                     self._device_inbox_stream_cache.entity_has_changed(
-                        entity, stream_id
+                        row.entity, token
                     )
                 else:
                     self._device_federation_outbox_stream_cache.entity_has_changed(
-                        entity, stream_id
+                        row.entity, token
                     )
-
-        return super(SlavedDeviceInboxStore, self).process_replication(result)
+        return super(SlavedDeviceInboxStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index ca46aa17b6..4d4a435471 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -51,22 +51,18 @@ class SlavedDeviceStore(BaseSlavedStore):
         result["device_lists"] = self._device_list_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("device_lists")
-        if stream:
-            self._device_list_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                stream_id = row[0]
-                user_id = row[1]
-                destination = row[2]
-
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "device_lists":
+            self._device_list_id_gen.advance(token)
+            for row in rows:
                 self._device_list_stream_cache.entity_has_changed(
-                    user_id, stream_id
+                    row.user_id, token
                 )
 
-                if destination:
+                if row.destination:
                     self._device_list_federation_stream_cache.entity_has_changed(
-                        destination, stream_id
+                        row.destination, token
                     )
-
-        return super(SlavedDeviceStore, self).process_replication(result)
+        return super(SlavedDeviceStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index d4db1e452e..5fd47706ef 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -201,48 +201,25 @@ class SlavedEventStore(BaseSlavedStore):
         result["backfill"] = -self._backfill_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("events")
-        if stream:
-            self._stream_id_gen.advance(int(stream["position"]))
-
-            if stream["rows"]:
-                logger.info("Got %d event rows", len(stream["rows"]))
-
-            for row in stream["rows"]:
-                self._process_replication_row(
-                    row, backfilled=False,
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "events":
+            self._stream_id_gen.advance(token)
+            for row in rows:
+                self.invalidate_caches_for_event(
+                    token, row.event_id, row.room_id, row.type, row.state_key,
+                    row.redacts,
+                    backfilled=False,
                 )
-
-        stream = result.get("backfill")
-        if stream:
-            self._backfill_id_gen.advance(-int(stream["position"]))
-            for row in stream["rows"]:
-                self._process_replication_row(
-                    row, backfilled=True,
+        elif stream_name == "backfill":
+            self._backfill_id_gen.advance(-token)
+            for row in rows:
+                self.invalidate_caches_for_event(
+                    -token, row.event_id, row.room_id, row.type, row.state_key,
+                    row.redacts,
+                    backfilled=True,
                 )
-
-        stream = result.get("forward_ex_outliers")
-        if stream:
-            self._stream_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                event_id = row[1]
-                self._invalidate_get_event_cache(event_id)
-
-        stream = result.get("backward_ex_outliers")
-        if stream:
-            self._backfill_id_gen.advance(-int(stream["position"]))
-            for row in stream["rows"]:
-                event_id = row[1]
-                self._invalidate_get_event_cache(event_id)
-
-        return super(SlavedEventStore, self).process_replication(result)
-
-    def _process_replication_row(self, row, backfilled):
-        stream_ordering = row[0] if not backfilled else -row[0]
-        self.invalidate_caches_for_event(
-            stream_ordering, row[1], row[2], row[3], row[4], row[5],
-            backfilled=backfilled,
+        return super(SlavedEventStore, self).process_replication_rows(
+            stream_name, token, rows
         )
 
     def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index e4a2414d78..dffc80adc3 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -48,15 +48,14 @@ class SlavedPresenceStore(BaseSlavedStore):
         result["presence"] = position
         return result
 
-    def process_replication(self, result):
-        stream = result.get("presence")
-        if stream:
-            self._presence_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "presence":
+            self._presence_id_gen.advance(token)
+            for row in rows:
                 self.presence_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-                self._get_presence_for_user.invalidate((user_id,))
-
-        return super(SlavedPresenceStore, self).process_replication(result)
+                self._get_presence_for_user.invalidate((row.user_id,))
+        return super(SlavedPresenceStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 21ceb0213a..83e880fdd2 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -50,18 +50,15 @@ class SlavedPushRuleStore(SlavedEventStore):
         result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("push_rules")
-        if stream:
-            for row in stream["rows"]:
-                position = row[0]
-                user_id = row[2]
-                self.get_push_rules_for_user.invalidate((user_id,))
-                self.get_push_rules_enabled_for_user.invalidate((user_id,))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "push_rules":
+            self._push_rules_stream_id_gen.advance(token)
+            for row in rows:
+                self.get_push_rules_for_user.invalidate((row.user_id,))
+                self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-            self._push_rules_stream_id_gen.advance(int(stream["position"]))
-
-        return super(SlavedPushRuleStore, self).process_replication(result)
+        return super(SlavedPushRuleStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index d88206b3bb..4e8d68ece9 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -40,13 +40,9 @@ class SlavedPusherStore(BaseSlavedStore):
         result["pushers"] = self._pushers_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("pushers")
-        if stream:
-            self._pushers_id_gen.advance(int(stream["position"]))
-
-        stream = result.get("deleted_pushers")
-        if stream:
-            self._pushers_id_gen.advance(int(stream["position"]))
-
-        return super(SlavedPusherStore, self).process_replication(result)
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "pushers":
+            self._pushers_id_gen.advance(token)
+        return super(SlavedPusherStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index ac9662d399..b371574ece 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -65,20 +65,22 @@ class SlavedReceiptsStore(BaseSlavedStore):
         result["receipts"] = self._receipts_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("receipts")
-        if stream:
-            self._receipts_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, room_id, receipt_type, user_id = row[:4]
-                self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
-                self._receipts_stream_cache.entity_has_changed(room_id, position)
-
-        return super(SlavedReceiptsStore, self).process_replication(result)
-
     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
         self.get_linearized_receipts_for_room.invalidate_many((room_id,))
         self.get_last_receipt_event_id_for_user.invalidate(
             (user_id, room_id, receipt_type)
         )
+
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "receipts":
+            self._receipts_id_gen.advance(token)
+            for row in rows:
+                self.invalidate_caches_for_receipt(
+                    row.room_id, row.receipt_type, row.user_id
+                )
+                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
+
+        return super(SlavedReceiptsStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 6df9a25ef3..f510384033 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -46,9 +46,10 @@ class RoomStore(BaseSlavedStore):
         result["public_rooms"] = self._public_room_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("public_rooms")
-        if stream:
-            self._public_room_id_gen.advance(int(stream["position"]))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "public_rooms":
+            self._public_room_id_gen.advance(token)
 
-        return super(RoomStore, self).process_replication(result)
+        return super(RoomStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index b82868054d..81063f19a1 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 from tests import unittest
 
 from mock import Mock, NonCallableMock
 from tests.utils import setup_test_homeserver
-from synapse.replication.resource import ReplicationResource
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.replication.tcp.client import (
+    ReplicationClientHandler, ReplicationClientFactory,
+)
 
 
 class BaseSlavedStoreTestCase(unittest.TestCase):
@@ -33,18 +36,29 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
         )
         self.hs.get_ratelimiter().send_message.return_value = (True, 0)
 
-        self.replication = ReplicationResource(self.hs)
-
         self.master_store = self.hs.get_datastore()
         self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
         self.event_id = 0
 
+        server_factory = ReplicationStreamProtocolFactory(self.hs)
+        listener = reactor.listenUNIX("\0xxx", server_factory)
+        self.addCleanup(listener.stopListening)
+        self.streamer = server_factory.streamer
+
+        self.replication_handler = ReplicationClientHandler(self.slaved_store)
+        client_factory = ReplicationClientFactory(
+            self.hs, "client_name", self.replication_handler
+        )
+        client_connector = reactor.connectUNIX("\0xxx", client_factory)
+        self.addCleanup(client_factory.stopTrying)
+        self.addCleanup(client_connector.disconnect)
+
     @defer.inlineCallbacks
     def replicate(self):
-        streams = self.slaved_store.stream_positions()
-        writer = yield self.replication.replicate(streams, 100)
-        result = writer.finish()
-        yield self.slaved_store.process_replication(result)
+        yield self.streamer.on_notifier_poke()
+        d = self.replication_handler.await_sync("replication_test")
+        self.streamer.send_sync_to_all_connections("replication_test")
+        yield d
 
     @defer.inlineCallbacks
     def check(self, method, args, expected_result=None):