summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_base.py28
-rw-r--r--synapse/federation/federation_client.py61
-rw-r--r--synapse/federation/federation_server.py181
-rw-r--r--synapse/federation/send_queue.py2
-rw-r--r--synapse/federation/sender/__init__.py9
-rw-r--r--synapse/federation/transport/client.py56
6 files changed, 212 insertions, 125 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 5c991e5412..4b115aac04 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -25,19 +25,15 @@ from twisted.python.failure import Failure
 
 from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
 from synapse.api.errors import Codes, SynapseError
-from synapse.api.room_versions import (
-    KNOWN_ROOM_VERSIONS,
-    EventFormatVersions,
-    RoomVersion,
-)
+from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.crypto.event_signing import check_event_content_hash
 from synapse.crypto.keyring import Keyring
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
-    LoggingContext,
     PreserveLoggingContext,
+    current_context,
     make_deferred_yieldable,
 )
 from synapse.types import JsonDict, get_domain_from_id
@@ -55,13 +51,15 @@ class FederationBase(object):
         self.store = hs.get_datastore()
         self._clock = hs.get_clock()
 
-    def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
+    def _check_sigs_and_hash(
+        self, room_version: RoomVersion, pdu: EventBase
+    ) -> Deferred:
         return make_deferred_yieldable(
             self._check_sigs_and_hashes(room_version, [pdu])[0]
         )
 
     def _check_sigs_and_hashes(
-        self, room_version: str, pdus: List[EventBase]
+        self, room_version: RoomVersion, pdus: List[EventBase]
     ) -> List[Deferred]:
         """Checks that each of the received events is correctly signed by the
         sending server.
@@ -80,7 +78,7 @@ class FederationBase(object):
         """
         deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
 
-        ctx = LoggingContext.current_context()
+        ctx = current_context()
 
         def callback(_, pdu: EventBase):
             with PreserveLoggingContext(ctx):
