summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py8
-rw-r--r--synapse/federation/federation_server.py13
-rw-r--r--synapse/federation/send_queue.py2
-rw-r--r--synapse/federation/sender/__init__.py17
-rw-r--r--synapse/federation/sender/per_destination_queue.py2
-rw-r--r--synapse/federation/transport/client.py96
-rw-r--r--synapse/federation/transport/server.py6
7 files changed, 70 insertions, 74 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 994e6c8d5a..38ac7ec699 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -135,7 +135,7 @@ class FederationClient(FederationBase):
                 and try the request anyway.
 
         Returns:
-            a Deferred which will eventually yield a JSON object from the
+            a Awaitable which will eventually yield a JSON object from the
             response
         """
         sent_queries_counter.labels(query_type).inc()
@@ -157,7 +157,7 @@ class FederationClient(FederationBase):
             content (dict): The query content.
 
         Returns:
-            a Deferred which will eventually yield a JSON object from the
+            an Awaitable which will eventually yield a JSON object from the
             response
         """
         sent_queries_counter.labels("client_device_keys").inc()
@@ -180,7 +180,7 @@ class FederationClient(FederationBase):
             content (dict): The query content.
 
         Returns:
-            a Deferred which will eventually yield a JSON object from the
+            an Awaitable which will eventually yield a JSON object from the
             response
         """
         sent_queries_counter.labels("client_one_time_keys").inc()
@@ -900,7 +900,7 @@ class FederationClient(FederationBase):
                 party instance
 
         Returns:
-            Deferred[Dict[str, Any]]: The response from the remote server, or None if
+            Awaitable[Dict[str, Any]]: The response from the remote server, or None if
             `remote_server` is the same as the local server_name
 
         Raises:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 23625ba995..11c5d63298 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -109,6 +109,9 @@ class FederationServer(FederationBase):
         # We cache responses to state queries, as they take a while and often
         # come in waves.
         self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
+        self._state_ids_resp_cache = ResponseCache(
+            hs, "state_ids_resp", timeout_ms=30000
+        )
 
     async def on_backfill_request(
         self, origin: str, room_id: str, versions: List[str], limit: int
@@ -376,10 +379,16 @@ class FederationServer(FederationBase):
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
+        resp = await self._state_ids_resp_cache.wrap(
+            (room_id, event_id), self._on_state_ids_request_compute, room_id, event_id,
+        )
+
+        return 200, resp
+
+    async def _on_state_ids_request_compute(self, room_id, event_id):
         state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
         auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
-
-        return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
+        return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
 
     async def _on_context_state_request_compute(
         self, room_id: str, event_id: str
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 4fc9ff92e5..2b0ab2dcbf 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -57,7 +57,7 @@ class FederationRemoteSendQueue(object):
 
         # We may have multiple federation sender instances, so we need to track
         # their positions separately.
-        self._sender_instances = hs.config.federation.federation_shard_config.instances
+        self._sender_instances = hs.config.worker.federation_shard_config.instances
         self._sender_positions = {}
 
         # Pending presence map user_id -> UserPresenceState
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index b328a4df09..94cc63001e 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -70,7 +70,7 @@ class FederationSender(object):
         self._transaction_manager = TransactionManager(hs)
 
         self._instance_name = hs.get_instance_name()
-        self._federation_shard_config = hs.config.federation.federation_shard_config
+        self._federation_shard_config = hs.config.worker.federation_shard_config
 
         # map from destination to PerDestinationQueue
         self._per_destination_queues = {}  # type: Dict[str, PerDestinationQueue]
@@ -288,8 +288,7 @@ class FederationSender(object):
         for destination in destinations:
             self._get_per_destination_queue(destination).send_pdu(pdu, order)
 
-    @defer.inlineCallbacks
-    def send_read_receipt(self, receipt: ReadReceipt):
+    async def send_read_receipt(self, receipt: ReadReceipt) -> None:
         """Send a RR to any other servers in the room
 
         Args:
@@ -330,7 +329,7 @@ class FederationSender(object):
         room_id = receipt.room_id
 
         # Work out which remote servers should be poked and poke them.
