diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 589ee94c66..19f214281e 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
from synapse.http.server import JsonResource
-from synapse.replication.http import membership, send_event
+from synapse.replication.http import federation, membership, send_event
REPLICATION_PREFIX = "/_synapse/replication"
@@ -27,3 +27,4 @@ class ReplicationRestResource(JsonResource):
def register_servlets(self, hs):
send_event.register_servlets(hs, self)
membership.register_servlets(hs, self)
+ federation.register_servlets(hs, self)
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
new file mode 100644
index 0000000000..64a79da162
--- /dev/null
+++ b/synapse/replication/http/federation.py
@@ -0,0 +1,259 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
+ """Handles events newly received from federation, including persisting and
+ notifying.
+
+ The API looks like:
+
+ POST /_synapse/replication/fed_send_events/:txn_id
+
+ {
+ "events": [{
+ "event": { .. serialized event .. },
+ "internal_metadata": { .. serialized internal_metadata .. },
+ "rejected_reason": .., // The event.rejected_reason field
+ "context": { .. serialized event context .. },
+ }],
+ "backfilled": false
+ """
+
+ NAME = "fed_send_events"
+ PATH_ARGS = ()
+
+ def __init__(self, hs):
+ super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.federation_handler = hs.get_handlers().federation_handler
+
+ @staticmethod
+ @defer.inlineCallbacks
+ def _serialize_payload(store, event_and_contexts, backfilled):
+ """
+ Args:
+ store
+ event_and_contexts (list[tuple[FrozenEvent, EventContext]])
+ backfilled (bool): Whether or not the events are the result of
+ backfilling
+ """
+ event_payloads = []
+ for event, context in event_and_contexts:
+ serialized_context = yield context.serialize(event, store)
+
+ event_payloads.append({
+ "event": event.get_pdu_json(),
+ "internal_metadata": event.internal_metadata.get_dict(),
+ "rejected_reason": event.rejected_reason,
+ "context": serialized_context,
+ })
+
+ payload = {
+ "events": event_payloads,
+ "backfilled": backfilled,
+ }
+
+ defer.returnValue(payload)
+
+ @defer.inlineCallbacks
+ def _handle_request(self, request):
+ with Measure(self.clock, "repl_fed_send_events_parse"):
+ content = parse_json_object_from_request(request)
+
+ backfilled = content["backfilled"]
+
+ event_payloads = content["events"]
+
+ event_and_contexts = []
+ for event_payload in event_payloads:
+ event_dict = event_payload["event"]
+ internal_metadata = event_payload["internal_metadata"]
+ rejected_reason = event_payload["rejected_reason"]
+ event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
+
+ context = yield EventContext.deserialize(
+ self.store, event_payload["context"],
+ )
+
+ event_and_contexts.append((event, context))
+
+ logger.info(
+ "Got %d events from federation",
+ len(event_and_contexts),
+ )
+
+ yield self.federation_handler.persist_events_and_notify(
+ event_and_contexts, backfilled,
+ )
+
+ defer.returnValue((200, {}))
+
+
+class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
+ """Handles EDUs newly received from federation, including persisting and
+ notifying.
+
+ Request format:
+
+ POST /_synapse/replication/fed_send_edu/:edu_type/:txn_id
+
+ {
+ "origin": ...,
+ "content: { ... }
+ }
+ """
+
+ NAME = "fed_send_edu"
+ PATH_ARGS = ("edu_type",)
+
+ def __init__(self, hs):
+ super(ReplicationFederationSendEduRestServlet, self).__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.registry = hs.get_federation_registry()
+
+ @staticmethod
+ def _serialize_payload(edu_type, origin, content):
+ return {
+ "origin": origin,
+ "content": content,
+ }
+
+ @defer.inlineCallbacks
+ def _handle_request(self, request, edu_type):
+ with Measure(self.clock, "repl_fed_send_edu_parse"):
+ content = parse_json_object_from_request(request)
+
+ origin = content["origin"]
+ edu_content = content["content"]
+
+ logger.info(
+ "Got %r edu from %s",
+ edu_type, origin,
+ )
+
+ result = yield self.registry.on_edu(edu_type, origin, edu_content)
+
+ defer.returnValue((200, result))
+
+
+class ReplicationGetQueryRestServlet(ReplicationEndpoint):
+ """Handle responding to queries from federation.
+
+ Request format:
+
+ POST /_synapse/replication/fed_query/:query_type
+
+ {
+ "args": { ... }
+ }
+ """
+
+ NAME = "fed_query"
+ PATH_ARGS = ("query_type",)
+
+ # This is a query, so let's not bother caching
+ CACHE = False
+
+ def __init__(self, hs):
+ super(ReplicationGetQueryRestServlet, self).__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.registry = hs.get_federation_registry()
+
+ @staticmethod
+ def _serialize_payload(query_type, args):
+ """
+ Args:
+ query_type (str)
+ args (dict): The arguments received for the given query type
+ """
+ return {
+ "args": args,
+ }
+
+ @defer.inlineCallbacks
+ def _handle_request(self, request, query_type):
+ with Measure(self.clock, "repl_fed_query_parse"):
+ content = parse_json_object_from_request(request)
+
+ args = content["args"]
+
+ logger.info(
+ "Got %r query",
+ query_type,
+ )
+
+ result = yield self.registry.on_query(query_type, args)
+
+ defer.returnValue((200, result))
+
+
+class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
+ """Called to clean up any data in DB for a given room, ready for the
+ server to join the room.
+
+ Request format:
+
+ POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id
+
+ {}
+ """
+
+ NAME = "fed_cleanup_room"
+ PATH_ARGS = ("room_id",)
+
+ def __init__(self, hs):
+ super(ReplicationCleanRoomRestServlet, self).__init__(hs)
+
+ self.store = hs.get_datastore()
+
+ @staticmethod
+ def _serialize_payload(room_id, args):
+ """
+ Args:
+ room_id (str)
+ """
+ return {}
+
+ @defer.inlineCallbacks
+ def _handle_request(self, request, room_id):
+ yield self.store.clean_room_for_join(room_id)
+
+ defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+ ReplicationFederationSendEventsRestServlet(hs).register(http_server)
+ ReplicationFederationSendEduRestServlet(hs).register(http_server)
+ ReplicationGetQueryRestServlet(hs).register(http_server)
+ ReplicationCleanRoomRestServlet(hs).register(http_server)
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index 9c9a5eadd9..3527beb3c9 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,19 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage import DataStore
from synapse.storage.transactions import TransactionStore
from ._base import BaseSlavedStore
-class TransactionStore(BaseSlavedStore):
- get_destination_retry_timings = TransactionStore.__dict__[
- "get_destination_retry_timings"
- ]
- _get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
- set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__
- _set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__
-
- prep_send_transaction = DataStore.prep_send_transaction.__func__
- delivered_txn = DataStore.delivered_txn.__func__
+class SlavedTransactionStore(TransactionStore, BaseSlavedStore):
+ pass
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 970e94313e..cbe9645817 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -107,7 +107,7 @@ class ReplicationClientHandler(object):
Can be overriden in subclasses to handle more.
"""
logger.info("Received rdata %s -> %s", stream_name, token)
- self.store.process_replication_rows(stream_name, token, rows)
+ return self.store.process_replication_rows(stream_name, token, rows)
def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
@@ -115,7 +115,7 @@ class ReplicationClientHandler(object):
Can be overriden in subclasses to handle more.
"""
- self.store.process_replication_rows(stream_name, token, [])
+ return self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index f3908df642..327556f6a1 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -59,6 +59,12 @@ class Command(object):
"""
return self.data
+ def get_logcontext_id(self):
+ """Get a suitable string for the logcontext when processing this command"""
+
+ # by default, we just use the command name.
+ return self.NAME
+
class ServerCommand(Command):
"""Sent by the server on new connection and includes the server_name.
@@ -116,6 +122,9 @@ class RdataCommand(Command):
_json_encoder.encode(self.row),
))
+ def get_logcontext_id(self):
+ return "RDATA-" + self.stream_name
+
class PositionCommand(Command):
"""Sent by the client to tell the client the stream postition without
@@ -190,6 +199,9 @@ class ReplicateCommand(Command):
def to_line(self):
return " ".join((self.stream_name, str(self.token),))
+ def get_logcontext_id(self):
+ return "REPLICATE-" + self.stream_name
+
class UserSyncCommand(Command):
"""Sent by the client to inform the server that a user has started or
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index dec5ac0913..74e892c104 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -63,6 +63,8 @@ from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import random_string
from .commands import (
@@ -222,7 +224,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Now lets try and call on_<CMD_NAME> function
try:
- getattr(self, "on_%s" % (cmd_name,))(cmd)
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(),
+ getattr(self, "on_%s" % (cmd_name,)),
+ cmd,
+ )
except Exception:
logger.exception("[%s] Failed to handle line: %r", self.id(), line)
@@ -387,7 +393,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.name = cmd.data
def on_USER_SYNC(self, cmd):
- self.streamer.on_user_sync(
+ return self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
)
@@ -397,22 +403,33 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
- for stream in iterkeys(self.streamer.streams_by_name):
- self.subscribe_to_stream(stream, token)
+ deferreds = [
+ run_in_background(
+ self.subscribe_to_stream,
+ stream, token,
+ )
+ for stream in iterkeys(self.streamer.streams_by_name)
+ ]
+
+ return make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True)
+ )
else:
- self.subscribe_to_stream(stream_name, token)
+ return self.subscribe_to_stream(stream_name, token)
def on_FEDERATION_ACK(self, cmd):
- self.streamer.federation_ack(cmd.token)
+ return self.streamer.federation_ack(cmd.token)
def on_REMOVE_PUSHER(self, cmd):
- self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
+ return self.streamer.on_remove_pusher(
+ cmd.app_id, cmd.push_key, cmd.user_id,
+ )
def on_INVALIDATE_CACHE(self, cmd):
- self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+ return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
def on_USER_IP(self, cmd):
- self.streamer.on_user_ip(
+ return self.streamer.on_user_ip(
cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
cmd.last_seen,
)
@@ -542,14 +559,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
-
- self.handler.on_rdata(stream_name, cmd.token, rows)
+ return self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd):
- self.handler.on_position(cmd.stream_name, cmd.token)
+ return self.handler.on_position(cmd.stream_name, cmd.token)
def on_SYNC(self, cmd):
- self.handler.on_sync(cmd.data)
+ return self.handler.on_sync(cmd.data)
def replicate(self, stream_name, token):
"""Send the subscription request to the server
|