@@ -146,7 +144,7 @@ class PduToCheckSig(
 
 
 def _check_sigs_on_pdus(
-    keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
+    keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
 ) -> List[Deferred]:
     """Check that the given events are correctly signed
 
@@ -191,10 +189,6 @@ def _check_sigs_on_pdus(
         for p in pdus
     ]
 
-    v = KNOWN_ROOM_VERSIONS.get(room_version)
-    if not v:
-        raise RuntimeError("Unrecognized room version %s" % (room_version,))
-
     # First we check that the sender event is signed by the sender's domain
     # (except if its a 3pid invite, in which case it may be sent by any server)
     pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
@@ -204,7 +198,7 @@ def _check_sigs_on_pdus(
             (
                 p.sender_domain,
                 p.redacted_pdu_json,
-                p.pdu.origin_server_ts if v.enforce_key_validity else 0,
+                p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
                 p.pdu.event_id,
             )
             for p in pdus_to_check_sender
@@ -227,7 +221,7 @@ def _check_sigs_on_pdus(
     # event id's domain (normally only the case for joins/leaves), and add additional
     # checks. Only do this if the room version has a concept of event ID domain
     # (ie, the room version uses old-style non-hash event IDs).
-    if v.event_format == EventFormatVersions.V1:
+    if room_version.event_format == EventFormatVersions.V1:
         pdus_to_check_event_id = [
             p
             for p in pdus_to_check
@@ -239,7 +233,7 @@ def _check_sigs_on_pdus(
                 (
                     get_domain_from_id(p.pdu.event_id),
                     p.redacted_pdu_json,
-                    p.pdu.origin_server_ts if v.enforce_key_validity else 0,
+                    p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
                     p.pdu.event_id,
                 )
                 for p in pdus_to_check_event_id
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 8c6b839478..687cd841ac 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -220,8 +220,7 @@ class FederationClient(FederationBase):
         # FIXME: We should handle signature failures more gracefully.
         pdus[:] = await make_deferred_yieldable(
             defer.gatherResults(
-                self._check_sigs_and_hashes(room_version.identifier, pdus),
-                consumeErrors=True,
+                self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
             ).addErrback(unwrapFirstError)
         )
 
@@ -291,9 +290,7 @@ class FederationClient(FederationBase):
                     pdu = pdu_list[0]
 
                     # Check signatures are correct.
-                    signed_pdu = await self._check_sigs_and_hash(
-                        room_version.identifier, pdu
-                    )
+                    signed_pdu = await self._check_sigs_and_hash(room_version, pdu)
 
                     break
 
@@ -350,7 +347,7 @@ class FederationClient(FederationBase):
         self,
         origin: str,
         pdus: List[EventBase],
-        room_version: str,
+        room_version: RoomVersion,
         outlier: bool = False,
         include_none: bool = False,
     ) -> List[EventBase]:
@@ -396,7 +393,7 @@ class FederationClient(FederationBase):
                         self.get_pdu(
                             destinations=[pdu.origin],
                             event_id=pdu.event_id,
-                            room_version=room_version,  # type: ignore
+                            room_version=room_version,
                             outlier=outlier,
                             timeout=10000,
                         )
@@ -434,7 +431,7 @@ class FederationClient(FederationBase):
         ]
 
         signed_auth = await self._check_sigs_and_hash_and_fetch(
-            destination, auth_chain, outlier=True, room_version=room_version.identifier
+            destination, auth_chain, outlier=True, room_version=room_version
         )
 
         signed_auth.sort(key=lambda e: e.depth)
@@ -661,7 +658,7 @@ class FederationClient(FederationBase):
                 destination,
                 list(pdus.values()),
                 outlier=True,
-                room_version=room_version.identifier,
+                room_version=room_version,
             )
 
             valid_pdus_map = {p.event_id: p for p in valid_pdus}
@@ -756,7 +753,7 @@ class FederationClient(FederationBase):
         pdu = event_from_pdu_json(pdu_dict, room_version)
 
         # Check signatures are correct.
-        pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+        pdu = await self._check_sigs_and_hash(room_version, pdu)
 
         # FIXME: We should handle signature failures more gracefully.
 
@@ -886,18 +883,37 @@ class FederationClient(FederationBase):
 
     def get_public_rooms(
         self,
-        destination,
-        limit=None,
-        since_token=None,
-        search_filter=None,
-        include_all_networks=False,
-        third_party_instance_id=None,
+        remote_server: str,
+        limit: Optional[int] = None,
+        since_token: Optional[str] = None,
+        search_filter: Optional[Dict] = None,
+        include_all_networks: bool = False,
+        third_party_instance_id: Optional[str] = None,
     ):
-        if destination == self.server_name:
-            return
+        """Get the list of public rooms from a remote homeserver
 
+        Args:
+            remote_server: The name of the remote server
+            limit: Maximum amount of rooms to return
+            since_token: Used for result pagination
+            search_filter: A filter dictionary to send the remote homeserver
+                and filter the result set
+            include_all_networks: Whether to include results from all third party instances
+            third_party_instance_id: Whether to only include results from a specific third
+                party instance
+
+        Returns:
+            Deferred[Dict[str, Any]]: The response from the remote server, or None if
+            `remote_server` is the same as the local server_name
+
+        Raises:
+            HttpResponseException: There was an exception returned from the remote server
+            SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom
+                requests over federation
+
+        """
         return self.transport_layer.get_public_rooms(
-            destination,
+            remote_server,
             limit,
             since_token,
             search_filter,
@@ -948,7 +964,7 @@ class FederationClient(FederationBase):
             ]
 
             signed_events = await self._check_sigs_and_hash_and_fetch(
-                destination, events, outlier=False, room_version=room_version.identifier
+                destination, events, outlier=False, room_version=room_version
             )
         except HttpResponseException as e:
             if not e.code == 400:
@@ -960,14 +976,13 @@ class FederationClient(FederationBase):
 
         return signed_events
 
-    @defer.inlineCallbacks
-    def forward_third_party_invite(self, destinations, room_id, event_dict):
+    async def forward_third_party_invite(self, destinations, room_id, event_dict):
         for destination in destinations:
             if destination == self.server_name:
                 continue
 
             try:
-                yield self.transport_layer.exchange_third_party_invite(
+                await self.transport_layer.exchange_third_party_invite(
                     destination=destination, room_id=room_id, event_dict=event_dict
                 )
                 return None
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 275b9c99d7..32a8a2ee46 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Dict
+from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
 
 import six
 from six import iteritems
@@ -38,6 +38,7 @@ from synapse.api.errors import (
     UnsupportedRoomVersionError,
 )
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import EventBase
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
@@ -94,7 +95,9 @@ class FederationServer(FederationBase):
         # come in waves.
         self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
 
-    async def on_backfill_request(self, origin, room_id, versions, limit):
+    async def on_backfill_request(
+        self, origin: str, room_id: str, versions: List[str], limit: int
+    ) -> Tuple[int, Dict[str, Any]]:
         with (await self._server_linearizer.queue((origin, room_id))):
             origin_host, _ = parse_server_name(origin)
             await self.check_server_matches_acl(origin_host, room_id)
@@ -107,23 +110,25 @@ class FederationServer(FederationBase):
 
         return 200, res
 
-    async def on_incoming_transaction(self, origin, transaction_data):
+    async def on_incoming_transaction(
+        self, origin: str, transaction_data: JsonDict
+    ) -> Tuple[int, Dict[str, Any]]:
         # keep this as early as possible to make the calculated origin ts as
         # accurate as possible.
         request_time = self._clock.time_msec()
 
         transaction = Transaction(**transaction_data)
 
-        if not transaction.transaction_id:
+        if not transaction.transaction_id:  # type: ignore
             raise Exception("Transaction missing transaction_id")
 
-        logger.debug("[%s] Got transaction", transaction.transaction_id)
+        logger.debug("[%s] Got transaction", transaction.transaction_id)  # type: ignore
 
         # use a linearizer to ensure that we don't process the same transaction
         # multiple times in parallel.
         with (
             await self._transaction_linearizer.queue(
-                (origin, transaction.transaction_id)
+                (origin, transaction.transaction_id)  # type: ignore
             )
         ):
             result = await self._handle_incoming_transaction(
@@ -132,31 +137,33 @@ class FederationServer(FederationBase):
 
         return result
 
-    async def _handle_incoming_transaction(self, origin, transaction, request_time):
+    async def _handle_incoming_transaction(
+        self, origin: str, transaction: Transaction, request_time: int
+    ) -> Tuple[int, Dict[str, Any]]:
         """ Process an incoming transaction and return the HTTP response
 
         Args:
-            origin (unicode): the server making the request
-            transaction (Transaction): incoming transaction
-            request_time (int): timestamp that the HTTP request arrived at
+            origin: the server making the request
+            transaction: incoming transaction
+            request_time: timestamp that the HTTP request arrived at
 
         Returns:
-            Deferred[(int, object)]: http response code and body
+            HTTP response code and body
         """
         response = await self.transaction_actions.have_responded(origin, transaction)
 
         if response:
             logger.debug(
                 "[%s] We've already responded to this request",
-                transaction.transaction_id,
+                transaction.transaction_id,  # type: ignore
             )
             return response
 
-        logger.debug("[%s] Transaction is new", transaction.transaction_id)
+        logger.debug("[%s] Transaction is new", transaction.transaction_id)  # type: ignore
 
         # Reject if PDU count > 50 or EDU count > 100
-        if len(transaction.pdus) > 50 or (
-            hasattr(transaction, "edus") and len(transaction.edus) > 100
+        if len(transaction.pdus) > 50 or (  # type: ignore
+            hasattr(transaction, "edus") and len(transaction.edus) > 100  # type: ignore
         ):
 
             logger.info("Transaction PDU or EDU count too large. Returning 400")
@@ -204,13 +211,13 @@ class FederationServer(FederationBase):
             report back to the sending server.
         """
 
-        received_pdus_counter.inc(len(transaction.pdus))
+        received_pdus_counter.inc(len(transaction.pdus))  # type: ignore
 
         origin_host, _ = parse_server_name(origin)
 
-        pdus_by_room = {}
+        pdus_by_room = {}  # type: Dict[str, List[EventBase]]
 
-        for p in transaction.pdus:
+        for p in transaction.pdus:  # type: ignore
             if "unsigned" in p:
                 unsigned = p["unsigned"]
                 if "age" in unsigned:
@@ -254,7 +261,7 @@ class FederationServer(FederationBase):
         # require callouts to other servers to fetch missing events), but
         # impose a limit to avoid going too crazy with ram/cpu.
 
-        async def process_pdus_for_room(room_id):
+        async def process_pdus_for_room(room_id: str):
             logger.debug("Processing PDUs for %s", room_id)
             try:
                 await self.check_server_matches_acl(origin_host, room_id)
@@ -310,7 +317,9 @@ class FederationServer(FederationBase):
             TRANSACTION_CONCURRENCY_LIMIT,
         )
 
