summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_server.py45
-rw-r--r--synapse/federation/send_queue.py2
-rw-r--r--synapse/federation/sender/__init__.py11
-rw-r--r--synapse/federation/sender/per_destination_queue.py2
-rw-r--r--synapse/federation/transport/server.py70
5 files changed, 70 insertions, 60 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 24329dd0e3..4b6ab470d0 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,7 +22,6 @@ from typing import (
     Callable,
     Dict,
     List,
-    Match,
     Optional,
     Tuple,
     Union,
@@ -50,6 +49,7 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
 from synapse.http.endpoint import parse_server_name
+from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
     make_deferred_yieldable,
     nested_logging_context,
@@ -100,10 +100,15 @@ class FederationServer(FederationBase):
         super().__init__(hs)
 
         self.auth = hs.get_auth()
-        self.handler = hs.get_handlers().federation_handler
+        self.handler = hs.get_federation_handler()
         self.state = hs.get_state_handler()
 
         self.device_handler = hs.get_device_handler()
+
+        # Ensure the following handlers are loaded since they register callbacks
+        # with FederationHandlerRegistry.
+        hs.get_directory_handler()
+
         self._federation_ratelimiter = hs.get_federation_ratelimiter()
 
         self._server_linearizer = Linearizer("fed_server")
@@ -112,7 +117,7 @@ class FederationServer(FederationBase):
         # We cache results for transaction with the same ID
         self._transaction_resp_cache = ResponseCache(
             hs, "fed_txn_handler", timeout_ms=30000
-        )
+        )  # type: ResponseCache[Tuple[str, str]]
 
         self.transaction_actions = TransactionActions(self.store)
 
@@ -120,10 +125,12 @@ class FederationServer(FederationBase):
 
         # We cache responses to state queries, as they take a while and often
         # come in waves.
-        self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
+        self._state_resp_cache = ResponseCache(
+            hs, "state_resp", timeout_ms=30000
+        )  # type: ResponseCache[Tuple[str, str]]
         self._state_ids_resp_cache = ResponseCache(
             hs, "state_ids_resp", timeout_ms=30000
-        )
+        )  # type: ResponseCache[Tuple[str, str]]
 
         self._federation_metrics_domains = (
             hs.get_config().federation.federation_metrics_domains
@@ -385,7 +392,7 @@ class FederationServer(FederationBase):
             TRANSACTION_CONCURRENCY_LIMIT,
         )
 
-    async def on_context_state_request(
+    async def on_room_state_request(
         self, origin: str, room_id: str, event_id: str
     ) -> Tuple[int, Dict[str, Any]]:
         origin_host, _ = parse_server_name(origin)
@@ -508,11 +515,12 @@ class FederationServer(FederationBase):
         return {"event": ret_pdu.get_pdu_json(time_now)}
 
     async def on_send_join_request(
-        self, origin: str, content: JsonDict, room_id: str
+        self, origin: str, content: JsonDict
     ) -> Dict[str, Any]:
         logger.debug("on_send_join_request: content: %s", content)
 
-        room_version = await self.store.get_room_version(room_id)
+        assert_params_in_dict(content, ["room_id"])
+        room_version = await self.store.get_room_version(content["room_id"])
         pdu = event_from_pdu_json(content, room_version)
 
         origin_host, _ = parse_server_name(origin)
@@ -541,12 +549,11 @@ class FederationServer(FederationBase):
         time_now = self._clock.time_msec()
         return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
 
-    async def on_send_leave_request(
-        self, origin: str, content: JsonDict, room_id: str
-    ) -> dict:
+    async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict:
         logger.debug("on_send_leave_request: content: %s", content)
 
-        room_version = await self.store.get_room_version(room_id)
+        assert_params_in_dict(content, ["room_id"])
+        room_version = await self.store.get_room_version(content["room_id"])
         pdu = event_from_pdu_json(content, room_version)
 
         origin_host, _ = parse_server_name(origin)
@@ -742,12 +749,8 @@ class FederationServer(FederationBase):
         )
         return ret
 
-    async def on_exchange_third_party_invite_request(
-        self, room_id: str, event_dict: Dict
-    ):
-        ret = await self.handler.on_exchange_third_party_invite_request(
-            room_id, event_dict
-        )
+    async def on_exchange_third_party_invite_request(self, event_dict: Dict):
+        ret = await self.handler.on_exchange_third_party_invite_request(event_dict)
         return ret
 
     async def check_server_matches_acl(self, server_name: str, room_id: str):
@@ -825,14 +828,14 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
     return False
 
 
-def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
+def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
     if not isinstance(acl_entry, str):
         logger.warning(
             "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
         )
         return False
     regex = glob_to_regex(acl_entry)
