summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/event_auth.py10
-rw-r--r--synapse/handlers/federation.py49
2 files changed, 43 insertions, 16 deletions
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index c582355146..c0981eee62 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -65,14 +65,16 @@ def check(
 
     room_id = event.room_id
 
-    # I'm not really expecting to get auth events in the wrong room, but let's
-    # sanity-check it
+    # We need to ensure that the auth events are actually for the same room, to
+    # stop people from using powers they've been granted in other rooms for
+    # example.
     for auth_event in auth_events.values():
         if auth_event.room_id != room_id:
-            raise Exception(
+            raise AuthError(
+                403,
                 "During auth for event %s in room %s, found event %s in the state "
                 "which is in room %s"
-                % (event.event_id, room_id, auth_event.event_id, auth_event.room_id)
+                % (event.event_id, room_id, auth_event.event_id, auth_event.room_id),
             )
 
     if do_sig_check:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ca7da42a3f..930ae088c6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -618,6 +618,11 @@ class FederationHandler(BaseHandler):
         will be omitted from the result. Likewise, any events which turn out not to
         be in the given room.
 
+        This function *does not* automatically get missing auth events of the
+        newly fetched events. Callers must include the full auth chain of
+        of the missing events in the `event_ids` argument, to ensure that any
+        missing auth events are correctly fetched.
+
         Returns:
             map from event_id to event
         """
@@ -1131,12 +1136,16 @@ class FederationHandler(BaseHandler):
     ):
         """Fetch the given events from a server, and persist them as outliers.
 
+        This function *does not* recursively get missing auth events of the
+        newly fetched events. Callers must include in the `events` argument
+        any missing events from the auth chain.
+
         Logs a warning if we can't find the given event.
         """
 
         room_version = await self.store.get_room_version(room_id)
 
-        event_infos = []
+        event_map = {}  # type: Dict[str, EventBase]
 
         async def get_event(event_id: str):
             with nested_logging_context(event_id):
@@ -1150,17 +1159,7 @@ class FederationHandler(BaseHandler):
                         )
                         return
 
-                    # recursively fetch the auth events for this event
-                    auth_events = await self._get_events_from_store_or_dest(
-                        destination, room_id, event.auth_event_ids()
-                    )
-                    auth = {}
-                    for auth_event_id in event.auth_event_ids():
-                        ae = auth_events.get(auth_event_id)
-                        if ae:
-                            auth[(ae.type, ae.state_key)] = ae
-
-                    event_infos.append(_NewEventInfo(event, None, auth))
+                    event_map[event.event_id] = event
 
                 except Exception as e:
                     logger.warning(
@@ -1172,6 +1171,32 @@ class FederationHandler(BaseHandler):
 
         await concurrently_execute(get_event, events, 5)
 
+        # Make a map of auth events for each event. We do this after fetching
+        # all the events as some of the events' auth events will be in the list
+        # of requested events.
+
+        auth_events = [
+            aid
+            for event in event_map.values()
+            for aid in event.auth_event_ids()
+            if aid not in event_map
+        ]
+        persisted_events = await self.store.get_events(
+            auth_events, allow_rejected=True,
+        )
+
+        event_infos = []
+        for event in event_map.values():
+            auth = {}
+            for auth_event_id in event.auth_event_ids():
+                ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id)
+                if ae:
+                    auth[(ae.type, ae.state_key)] = ae
+                else:
+                    logger.info("Missing auth event %s", auth_event_id)
+
+            event_infos.append(_NewEventInfo(event, None, auth))
+
         await self._handle_new_events(
             destination, event_infos,
         )