-    async def on_context_state_request(self, origin, room_id, event_id):
+    async def on_context_state_request(
+        self, origin: str, room_id: str, event_id: str
+    ) -> Tuple[int, Dict[str, Any]]:
         origin_host, _ = parse_server_name(origin)
         await self.check_server_matches_acl(origin_host, room_id)
 
@@ -338,7 +347,9 @@ class FederationServer(FederationBase):
 
         return 200, resp
 
-    async def on_state_ids_request(self, origin, room_id, event_id):
+    async def on_state_ids_request(
+        self, origin: str, room_id: str, event_id: str
+    ) -> Tuple[int, Dict[str, Any]]:
         if not event_id:
             raise NotImplementedError("Specify an event")
 
@@ -354,7 +365,9 @@ class FederationServer(FederationBase):
 
         return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
 
-    async def _on_context_state_request_compute(self, room_id, event_id):
+    async def _on_context_state_request_compute(
+        self, room_id: str, event_id: str
+    ) -> Dict[str, list]:
         if event_id:
             pdus = await self.handler.get_state_for_pdu(room_id, event_id)
         else:
@@ -367,7 +380,9 @@ class FederationServer(FederationBase):
             "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
         }
 
-    async def on_pdu_request(self, origin, event_id):
+    async def on_pdu_request(
+        self, origin: str, event_id: str
+    ) -> Tuple[int, Union[JsonDict, str]]:
         pdu = await self.handler.get_persisted_pdu(origin, event_id)
 
         if pdu:
