summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-04-14 12:35:28 -0400
committerGitHub <noreply@github.com>2021-04-14 12:35:28 -0400
commit936e69825ab684ca580edb6038e86fdb2e561776 (patch)
tree586428b9d2306769c22dad062cbf8e63151f6e91
parentRevert "Check for space membership during a remote join of a restricted room.... (diff)
downloadsynapse-936e69825ab684ca580edb6038e86fdb2e561776.tar.xz
Separate creating an event context from persisting it in the federation handler (#9800)
This refactoring allows adding logic that uses the event context
before persisting it.
-rw-r--r--changelog.d/9800.feature1
-rw-r--r--synapse/handlers/federation.py178
-rw-r--r--tests/test_federation.py6
3 files changed, 118 insertions, 67 deletions
diff --git a/changelog.d/9800.feature b/changelog.d/9800.feature
new file mode 100644
index 0000000000..9404ad2fc0
--- /dev/null
+++ b/changelog.d/9800.feature
@@ -0,0 +1 @@
+Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index fe1d83f6b8..4b3730aa3b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -103,7 +103,7 @@ logger = logging.getLogger(__name__)
 
 @attr.s(slots=True)
 class _NewEventInfo:
-    """Holds information about a received event, ready for passing to _handle_new_events
+    """Holds information about a received event, ready for passing to _auth_and_persist_events
 
     Attributes:
         event: the received event
@@ -807,7 +807,10 @@ class FederationHandler(BaseHandler):
         logger.debug("Processing event: %s", event)
 
         try:
-            await self._handle_new_event(origin, event, state=state)
+            context = await self.state_handler.compute_event_context(
+                event, old_state=state
+            )
+            await self._auth_and_persist_event(origin, event, context, state=state)
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
@@ -1010,7 +1013,9 @@ class FederationHandler(BaseHandler):
             )
 
         if ev_infos:
-            await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
+            await self._auth_and_persist_events(
+                dest, room_id, ev_infos, backfilled=True
+            )
 
         # Step 2: Persist the rest of the events in the chunk one by one
         events.sort(key=lambda e: e.depth)
@@ -1023,10 +1028,12 @@ class FederationHandler(BaseHandler):
             # non-outliers
             assert not event.internal_metadata.is_outlier()
 
+            context = await self.state_handler.compute_event_context(event)
+
             # We store these one at a time since each event depends on the
             # previous to work out the state.
             # TODO: We can probably do something more clever here.
-            await self._handle_new_event(dest, event, backfilled=True)
+            await self._auth_and_persist_event(dest, event, context, backfilled=True)
 
         return events
 
@@ -1360,7 +1367,7 @@ class FederationHandler(BaseHandler):
 
             event_infos.append(_NewEventInfo(event, None, auth))
 