-    return regex.match(server_name)
+    return bool(regex.match(server_name))
 
 
 class FederationHandlerRegistry:
@@ -862,7 +865,7 @@ class FederationHandlerRegistry:
         self._edu_type_to_instance = {}  # type: Dict[str, str]
 
     def register_edu_handler(
-        self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]]
+        self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
     ):
         """Sets the handler callable that will be used to handle an incoming
         federation EDU of the given type.
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 8e46957d15..5f1bf492c1 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -188,7 +188,7 @@ class FederationRemoteSendQueue:
             for key in keys[:i]:
                 del self.edus[key]
 
-    def notify_new_events(self, current_id):
+    def notify_new_events(self, max_token):
         """As per FederationSender"""
         # We don't need to replicate this as it gets sent down a different
         # stream.
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 8bb17b3a05..604cfd1935 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -40,7 +40,7 @@ from synapse.metrics import (
     events_processed_counter,
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import ReadReceipt
+from synapse.types import ReadReceipt, RoomStreamToken
 from synapse.util.metrics import Measure, measure_func
 
 logger = logging.getLogger(__name__)
@@ -154,10 +154,15 @@ class FederationSender:
             self._per_destination_queues[destination] = queue
         return queue
 
-    def notify_new_events(self, current_id: int) -> None:
+    def notify_new_events(self, max_token: RoomStreamToken) -> None:
         """This gets called when we have some new events we might want to
         send out to other servers.
         """
+        # We just use the minimum stream ordering and ignore the vector clock
+        # component. This is safe to do as long as we *always* ignore the vector
+        # clock components.
+        current_id = max_token.stream
+
         self._last_poked_id = max(current_id, self._last_poked_id)
 
         if self._is_processing:
@@ -297,6 +302,8 @@ class FederationSender:
         sent_pdus_destination_dist_total.inc(len(destinations))
         sent_pdus_destination_dist_count.inc()
 
+        assert pdu.internal_metadata.stream_ordering
+
         # track the fact that we have a PDU for these destinations,
         # to allow us to perform catch-up later on if the remote is unreachable
         # for a while.
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index bc99af3fdd..db8e456fe8 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -158,6 +158,7 @@ class PerDestinationQueue:
             # yet know if we have anything to catch up (None)
             self._pending_pdus.append(pdu)
         else:
+            assert pdu.internal_metadata.stream_ordering
             self._catchup_last_skipped = pdu.internal_metadata.stream_ordering
 
         self.attempt_new_transaction()
@@ -361,6 +362,7 @@ class PerDestinationQueue:
                         last_successful_stream_ordering = (
                             final_pdu.internal_metadata.stream_ordering
                         )
+                        assert last_successful_stream_ordering
                         await self._store.set_destination_last_successful_stream_ordering(
                             self._destination, last_successful_stream_ordering
                         )
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 3a6b95631e..b53e7a20ec 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -154,7 +154,7 @@ class Authenticator:
         )
 
         logger.debug("Request from %s", origin)
-        request.authenticated_entity = origin
+        request.requester = origin
 
         # If we get a valid signed request from the other side, its probably
         # alive
@@ -440,13 +440,13 @@ class FederationEventServlet(BaseFederationServlet):
 
 
 class FederationStateV1Servlet(BaseFederationServlet):
-    PATH = "/state/(?P<context>[^/]*)/?"
+    PATH = "/state/(?P<room_id>[^/]*)/?"
 
-    # This is when someone asks for all data for a given context.
-    async def on_GET(self, origin, content, query, context):
-        return await self.handler.on_context_state_request(
+    # This is when someone asks for all data for a given room.
+    async def on_GET(self, origin, content, query, room_id):
+        return await self.handler.on_room_state_request(
             origin,
-            context,
+            room_id,
             parse_string_from_args(query, "event_id", None, required=False),
         )
 
@@ -463,16 +463,16 @@ class FederationStateIdsServlet(BaseFederationServlet):
 
 
 class FederationBackfillServlet(BaseFederationServlet):
-    PATH = "/backfill/(?P<context>[^/]*)/?"
+    PATH = "/backfill/(?P<room_id>[^/]*)/?"
 
-    async def on_GET(self, origin, content, query, context):
+    async def on_GET(self, origin, content, query, room_id):
         versions = [x.decode("ascii") for x in query[b"v"]]
         limit = parse_integer_from_args(query, "limit", None)
 
         if not limit:
             return 400, {"error": "Did not include limit param"}
 
-        return await self.handler.on_backfill_request(origin, context, versions, limit)
+        return await self.handler.on_backfill_request(origin, room_id, versions, limit)
 
 
 class FederationQueryServlet(BaseFederationServlet):
@@ -487,9 +487,9 @@ class FederationQueryServlet(BaseFederationServlet):
 
 
 class FederationMakeJoinServlet(BaseFederationServlet):
-    PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
+    PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
 
-    async def on_GET(self, origin, _content, query, context, user_id):
+    async def on_GET(self, origin, _content, query, room_id, user_id):
         """
         Args:
             origin (unicode): The authenticated server_name of the calling server