@@ -375,12 +390,16 @@ class FederationServer(FederationBase):
         else:
             return 404, ""
 
-    async def on_query_request(self, query_type, args):
+    async def on_query_request(
+        self, query_type: str, args: Dict[str, str]
+    ) -> Tuple[int, Dict[str, Any]]:
         received_queries_counter.labels(query_type).inc()
         resp = await self.registry.on_query(query_type, args)
         return 200, resp
 
-    async def on_make_join_request(self, origin, room_id, user_id, supported_versions):
+    async def on_make_join_request(
+        self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
+    ) -> Dict[str, Any]:
         origin_host, _ = parse_server_name(origin)
         await self.check_server_matches_acl(origin_host, room_id)
 
@@ -397,7 +416,7 @@ class FederationServer(FederationBase):
 
     async def on_invite_request(
         self, origin: str, content: JsonDict, room_version_id: str
-    ):
+    ) -> Dict[str, Any]:
         room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
         if not room_version:
             raise SynapseError(
@@ -409,12 +428,14 @@ class FederationServer(FederationBase):
         pdu = event_from_pdu_json(content, room_version)
         origin_host, _ = parse_server_name(origin)
         await self.check_server_matches_acl(origin_host, pdu.room_id)
-        pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+        pdu = await self._check_sigs_and_hash(room_version, pdu)
         ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
         time_now = self._clock.time_msec()
         return {"event": ret_pdu.get_pdu_json(time_now)}
 
-    async def on_send_join_request(self, origin, content, room_id):
+    async def on_send_join_request(
+        self, origin: str, content: JsonDict, room_id: str
+    ) -> Dict[str, Any]:
         logger.debug("on_send_join_request: content: %s", content)
 
         room_version = await self.store.get_room_version(room_id)
@@ -425,7 +446,7 @@ class FederationServer(FederationBase):
 
         logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
 
-        pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+        pdu = await self._check_sigs_and_hash(room_version, pdu)
 
         res_pdus = await self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
@@ -434,7 +455,9 @@ class FederationServer(FederationBase):
             "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
         }
 
-    async def on_make_leave_request(self, origin, room_id, user_id):
+    async def on_make_leave_request(
+        self, origin: str, room_id: str, user_id: str
+    ) -> Dict[str, Any]:
         origin_host, _ = parse_server_name(origin)
         await self.check_server_matches_acl(origin_host, room_id)
         pdu = await self.handler.on_make_leave_request(origin, room_id, user_id)
@@ -444,7 +467,9 @@ class FederationServer(FederationBase):
         time_now = self._clock.time_msec()
         return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
 
