diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index b9ce3477ad..a7174c4a8f 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -14,14 +14,16 @@
# limitations under the License.
import logging
-
-from twisted.internet import defer
+from typing import TYPE_CHECKING
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID
from synapse.util.distributor import user_joined_room, user_left_room
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -65,8 +67,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"content": content,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id, user_id):
+ async def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@@ -79,11 +80,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
logger.info("remote_join: %s into room: %s", user_id, room_id)
- yield self.federation_handler.do_invite_join(
+ event_id, stream_id = await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user_id, event_content
)
- return 200, {}
+ return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
@@ -96,6 +97,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
{
"requester": ...,
"remote_room_hosts": [...],
+ "content": { ... }
}
"""
@@ -108,9 +110,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
+ self.member_handler = hs.get_room_member_handler()
@staticmethod
- def _serialize_payload(requester, room_id, user_id, remote_room_hosts):
+ def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
"""
Args:
requester(Requester)
@@ -121,13 +124,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
return {
"requester": requester.serialize(),
"remote_room_hosts": remote_room_hosts,
+ "content": content,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id, user_id):
+ async def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
+ event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
@@ -137,10 +141,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
try:
- event = yield self.federation_handler.do_remotely_reject_invite(
- remote_room_hosts, room_id, user_id
+ event, stream_id = await self.federation_handler.do_remotely_reject_invite(
+ remote_room_hosts, room_id, user_id, event_content,
)
- ret = event.get_pdu_json()
+ event_id = event.event_id
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
@@ -148,12 +152,44 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
- logger.warn("Failed to reject invite: %s", e)
+ logger.warning("Failed to reject invite: %s", e)
+
+ stream_id = await self.member_handler.locally_reject_invite(
+ user_id, room_id
+ )
+ event_id = None
+
+ return 200, {"event_id": event_id, "stream_id": stream_id}
+
+
+class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint):
+ """Rejects the invite for the user and room locally.
+
+ Request format:
+
+ POST /_synapse/replication/locally_reject_invite/:room_id/:user_id
+
+ {}
+ """
+
+ NAME = "locally_reject_invite"
+ PATH_ARGS = ("room_id", "user_id")
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.member_handler = hs.get_room_member_handler()
+
+ @staticmethod
+ def _serialize_payload(room_id, user_id):
+ return {}
+
+ async def _handle_request(self, request, room_id, user_id):
+ logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id)
- yield self.store.locally_reject_invite(user_id, room_id)
- ret = {}
+ stream_id = await self.member_handler.locally_reject_invite(user_id, room_id)
- return 200, ret
+ return 200, {"stream_id": stream_id}
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
@@ -209,3 +245,4 @@ def register_servlets(hs, http_server):
ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
+ ReplicationLocallyRejectInviteRestServlet(hs).register(http_server)
|