-        domains = yield self.state.get_current_hosts_in_room(room_id)
+        domains = await self.state.get_current_hosts_in_room(room_id)
         domains = [
             d
             for d in domains
@@ -385,8 +384,7 @@ class FederationSender(object):
             queue.flush_read_receipts_for_room(room_id)
 
     @preserve_fn  # the caller should not yield on this
-    @defer.inlineCallbacks
-    def send_presence(self, states: List[UserPresenceState]):
+    async def send_presence(self, states: List[UserPresenceState]):
         """Send the new presence states to the appropriate destinations.
 
         This actually queues up the presence states ready for sending and
@@ -421,7 +419,7 @@ class FederationSender(object):
                 if not states_map:
                     break
 
-                yield self._process_presence_inner(list(states_map.values()))
+                await self._process_presence_inner(list(states_map.values()))
         except Exception:
             logger.exception("Error sending presence states to servers")
         finally:
@@ -448,12 +446,11 @@ class FederationSender(object):
             self._get_per_destination_queue(destination).send_presence(states)
 
     @measure_func("txnqueue._process_presence")
-    @defer.inlineCallbacks
-    def _process_presence_inner(self, states: List[UserPresenceState]):
+    async def _process_presence_inner(self, states: List[UserPresenceState]):
         """Given a list of states populate self.pending_presence_by_dest and
         poke to send a new transaction to each destination
         """
-        hosts_and_states = yield get_interested_remotes(self.store, states, self.state)
+        hosts_and_states = await get_interested_remotes(self.store, states, self.state)
 
         for destinations, states in hosts_and_states:
             for destination in destinations:
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 3436741783..dd150f89a6 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -75,7 +75,7 @@ class PerDestinationQueue(object):
         self._store = hs.get_datastore()
         self._transaction_manager = transaction_manager
         self._instance_name = hs.get_instance_name()
-        self._federation_shard_config = hs.config.federation.federation_shard_config
+        self._federation_shard_config = hs.config.worker.federation_shard_config
 
         self._should_send_on_this_instance = True
         if not self._federation_shard_config.should_handle(
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index cfdf23d366..9ea821dbb2 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -18,8 +18,6 @@ import logging
 import urllib
 from typing import Any, Dict, Optional
 
-from twisted.internet import defer
-
 from synapse.api.constants import Membership
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.api.urls import (
@@ -51,7 +49,7 @@ class TransportLayerClient(object):
             event_id (str): The event we want the context at.
 
         Returns:
-            Deferred: Results in a dict received from the remote homeserver.
+            Awaitable: Results in a dict received from the remote homeserver.
         """
         logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
 
@@ -75,7 +73,7 @@ class TransportLayerClient(object):
                 giving up. None indicates no timeout.
 
         Returns:
-            Deferred: Results in a dict received from the remote homeserver.
+            Awaitable: Results in a dict received from the remote homeserver.
         """
         logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)
 
@@ -96,7 +94,7 @@ class TransportLayerClient(object):
             limit (int)
 
         Returns:
-            Deferred: Results in a dict received from the remote homeserver.
+            Awaitable: Results in a dict received from the remote homeserver.
         """
         logger.debug(
             "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s",
@@ -118,16 +116,15 @@ class TransportLayerClient(object):
             destination, path=path, args=args, try_trailing_slash_on_400=True
         )
 
-    @defer.inlineCallbacks
     @log_function
-    def send_transaction(self, transaction, json_data_callback=None):
+    async def send_transaction(self, transaction, json_data_callback=None):
         """ Sends the given Transaction to its destination
 
         Args:
             transaction (Transaction)
 
         Returns:
-            Deferred: Succeeds when we get a 2xx HTTP response. The result
+            Succeeds when we get a 2xx HTTP response. The result
             will be the decoded JSON body.
 
             Fails with ``HTTPRequestException`` if we get an HTTP response
@@ -154,7 +151,7 @@ class TransportLayerClient(object):
 
         path = _create_v1_path("/send/%s", transaction.transaction_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             transaction.destination,
             path=path,
             data=json_data,
@@ -166,14 +163,13 @@ class TransportLayerClient(object):
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def make_query(
+    async def make_query(
         self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
     ):
         path = _create_v1_path("/query/%s", query_type)
 
-        content = yield self.client.get_json(
+        content = await self.client.get_json(
             destination=destination,
             path=path,
             args=args,
@@ -184,9 +180,10 @@ class TransportLayerClient(object):
 
         return content
 
-    @defer.inlineCallbacks
     @log_function
-    def make_membership_event(self, destination, room_id, user_id, membership, params):
+    async def make_membership_event(
+        self, destination, room_id, user_id, membership, params
+    ):
         """Asks a remote server to build and sign us a membership event
 
         Note that this does not append any events to any graphs.
@@ -200,7 +197,7 @@ class TransportLayerClient(object):
                 request.
 
         Returns:
-            Deferred: Succeeds when we get a 2xx HTTP response. The result
+            Succeeds when we get a 2xx HTTP response. The result
             will be the decoded JSON body (ie, the new event).
 
             Fails with ``HTTPRequestException`` if we get an HTTP response
@@ -231,7 +228,7 @@ class TransportLayerClient(object):
             ignore_backoff = True
             retry_on_dns_fail = True
 
-        content = yield self.client.get_json(
+        content = await self.client.get_json(
             destination=destination,
             path=path,
             args=params,
@@ -242,34 +239,31 @@ class TransportLayerClient(object):
 
         return content
 
-    @defer.inlineCallbacks
     @log_function
-    def send_join_v1(self, destination, room_id, event_id, content):
+    async def send_join_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             destination=destination, path=path, data=content
         )
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def send_join_v2(self, destination, room_id, event_id, content):
+    async def send_join_v2(self, destination, room_id, event_id, content):
         path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             destination=destination, path=path, data=content
         )
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def send_leave_v1(self, destination, room_id, event_id, content):
+    async def send_leave_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             destination=destination,
             path=path,
             data=content,
@@ -282,12 +276,11 @@ class TransportLayerClient(object):
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def send_leave_v2(self, destination, room_id, event_id, content):
+    async def send_leave_v2(self, destination, room_id, event_id, content):
         path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             destination=destination,
             path=path,
             data=content,
@@ -300,31 +293,28 @@ class TransportLayerClient(object):
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def send_invite_v1(self, destination, room_id, event_id, content):
+    async def send_invite_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/invite/%s/%s", room_id, event_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def send_invite_v2(self, destination, room_id, event_id, content):
+    async def send_invite_v2(self, destination, room_id, event_id, content):
         path = _create_v2_path("/invite/%s/%s", room_id, event_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def get_public_rooms(
+    async def get_public_rooms(
         self,
         remote_server: str,
         limit: Optional[int] = None,
@@ -355,7 +345,7 @@ class TransportLayerClient(object):
             data["filter"] = search_filter
 
             try:
-                response = yield self.client.post_json(
+                response = await self.client.post_json(
                     destination=remote_server, path=path, data=data, ignore_backoff=True
                 )
             except HttpResponseException as e:
@@ -381,7 +371,7 @@ class TransportLayerClient(object):
                 args["since"] = [since_token]
 
             try:
-                response = yield self.client.get_json(
+                response = await self.client.get_json(
                     destination=remote_server, path=path, args=args, ignore_backoff=True
                 )
             except HttpResponseException as e:
@@ -396,29 +386,26 @@ class TransportLayerClient(object):
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def exchange_third_party_invite(self, destination, room_id, event_dict):
+    async def exchange_third_party_invite(self, destination, room_id, event_dict):
         path = _create_v1_path("/exchange_third_party_invite/%s", room_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             destination=destination, path=path, data=event_dict
         )
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def get_event_auth(self, destination, room_id, event_id):
+    async def get_event_auth(self, destination, room_id, event_id):
         path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
 
-        content = yield self.client.get_json(destination=destination, path=path)
+        content = await self.client.get_json(destination=destination, path=path)
 
         return content
 
-    @defer.inlineCallbacks
     @log_function
-    def query_client_keys(self, destination, query_content, timeout):
+    async def query_client_keys(self, destination, query_content, timeout):
         """Query the device keys for a list of user ids hosted on a remote
         server.
 
@@ -453,14 +440,13 @@ class TransportLayerClient(object):
         """
         path = _create_v1_path("/user/keys/query")
 
-        content = yield self.client.post_json(
+        content = await self.client.post_json(
             destination=destination, path=path, data=query_content, timeout=timeout
         )
         return content
 
-    @defer.inlineCallbacks
     @log_function
-    def query_user_devices(self, destination, user_id, timeout):
+    async def query_user_devices(self, destination, user_id, timeout):
         """Query the devices for a user id hosted on a remote server.
 
         Response:
@@ -493,14 +479,13 @@ class TransportLayerClient(object):
         """
         path = _create_v1_path("/user/devices/%s", user_id)
 
-        content = yield self.client.get_json(
+        content = await self.client.get_json(
             destination=destination, path=path, timeout=timeout
         )
         return content
 
-    @defer.inlineCallbacks
     @log_function
-    def claim_client_keys(self, destination, query_content, timeout):
+    async def claim_client_keys(self, destination, query_content, timeout):
         """Claim one-time keys for a list of devices hosted on a remote server.
 
         Request:
@@ -532,14 +517,13 @@ class TransportLayerClient(object):
 
         path = _create_v1_path("/user/keys/claim")
 
-        content = yield self.client.post_json(
+        content = await self.client.post_json(
             destination=destination, path=path, data=query_content, timeout=timeout
         )
         return content
 
-    @defer.inlineCallbacks
     @log_function
-    def get_missing_events(
+    async def get_missing_events(
         self,
         destination,
         room_id,
@@ -551,7 +535,7 @@ class TransportLayerClient(object):
     ):
         path = _create_v1_path("/get_missing_events/%s", room_id)
 
-        content = yield self.client.post_json(
+        content = await self.client.post_json(
             destination=destination,
             path=path,
             data={
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 24f7d4b3bc..5e111aa902 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -338,6 +338,12 @@ class BaseFederationServlet(object):
                 if origin:
                     with ratelimiter.ratelimit(origin) as d:
                         await d
+                        if request._disconnected:
+                            logger.warning(
+                                "client disconnected before we started processing "
+                                "request"
+                            )
+                            return -1, None
                         response = await func(
                             origin, content, request.args, *args, **kwargs
                         )