summary refs log tree commit diff
path: root/synapse/replication/http/membership.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/http/membership.py')
-rw-r--r--synapse/replication/http/membership.py69
1 files changed, 53 insertions, 16 deletions
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index b9ce3477ad..a7174c4a8f 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -14,14 +14,16 @@
 # limitations under the License.
 
 import logging
-
-from twisted.internet import defer
+from typing import TYPE_CHECKING
 
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import Requester, UserID
 from synapse.util.distributor import user_joined_room, user_left_room
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -65,8 +67,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
             "content": content,
         }
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, room_id, user_id):
+    async def _handle_request(self, request, room_id, user_id):
         content = parse_json_object_from_request(request)
 
         remote_room_hosts = content["remote_room_hosts"]
@@ -79,11 +80,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
 
         logger.info("remote_join: %s into room: %s", user_id, room_id)
 
-        yield self.federation_handler.do_invite_join(
+        event_id, stream_id = await self.federation_handler.do_invite_join(
             remote_room_hosts, room_id, user_id, event_content
         )
 
-        return 200, {}
+        return 200, {"event_id": event_id, "stream_id": stream_id}
 
 
 class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
@@ -96,6 +97,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         {
             "requester": ...,
             "remote_room_hosts": [...],
+            "content": { ... }
         }
     """
 
@@ -108,9 +110,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         self.federation_handler = hs.get_handlers().federation_handler
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
+        self.member_handler = hs.get_room_member_handler()
 
     @staticmethod
-    def _serialize_payload(requester, room_id, user_id, remote_room_hosts):
+    def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
         """
         Args:
             requester(Requester)
@@ -121,13 +124,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         return {
             "requester": requester.serialize(),
             "remote_room_hosts": remote_room_hosts,
+            "content": content,
         }
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, room_id, user_id):
+    async def _handle_request(self, request, room_id, user_id):
         content = parse_json_object_from_request(request)
 
         remote_room_hosts = content["remote_room_hosts"]
+        event_content = content["content"]
 
         requester = Requester.deserialize(self.store, content["requester"])
 
@@ -137,10 +141,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
 
         try:
-            event = yield self.federation_handler.do_remotely_reject_invite(
-                remote_room_hosts, room_id, user_id
+            event, stream_id = await self.federation_handler.do_remotely_reject_invite(
+                remote_room_hosts, room_id, user_id, event_content,
             )
-            ret = event.get_pdu_json()
+            event_id = event.event_id
         except Exception as e:
             # if we were unable to reject the exception, just mark
             # it as rejected on our end and plough ahead.
@@ -148,12 +152,44 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
             # The 'except' clause is very broad, but we need to
             # capture everything from DNS failures upwards
             #
-            logger.warn("Failed to reject invite: %s", e)
+            logger.warning("Failed to reject invite: %s", e)
+
+            stream_id = await self.member_handler.locally_reject_invite(
+                user_id, room_id
+            )
+            event_id = None
+
+        return 200, {"event_id": event_id, "stream_id": stream_id}
+
+
+class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint):
+    """Rejects the invite for the user and room locally.
+
+    Request format:
+
+        POST /_synapse/replication/locally_reject_invite/:room_id/:user_id
+
+        {}
+    """
+
+    NAME = "locally_reject_invite"
+    PATH_ARGS = ("room_id", "user_id")
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self.member_handler = hs.get_room_member_handler()
+
+    @staticmethod
+    def _serialize_payload(room_id, user_id):
+        return {}
+
+    async def _handle_request(self, request, room_id, user_id):
+        logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id)
 
-            yield self.store.locally_reject_invite(user_id, room_id)
-            ret = {}
+        stream_id = await self.member_handler.locally_reject_invite(user_id, room_id)
 
-        return 200, ret
+        return 200, {"stream_id": stream_id}
 
 
 class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
@@ -209,3 +245,4 @@ def register_servlets(hs, http_server):
     ReplicationRemoteJoinRestServlet(hs).register(http_server)
     ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
     ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
+    ReplicationLocallyRejectInviteRestServlet(hs).register(http_server)