diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index ab133db872..b962641166 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -15,7 +15,6 @@
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine
-from twisted.internet import defer
from ._slaved_id_tracker import SlavedIdTracker
@@ -34,8 +33,7 @@ class BaseSlavedStore(SQLBaseStore):
else:
self._cache_id_gen = None
- self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
- self.http_client = hs.get_simple_http_client()
+ self.hs = hs
def stream_positions(self):
pos = {}
@@ -43,35 +41,20 @@ class BaseSlavedStore(SQLBaseStore):
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
- def process_replication(self, result):
- stream = result.get("caches")
- if stream:
- for row in stream["rows"]:
- (
- position, cache_func, keys, invalidation_ts,
- ) = row
-
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "caches":
+ self._cache_id_gen.advance(token)
+ for row in rows:
try:
- getattr(self, cache_func).invalidate(tuple(keys))
+ getattr(self, row.cache_func).invalidate(tuple(row.keys))
except AttributeError:
# We probably haven't pulled in the cache in this worker,
# which is fine.
pass
- self._cache_id_gen.advance(int(stream["position"]))
- return defer.succeed(None)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
txn.call_after(cache_func.invalidate, keys)
txn.call_after(self._send_invalidation_poke, cache_func, keys)
- @defer.inlineCallbacks
def _send_invalidation_poke(self, cache_func, keys):
- try:
- yield self.http_client.post_json_get_json(self.expire_cache_url, {
- "invalidate": [{
- "name": cache_func.__name__,
- "keys": list(keys),
- }]
- })
- except:
- logger.exception("Failed to poke on expire_cache")
+ self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
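
Note: this hunk is the heart of the patch. Instead of each slaved store parsing
the JSON blob returned by the HTTP long-poll (process_replication), the TCP
client hands every store already-parsed rows for one stream at a time
(process_replication_rows), and cache-invalidation pokes travel back over the
same connection instead of being POSTed to /expire_cache. A minimal sketch of a
store under the new contract; the stream name, id generator and cached method
are made-up placeholders, not real Synapse names:

    class SlavedFooStore(BaseSlavedStore):
        def stream_positions(self):
            pos = super(SlavedFooStore, self).stream_positions()
            pos["foo"] = self._foo_id_gen.get_current_token()
            return pos

        def process_replication_rows(self, stream_name, token, rows):
            if stream_name == "foo":
                # Advance our local copy of the stream position, then
                # invalidate any caches keyed on the changed entities.
                self._foo_id_gen.advance(token)
                for row in rows:
                    self.get_foo_for_user.invalidate((row.user_id,))
            return super(SlavedFooStore, self).process_replication_rows(
                stream_name, token, rows
            )
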
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 77c64722c7..efbd87918e 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -69,38 +69,25 @@ class SlavedAccountDataStore(BaseSlavedStore):
result["tag_account_data"] = position
return result
- def process_replication(self, result):
- stream = result.get("user_account_data")
- if stream:
- self._account_data_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- position, user_id, data_type = row[:3]
- self.get_global_account_data_by_type_for_user.invalidate(
- (data_type, user_id,)
- )
- self.get_account_data_for_user.invalidate((user_id,))
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "tag_account_data":
+ self._account_data_id_gen.advance(token)
+ for row in rows:
+ self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(
- user_id, position
+ row.user_id, token
)
-
- stream = result.get("room_account_data")
- if stream:
- self._account_data_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- position, user_id = row[:2]
- self.get_account_data_for_user.invalidate((user_id,))
+ elif stream_name == "account_data":
+ self._account_data_id_gen.advance(token)
+ for row in rows:
+ if not row.room_id:
+ self.get_global_account_data_by_type_for_user.invalidate(
+ (row.data_type, row.user_id,)
+ )
+ self.get_account_data_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(
- user_id, position
+ row.user_id, token
)
-
- stream = result.get("tag_account_data")
- if stream:
- self._account_data_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- position, user_id = row[:2]
- self.get_tags_for_user.invalidate((user_id,))
- self._account_data_stream_cache.entity_has_changed(
- user_id, position
- )
-
- return super(SlavedAccountDataStore, self).process_replication(result)
+ return super(SlavedAccountDataStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
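
Note: the old user_account_data and room_account_data streams are folded into a
single account_data stream whose rows carry a room_id; an empty room_id marks
global (per-user) account data, which is what the "if not row.room_id" branch
above keys off. The rows are assumed to be simple structs along these lines (a
sketch mirroring the attributes the store touches, not the actual stream
definition):

    from collections import namedtuple

    # Assumed shape only; the real row types live with the TCP stream
    # definitions elsewhere in this change.
    AccountDataStreamRow = namedtuple("AccountDataStreamRow", (
        "user_id", "room_id", "data_type",
    ))
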
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index f9102e0d89..6f3fb64770 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -53,21 +53,18 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
result["to_device"] = self._device_inbox_id_gen.get_current_token()
return result
- def process_replication(self, result):
- stream = result.get("to_device")
- if stream:
- self._device_inbox_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- stream_id = row[0]
- entity = row[1]
-
- if entity.startswith("@"):
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "to_device":
+ self._device_inbox_id_gen.advance(token)
+ for row in rows:
+ if row.entity.startswith("@"):
self._device_inbox_stream_cache.entity_has_changed(
- entity, stream_id
+ row.entity, token
)
else:
self._device_federation_outbox_stream_cache.entity_has_changed(
- entity, stream_id
+ row.entity, token
)
-
- return super(SlavedDeviceInboxStore, self).process_replication(result)
+ return super(SlavedDeviceInboxStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
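
Note: the to_device stream multiplexes local and federated delivery, with
row.entity doing double duty: a leading "@" marks a local user, anything else
is a remote server name destined for the federation outbox. For example:

    # Illustrative entities on the "to_device" stream:
    #   "@alice:example.com"  -> user:   poke _device_inbox_stream_cache
    #   "remote.example.org"  -> server: poke _device_federation_outbox_stream_cache
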
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index ca46aa17b6..4d4a435471 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -51,22 +51,18 @@ class SlavedDeviceStore(BaseSlavedStore):
result["device_lists"] = self._device_list_id_gen.get_current_token()
return result
- def process_replication(self, result):
- stream = result.get("device_lists")
- if stream:
- self._device_list_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- stream_id = row[0]
- user_id = row[1]
- destination = row[2]
-
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "device_lists":
+ self._device_list_id_gen.advance(token)
+ for row in rows:
self._device_list_stream_cache.entity_has_changed(
- user_id, stream_id
+ row.user_id, token
)
- if destination:
+ if row.destination:
self._device_list_federation_stream_cache.entity_has_changed(
- destination, stream_id
+ row.destination, token
)
-
- return super(SlavedDeviceStore, self).process_replication(result)
+ return super(SlavedDeviceStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index d4db1e452e..5fd47706ef 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -201,48 +201,25 @@ class SlavedEventStore(BaseSlavedStore):
result["backfill"] = -self._backfill_id_gen.get_current_token()
return result
- def process_replication(self, result):
- stream = result.get("events")
- if stream:
- self._stream_id_gen.advance(int(stream["position"]))
-
- if stream["rows"]:
- logger.info("Got %d event rows", len(stream["rows"]))
-
- for row in stream["rows"]:
- self._process_replication_row(
- row, backfilled=False,
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "events":
+ self._stream_id_gen.advance(token)
+ for row in rows:
+ self.invalidate_caches_for_event(
+ token, row.event_id, row.room_id, row.type, row.state_key,
+ row.redacts,
+ backfilled=False,
)
-
- stream = result.get("backfill")
- if stream:
- self._backfill_id_gen.advance(-int(stream["position"]))
- for row in stream["rows"]:
- self._process_replication_row(
- row, backfilled=True,
+ elif stream_name == "backfill":
+ self._backfill_id_gen.advance(-token)
+ for row in rows:
+ self.invalidate_caches_for_event(
+ -token, row.event_id, row.room_id, row.type, row.state_key,
+ row.redacts,
+ backfilled=True,
)
-
- stream = result.get("forward_ex_outliers")
- if stream:
- self._stream_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- event_id = row[1]
- self._invalidate_get_event_cache(event_id)
-
- stream = result.get("backward_ex_outliers")
- if stream:
- self._backfill_id_gen.advance(-int(stream["position"]))
- for row in stream["rows"]:
- event_id = row[1]
- self._invalidate_get_event_cache(event_id)
-
- return super(SlavedEventStore, self).process_replication(result)
-
- def _process_replication_row(self, row, backfilled):
- stream_ordering = row[0] if not backfilled else -row[0]
- self.invalidate_caches_for_event(
- stream_ordering, row[1], row[2], row[3], row[4], row[5],
- backfilled=backfilled,
+ return super(SlavedEventStore, self).process_replication_rows(
+ stream_name, token, rows
)
def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
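
Note: the backfill stream runs backwards, so the token is negated at both ends:
stream_positions (context above) advertises
-self._backfill_id_gen.get_current_token(), and the handler undoes that with
advance(-token), passing -token as the stream ordering. A worked example:

    # If the backfill generator sits at -5, the worker advertises position 5.
    # When rows later arrive with token 7, advance(-7) moves the generator to
    # -7, and caches are invalidated with stream_ordering -7.
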
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index e4a2414d78..dffc80adc3 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -48,15 +48,14 @@ class SlavedPresenceStore(BaseSlavedStore):
result["presence"] = position
return result
- def process_replication(self, result):
- stream = result.get("presence")
- if stream:
- self._presence_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- position, user_id = row[:2]
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "presence":
+ self._presence_id_gen.advance(token)
+ for row in rows:
self.presence_stream_cache.entity_has_changed(
- user_id, position
+ row.user_id, token
)
- self._get_presence_for_user.invalidate((user_id,))
-
- return super(SlavedPresenceStore, self).process_replication(result)
+ self._get_presence_for_user.invalidate((row.user_id,))
+ return super(SlavedPresenceStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 21ceb0213a..83e880fdd2 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -50,18 +50,15 @@ class SlavedPushRuleStore(SlavedEventStore):
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
return result
- def process_replication(self, result):
- stream = result.get("push_rules")
- if stream:
- for row in stream["rows"]:
- position = row[0]
- user_id = row[2]
- self.get_push_rules_for_user.invalidate((user_id,))
- self.get_push_rules_enabled_for_user.invalidate((user_id,))
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "push_rules":
+ self._push_rules_stream_id_gen.advance(token)
+ for row in rows:
+ self.get_push_rules_for_user.invalidate((row.user_id,))
+ self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(
- user_id, position
+ row.user_id, token
)
-
- self._push_rules_stream_id_gen.advance(int(stream["position"]))
-
- return super(SlavedPushRuleStore, self).process_replication(result)
+ return super(SlavedPushRuleStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
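
Note: unlike the HTTP version, which advanced _push_rules_stream_id_gen only
after invalidating, the id generator is now advanced before the caches are
poked, matching the pattern used by the other stores above, presumably so that
a reader that re-fills the cache immediately afterwards sees a stream position
that already covers the change:

    # Order of operations per batch of rows:
    #   1. _push_rules_stream_id_gen.advance(token)
    #   2. invalidate the per-user push rule caches
    #   3. push_rules_stream_cache.entity_has_changed(user_id, token)
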
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index d88206b3bb..4e8d68ece9 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -40,13 +40,9 @@ class SlavedPusherStore(BaseSlavedStore):
result["pushers"] = self._pushers_id_gen.get_current_token()
return result
- def process_replication(self, result):
- stream = result.get("pushers")
- if stream:
- self._pushers_id_gen.advance(int(stream["position"]))
-
- stream = result.get("deleted_pushers")
- if stream:
- self._pushers_id_gen.advance(int(stream["position"]))
-
- return super(SlavedPusherStore, self).process_replication(result)
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "pushers":
+ self._pushers_id_gen.advance(token)
+ return super(SlavedPusherStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index ac9662d399..b371574ece 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -65,20 +65,22 @@ class SlavedReceiptsStore(BaseSlavedStore):
result["receipts"] = self._receipts_id_gen.get_current_token()
return result
- def process_replication(self, result):
- stream = result.get("receipts")
- if stream:
- self._receipts_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- position, room_id, receipt_type, user_id = row[:4]
- self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
- self._receipts_stream_cache.entity_has_changed(room_id, position)
-
- return super(SlavedReceiptsStore, self).process_replication(result)
-
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self.get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)
+
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "receipts":
+ self._receipts_id_gen.advance(token)
+ for row in rows:
+ self.invalidate_caches_for_receipt(
+ row.room_id, row.receipt_type, row.user_id
+ )
+ self._receipts_stream_cache.entity_has_changed(row.room_id, token)
+
+ return super(SlavedReceiptsStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 6df9a25ef3..f510384033 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -46,9 +46,10 @@ class RoomStore(BaseSlavedStore):
result["public_rooms"] = self._public_room_id_gen.get_current_token()
return result
- def process_replication(self, result):
- stream = result.get("public_rooms")
- if stream:
- self._public_room_id_gen.advance(int(stream["position"]))
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "public_rooms":
+ self._public_room_id_gen.advance(token)
- return super(RoomStore, self).process_replication(result)
+ return super(RoomStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
new file mode 100644
index 0000000000..251d3afcf4
--- /dev/null
+++ b/synapse/replication/tcp/client.py
@@ -0,0 +1,198 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A replication client for use by synapse workers.
+"""
+
+from twisted.internet import reactor, defer
+from twisted.internet.protocol import ReconnectingClientFactory
+
+from .commands import (
+ FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
+)
+from .protocol import ClientReplicationStreamProtocol
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationClientFactory(ReconnectingClientFactory):
+ """Factory for building connections to the master. Will reconnect if the
+ connection is lost.
+
+ Accepts a handler that will be called when new data is available or data
+ is required.
+ """
+    maxDelay = 5  # Cap the reconnect backoff so we retry at least every 5 seconds
+
+ def __init__(self, hs, client_name, handler):
+ self.client_name = client_name
+ self.handler = handler
+ self.server_name = hs.config.server_name
+        self._clock = hs.get_clock()  # self.clock is already defined in the superclass
+
+ reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying)
+
+ def startedConnecting(self, connector):
+ logger.info("Connecting to replication: %r", connector.getDestination())
+
+ def buildProtocol(self, addr):
+ logger.info("Connected to replication: %r", addr)
+ self.resetDelay()
+ return ClientReplicationStreamProtocol(
+ self.client_name, self.server_name, self._clock, self.handler
+ )
+
+ def clientConnectionLost(self, connector, reason):
+ logger.error("Lost replication conn: %r", reason)
+ ReconnectingClientFactory.clientConnectionLost(self, connector, reason)
+
+ def clientConnectionFailed(self, connector, reason):
+ logger.error("Failed to connect to replication: %r", reason)
+ ReconnectingClientFactory.clientConnectionFailed(
+ self, connector, reason
+ )
+
+
+class ReplicationClientHandler(object):
+ """A base handler that can be passed to the ReplicationClientFactory.
+
+ By default proxies incoming replication data to the SlaveStore.
+ """
+ def __init__(self, store):
+ self.store = store
+
+ # The current connection. None if we are currently (re)connecting
+ self.connection = None
+
+ # Any pending commands to be sent once a new connection has been
+ # established
+ self.pending_commands = []
+
+        # Map from string -> deferred, to wake up when receiving a SYNC with
+ # the given string.
+ # Used for tests.
+ self.awaiting_syncs = {}
+
+ def start_replication(self, hs):
+ """Helper method to start a replication connection to the remote server
+ using TCP.
+ """
+ client_name = hs.config.worker_name
+ factory = ReplicationClientFactory(hs, client_name, self)
+ host = hs.config.worker_replication_host
+ port = hs.config.worker_replication_port
+ reactor.connectTCP(host, port, factory)
+
+ def on_rdata(self, stream_name, token, rows):
+ """Called when we get new replication data. By default this just pokes
+ the slave store.
+
+        Can be overridden in subclasses to handle more.
+ """
+ logger.info("Received rdata %s -> %s", stream_name, token)
+ self.store.process_replication_rows(stream_name, token, rows)
+
+ def on_position(self, stream_name, token):
+ """Called when we get new position data. By default this just pokes
+ the slave store.
+
+        Can be overridden in subclasses to handle more.
+ """
+ self.store.process_replication_rows(stream_name, token, [])
+
+ def on_sync(self, data):
+        """When we receive a SYNC we wake up any deferreds that were waiting
+        for a sync with the given data.
+
+ Used by tests.
+ """
+ d = self.awaiting_syncs.pop(data, None)
+ if d:
+ d.callback(data)
+
+ def get_streams_to_replicate(self):
+ """Called when a new connection has been established and we need to
+ subscribe to streams.
+
+ Returns a dictionary of stream name to token.
+ """
+ args = self.store.stream_positions()
+ user_account_data = args.pop("user_account_data", None)
+ room_account_data = args.pop("room_account_data", None)
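+        # Both tokens come from the same id generator on the master, so
+        # either one is a valid position on the merged "account_data" stream.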
+ if user_account_data:
+ args["account_data"] = user_account_data
+ elif room_account_data:
+ args["account_data"] = room_account_data
+ return args
+
+ def get_currently_syncing_users(self):
+ """Get the list of currently syncing users (if any). This is called
+ when a connection has been established and we need to send the
+        currently syncing users. (Overridden by the synchrotrons only.)
+ """
+ return []
+
+ def send_command(self, cmd):
+        """Send a command to the master (or queue it until we establish a
+        connection, if we don't have one already).
+ """
+ if self.connection:
+ self.connection.send_command(cmd)
+ else:
+ logger.warn("Queuing command as not connected: %r", cmd.NAME)
+ self.pending_commands.append(cmd)
+
+ def send_federation_ack(self, token):
+ """Ack data for the federation stream. This allows the master to drop
+ data stored purely in memory.
+ """
+ self.send_command(FederationAckCommand(token))
+
+ def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+ """Poke the master that a user has started/stopped syncing.
+ """
+ self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
+
+ def send_remove_pusher(self, app_id, push_key, user_id):
+        """Poke the master to remove a pusher for a user.
+ """
+ cmd = RemovePusherCommand(app_id, push_key, user_id)
+ self.send_command(cmd)
+
+ def send_invalidate_cache(self, cache_func, keys):
+ """Poke the master to invalidate a cache.
+ """
+ cmd = InvalidateCacheCommand(cache_func, keys)
+ self.send_command(cmd)
+
+ def await_sync(self, data):
+ """Returns a deferred that is resolved when we receive a SYNC command
+        with the given data.
+
+ Used by tests.
+ """
+ return self.awaiting_syncs.setdefault(data, defer.Deferred())
+
+ def update_connection(self, connection):
+        """Called when a connection has been established, or lost (connection is None).
+ """
+ self.connection = connection
+ if connection:
+ for cmd in self.pending_commands:
+ connection.send_command(cmd)
+ self.pending_commands = []
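
Note: wiring this into a worker then looks roughly like the sketch below.
SlavedFooStore stands in for whichever BaseSlavedStore subclass the worker
uses, and real workers subclass ReplicationClientHandler to hook per-stream
behaviour:

    store = SlavedFooStore(db_conn, hs)    # hypothetical slaved store
    handler = ReplicationClientHandler(store)
    handler.start_replication(hs)          # connects to worker_replication_host/port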