-    async def on_send_leave_request(self, origin, content, room_id):
+    async def on_send_leave_request(
+        self, origin: str, content: JsonDict, room_id: str
+    ) -> dict:
         logger.debug("on_send_leave_request: content: %s", content)
 
         room_version = await self.store.get_room_version(room_id)
@@ -455,12 +480,14 @@ class FederationServer(FederationBase):
 
         logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
 
-        pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+        pdu = await self._check_sigs_and_hash(room_version, pdu)
 
         await self.handler.on_send_leave_request(origin, pdu)
         return {}
 
-    async def on_event_auth(self, origin, room_id, event_id):
+    async def on_event_auth(
+        self, origin: str, room_id: str, event_id: str
+    ) -> Tuple[int, Dict[str, Any]]:
         with (await self._server_linearizer.queue((origin, room_id))):
             origin_host, _ = parse_server_name(origin)
             await self.check_server_matches_acl(origin_host, room_id)
@@ -471,15 +498,21 @@ class FederationServer(FederationBase):
         return 200, res
 
     @log_function
-    def on_query_client_keys(self, origin, content):
-        return self.on_query_request("client_keys", content)
-
-    async def on_query_user_devices(self, origin: str, user_id: str):
+    async def on_query_client_keys(
+        self, origin: str, content: Dict[str, str]
+    ) -> Tuple[int, Dict[str, Any]]:
+        return await self.on_query_request("client_keys", content)
+
+    async def on_query_user_devices(
+        self, origin: str, user_id: str
+    ) -> Tuple[int, Dict[str, Any]]:
         keys = await self.device_handler.on_federation_query_user_devices(user_id)
         return 200, keys
 
     @trace
-    async def on_claim_client_keys(self, origin, content):
+    async def on_claim_client_keys(
+        self, origin: str, content: JsonDict
+    ) -> Dict[str, Any]:
         query = []
         for user_id, device_keys in content.get("one_time_keys", {}).items():
             for device_id, algorithm in device_keys.items():
@@ -488,7 +521,7 @@ class FederationServer(FederationBase):
         log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
         results = await self.store.claim_e2e_one_time_keys(query)
 
-        json_result = {}
+        json_result = {}  # type: Dict[str, Dict[str, dict]]
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
                 for key_id, json_bytes in keys.items():
@@ -511,8 +544,13 @@ class FederationServer(FederationBase):
         return {"one_time_keys": json_result}
 
     async def on_get_missing_events(
-        self, origin, room_id, earliest_events, latest_events, limit
-    ):
+        self,
+        origin: str,
+        room_id: str,
+        earliest_events: List[str],
+        latest_events: List[str],
+        limit: int,
+    ) -> Dict[str, list]:
         with (await self._server_linearizer.queue((origin, room_id))):
             origin_host, _ = parse_server_name(origin)
             await self.check_server_matches_acl(origin_host, room_id)
@@ -541,11 +579,11 @@ class FederationServer(FederationBase):
         return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
 
     @log_function
-    def on_openid_userinfo(self, token):
+    async def on_openid_userinfo(self, token: str) -> Optional[str]:
         ts_now_ms = self._clock.time_msec()
-        return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
+        return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
 
-    def _transaction_from_pdus(self, pdu_list):
+    def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction:
         """Returns a new Transaction containing the given PDUs suitable for
         transmission.
         """
@@ -558,7 +596,7 @@ class FederationServer(FederationBase):
             destination=None,
         )
 
