summary refs log tree commit diff
path: root/synapse/federation/federation_server.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_server.py')
-rw-r--r--synapse/federation/federation_server.py121
1 files changed, 72 insertions, 49 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 1d050e54e2..742d29291e 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -34,7 +34,7 @@ from twisted.internet import defer
 from twisted.internet.abstract import isIPAddress
 from twisted.python import failure
 
-from synapse.api.constants import EduTypes, EventTypes
+from synapse.api.constants import EduTypes, EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -46,6 +46,7 @@ from synapse.api.errors import (
 )
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
@@ -544,26 +545,21 @@ class FederationServer(FederationBase):
         return {"event": ret_pdu.get_pdu_json(time_now)}
 
     async def on_send_join_request(
-        self, origin: str, content: JsonDict
+        self, origin: str, content: JsonDict, room_id: str
     ) -> Dict[str, Any]:
-        logger.debug("on_send_join_request: content: %s", content)
-
-        assert_params_in_dict(content, ["room_id"])
-        room_version = await self.store.get_room_version(content["room_id"])
-        pdu = event_from_pdu_json(content, room_version)
-
-        origin_host, _ = parse_server_name(origin)
-        await self.check_server_matches_acl(origin_host, pdu.room_id)
-
-        logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
+        context = await self._on_send_membership_event(
+            origin, content, Membership.JOIN, room_id
+        )
 
-        pdu = await self._check_sigs_and_hash(room_version, pdu)
+        prev_state_ids = await context.get_prev_state_ids()
+        state_ids = list(prev_state_ids.values())
+        auth_chain = await self.store.get_auth_chain(room_id, state_ids)
+        state = await self.store.get_events(state_ids)
 
-        res_pdus = await self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
         return {
-            "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
-            "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+            "state": [p.get_pdu_json(time_now) for p in state.values()],
+            "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
         }
 
     async def on_make_leave_request(
@@ -578,21 +574,11 @@ class FederationServer(FederationBase):
         time_now = self._clock.time_msec()
         return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
 
-    async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict:
+    async def on_send_leave_request(
+        self, origin: str, content: JsonDict, room_id: str
+    ) -> dict:
         logger.debug("on_send_leave_request: content: %s", content)
-
-        assert_params_in_dict(content, ["room_id"])
-        room_version = await self.store.get_room_version(content["room_id"])
-        pdu = event_from_pdu_json(content, room_version)
-
-        origin_host, _ = parse_server_name(origin)
-        await self.check_server_matches_acl(origin_host, pdu.room_id)
-
-        logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
-
-        pdu = await self._check_sigs_and_hash(room_version, pdu)
-
-        await self.handler.on_send_leave_request(origin, pdu)
+        await self._on_send_membership_event(origin, content, Membership.LEAVE, room_id)
         return {}
 
     async def on_make_knock_request(
@@ -658,39 +644,76 @@ class FederationServer(FederationBase):
         Returns:
             The stripped room state.
         """
-        logger.debug("on_send_knock_request: content: %s", content)
+        event_context = await self._on_send_membership_event(
+            origin, content, Membership.KNOCK, room_id
+        )
+
+        # Retrieve stripped state events from the room and send them back to the remote
+        # server. This will allow the remote server's clients to display information
+        # related to the room while the knock request is pending.
+        stripped_room_state = (
+            await self.store.get_stripped_room_state_from_event_context(
+                event_context, self._room_prejoin_state_types
+            )
+        )
+        return {"knock_state_events": stripped_room_state}
+
+    async def _on_send_membership_event(
+        self, origin: str, content: JsonDict, membership_type: str, room_id: str
+    ) -> EventContext:
+        """Handle an on_send_{join,leave,knock} request
+
+        Does some preliminary validation before passing the request on to the
+        federation handler.
+
+        Args:
+            origin: The (authenticated) requesting server
+            content: The body of the send_* request - a complete membership event
+            membership_type: The expected membership type (join or leave, depending
+                on the endpoint)
+            room_id: The room_id from the request, to be validated against the room_id
+                in the event
+
+        Returns:
+            The context of the event after inserting it into the room graph.
+
+        Raises:
+            SynapseError if there is a problem with the request, including things like
+               the room_id not matching or the event not being authorized.
+        """
+        assert_params_in_dict(content, ["room_id"])
+        if content["room_id"] != room_id:
+            raise SynapseError(
+                400,
+                "Room ID in body does not match that in request path",
+                Codes.BAD_JSON,
+            )
 
         room_version = await self.store.get_room_version(room_id)
 
-        # Check that this room supports knocking as defined by its room version
-        if not room_version.msc2403_knocking:
+        if membership_type == Membership.KNOCK and not room_version.msc2403_knocking:
             raise SynapseError(
                 403,
                 "This room version does not support knocking",
                 errcode=Codes.FORBIDDEN,
             )
 
-        pdu = event_from_pdu_json(content, room_version)
+        event = event_from_pdu_json(content, room_version)
 
-        origin_host, _ = parse_server_name(origin)
-        await self.check_server_matches_acl(origin_host, pdu.room_id)
+        if event.type != EventTypes.Member or not event.is_state():
+            raise SynapseError(400, "Not an m.room.member event", Codes.BAD_JSON)
 
-        logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures)
+        if event.content.get("membership") != membership_type:
+            raise SynapseError(400, "Not a %s event" % membership_type, Codes.BAD_JSON)
 
-        pdu = await self._check_sigs_and_hash(room_version, pdu)
+        origin_host, _ = parse_server_name(origin)
+        await self.check_server_matches_acl(origin_host, event.room_id)
 
-        # Handle the event, and retrieve the EventContext
-        event_context = await self.handler.on_send_knock_request(origin, pdu)
+        logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)
 
-        # Retrieve stripped state events from the room and send them back to the remote
-        # server. This will allow the remote server's clients to display information
-        # related to the room while the knock request is pending.
-        stripped_room_state = (
-            await self.store.get_stripped_room_state_from_event_context(
-                event_context, self._room_prejoin_state_types
-            )
-        )
-        return {"knock_state_events": stripped_room_state}
+        event = await self._check_sigs_and_hash(room_version, event)
+
+        return await self.handler.on_send_membership_event(origin, event)
 
     async def on_event_auth(
         self, origin: str, room_id: str, event_id: str