@@ -511,16 +511,16 @@ class FederationMakeJoinServlet(BaseFederationServlet):
             supported_versions = ["1"]
 
         content = await self.handler.on_make_join_request(
-            origin, context, user_id, supported_versions=supported_versions
+            origin, room_id, user_id, supported_versions=supported_versions
         )
         return 200, content
 
 
 class FederationMakeLeaveServlet(BaseFederationServlet):
-    PATH = "/make_leave/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
+    PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
 
-    async def on_GET(self, origin, content, query, context, user_id):
-        content = await self.handler.on_make_leave_request(origin, context, user_id)
+    async def on_GET(self, origin, content, query, room_id, user_id):
+        content = await self.handler.on_make_leave_request(origin, room_id, user_id)
         return 200, content
 
 
@@ -528,7 +528,7 @@ class FederationV1SendLeaveServlet(BaseFederationServlet):
     PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
     async def on_PUT(self, origin, content, query, room_id, event_id):
-        content = await self.handler.on_send_leave_request(origin, content, room_id)
+        content = await self.handler.on_send_leave_request(origin, content)
         return 200, (200, content)
 
 
@@ -538,43 +538,43 @@ class FederationV2SendLeaveServlet(BaseFederationServlet):
     PREFIX = FEDERATION_V2_PREFIX
 
     async def on_PUT(self, origin, content, query, room_id, event_id):
-        content = await self.handler.on_send_leave_request(origin, content, room_id)
+        content = await self.handler.on_send_leave_request(origin, content)
         return 200, content
 
 
 class FederationEventAuthServlet(BaseFederationServlet):
-    PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+    PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
-    async def on_GET(self, origin, content, query, context, event_id):
-        return await self.handler.on_event_auth(origin, context, event_id)
+    async def on_GET(self, origin, content, query, room_id, event_id):
+        return await self.handler.on_event_auth(origin, room_id, event_id)
 
 
 class FederationV1SendJoinServlet(BaseFederationServlet):
-    PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+    PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
-    async def on_PUT(self, origin, content, query, context, event_id):
-        # TODO(paul): assert that context/event_id parsed from path actually
+    async def on_PUT(self, origin, content, query, room_id, event_id):
+        # TODO(paul): assert that room_id/event_id parsed from path actually
         #   match those given in content
-        content = await self.handler.on_send_join_request(origin, content, context)
+        content = await self.handler.on_send_join_request(origin, content)
         return 200, (200, content)
 
 
 class FederationV2SendJoinServlet(BaseFederationServlet):
-    PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+    PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
     PREFIX = FEDERATION_V2_PREFIX
 
-    async def on_PUT(self, origin, content, query, context, event_id):
-        # TODO(paul): assert that context/event_id parsed from path actually
+    async def on_PUT(self, origin, content, query, room_id, event_id):
+        # TODO(paul): assert that room_id/event_id parsed from path actually
         #   match those given in content
-        content = await self.handler.on_send_join_request(origin, content, context)
+        content = await self.handler.on_send_join_request(origin, content)
         return 200, content
 
 
 class FederationV1InviteServlet(BaseFederationServlet):
-    PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+    PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
-    async def on_PUT(self, origin, content, query, context, event_id):
+    async def on_PUT(self, origin, content, query, room_id, event_id):
         # We don't get a room version, so we have to assume its EITHER v1 or
         # v2. This is "fine" as the only difference between V1 and V2 is the
         # state resolution algorithm, and we don't use that for processing
@@ -589,12 +589,12 @@ class FederationV1InviteServlet(BaseFederationServlet):
 
 
 class FederationV2InviteServlet(BaseFederationServlet):
-    PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+    PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
     PREFIX = FEDERATION_V2_PREFIX
 
-    async def on_PUT(self, origin, content, query, context, event_id):
-        # TODO(paul): assert that context/event_id parsed from path actually
+    async def on_PUT(self, origin, content, query, room_id, event_id):
+        # TODO(paul): assert that room_id/event_id parsed from path actually
         #   match those given in content
 
         room_version = content["room_version"]
@@ -616,9 +616,7 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
     PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
 
     async def on_PUT(self, origin, content, query, room_id):
-        content = await self.handler.on_exchange_third_party_invite_request(
-            room_id, content
-        )
+        content = await self.handler.on_exchange_third_party_invite_request(content)
         return 200, content