-        await self._handle_new_events(
+        await self._auth_and_persist_events(
             destination,
             room_id,
             event_infos,
@@ -1666,10 +1673,11 @@ class FederationHandler(BaseHandler):
         # would introduce the danger of backwards-compatibility problems.
         event.internal_metadata.send_on_behalf_of = origin
 
-        context = await self._handle_new_event(origin, event)
+        context = await self.state_handler.compute_event_context(event)
+        context = await self._auth_and_persist_event(origin, event, context)
 
         logger.debug(
-            "on_send_join_request: After _handle_new_event: %s, sigs: %s",
+            "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s",
             event.event_id,
             event.signatures,
         )
@@ -1878,10 +1886,11 @@ class FederationHandler(BaseHandler):
 
         event.internal_metadata.outlier = False
 
-        await self._handle_new_event(origin, event)
+        context = await self.state_handler.compute_event_context(event)
+        await self._auth_and_persist_event(origin, event, context)
 
         logger.debug(
-            "on_send_leave_request: After _handle_new_event: %s, sigs: %s",
+            "on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s",
             event.event_id,
             event.signatures,
         )
@@ -1989,16 +1998,47 @@ class FederationHandler(BaseHandler):
     async def get_min_depth_for_context(self, context: str) -> int:
         return await self.store.get_min_depth(context)
 
-    async def _handle_new_event(
+    async def _auth_and_persist_event(
         self,
         origin: str,
         event: EventBase,
+        context: EventContext,
         state: Optional[Iterable[EventBase]] = None,
         auth_events: Optional[MutableStateMap[EventBase]] = None,
         backfilled: bool = False,
     ) -> EventContext:
-        context = await self._prep_event(
-            origin, event, state=state, auth_events=auth_events, backfilled=backfilled
+        """
+        Process an event by performing auth checks and then persisting to the database.
+
+        Args:
+            origin: The host the event originates from.
+            event: The event itself.
+            context:
+                The event context.
+
+                NB that this function potentially modifies it.
+            state:
+                The state events used to check the event for soft-fail. If this is
+                not provided the current state events will be used.
+            auth_events:
+                Map from (event_type, state_key) to event
+
+                Normally, our calculated auth_events based on the state of the room
+                at the event's position in the DAG, though occasionally (eg if the
+                event is an outlier), may be the auth events claimed by the remote
+                server.
+            backfilled: True if the event was backfilled.
+
+        Returns:
+             The event context.
+        """
+        context = await self._check_event_auth(
+            origin,
+            event,
+            context,
+            state=state,
+            auth_events=auth_events,
+            backfilled=backfilled,
         )
 
         try:
@@ -2022,7 +2062,7 @@ class FederationHandler(BaseHandler):
 
         return context
 
-    async def _handle_new_events(
+    async def _auth_and_persist_events(
         self,
         origin: str,
         room_id: str,
@@ -2040,9 +2080,13 @@ class FederationHandler(BaseHandler):
         async def prep(ev_info: _NewEventInfo):
             event = ev_info.event
             with nested_logging_context(suffix=event.event_id):
-                res = await self._prep_event(
+                res = await self.state_handler.compute_event_context(
+                    event, old_state=ev_info.state
+                )
+                res = await self._check_event_auth(
                     origin,
                     event,
+                    res,
                     state=ev_info.state,
                     auth_events=ev_info.auth_events,
                     backfilled=backfilled,
@@ -2177,49 +2221,6 @@ class FederationHandler(BaseHandler):
             room_id, [(event, new_event_context)]
         )
 
-    async def _prep_event(
-        self,
-        origin: str,
-        event: EventBase,
-        state: Optional[Iterable[EventBase]],
-        auth_events: Optional[MutableStateMap[EventBase]],
-        backfilled: bool,
-    ) -> EventContext:
-        context = await self.state_handler.compute_event_context(event, old_state=state)
-
-        if not auth_events:
-            prev_state_ids = await context.get_prev_state_ids()
-            auth_events_ids = self.auth.compute_auth_events(
-                event, prev_state_ids, for_verification=True
-            )
-            auth_events_x = await self.store.get_events(auth_events_ids)
-            auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
-
-        # This is a hack to fix some old rooms where the initial join event
-        # didn't reference the create event in its auth events.
-        if event.type == EventTypes.Member and not event.auth_event_ids():
-            if len(event.prev_event_ids()) == 1 and event.depth < 5:
-                c = await self.store.get_event(
-                    event.prev_event_ids()[0], allow_none=True
-                )
-                if c and c.type == EventTypes.Create:
-                    auth_events[(c.type, c.state_key)] = c
-
-        context = await self.do_auth(origin, event, context, auth_events=auth_events)
-
-        if not context.rejected:
-            await self._check_for_soft_fail(event, state, backfilled)
-
-        if event.type == EventTypes.GuestAccess and not context.rejected:
-            await self.maybe_kick_guest_users(event)
-
-        # If we are going to send this event over federation we precaclculate
-        # the joined hosts.
-        if event.internal_metadata.get_send_on_behalf_of():
-            await self.event_creation_handler.cache_joined_hosts_for_event(event)
-
-        return context
-
     async def _check_for_soft_fail(
         self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
     ) -> None:
@@ -2330,19 +2331,28 @@ class FederationHandler(BaseHandler):
 
         return missing_events
 
-    async def do_auth(
+    async def _check_event_auth(
         self,
         origin: str,
         event: EventBase,
         context: EventContext,
-        auth_events: MutableStateMap[EventBase],
+        state: Optional[Iterable[EventBase]],
+        auth_events: Optional[MutableStateMap[EventBase]],
+        backfilled: bool,
     ) -> EventContext:
         """
+        Checks whether an event should be rejected (for failing auth checks).
 
         Args:
-            origin:
-            event:
+            origin: The host the event originates from.
+            event: The event itself.
             context:
+                The event context.
+
+                NB that this function potentially modifies it.
+            state:
+                The state events used to check the event for soft-fail. If this is
+                not provided the current state events will be used.
             auth_events:
                 Map from (event_type, state_key) to event
 
@@ -2352,12 +2362,34 @@ class FederationHandler(BaseHandler):
                 server.
 
                 Also NB that this function adds entries to it.
+
+                If this is not provided, it is calculated from the previous state IDs.
+            backfilled: True if the event was backfilled.
+
         Returns:
-            updated context object
+            The updated context object.
         """
         room_version = await self.store.get_room_version_id(event.room_id)
         room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
+        if not auth_events:
+            prev_state_ids = await context.get_prev_state_ids()
+            auth_events_ids = self.auth.compute_auth_events(
+                event, prev_state_ids, for_verification=True
+            )
+            auth_events_x = await self.store.get_events(auth_events_ids)
+            auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
+
+        # This is a hack to fix some old rooms where the initial join event
+        # didn't reference the create event in its auth events.
+        if event.type == EventTypes.Member and not event.auth_event_ids():
+            if len(event.prev_event_ids()) == 1 and event.depth < 5:
+                c = await self.store.get_event(
+                    event.prev_event_ids()[0], allow_none=True
+                )
+                if c and c.type == EventTypes.Create:
+                    auth_events[(c.type, c.state_key)] = c
+
         try:
             context = await self._update_auth_events_and_context_for_auth(
                 origin, event, context, auth_events
@@ -2379,6 +2411,17 @@ class FederationHandler(BaseHandler):
             logger.warning("Failed auth resolution for %r because %s", event, e)
             context.rejected = RejectedReason.AUTH_ERROR
 
+        if not context.rejected:
+            await self._check_for_soft_fail(event, state, backfilled)
+
+        if event.type == EventTypes.GuestAccess and not context.rejected:
+            await self.maybe_kick_guest_users(event)
+
+        # If we are going to send this event over federation we precaclculate
+        # the joined hosts.
+        if event.internal_metadata.get_send_on_behalf_of():
+            await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
         return context
 
     async def _update_auth_events_and_context_for_auth(
@@ -2388,7 +2431,7 @@ class FederationHandler(BaseHandler):
         context: EventContext,
         auth_events: MutableStateMap[EventBase],
     ) -> EventContext:
-        """Helper for do_auth. See there for docs.
+        """Helper for _check_event_auth. See there for docs.
 
         Checks whether a given event has the expected auth events. If it
         doesn't then we talk to the remote server to compare state to see if
@@ -2468,9 +2511,14 @@ class FederationHandler(BaseHandler):
                         e.internal_metadata.outlier = True
 
                         logger.debug(
-                            "do_auth %s missing_auth: %s", event.event_id, e.event_id
+                            "_check_event_auth %s missing_auth: %s",
+                            event.event_id,
+                            e.event_id,
+                        )
+                        context = await self.state_handler.compute_event_context(e)
+                        await self._auth_and_persist_event(
+                            origin, e, context, auth_events=auth
                         )
-                        await self._handle_new_event(origin, e, auth_events=auth)
 
                         if e.event_id in event_auth_events:
                             auth_events[(e.type, e.state_key)] = e
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 86a44a13da..0a3a996ec1 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -75,8 +75,10 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         )
 
         self.handler = self.homeserver.get_federation_handler()
-        self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
-            context
+        self.handler._check_event_auth = (
+            lambda origin, event, context, state, auth_events, backfilled: succeed(
+                context
+            )
         )
         self.client = self.homeserver.get_federation_client()
         self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(