diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 1d7a607529..19f214281e 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -14,8 +14,7 @@
# limitations under the License.
from synapse.http.server import JsonResource
-from synapse.replication.http import membership, send_event
-
+from synapse.replication.http import federation, membership, send_event
REPLICATION_PREFIX = "/_synapse/replication"
@@ -28,3 +27,4 @@ class ReplicationRestResource(JsonResource):
def register_servlets(self, hs):
send_event.register_servlets(hs, self)
membership.register_servlets(hs, self)
+ federation.register_servlets(hs, self)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
new file mode 100644
index 0000000000..5e5376cf58
--- /dev/null
+++ b/synapse/replication/http/_base.py
@@ -0,0 +1,215 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import logging
+import re
+
+from six.moves import urllib
+
+from twisted.internet import defer
+
+from synapse.api.errors import CodeMessageException, HttpResponseException
+from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import random_string
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationEndpoint(object):
+ """Helper base class for defining new replication HTTP endpoints.
+
+ This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
+ (with an `/:txn_id` prefix for cached requests.), where NAME is a name,
+ PATH_ARGS are a tuple of parameters to be encoded in the URL.
+
+ For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`,
+ with `CACHE` set to true then this generates an endpoint:
+
+ /_synapse/replication/send_event/:event_id/:txn_id
+
+ For POST/PUT requests the payload is serialized to json and sent as the
+ body, while for GET requests the payload is added as query parameters. See
+ `_serialize_payload` for details.
+
+ Incoming requests are handled by overriding `_handle_request`. Servers
+ must call `register` to register the path with the HTTP server.
+
+ Requests can be sent by calling the client returned by `make_client`.
+
+ Attributes:
+ NAME (str): A name for the endpoint, added to the path as well as used
+ in logging and metrics.
+ PATH_ARGS (tuple[str]): A list of parameters to be added to the path.
+ Adding parameters to the path (rather than payload) can make it
+ easier to follow along in the log files.
+ METHOD (str): The method of the HTTP request, defaults to POST. Can be
+ one of POST, PUT or GET. If GET then the payload is sent as query
+ parameters rather than a JSON body.
+ CACHE (bool): Whether server should cache the result of the request/
+ If true then transparently adds a txn_id to all requests, and
+ `_handle_request` must return a Deferred.
+ RETRY_ON_TIMEOUT(bool): Whether or not to retry the request when a 504
+ is received.
+ """
+
+ __metaclass__ = abc.ABCMeta
+
+ NAME = abc.abstractproperty()
+ PATH_ARGS = abc.abstractproperty()
+
+ METHOD = "POST"
+ CACHE = True
+ RETRY_ON_TIMEOUT = True
+
+ def __init__(self, hs):
+ if self.CACHE:
+ self.response_cache = ResponseCache(
+ hs, "repl." + self.NAME,
+ timeout_ms=30 * 60 * 1000,
+ )
+
+ assert self.METHOD in ("PUT", "POST", "GET")
+
+ @abc.abstractmethod
+ def _serialize_payload(**kwargs):
+ """Static method that is called when creating a request.
+
+ Concrete implementations should have explicit parameters (rather than
+ kwargs) so that an appropriate exception is raised if the client is
+ called with unexpected parameters. All PATH_ARGS must appear in
+ argument list.
+
+ Returns:
+ Deferred[dict]|dict: If POST/PUT request then dictionary must be
+ JSON serialisable, otherwise must be appropriate for adding as
+ query args.
+ """
+ return {}
+
+ @abc.abstractmethod
+ def _handle_request(self, request, **kwargs):
+ """Handle incoming request.
+
+ This is called with the request object and PATH_ARGS.
+
+ Returns:
+ Deferred[dict]: A JSON serialisable dict to be used as response
+ body of request.
+ """
+ pass
+
+ @classmethod
+ def make_client(cls, hs):
+ """Create a client that makes requests.
+
+ Returns a callable that accepts the same parameters as `_serialize_payload`.
+ """
+ clock = hs.get_clock()
+ host = hs.config.worker_replication_host
+ port = hs.config.worker_replication_http_port
+
+ client = hs.get_simple_http_client()
+
+ @defer.inlineCallbacks
+ def send_request(**kwargs):
+ data = yield cls._serialize_payload(**kwargs)
+
+ url_args = [urllib.parse.quote(kwargs[name]) for name in cls.PATH_ARGS]
+
+ if cls.CACHE:
+ txn_id = random_string(10)
+ url_args.append(txn_id)
+
+ if cls.METHOD == "POST":
+ request_func = client.post_json_get_json
+ elif cls.METHOD == "PUT":
+ request_func = client.put_json
+ elif cls.METHOD == "GET":
+ request_func = client.get_json
+ else:
+ # We have already asserted in the constructor that a
+ # compatible was picked, but lets be paranoid.
+ raise Exception(
+ "Unknown METHOD on %s replication endpoint" % (cls.NAME,)
+ )
+
+ uri = "http://%s:%s/_synapse/replication/%s/%s" % (
+ host, port, cls.NAME, "/".join(url_args)
+ )
+
+ try:
+ # We keep retrying the same request for timeouts. This is so that we
+ # have a good idea that the request has either succeeded or failed on
+ # the master, and so whether we should clean up or not.
+ while True:
+ try:
+ result = yield request_func(uri, data)
+ break
+ except CodeMessageException as e:
+ if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
+ raise
+
+ logger.warn("%s request timed out", cls.NAME)
+
+ # If we timed out we probably don't need to worry about backing
+ # off too much, but lets just wait a little anyway.
+ yield clock.sleep(1)
+ except HttpResponseException as e:
+ # We convert to SynapseError as we know that it was a SynapseError
+ # on the master process that we should send to the client. (And
+ # importantly, not stack traces everywhere)
+ raise e.to_synapse_error()
+
+ defer.returnValue(result)
+
+ return send_request
+
+ def register(self, http_server):
+ """Called by the server to register this as a handler to the
+ appropriate path.
+ """
+
+ url_args = list(self.PATH_ARGS)
+ handler = self._handle_request
+ method = self.METHOD
+
+ if self.CACHE:
+ handler = self._cached_handler
+ url_args.append("txn_id")
+
+ args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
+ pattern = re.compile("^/_synapse/replication/%s/%s$" % (
+ self.NAME,
+ args
+ ))
+
+ http_server.register_paths(method, [pattern], handler)
+
+ def _cached_handler(self, request, txn_id, **kwargs):
+ """Called on new incoming requests when caching is enabled. Checks
+ if there is a cached response for the request and returns that,
+ otherwise calls `_handle_request` and caches its response.
+ """
+ # We just use the txn_id here, but we probably also want to use the
+ # other PATH_ARGS as well.
+
+ assert self.CACHE
+
+ return self.response_cache.wrap(
+ txn_id,
+ self._handle_request,
+ request, **kwargs
+ )
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
new file mode 100644
index 0000000000..64a79da162
--- /dev/null
+++ b/synapse/replication/http/federation.py
@@ -0,0 +1,259 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
+ """Handles events newly received from federation, including persisting and
+ notifying.
+
+ The API looks like:
+
+ POST /_synapse/replication/fed_send_events/:txn_id
+
+ {
+ "events": [{
+ "event": { .. serialized event .. },
+ "internal_metadata": { .. serialized internal_metadata .. },
+ "rejected_reason": .., // The event.rejected_reason field
+ "context": { .. serialized event context .. },
+ }],
+ "backfilled": false
+ """
+
+ NAME = "fed_send_events"
+ PATH_ARGS = ()
+
+ def __init__(self, hs):
+ super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.federation_handler = hs.get_handlers().federation_handler
+
+ @staticmethod
+ @defer.inlineCallbacks
+ def _serialize_payload(store, event_and_contexts, backfilled):
+ """
+ Args:
+ store
+ event_and_contexts (list[tuple[FrozenEvent, EventContext]])
+ backfilled (bool): Whether or not the events are the result of
+ backfilling
+ """
+ event_payloads = []
+ for event, context in event_and_contexts:
+ serialized_context = yield context.serialize(event, store)
+
+ event_payloads.append({
+ "event": event.get_pdu_json(),
+ "internal_metadata": event.internal_metadata.get_dict(),
+ "rejected_reason": event.rejected_reason,
+ "context": serialized_context,
+ })
+
+ payload = {
+ "events": event_payloads,
+ "backfilled": backfilled,
+ }
+
+ defer.returnValue(payload)
+
+ @defer.inlineCallbacks
+ def _handle_request(self, request):
+ with Measure(self.clock, "repl_fed_send_events_parse"):
+ content = parse_json_object_from_request(request)
+
+ backfilled = content["backfilled"]
+
+ event_payloads = content["events"]
+
+ event_and_contexts = []
+ for event_payload in event_payloads:
+ event_dict = event_payload["event"]
+ internal_metadata = event_payload["internal_metadata"]
+ rejected_reason = event_payload["rejected_reason"]
+ event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
+
+ context = yield EventContext.deserialize(
+ self.store, event_payload["context"],
+ )
+
+ event_and_contexts.append((event, context))
+
+ logger.info(
+ "Got %d events from federation",
+ len(event_and_contexts),
+ )
+
+ yield self.federation_handler.persist_events_and_notify(
+ event_and_contexts, backfilled,
+ )
+
+ defer.returnValue((200, {}))
+
+
+class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
+ """Handles EDUs newly received from federation, including persisting and
+ notifying.
+
+ Request format:
+
+ POST /_synapse/replication/fed_send_edu/:edu_type/:txn_id
+
+ {
+ "origin": ...,
+ "content: { ... }
+ }
+ """
+
+ NAME = "fed_send_edu"
+ PATH_ARGS = ("edu_type",)
+
+ def __init__(self, hs):
+ super(ReplicationFederationSendEduRestServlet, self).__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.registry = hs.get_federation_registry()
+
+ @staticmethod
+ def _serialize_payload(edu_type, origin, content):
+ return {
+ "origin": origin,
+ "content": content,
+ }
+
+ @defer.inlineCallbacks
+ def _handle_request(self, request, edu_type):
+ with Measure(self.clock, "repl_fed_send_edu_parse"):
+ content = parse_json_object_from_request(request)
+
+ origin = content["origin"]
+ edu_content = content["content"]
+
+ logger.info(
+ "Got %r edu from %s",
+ edu_type, origin,
+ )
+
+ result = yield self.registry.on_edu(edu_type, origin, edu_content)
+
+ defer.returnValue((200, result))
+
+
+class ReplicationGetQueryRestServlet(ReplicationEndpoint):
+ """Handle responding to queries from federation.
+
+ Request format:
+
+ POST /_synapse/replication/fed_query/:query_type
+
+ {
+ "args": { ... }
+ }
+ """
+
+ NAME = "fed_query"
+ PATH_ARGS = ("query_type",)
+
+ # This is a query, so let's not bother caching
+ CACHE = False
+
+ def __init__(self, hs):
+ super(ReplicationGetQueryRestServlet, self).__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.registry = hs.get_federation_registry()
+
+ @staticmethod
+ def _serialize_payload(query_type, args):
+ """
+ Args:
+ query_type (str)
+ args (dict): The arguments received for the given query type
+ """
+ return {
+ "args": args,
+ }
+
+ @defer.inlineCallbacks
+ def _handle_request(self, request, query_type):
+ with Measure(self.clock, "repl_fed_query_parse"):
+ content = parse_json_object_from_request(request)
+
+ args = content["args"]
+
+ logger.info(
+ "Got %r query",
+ query_type,
+ )
+
+ result = yield self.registry.on_query(query_type, args)
+
+ defer.returnValue((200, result))
+
+
+class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
+ """Called to clean up any data in DB for a given room, ready for the
+ server to join the room.
+
+ Request format:
+
+ POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id
+
+ {}
+ """
+
+ NAME = "fed_cleanup_room"
+ PATH_ARGS = ("room_id",)
+
+ def __init__(self, hs):
+ super(ReplicationCleanRoomRestServlet, self).__init__(hs)
+
+ self.store = hs.get_datastore()
+
+ @staticmethod
+ def _serialize_payload(room_id, args):
+ """
+ Args:
+ room_id (str)
+ """
+ return {}
+
+ @defer.inlineCallbacks
+ def _handle_request(self, request, room_id):
+ yield self.store.clean_room_for_join(room_id)
+
+ defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+ ReplicationFederationSendEventsRestServlet(hs).register(http_server)
+ ReplicationFederationSendEduRestServlet(hs).register(http_server)
+ ReplicationGetQueryRestServlet(hs).register(http_server)
+ ReplicationCleanRoomRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index e66c4e881f..e58bebf12a 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -14,182 +14,63 @@
# limitations under the License.
import logging
-import re
from twisted.internet import defer
-from synapse.api.errors import SynapseError, MatrixCodeMessageException
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID
-from synapse.util.distributor import user_left_room, user_joined_room
+from synapse.util.distributor import user_joined_room, user_left_room
logger = logging.getLogger(__name__)
-@defer.inlineCallbacks
-def remote_join(client, host, port, requester, remote_room_hosts,
- room_id, user_id, content):
- """Ask the master to do a remote join for the given user to the given room
+class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
+ """Does a remote join for the given user to the given room
- Args:
- client (SimpleHttpClient)
- host (str): host of master
- port (int): port on master listening for HTTP replication
- requester (Requester)
- remote_room_hosts (list[str]): Servers to try and join via
- room_id (str)
- user_id (str)
- content (dict): The event content to use for the join event
+ Request format:
- Returns:
- Deferred
- """
- uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port)
-
- payload = {
- "requester": requester.serialize(),
- "remote_room_hosts": remote_room_hosts,
- "room_id": room_id,
- "user_id": user_id,
- "content": content,
- }
-
- try:
- result = yield client.post_json_get_json(uri, payload)
- except MatrixCodeMessageException as e:
- # We convert to SynapseError as we know that it was a SynapseError
- # on the master process that we should send to the client. (And
- # importantly, not stack traces everywhere)
- raise SynapseError(e.code, e.msg, e.errcode)
- defer.returnValue(result)
-
-
-@defer.inlineCallbacks
-def remote_reject_invite(client, host, port, requester, remote_room_hosts,
- room_id, user_id):
- """Ask master to reject the invite for the user and room.
-
- Args:
- client (SimpleHttpClient)
- host (str): host of master
- port (int): port on master listening for HTTP replication
- requester (Requester)
- remote_room_hosts (list[str]): Servers to try and reject via
- room_id (str)
- user_id (str)
-
- Returns:
- Deferred
- """
- uri = "http://%s:%s/_synapse/replication/remote_reject_invite" % (host, port)
-
- payload = {
- "requester": requester.serialize(),
- "remote_room_hosts": remote_room_hosts,
- "room_id": room_id,
- "user_id": user_id,
- }
-
- try:
- result = yield client.post_json_get_json(uri, payload)
- except MatrixCodeMessageException as e:
- # We convert to SynapseError as we know that it was a SynapseError
- # on the master process that we should send to the client. (And
- # importantly, not stack traces everywhere)
- raise SynapseError(e.code, e.msg, e.errcode)
- defer.returnValue(result)
-
-
-@defer.inlineCallbacks
-def get_or_register_3pid_guest(client, host, port, requester,
- medium, address, inviter_user_id):
- """Ask the master to get/create a guest account for given 3PID.
-
- Args:
- client (SimpleHttpClient)
- host (str): host of master
- port (int): port on master listening for HTTP replication
- requester (Requester)
- medium (str)
- address (str)
- inviter_user_id (str): The user ID who is trying to invite the
- 3PID
-
- Returns:
- Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
- 3PID guest account.
- """
+ POST /_synapse/replication/remote_join/:room_id/:user_id
- uri = "http://%s:%s/_synapse/replication/get_or_register_3pid_guest" % (host, port)
-
- payload = {
- "requester": requester.serialize(),
- "medium": medium,
- "address": address,
- "inviter_user_id": inviter_user_id,
- }
-
- try:
- result = yield client.post_json_get_json(uri, payload)
- except MatrixCodeMessageException as e:
- # We convert to SynapseError as we know that it was a SynapseError
- # on the master process that we should send to the client. (And
- # importantly, not stack traces everywhere)
- raise SynapseError(e.code, e.msg, e.errcode)
- defer.returnValue(result)
-
-
-@defer.inlineCallbacks
-def notify_user_membership_change(client, host, port, user_id, room_id, change):
- """Notify master that a user has joined or left the room
-
- Args:
- client (SimpleHttpClient)
- host (str): host of master
- port (int): port on master listening for HTTP replication.
- user_id (str)
- room_id (str)
- change (str): Either "join" or "left"
-
- Returns:
- Deferred
+ {
+ "requester": ...,
+ "remote_room_hosts": [...],
+ "content": { ... }
+ }
"""
- assert change in ("joined", "left")
-
- uri = "http://%s:%s/_synapse/replication/user_%s_room" % (host, port, change)
-
- payload = {
- "user_id": user_id,
- "room_id": room_id,
- }
-
- try:
- result = yield client.post_json_get_json(uri, payload)
- except MatrixCodeMessageException as e:
- # We convert to SynapseError as we know that it was a SynapseError
- # on the master process that we should send to the client. (And
- # importantly, not stack traces everywhere)
- raise SynapseError(e.code, e.msg, e.errcode)
- defer.returnValue(result)
-
-class ReplicationRemoteJoinRestServlet(RestServlet):
- PATTERNS = [re.compile("^/_synapse/replication/remote_join$")]
+ NAME = "remote_join"
+ PATH_ARGS = ("room_id", "user_id",)
def __init__(self, hs):
- super(ReplicationRemoteJoinRestServlet, self).__init__()
+ super(ReplicationRemoteJoinRestServlet, self).__init__(hs)
self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
+ @staticmethod
+ def _serialize_payload(requester, room_id, user_id, remote_room_hosts,
+ content):
+ """
+ Args:
+ requester(Requester)
+ room_id (str)
+ user_id (str)
+ remote_room_hosts (list[str]): Servers to try and join via
+ content(dict): The event content to use for the join event
+ """
+ return {
+ "requester": requester.serialize(),
+ "remote_room_hosts": remote_room_hosts,
+ "content": content,
+ }
+
@defer.inlineCallbacks
- def on_POST(self, request):
+ def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
- room_id = content["room_id"]
- user_id = content["user_id"]
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
@@ -212,23 +93,48 @@ class ReplicationRemoteJoinRestServlet(RestServlet):
defer.returnValue((200, {}))
-class ReplicationRemoteRejectInviteRestServlet(RestServlet):
- PATTERNS = [re.compile("^/_synapse/replication/remote_reject_invite$")]
+class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
+ """Rejects the invite for the user and room.
+
+ Request format:
+
+ POST /_synapse/replication/remote_reject_invite/:room_id/:user_id
+
+ {
+ "requester": ...,
+ "remote_room_hosts": [...],
+ }
+ """
+
+ NAME = "remote_reject_invite"
+ PATH_ARGS = ("room_id", "user_id",)
def __init__(self, hs):
- super(ReplicationRemoteRejectInviteRestServlet, self).__init__()
+ super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
+ @staticmethod
+ def _serialize_payload(requester, room_id, user_id, remote_room_hosts):
+ """
+ Args:
+ requester(Requester)
+ room_id (str)
+ user_id (str)
+ remote_room_hosts (list[str]): Servers to try and reject via
+ """
+ return {
+ "requester": requester.serialize(),
+ "remote_room_hosts": remote_room_hosts,
+ }
+
@defer.inlineCallbacks
- def on_POST(self, request):
+ def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
- room_id = content["room_id"]
- user_id = content["user_id"]
requester = Requester.deserialize(self.store, content["requester"])
@@ -264,18 +170,50 @@ class ReplicationRemoteRejectInviteRestServlet(RestServlet):
defer.returnValue((200, ret))
-class ReplicationRegister3PIDGuestRestServlet(RestServlet):
- PATTERNS = [re.compile("^/_synapse/replication/get_or_register_3pid_guest$")]
+class ReplicationRegister3PIDGuestRestServlet(ReplicationEndpoint):
+ """Gets/creates a guest account for given 3PID.
+
+ Request format:
+
+ POST /_synapse/replication/get_or_register_3pid_guest/
+
+ {
+ "requester": ...,
+ "medium": ...,
+ "address": ...,
+ "inviter_user_id": ...
+ }
+ """
+
+ NAME = "get_or_register_3pid_guest"
+ PATH_ARGS = ()
def __init__(self, hs):
- super(ReplicationRegister3PIDGuestRestServlet, self).__init__()
+ super(ReplicationRegister3PIDGuestRestServlet, self).__init__(hs)
self.registeration_handler = hs.get_handlers().registration_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
+ @staticmethod
+ def _serialize_payload(requester, medium, address, inviter_user_id):
+ """
+ Args:
+ requester(Requester)
+ medium (str)
+ address (str)
+ inviter_user_id (str): The user ID who is trying to invite the
+ 3PID
+ """
+ return {
+ "requester": requester.serialize(),
+ "medium": medium,
+ "address": address,
+ "inviter_user_id": inviter_user_id,
+ }
+
@defer.inlineCallbacks
- def on_POST(self, request):
+ def _handle_request(self, request):
content = parse_json_object_from_request(request)
medium = content["medium"]
@@ -296,23 +234,41 @@ class ReplicationRegister3PIDGuestRestServlet(RestServlet):
defer.returnValue((200, ret))
-class ReplicationUserJoinedLeftRoomRestServlet(RestServlet):
- PATTERNS = [re.compile("^/_synapse/replication/user_(?P<change>joined|left)_room$")]
+class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
+ """Notifies that a user has joined or left the room
+
+ Request format:
+
+ POST /_synapse/replication/membership_change/:room_id/:user_id/:change
+
+ {}
+ """
+
+ NAME = "membership_change"
+ PATH_ARGS = ("room_id", "user_id", "change")
+ CACHE = False # No point caching as should return instantly.
def __init__(self, hs):
- super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__()
+ super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__(hs)
self.registeration_handler = hs.get_handlers().registration_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.distributor = hs.get_distributor()
- def on_POST(self, request, change):
- content = parse_json_object_from_request(request)
+ @staticmethod
+ def _serialize_payload(room_id, user_id, change):
+ """
+ Args:
+ room_id (str)
+ user_id (str)
+ change (str): Either "joined" or "left"
+ """
+ assert change in ("joined", "left",)
- user_id = content["user_id"]
- room_id = content["room_id"]
+ return {}
+ def _handle_request(self, request, room_id, user_id, change):
logger.info("user membership change: %s in %s", user_id, room_id)
user = UserID.from_string(user_id)
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index a9baa2c1c3..5b52c91650 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -13,86 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from synapse.api.errors import (
- SynapseError, MatrixCodeMessageException, CodeMessageException,
-)
from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.util.async import sleep
-from synapse.util.caches.response_cache import ResponseCache
-from synapse.util.metrics import Measure
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID
-
-import logging
-import re
+from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
-@defer.inlineCallbacks
-def send_event_to_master(client, host, port, requester, event, context,
- ratelimit, extra_users):
- """Send event to be handled on the master
-
- Args:
- client (SimpleHttpClient)
- host (str): host of master
- port (int): port on master listening for HTTP replication
- requester (Requester)
- event (FrozenEvent)
- context (EventContext)
- ratelimit (bool)
- extra_users (list(UserID)): Any extra users to notify about event
- """
- uri = "http://%s:%s/_synapse/replication/send_event/%s" % (
- host, port, event.event_id,
- )
-
- payload = {
- "event": event.get_pdu_json(),
- "internal_metadata": event.internal_metadata.get_dict(),
- "rejected_reason": event.rejected_reason,
- "context": context.serialize(event),
- "requester": requester.serialize(),
- "ratelimit": ratelimit,
- "extra_users": [u.to_string() for u in extra_users],
- }
-
- try:
- # We keep retrying the same request for timeouts. This is so that we
- # have a good idea that the request has either succeeded or failed on
- # the master, and so whether we should clean up or not.
- while True:
- try:
- result = yield client.put_json(uri, payload)
- break
- except CodeMessageException as e:
- if e.code != 504:
- raise
-
- logger.warn("send_event request timed out")
-
- # If we timed out we probably don't need to worry about backing
- # off too much, but lets just wait a little anyway.
- yield sleep(1)
- except MatrixCodeMessageException as e:
- # We convert to SynapseError as we know that it was a SynapseError
- # on the master process that we should send to the client. (And
- # importantly, not stack traces everywhere)
- raise SynapseError(e.code, e.msg, e.errcode)
- defer.returnValue(result)
-
-
-class ReplicationSendEventRestServlet(RestServlet):
+class ReplicationSendEventRestServlet(ReplicationEndpoint):
"""Handles events newly created on workers, including persisting and
notifying.
The API looks like:
- POST /_synapse/replication/send_event/:event_id
+ POST /_synapse/replication/send_event/:event_id/:txn_id
{
"event": { .. serialized event .. },
@@ -104,27 +45,47 @@ class ReplicationSendEventRestServlet(RestServlet):
"extra_users": [],
}
"""
- PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")]
+ NAME = "send_event"
+ PATH_ARGS = ("event_id",)
def __init__(self, hs):
- super(ReplicationSendEventRestServlet, self).__init__()
+ super(ReplicationSendEventRestServlet, self).__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
- # The responses are tiny, so we may as well cache them for a while
- self.response_cache = ResponseCache(hs, "send_event", timeout_ms=30 * 60 * 1000)
+ @staticmethod
+ @defer.inlineCallbacks
+ def _serialize_payload(event_id, store, event, context, requester,
+ ratelimit, extra_users):
+ """
+ Args:
+ event_id (str)
+ store (DataStore)
+ requester (Requester)
+ event (FrozenEvent)
+ context (EventContext)
+ ratelimit (bool)
+ extra_users (list(UserID)): Any extra users to notify about event
+ """
+
+ serialized_context = yield context.serialize(event, store)
+
+ payload = {
+ "event": event.get_pdu_json(),
+ "internal_metadata": event.internal_metadata.get_dict(),
+ "rejected_reason": event.rejected_reason,
+ "context": serialized_context,
+ "requester": requester.serialize(),
+ "ratelimit": ratelimit,
+ "extra_users": [u.to_string() for u in extra_users],
+ }
- def on_PUT(self, request, event_id):
- return self.response_cache.wrap(
- event_id,
- self._handle_request,
- request
- )
+ defer.returnValue(payload)
@defer.inlineCallbacks
- def _handle_request(self, request):
+ def _handle_request(self, request, event_id):
with Measure(self.clock, "repl_send_event_parse"):
content = parse_json_object_from_request(request)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 61f5590c53..3f7be74e02 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine
from ._slaved_id_tracker import SlavedIdTracker
-import logging
-
logger = logging.getLogger(__name__)
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index 8cae3076f4..b53a4c6bd1 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -15,7 +15,8 @@
# limitations under the License.
from synapse.storage.appservice import (
- ApplicationServiceWorkerStore, ApplicationServiceTransactionWorkerStore,
+ ApplicationServiceTransactionWorkerStore,
+ ApplicationServiceWorkerStore,
)
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 352c9a2aa8..60641f1a49 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
from synapse.storage.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import Cache
+from ._base import BaseSlavedStore
+
class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 6f3fb64770..87eaa53004 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
class SlavedDeviceInboxStore(BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 7687867aee..8206a988f7 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.end_to_end_keys import EndToEndKeyStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+
class SlavedDeviceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 6deecd3963..1d1d48709a 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
from synapse.storage.directory import DirectoryWorkerStore
+from ._base import BaseSlavedStore
+
class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore):
pass
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index b1f64ef0d8..4830c68f35 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -20,9 +20,11 @@ from synapse.storage.event_federation import EventFederationWorkerStore
from synapse.storage.event_push_actions import EventPushActionsWorkerStore
from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.roommember import RoomMemberWorkerStore
+from synapse.storage.signatures import SignatureWorkerStore
from synapse.storage.state import StateGroupWorkerStore
from synapse.storage.stream import StreamWorkerStore
-from synapse.storage.signatures import SignatureWorkerStore
+from synapse.storage.user_erasure_store import UserErasureWorkerStore
+
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
@@ -42,9 +44,10 @@ class SlavedEventStore(EventFederationWorkerStore,
RoomMemberWorkerStore,
EventPushActionsWorkerStore,
StreamWorkerStore,
- EventsWorkerStore,
StateGroupWorkerStore,
+ EventsWorkerStore,
SignatureWorkerStore,
+ UserErasureWorkerStore,
BaseSlavedStore):
def __init__(self, db_conn, hs):
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 819ed62881..456a14cd5c 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
from synapse.storage.filtering import FilteringStore
+from ._base import BaseSlavedStore
+
class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 0bc4bce5b0..5777f07c8d 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+
class SlavedGroupServerStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
index dd2ae49e48..05ed168463 100644
--- a/synapse/replication/slave/storage/keys.py
+++ b/synapse/replication/slave/storage/keys.py
@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.keys import KeyStore
+from ._base import BaseSlavedStore
+
class SlavedKeyStore(BaseSlavedStore):
_get_server_verify_key = KeyStore.__dict__[
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index cfb9280181..80b744082a 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
-
-from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.storage import DataStore
from synapse.storage.presence import PresenceStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
class SlavedPresenceStore(BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index bb2c40b6e3..f0200c1e98 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,10 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .events import SlavedEventStore
-from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage.push_rule import PushRulesWorkerStore
+from ._slaved_id_tracker import SlavedIdTracker
+from .events import SlavedEventStore
+
class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
def __init__(self, db_conn, hs):
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index a7cd5a7291..3b2213c0d4 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -14,11 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.storage.pusher import PusherWorkerStore
+
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.pusher import PusherWorkerStore
-
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 1647072f65..ed12342f40 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,11 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.storage.receipts import ReceiptsWorkerStore
+
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.receipts import ReceiptsWorkerStore
-
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
# DataStore or are cached and don't have cache invalidation logic.
@@ -49,7 +49,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
- self.get_linearized_receipts_for_room.invalidate_many((room_id,))
+ self._get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index 7323bf0f1e..408d91df1c 100644
--- a/synapse/replication/slave/storage/registration.py
+++ b/synapse/replication/slave/storage/registration.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
from synapse.storage.registration import RegistrationWorkerStore
+from ._base import BaseSlavedStore
+
class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
pass
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 5ae1670157..0cb474928c 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
from synapse.storage.room import RoomWorkerStore
+
+from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index fbb58f35da..3527beb3c9 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,18 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from synapse.storage import DataStore
from synapse.storage.transactions import TransactionStore
+from ._base import BaseSlavedStore
-class TransactionStore(BaseSlavedStore):
- get_destination_retry_timings = TransactionStore.__dict__[
- "get_destination_retry_timings"
- ]
- _get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
- set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__
- _set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__
- prep_send_transaction = DataStore.prep_send_transaction.__func__
- delivered_txn = DataStore.delivered_txn.__func__
+class SlavedTransactionStore(TransactionStore, BaseSlavedStore):
+ pass
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 6d2513c4e2..cbe9645817 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -15,17 +15,20 @@
"""A replication client for use by synapse workers.
"""
-from twisted.internet import reactor, defer
+import logging
+
+from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
from .commands import (
- FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
+ FederationAckCommand,
+ InvalidateCacheCommand,
+ RemovePusherCommand,
UserIpCommand,
+ UserSyncCommand,
)
from .protocol import ClientReplicationStreamProtocol
-import logging
-
logger = logging.getLogger(__name__)
@@ -44,7 +47,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
self.server_name = hs.config.server_name
self._clock = hs.get_clock() # As self.clock is defined in super class
- reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying)
+ hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
def startedConnecting(self, connector):
logger.info("Connecting to replication: %r", connector.getDestination())
@@ -95,7 +98,7 @@ class ReplicationClientHandler(object):
factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
- reactor.connectTCP(host, port, factory)
+ hs.get_reactor().connectTCP(host, port, factory)
def on_rdata(self, stream_name, token, rows):
"""Called when we get new replication data. By default this just pokes
@@ -104,7 +107,7 @@ class ReplicationClientHandler(object):
Can be overriden in subclasses to handle more.
"""
logger.info("Received rdata %s -> %s", stream_name, token)
- self.store.process_replication_rows(stream_name, token, rows)
+ return self.store.process_replication_rows(stream_name, token, rows)
def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
@@ -112,7 +115,7 @@ class ReplicationClientHandler(object):
Can be overriden in subclasses to handle more.
"""
- self.store.process_replication_rows(stream_name, token, [])
+ return self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
@@ -189,7 +192,7 @@ class ReplicationClientHandler(object):
"""Returns a deferred that is resolved when we receive a SYNC command
with given data.
- Used by tests.
+ [Not currently] used by tests.
"""
return self.awaiting_syncs.setdefault(data, defer.Deferred())
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 12aac3cc6b..327556f6a1 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -19,13 +19,17 @@ allowed to be sent by which side.
"""
import logging
-import simplejson
+import platform
+if platform.python_implementation() == "PyPy":
+ import json
+ _json_encoder = json.JSONEncoder()
+else:
+ import simplejson as json
+ _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
logger = logging.getLogger(__name__)
-_json_encoder = simplejson.JSONEncoder(namedtuple_as_object=False)
-
class Command(object):
"""The base command class.
@@ -55,6 +59,12 @@ class Command(object):
"""
return self.data
+ def get_logcontext_id(self):
+ """Get a suitable string for the logcontext when processing this command"""
+
+ # by default, we just use the command name.
+ return self.NAME
+
class ServerCommand(Command):
"""Sent by the server on new connection and includes the server_name.
@@ -102,7 +112,7 @@ class RdataCommand(Command):
return cls(
stream_name,
None if token == "batch" else int(token),
- simplejson.loads(row_json)
+ json.loads(row_json)
)
def to_line(self):
@@ -112,6 +122,9 @@ class RdataCommand(Command):
_json_encoder.encode(self.row),
))
+ def get_logcontext_id(self):
+ return "RDATA-" + self.stream_name
+
class PositionCommand(Command):
"""Sent by the client to tell the client the stream postition without
@@ -186,6 +199,9 @@ class ReplicateCommand(Command):
def to_line(self):
return " ".join((self.stream_name, str(self.token),))
+ def get_logcontext_id(self):
+ return "REPLICATE-" + self.stream_name
+
class UserSyncCommand(Command):
"""Sent by the client to inform the server that a user has started or
@@ -300,7 +316,7 @@ class InvalidateCacheCommand(Command):
def from_line(cls, line):
cache_func, keys_json = line.split(" ", 1)
- return cls(cache_func, simplejson.loads(keys_json))
+ return cls(cache_func, json.loads(keys_json))
def to_line(self):
return " ".join((
@@ -329,7 +345,7 @@ class UserIpCommand(Command):
def from_line(cls, line):
user_id, jsn = line.split(" ", 1)
- access_token, ip, user_agent, device_id, last_seen = simplejson.loads(jsn)
+ access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
return cls(
user_id, access_token, ip, user_agent, device_id, last_seen
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index c870475cd1..74e892c104 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -49,29 +49,39 @@ indicate which side is sending, these are *not* included on the wire::
* connection closed by server *
"""
+import fcntl
+import logging
+import struct
+from collections import defaultdict
+
+from six import iteritems, iterkeys
+
+from prometheus_client import Counter
+
from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
-from .commands import (
- COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS,
- ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand,
- NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand,
-)
-from .streams import STREAMS_MAP
-
from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import random_string
-from prometheus_client import Counter
-
-from collections import defaultdict
-
-from six import iterkeys, iteritems
-
-import logging
-import struct
-import fcntl
+from .commands import (
+ COMMAND_MAP,
+ VALID_CLIENT_COMMANDS,
+ VALID_SERVER_COMMANDS,
+ ErrorCommand,
+ NameCommand,
+ PingCommand,
+ PositionCommand,
+ RdataCommand,
+ ReplicateCommand,
+ ServerCommand,
+ SyncCommand,
+ UserSyncCommand,
+)
+from .streams import STREAMS_MAP
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
@@ -214,7 +224,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Now lets try and call on_<CMD_NAME> function
try:
- getattr(self, "on_%s" % (cmd_name,))(cmd)
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(),
+ getattr(self, "on_%s" % (cmd_name,)),
+ cmd,
+ )
except Exception:
logger.exception("[%s] Failed to handle line: %r", self.id(), line)
@@ -379,7 +393,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.name = cmd.data
def on_USER_SYNC(self, cmd):
- self.streamer.on_user_sync(
+ return self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
)
@@ -389,22 +403,33 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
- for stream in iterkeys(self.streamer.streams_by_name):
- self.subscribe_to_stream(stream, token)
+ deferreds = [
+ run_in_background(
+ self.subscribe_to_stream,
+ stream, token,
+ )
+ for stream in iterkeys(self.streamer.streams_by_name)
+ ]
+
+ return make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True)
+ )
else:
- self.subscribe_to_stream(stream_name, token)
+ return self.subscribe_to_stream(stream_name, token)
def on_FEDERATION_ACK(self, cmd):
- self.streamer.federation_ack(cmd.token)
+ return self.streamer.federation_ack(cmd.token)
def on_REMOVE_PUSHER(self, cmd):
- self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
+ return self.streamer.on_remove_pusher(
+ cmd.app_id, cmd.push_key, cmd.user_id,
+ )
def on_INVALIDATE_CACHE(self, cmd):
- self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+ return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
def on_USER_IP(self, cmd):
- self.streamer.on_user_ip(
+ return self.streamer.on_user_ip(
cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
cmd.last_seen,
)
@@ -534,14 +559,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
-
- self.handler.on_rdata(stream_name, cmd.token, rows)
+ return self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd):
- self.handler.on_position(cmd.stream_name, cmd.token)
+ return self.handler.on_position(cmd.stream_name, cmd.token)
def on_SYNC(self, cmd):
- self.handler.on_sync(cmd.data)
+ return self.handler.on_sync(cmd.data)
def replicate(self, stream_name, token):
"""Send the subscription request to the server
@@ -564,11 +588,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# The following simply registers metrics for the replication connections
pending_commands = LaterGauge(
- "pending_commands", "", ["name", "conn_id"],
+ "synapse_replication_tcp_protocol_pending_commands",
+ "",
+ ["name", "conn_id"],
lambda: {
- (p.name, p.conn_id): len(p.pending_commands)
- for p in connected_connections
- })
+ (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections
+ },
+)
def transport_buffer_size(protocol):
@@ -579,11 +605,13 @@ def transport_buffer_size(protocol):
transport_send_buffer = LaterGauge(
- "synapse_replication_tcp_transport_send_buffer", "", ["name", "conn_id"],
+ "synapse_replication_tcp_protocol_transport_send_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
- (p.name, p.conn_id): transport_buffer_size(p)
- for p in connected_connections
- })
+ (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections
+ },
+)
def transport_kernel_read_buffer_size(protocol, read=True):
@@ -602,37 +630,50 @@ def transport_kernel_read_buffer_size(protocol, read=True):
tcp_transport_kernel_send_buffer = LaterGauge(
- "synapse_replication_tcp_transport_kernel_send_buffer", "", ["name", "conn_id"],
+ "synapse_replication_tcp_protocol_transport_kernel_send_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
for p in connected_connections
- })
+ },
+)
tcp_transport_kernel_read_buffer = LaterGauge(
- "synapse_replication_tcp_transport_kernel_read_buffer", "", ["name", "conn_id"],
+ "synapse_replication_tcp_protocol_transport_kernel_read_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
for p in connected_connections
- })
+ },
+)
tcp_inbound_commands = LaterGauge(
- "synapse_replication_tcp_inbound_commands", "", ["command", "name", "conn_id"],
+ "synapse_replication_tcp_protocol_inbound_commands",
+ "",
+ ["command", "name", "conn_id"],
lambda: {
(k[0], p.name, p.conn_id): count
for p in connected_connections
for k, count in iteritems(p.inbound_commands_counter)
- })
+ },
+)
tcp_outbound_commands = LaterGauge(
- "synapse_replication_tcp_outbound_commands", "", ["command", "name", "conn_id"],
+ "synapse_replication_tcp_protocol_outbound_commands",
+ "",
+ ["command", "name", "conn_id"],
lambda: {
(k[0], p.name, p.conn_id): count
for p in connected_connections
for k, count in iteritems(p.outbound_commands_counter)
- })
+ },
+)
# number of updates received for each RDATA stream
-inbound_rdata_count = Counter("synapse_replication_tcp_inbound_rdata_count", "",
- ["stream_name"])
+inbound_rdata_count = Counter(
+ "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
+)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 63bd6d2652..fd59f1595f 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -15,19 +15,21 @@
"""The server side of the replication stream.
"""
-from twisted.internet import defer, reactor
-from twisted.internet.protocol import Factory
+import logging
-from .streams import STREAMS_MAP, FederationStream
-from .protocol import ServerReplicationStreamProtocol
+from six import itervalues
-from synapse.util.metrics import Measure, measure_func
-from synapse.metrics import LaterGauge
+from prometheus_client import Counter
-import logging
+from twisted.internet import defer
+from twisted.internet.protocol import Factory
-from prometheus_client import Counter
-from six import itervalues
+from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.metrics import Measure, measure_func
+
+from .protocol import ServerReplicationStreamProtocol
+from .streams import STREAMS_MAP, FederationStream
stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
"", ["stream_name"])
@@ -109,14 +111,13 @@ class ReplicationStreamer(object):
self.is_looping = False
self.pending_updates = False
- reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown)
+ hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
def on_shutdown(self):
# close all connections on shutdown
for conn in self.connections:
conn.send_error("server shutting down")
- @defer.inlineCallbacks
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
connections if there are.
@@ -131,14 +132,16 @@ class ReplicationStreamer(object):
stream.discard_updates_and_advance()
return
- # If we're in the process of checking for new updates, mark that fact
- # and return
+ self.pending_updates = True
+
if self.is_looping:
- logger.debug("Noitifier poke loop already running")
- self.pending_updates = True
+ logger.debug("Notifier poke loop already running")
return
- self.pending_updates = True
+ run_as_background_process("replication_notifier", self._run_notifier_loop)
+
+ @defer.inlineCallbacks
+ def _run_notifier_loop(self):
self.is_looping = True
try:
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
index 4c60bf79f9..55fe701c5c 100644
--- a/synapse/replication/tcp/streams.py
+++ b/synapse/replication/tcp/streams.py
@@ -24,11 +24,10 @@ Each stream is defined by the following information:
update_function: The function that returns a list of updates between two tokens
"""
-from twisted.internet import defer
-from collections import namedtuple
-
import logging
+from collections import namedtuple
+from twisted.internet import defer
logger = logging.getLogger(__name__)
|