-    async def _handle_received_pdu(self, origin, pdu):
+    async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
         """ Process a PDU received in a federation /send/ transaction.
 
         If the event is invalid, then this method throws a FederationError.
@@ -579,10 +617,8 @@ class FederationServer(FederationBase):
         until we try to backfill across the discontinuity.
 
         Args:
-            origin (str): server which sent the pdu
-            pdu (FrozenEvent): received pdu
-
-        Returns (Deferred): completes with None
+            origin: server which sent the pdu
+            pdu: received pdu
 
         Raises: FederationError if the signatures / hash do not match, or
             if the event was unacceptable for any other reason (eg, too large,
@@ -611,7 +647,7 @@ class FederationServer(FederationBase):
                 logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
 
         # We've already checked that we know the room version by this point
-        room_version = await self.store.get_room_version_id(pdu.room_id)
+        room_version = await self.store.get_room_version(pdu.room_id)
 
         # Check signature.
         try:
@@ -625,25 +661,27 @@ class FederationServer(FederationBase):
         return "<ReplicationLayer(%s)>" % self.server_name
 
     async def exchange_third_party_invite(
-        self, sender_user_id, target_user_id, room_id, signed
+        self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
     ):
         ret = await self.handler.exchange_third_party_invite(
             sender_user_id, target_user_id, room_id, signed
         )
         return ret
 
-    async def on_exchange_third_party_invite_request(self, room_id, event_dict):
+    async def on_exchange_third_party_invite_request(
+        self, room_id: str, event_dict: Dict
+    ):
         ret = await self.handler.on_exchange_third_party_invite_request(
             room_id, event_dict
         )
         return ret
 
-    async def check_server_matches_acl(self, server_name, room_id):
+    async def check_server_matches_acl(self, server_name: str, room_id: str):
         """Check if the given server is allowed by the server ACLs in the room
 
         Args:
-            server_name (str): name of server, *without any port part*
-            room_id (str): ID of the room to check
+            server_name: name of server, *without any port part*
+            room_id: ID of the room to check
 
         Raises:
             AuthError if the server does not match the ACL
@@ -661,15 +699,15 @@ class FederationServer(FederationBase):
         raise AuthError(code=403, msg="Server is banned from room")
 
 
-def server_matches_acl_event(server_name, acl_event):
+def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
     """Check if the given server is allowed by the ACL event
 
     Args:
-        server_name (str): name of server, without any port part
-        acl_event (EventBase): m.room.server_acl event
+        server_name: name of server, without any port part
+        acl_event: m.room.server_acl event
 
     Returns:
-        bool: True if this server is allowed by the ACLs
+        True if this server is allowed by the ACLs
     """
     logger.debug("Checking %s against acl %s", server_name, acl_event.content)
 
@@ -713,7 +751,7 @@ def server_matches_acl_event(server_name, acl_event):
     return False
 
 
-def _acl_entry_matches(server_name, acl_entry):
+def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
     if not isinstance(acl_entry, six.string_types):
         logger.warning(
             "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
@@ -732,13 +770,13 @@ class FederationHandlerRegistry(object):
         self.edu_handlers = {}
         self.query_handlers = {}
 
-    def register_edu_handler(self, edu_type, handler):
+    def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]):
         """Sets the handler callable that will be used to handle an incoming
         federation EDU of the given type.
 
         Args:
-            edu_type (str): The type of the incoming EDU to register handler for
-            handler (Callable[[str, dict]]): A callable invoked on incoming EDU
+            edu_type: The type of the incoming EDU to register handler for
+            handler: A callable invoked on incoming EDU
                 of the given type. The arguments are the origin server name and
                 the EDU contents.
         """
@@ -749,14 +787,16 @@ class FederationHandlerRegistry(object):
 
         self.edu_handlers[edu_type] = handler
 
-    def register_query_handler(self, query_type, handler):
+    def register_query_handler(
+        self, query_type: str, handler: Callable[[dict], defer.Deferred]
+    ):
         """Sets the handler callable that will be used to handle an incoming
         federation query of the given type.
 
         Args:
-            query_type (str): Category name of the query, which should match
+            query_type: Category name of the query, which should match
                 the string used by make_query.
-            handler (Callable[[dict], Deferred[dict]]): Invoked to handle
+            handler: Invoked to handle
                 incoming queries of this type. The return will be yielded
                 on and the result used as the response to the query request.
         """
@@ -767,10 +807,11 @@ class FederationHandlerRegistry(object):
 
         self.query_handlers[query_type] = handler
 
-    async def on_edu(self, edu_type, origin, content):
+    async def on_edu(self, edu_type: str, origin: str, content: dict):
         handler = self.edu_handlers.get(edu_type)
         if not handler:
             logger.warning("No handler registered for EDU type %s", edu_type)
+            return
 
         with start_active_span_from_edu(content, "handle_edu"):
             try:
@@ -780,7 +821,7 @@ class FederationHandlerRegistry(object):
             except Exception:
                 logger.exception("Failed to handle edu %r", edu_type)
 
-    def on_query(self, query_type, args):
+    def on_query(self, query_type: str, args: dict) -> defer.Deferred:
         handler = self.query_handlers.get(query_type)
         if not handler:
             logger.warning("No handler registered for query type %s", query_type)
