summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2019-12-05 15:02:35 +0000
committerGitHub <noreply@github.com>2019-12-05 15:02:35 +0000
commit63d6ad1064c1a5fe23da3b6b64474a2b211f5eea (patch)
tree3e2b1e6632ea1c0524771faca1a98d646904e798
parentSanity-check the rooms of auth events before pulling them in. (#6472) (diff)
downloadsynapse-63d6ad1064c1a5fe23da3b6b64474a2b211f5eea.tar.xz
Stronger typing in the federation handler (#6480)
replace the event_info dict with an attrs thing

Diffstat (limited to '')
-rw-r--r--changelog.d/6480.misc1
-rw-r--r--synapse/handlers/federation.py81
2 files changed, 58 insertions, 24 deletions
diff --git a/changelog.d/6480.misc b/changelog.d/6480.misc
new file mode 100644
index 0000000000..d9a44389b9
--- /dev/null
+++ b/changelog.d/6480.misc
@@ -0,0 +1 @@
+Refactor some code in the event authentication path for clarity.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f5d04cdf91..bc26921768 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,11 +19,13 @@
 
 import itertools
 import logging
+from typing import Dict, Iterable, Optional, Sequence, Tuple
 
 import six
 from six import iteritems, itervalues
 from six.moves import http_client, zip
 
+import attr
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json
 from unpaddedbase64 import decode_base64
@@ -45,6 +47,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
 from synapse.crypto.event_signing import compute_event_signature
 from synapse.event_auth import auth_types_for_event
+from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
 from synapse.logging.context import (
@@ -72,6 +75,23 @@ from ._base import BaseHandler
 logger = logging.getLogger(__name__)
 
 
+@attr.s
+class _NewEventInfo:
+    """Holds information about a received event, ready for passing to _handle_new_events
+
+    Attributes:
+        event: the received event
+
+        state: the state at that event
+
+        auth_events: the auth_event map for that event
+    """
+
+    event = attr.ib(type=EventBase)
+    state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
+    auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None)
+
+
 def shortstr(iterable, maxitems=5):
     """If iterable has maxitems or fewer, return the stringification of a list
     containing those items.
@@ -597,14 +617,14 @@ class FederationHandler(BaseHandler):
                     for e in auth_chain
                     if e.event_id in auth_ids or e.type == EventTypes.Create
                 }
-                event_infos.append({"event": e, "auth_events": auth})
+                event_infos.append(_NewEventInfo(event=e, auth_events=auth))
                 seen_ids.add(e.event_id)
 
             logger.info(
                 "[%s %s] persisting newly-received auth/state events %s",
                 room_id,
                 event_id,
-                [e["event"].event_id for e in event_infos],
+                [e.event.event_id for e in event_infos],
             )
             yield self._handle_new_events(origin, event_infos)
 
@@ -795,9 +815,9 @@ class FederationHandler(BaseHandler):
 
             a.internal_metadata.outlier = True
             ev_infos.append(
-                {
-                    "event": a,
-                    "auth_events": {
+                _NewEventInfo(
+                    event=a,
+                    auth_events={
                         (
                             auth_events[a_id].type,
                             auth_events[a_id].state_key,
@@ -805,7 +825,7 @@ class FederationHandler(BaseHandler):
                         for a_id in a.auth_event_ids()
                         if a_id in auth_events
                     },
-                }
+                )
             )
 
         # Step 1b: persist the events in the chunk we fetched state for (i.e.
@@ -817,10 +837,10 @@ class FederationHandler(BaseHandler):
             assert not ev.internal_metadata.is_outlier()
 
             ev_infos.append(
-                {
-                    "event": ev,
-                    "state": events_to_state[e_id],
-                    "auth_events": {
+                _NewEventInfo(
+                    event=ev,
+                    state=events_to_state[e_id],
+                    auth_events={
                         (
                             auth_events[a_id].type,
                             auth_events[a_id].state_key,
@@ -828,7 +848,7 @@ class FederationHandler(BaseHandler):
                         for a_id in ev.auth_event_ids()
                         if a_id in auth_events
                     },
-                }
+                )
             )
 
         yield self._handle_new_events(dest, ev_infos, backfilled=True)
@@ -1713,7 +1733,12 @@ class FederationHandler(BaseHandler):
         return context
 
     @defer.inlineCallbacks
-    def _handle_new_events(self, origin, event_infos, backfilled=False):
+    def _handle_new_events(
+        self,
+        origin: str,
+        event_infos: Iterable[_NewEventInfo],
+        backfilled: bool = False,
+    ):
         """Creates the appropriate contexts and persists events. The events
         should not depend on one another, e.g. this should be used to persist
         a bunch of outliers, but not a chunk of individual events that depend
@@ -1723,14 +1748,14 @@ class FederationHandler(BaseHandler):
         """
 
         @defer.inlineCallbacks
-        def prep(ev_info):
-            event = ev_info["event"]
+        def prep(ev_info: _NewEventInfo):
+            event = ev_info.event
             with nested_logging_context(suffix=event.event_id):
                 res = yield self._prep_event(
                     origin,
                     event,
-                    state=ev_info.get("state"),
-                    auth_events=ev_info.get("auth_events"),
+                    state=ev_info.state,
+                    auth_events=ev_info.auth_events,
                     backfilled=backfilled,
                 )
             return res
@@ -1744,7 +1769,7 @@ class FederationHandler(BaseHandler):
 
         yield self.persist_events_and_notify(
             [
-                (ev_info["event"], context)
+                (ev_info.event, context)
                 for ev_info, context in zip(event_infos, contexts)
             ],
             backfilled=backfilled,
@@ -1846,7 +1871,14 @@ class FederationHandler(BaseHandler):
         yield self.persist_events_and_notify([(event, new_event_context)])
 
     @defer.inlineCallbacks
-    def _prep_event(self, origin, event, state, auth_events, backfilled):
+    def _prep_event(
+        self,
+        origin: str,
+        event: EventBase,
+        state: Optional[Iterable[EventBase]],
+        auth_events: Optional[Dict[Tuple[str, str], EventBase]],
+        backfilled: bool,
+    ):
         """
 
         Args:
@@ -1854,7 +1886,7 @@ class FederationHandler(BaseHandler):
             event:
             state:
             auth_events:
-            backfilled (bool)
+            backfilled:
 
         Returns:
             Deferred, which resolves to synapse.events.snapshot.EventContext
@@ -1890,15 +1922,16 @@ class FederationHandler(BaseHandler):
         return context
 
     @defer.inlineCallbacks
-    def _check_for_soft_fail(self, event, state, backfilled):
+    def _check_for_soft_fail(
+        self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
+    ):
         """Checks if we should soft fail the event, if so marks the event as
         such.
 
         Args:
-            event (FrozenEvent)
-            state (dict|None): The state at the event if we don't have all the
-                event's prev events
-            backfilled (bool): Whether the event is from backfill
+            event
+            state: The state at the event if we don't have all the event's prev events
+            backfilled: Whether the event is from backfill
 
         Returns:
             Deferred