summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
authorreivilibre <oliverw@matrix.org>2021-10-27 17:27:23 +0100
committerGitHub <noreply@github.com>2021-10-27 17:27:23 +0100
commit75ca0a6168f92dab3255839cf85fb0df3a0076c3 (patch)
treeb4326bf5fae23b6df52d9f43dbc6d1ddce3b68c6 /synapse/federation
parentFixed config parse bug in review_recent_signups (#11191) (diff)
downloadsynapse-75ca0a6168f92dab3255839cf85fb0df3a0076c3.tar.xz
Annotate `log_function` decorator (#10943)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py17
-rw-r--r--synapse/federation/federation_server.py10
-rw-r--r--synapse/federation/sender/transaction_manager.py1
-rw-r--r--synapse/federation/transport/client.py22
4 files changed, 39 insertions, 11 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 2ab4dec88f..670186f548 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -227,7 +227,7 @@ class FederationClient(FederationBase):
         )
 
     async def backfill(
-        self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
+        self, dest: str, room_id: str, limit: int, extremities: Collection[str]
     ) -> Optional[List[EventBase]]:
         """Requests some more historic PDUs for the given room from the
         given destination server.
@@ -237,6 +237,8 @@ class FederationClient(FederationBase):
             room_id: The room_id to backfill.
             limit: The maximum number of events to return.
             extremities: our current backwards extremities, to backfill from
+                Must be a Collection that is falsy when empty.
+                (Iterable is not enough here!)
         """
         logger.debug("backfill extrem=%s", extremities)
 
@@ -250,11 +252,22 @@ class FederationClient(FederationBase):
 
         logger.debug("backfill transaction_data=%r", transaction_data)
 
+        if not isinstance(transaction_data, dict):
+            # TODO we probably want an exception type specific to federation
+            # client validation.
+            raise TypeError("Backfill transaction_data is not a dict.")
+
+        transaction_data_pdus = transaction_data.get("pdus")
+        if not isinstance(transaction_data_pdus, list):
+            # TODO we probably want an exception type specific to federation
+            # client validation.
+            raise TypeError("transaction_data.pdus is not a list.")
+
         room_version = await self.store.get_room_version(room_id)
 
         pdus = [
             event_from_pdu_json(p, room_version, outlier=False)
-            for p in transaction_data["pdus"]
+            for p in transaction_data_pdus
         ]
 
         # Check signatures and hash of pdus, removing any from the list that fail checks
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 0d66034f44..32a75993d9 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -295,14 +295,16 @@ class FederationServer(FederationBase):
         Returns:
             HTTP response code and body
         """
-        response = await self.transaction_actions.have_responded(origin, transaction)
+        existing_response = await self.transaction_actions.have_responded(
+            origin, transaction
+        )
 
-        if response:
+        if existing_response:
             logger.debug(
                 "[%s] We've already responded to this request",
                 transaction.transaction_id,
             )
-            return response
+            return existing_response
 
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
@@ -632,7 +634,7 @@ class FederationServer(FederationBase):
 
     async def on_make_knock_request(
         self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
-    ) -> Dict[str, Union[EventBase, str]]:
+    ) -> JsonDict:
         """We've received a /make_knock/ request, so we create a partial knock
         event for the room and hand that back, along with the room version, to the knocking
         homeserver. We do *not* persist or process this event until the other server has
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index dc555cca0b..ab935e5a7e 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -149,7 +149,6 @@ class TransactionManager:
                 )
             except HttpResponseException as e:
                 code = e.code
-                response = e.response
 
                 set_tag(tags.ERROR, True)
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 8b247fe206..d963178838 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -15,7 +15,19 @@
 
 import logging
 import urllib
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import attr
 import ijson
@@ -100,7 +112,7 @@ class TransportLayerClient:
 
     @log_function
     async def backfill(
-        self, destination: str, room_id: str, event_tuples: Iterable[str], limit: int
+        self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
     ) -> Optional[JsonDict]:
         """Requests `limit` previous PDUs in a given context before list of
         PDUs.
@@ -108,7 +120,9 @@ class TransportLayerClient:
         Args:
             destination
             room_id
-            event_tuples
+            event_tuples:
+                Must be a Collection that is falsy when empty.
+                (Iterable is not enough here!)
             limit
 
         Returns:
@@ -786,7 +800,7 @@ class TransportLayerClient:
     @log_function
     def join_group(
         self, destination: str, group_id: str, user_id: str, content: JsonDict
-    ) -> JsonDict:
+    ) -> Awaitable[JsonDict]:
         """Attempts to join a group"""
         path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)