@@ -807,7 +848,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
 
         super(ReplicationFederationHandlerRegistry, self).__init__()
 
-    async def on_edu(self, edu_type, origin, content):
+    async def on_edu(self, edu_type: str, origin: str, content: dict):
         """Overrides FederationHandlerRegistry
         """
         if not self.config.use_presence and edu_type == "m.presence":
@@ -821,7 +862,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
 
         return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
 
-    async def on_query(self, query_type, args):
+    async def on_query(self, query_type: str, args: dict):
         """Overrides FederationHandlerRegistry
         """
         handler = self.query_handlers.get(query_type)
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 876fb0e245..e1700ca8aa 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -477,7 +477,7 @@ def process_rows_for_federation(transaction_queue, rows):
 
     Args:
         transaction_queue (FederationSender)
-        rows (list(synapse.replication.tcp.streams.FederationStreamRow))
+        rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow))
     """
 
     # The federation stream contains a bunch of different types of
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 233cb33daf..a477578e44 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -499,4 +499,13 @@ class FederationSender(object):
         self._get_per_destination_queue(destination).attempt_new_transaction()
 
     def get_current_token(self) -> int:
+        # Dummy implementation for case where federation sender isn't offloaded
+        # to a worker.
         return 0
+
+    async def get_replication_rows(
+        self, from_token, to_token, limit, federation_ack=None
+    ):
+        # Dummy implementation for case where federation sender isn't offloaded
+        # to a worker.
+        return []
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 383e3fdc8b..060bf07197 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -15,13 +15,14 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 from six.moves import urllib
 
 from twisted.internet import defer
 
 from synapse.api.constants import Membership
+from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.api.urls import (
     FEDERATION_UNSTABLE_PREFIX,
     FEDERATION_V1_PREFIX,
@@ -326,18 +327,25 @@ class TransportLayerClient(object):
     @log_function
     def get_public_rooms(
         self,
-        remote_server,
-        limit,
-        since_token,
-        search_filter=None,
-        include_all_networks=False,
-        third_party_instance_id=None,
+        remote_server: str,
+        limit: Optional[int] = None,
+        since_token: Optional[str] = None,
+        search_filter: Optional[Dict] = None,
+        include_all_networks: bool = False,
+        third_party_instance_id: Optional[str] = None,
     ):
+        """Get the list of public rooms from a remote homeserver
+
+        See synapse.federation.federation_client.FederationClient.get_public_rooms for
+        more information.
+        """
         if search_filter:
             # this uses MSC2197 (Search Filtering over Federation)
             path = _create_v1_path("/publicRooms")
 
-            data = {"include_all_networks": "true" if include_all_networks else "false"}
+            data = {
+                "include_all_networks": "true" if include_all_networks else "false"
+            }  # type: Dict[str, Any]
             if third_party_instance_id:
                 data["third_party_instance_id"] = third_party_instance_id
             if limit:
@@ -347,9 +355,19 @@ class TransportLayerClient(object):
 
             data["filter"] = search_filter
 
-            response = yield self.client.post_json(
-                destination=remote_server, path=path, data=data, ignore_backoff=True
-            )
+            try:
+                response = yield self.client.post_json(
+                    destination=remote_server, path=path, data=data, ignore_backoff=True
+                )
+            except HttpResponseException as e:
+                if e.code == 403:
+                    raise SynapseError(
+                        403,
+                        "You are not allowed to view the public rooms list of %s"
+                        % (remote_server,),
+                        errcode=Codes.FORBIDDEN,
+                    )
+                raise
         else:
             path = _create_v1_path("/publicRooms")
 
@@ -363,9 +381,19 @@ class TransportLayerClient(object):
             if since_token:
                 args["since"] = [since_token]
 
-            response = yield self.client.get_json(
-                destination=remote_server, path=path, args=args, ignore_backoff=True
-            )
+            try:
+                response = yield self.client.get_json(
+                    destination=remote_server, path=path, args=args, ignore_backoff=True
+                )
+            except HttpResponseException as e:
+                if e.code == 403:
+                    raise SynapseError(
+                        403,
+                        "You are not allowed to view the public rooms list of %s"
+                        % (remote_server,),
+                        errcode=Codes.FORBIDDEN,
+                    )
+                raise
 
         return response