diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 19b69e0e11..a84a064c8d 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication"
class ReplicationRestResource(JsonResource):
def __init__(self, hs):
- JsonResource.__init__(self, hs, canonical_json=False)
+ # We enable extracting jaeger contexts here as these are internal APIs.
+ super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs)
def register_servlets(self, hs):
@@ -38,10 +39,10 @@ class ReplicationRestResource(JsonResource):
federation.register_servlets(hs, self)
presence.register_servlets(hs, self)
membership.register_servlets(hs, self)
+ streams.register_servlets(hs, self)
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
- streams.register_servlets(hs, self)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 793cef6c26..ba16f22c91 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -16,32 +16,24 @@
import abc
import logging
import re
+import urllib
from inspect import signature
from typing import Dict, List, Tuple
-from six import raise_from
-from six.moves import urllib
-
-from twisted.internet import defer
-
from synapse.api.errors import (
CodeMessageException,
HttpResponseException,
RequestSendFailed,
SynapseError,
)
-from synapse.logging.opentracing import (
- inject_active_span_byte_dict,
- trace,
- trace_servlet,
-)
+from synapse.logging.opentracing import inject_active_span_byte_dict, trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
-class ReplicationEndpoint(object):
+class ReplicationEndpoint:
"""Helper base class for defining new replication HTTP endpoints.
This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
@@ -98,16 +90,16 @@ class ReplicationEndpoint(object):
# assert here that sub classes don't try and use the name.
assert (
"instance_name" not in self.PATH_ARGS
- ), "`instance_name` is a reserved paramater name"
+ ), "`instance_name` is a reserved parameter name"
assert (
"instance_name"
not in signature(self.__class__._serialize_payload).parameters
- ), "`instance_name` is a reserved paramater name"
+ ), "`instance_name` is a reserved parameter name"
assert self.METHOD in ("PUT", "POST", "GET")
@abc.abstractmethod
- def _serialize_payload(**kwargs):
+ async def _serialize_payload(**kwargs):
"""Static method that is called when creating a request.
Concrete implementations should have explicit parameters (rather than
@@ -116,9 +108,8 @@ class ReplicationEndpoint(object):
argument list.
Returns:
- Deferred[dict]|dict: If POST/PUT request then dictionary must be
- JSON serialisable, otherwise must be appropriate for adding as
- query args.
+ dict: If POST/PUT request then dictionary must be JSON serialisable,
+ otherwise must be appropriate for adding as query args.
"""
return {}
@@ -150,8 +141,7 @@ class ReplicationEndpoint(object):
instance_map = hs.config.worker.instance_map
@trace(opname="outgoing_replication_request")
- @defer.inlineCallbacks
- def send_request(instance_name="master", **kwargs):
+ async def send_request(instance_name="master", **kwargs):
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
if instance_name == "master":
@@ -165,7 +155,7 @@ class ReplicationEndpoint(object):
"Instance %r not in 'instance_map' config" % (instance_name,)
)
- data = yield cls._serialize_payload(**kwargs)
+ data = await cls._serialize_payload(**kwargs)
url_args = [
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
@@ -203,7 +193,7 @@ class ReplicationEndpoint(object):
headers = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers, None, check_destination=False)
try:
- result = yield request_func(uri, data, headers=headers)
+ result = await request_func(uri, data, headers=headers)
break
except CodeMessageException as e:
if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
@@ -213,14 +203,14 @@ class ReplicationEndpoint(object):
# If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway.
- yield clock.sleep(1)
+ await clock.sleep(1)
except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise e.to_synapse_error()
except RequestSendFailed as e:
- raise_from(SynapseError(502, "Failed to talk to master"), e)
+ raise SynapseError(502, "Failed to talk to master") from e
return result
@@ -242,11 +232,8 @@ class ReplicationEndpoint(object):
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
- handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler)
- # We don't let register paths trace this servlet using the default tracing
- # options because we wish to extract the context explicitly.
http_server.register_paths(
- method, [pattern], handler, self.__class__.__name__, trace=False
+ method, [pattern], handler, self.__class__.__name__,
)
def _cached_handler(self, request, txn_id, **kwargs):
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index e32aac0a25..20f3ba76c0 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -60,7 +60,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- def _serialize_payload(user_id):
+ async def _serialize_payload(user_id):
return {}
async def _handle_request(self, request, user_id):
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index c287c4e269..6b56315148 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
@@ -67,8 +65,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler
@staticmethod
- @defer.inlineCallbacks
- def _serialize_payload(store, event_and_contexts, backfilled):
+ async def _serialize_payload(store, event_and_contexts, backfilled):
"""
Args:
store
@@ -78,7 +75,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"""
event_payloads = []
for event, context in event_and_contexts:
- serialized_context = yield context.serialize(event, store)
+ serialized_context = await context.serialize(event, store)
event_payloads.append(
{
@@ -154,7 +151,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
self.registry = hs.get_federation_registry()
@staticmethod
- def _serialize_payload(edu_type, origin, content):
+ async def _serialize_payload(edu_type, origin, content):
return {"origin": origin, "content": content}
async def _handle_request(self, request, edu_type):
@@ -197,7 +194,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
self.registry = hs.get_federation_registry()
@staticmethod
- def _serialize_payload(query_type, args):
+ async def _serialize_payload(query_type, args):
"""
Args:
query_type (str)
@@ -238,7 +235,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
@staticmethod
- def _serialize_payload(room_id, args):
+ async def _serialize_payload(room_id, args):
"""
Args:
room_id (str)
@@ -273,7 +270,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
@staticmethod
- def _serialize_payload(room_id, room_version):
+ async def _serialize_payload(room_id, room_version):
return {"room_version": room_version.identifier}
async def _handle_request(self, request, room_id):
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 798b9d3af5..fb326bb869 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -36,7 +36,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
- def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
+ async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
"""
Args:
device_id (str|None): Device ID to use, if None a new one is
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index a7174c4a8f..741329ab5f 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -14,11 +14,11 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
-from synapse.types import Requester, UserID
+from synapse.types import JsonDict, Requester, UserID
from synapse.util.distributor import user_joined_room, user_left_room
if TYPE_CHECKING:
@@ -52,7 +52,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
+ async def _serialize_payload(
+ requester, room_id, user_id, remote_room_hosts, content
+ ):
"""
Args:
requester(Requester)
@@ -88,49 +90,54 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
- """Rejects the invite for the user and room.
+ """Rejects an out-of-band invite we have received from a remote server
Request format:
- POST /_synapse/replication/remote_reject_invite/:room_id/:user_id
+ POST /_synapse/replication/remote_reject_invite/:event_id
{
+ "txn_id": ...,
"requester": ...,
- "remote_room_hosts": [...],
"content": { ... }
}
"""
NAME = "remote_reject_invite"
- PATH_ARGS = ("room_id", "user_id")
+ PATH_ARGS = ("invite_event_id",)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
- self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.member_handler = hs.get_room_member_handler()
@staticmethod
- def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
+ async def _serialize_payload( # type: ignore
+ invite_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ):
"""
Args:
- requester(Requester)
- room_id (str)
- user_id (str)
- remote_room_hosts (list[str]): Servers to try and reject via
+ invite_event_id: ID of the invite to be rejected
+ txn_id: optional transaction ID supplied by the client
+ requester: user making the rejection request, according to the access token
+ content: additional content to include in the rejection event.
+ Normally an empty dict.
"""
return {
+ "txn_id": txn_id,
"requester": requester.serialize(),
- "remote_room_hosts": remote_room_hosts,
"content": content,
}
- async def _handle_request(self, request, room_id, user_id):
+ async def _handle_request(self, request, invite_event_id):
content = parse_json_object_from_request(request)
- remote_room_hosts = content["remote_room_hosts"]
+ txn_id = content["txn_id"]
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
@@ -138,60 +145,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
if requester.user:
request.authenticated_entity = requester.user.to_string()
- logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
-
- try:
- event, stream_id = await self.federation_handler.do_remotely_reject_invite(
- remote_room_hosts, room_id, user_id, event_content,
- )
- event_id = event.event_id
- except Exception as e:
- # if we were unable to reject the exception, just mark
- # it as rejected on our end and plough ahead.
- #
- # The 'except' clause is very broad, but we need to
- # capture everything from DNS failures upwards
- #
- logger.warning("Failed to reject invite: %s", e)
-
- stream_id = await self.member_handler.locally_reject_invite(
- user_id, room_id
- )
- event_id = None
+ # hopefully we're now on the master, so this won't recurse!
+ event_id, stream_id = await self.member_handler.remote_reject_invite(
+ invite_event_id, txn_id, requester, event_content,
+ )
return 200, {"event_id": event_id, "stream_id": stream_id}
-class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint):
- """Rejects the invite for the user and room locally.
-
- Request format:
-
- POST /_synapse/replication/locally_reject_invite/:room_id/:user_id
-
- {}
- """
-
- NAME = "locally_reject_invite"
- PATH_ARGS = ("room_id", "user_id")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
- self.member_handler = hs.get_room_member_handler()
-
- @staticmethod
- def _serialize_payload(room_id, user_id):
- return {}
-
- async def _handle_request(self, request, room_id, user_id):
- logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id)
-
- stream_id = await self.member_handler.locally_reject_invite(user_id, room_id)
-
- return 200, {"stream_id": stream_id}
-
-
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
"""Notifies that a user has joined or left the room
@@ -215,7 +176,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
self.distributor = hs.get_distributor()
@staticmethod
- def _serialize_payload(room_id, user_id, change):
+ async def _serialize_payload(room_id, user_id, change):
"""
Args:
room_id (str)
@@ -245,4 +206,3 @@ def register_servlets(hs, http_server):
ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
- ReplicationLocallyRejectInviteRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
index ea1b33331b..bc9aa82cb4 100644
--- a/synapse/replication/http/presence.py
+++ b/synapse/replication/http/presence.py
@@ -50,7 +50,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
self._presence_handler = hs.get_presence_handler()
@staticmethod
- def _serialize_payload(user_id):
+ async def _serialize_payload(user_id):
return {}
async def _handle_request(self, request, user_id):
@@ -92,7 +92,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
self._presence_handler = hs.get_presence_handler()
@staticmethod
- def _serialize_payload(user_id, state, ignore_status_msg=False):
+ async def _serialize_payload(user_id, state, ignore_status_msg=False):
return {
"state": state,
"ignore_status_msg": ignore_status_msg,
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 0c4aca1291..a02b27474d 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -34,7 +34,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
- def _serialize_payload(
+ async def _serialize_payload(
user_id,
password_hash,
was_guest,
@@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin,
user_type,
address,
+ shadow_banned,
):
"""
Args:
@@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the regitration.
+ shadow_banned (bool): Whether to shadow-ban the user
"""
return {
"password_hash": password_hash,
@@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"admin": admin,
"user_type": user_type,
"address": address,
+ "shadow_banned": shadow_banned,
}
async def _handle_request(self, request, user_id):
@@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin=content["admin"],
user_type=content["user_type"],
address=content["address"],
+ shadow_banned=content["shadow_banned"],
)
return 200, {}
@@ -105,7 +109,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
- def _serialize_payload(user_id, auth_result, access_token):
+ async def _serialize_payload(user_id, auth_result, access_token):
"""
Args:
user_id (str): The user ID that consented
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index c981723c1a..f13d452426 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
@@ -62,8 +60,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- @defer.inlineCallbacks
- def _serialize_payload(
+ async def _serialize_payload(
event_id, store, event, context, requester, ratelimit, extra_users
):
"""
@@ -77,7 +74,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
extra_users (list(UserID)): Any extra users to notify about event
"""
- serialized_context = yield context.serialize(event, store)
+ serialized_context = await context.serialize(event, store)
payload = {
"event": event.get_pdu_json(),
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index bde97eef32..309159e304 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -54,7 +54,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
self.streams = hs.get_replication_streams()
@staticmethod
- def _serialize_payload(stream_name, from_token, upto_token):
+ async def _serialize_payload(stream_name, from_token, upto_token):
return {"from_token": from_token, "upto_token": upto_token}
async def _handle_request(self, request, stream_name):
|