diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 32a8a2ee46..218df884b0 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,13 +15,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
-
-import six
-from six import iteritems
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ List,
+ Match,
+ Optional,
+ Tuple,
+ Union,
+)
-from canonicaljson import json
-from prometheus_client import Counter
+from prometheus_client import Counter, Histogram
from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
@@ -55,10 +62,13 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet,
)
from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import glob_to_regex, unwrapFirstError
+from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit.
TRANSACTION_CONCURRENCY_LIMIT = 10
@@ -73,6 +83,10 @@ received_queries_counter = Counter(
"synapse_federation_server_received_queries", "", ["type"]
)
+pdu_process_time = Histogram(
+ "synapse_federation_server_pdu_process_time", "Time taken to process an event",
+)
+
class FederationServer(FederationBase):
def __init__(self, hs):
@@ -94,6 +108,9 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
+ self._state_ids_resp_cache = ResponseCache(
+ hs, "state_ids_resp", timeout_ms=30000
+ )
async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
@@ -274,21 +291,22 @@ class FederationServer(FederationBase):
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
- with nested_logging_context(event_id):
- try:
- await self._handle_received_pdu(origin, pdu)
- pdu_results[event_id] = {}
- except FederationError as e:
- logger.warning("Error handling PDU %s: %s", event_id, e)
- pdu_results[event_id] = {"error": str(e)}
- except Exception as e:
- f = failure.Failure()
- pdu_results[event_id] = {"error": str(e)}
- logger.error(
- "Failed to handle PDU %s",
- event_id,
- exc_info=(f.type, f.value, f.getTracebackObject()),
- )
+ with pdu_process_time.time():
+ with nested_logging_context(event_id):
+ try:
+ await self._handle_received_pdu(origin, pdu)
+ pdu_results[event_id] = {}
+ except FederationError as e:
+ logger.warning("Error handling PDU %s: %s", event_id, e)
+ pdu_results[event_id] = {"error": str(e)}
+ except Exception as e:
+ f = failure.Failure()
+ pdu_results[event_id] = {"error": str(e)}
+ logger.error(
+ "Failed to handle PDU %s",
+ event_id,
+ exc_info=(f.type, f.value, f.getTracebackObject()),
+ )
await concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
@@ -360,10 +378,16 @@ class FederationServer(FederationBase):
if not in_room:
raise AuthError(403, "Host not in room.")
+ resp = await self._state_ids_resp_cache.wrap(
+ (room_id, event_id), self._on_state_ids_request_compute, room_id, event_id,
+ )
+
+ return 200, resp
+
+ async def _on_state_ids_request_compute(self, room_id, event_id):
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
-
- return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
+ return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(
self, room_id: str, event_id: str
@@ -524,9 +548,9 @@ class FederationServer(FederationBase):
json_result = {} # type: Dict[str, Dict[str, dict]]
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
- for key_id, json_bytes in keys.items():
+ for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
- key_id: json.loads(json_bytes)
+ key_id: json_decoder.decode(json_str)
}
logger.info(
@@ -534,9 +558,9 @@ class FederationServer(FederationBase):
",".join(
(
"%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in iteritems(json_result)
- for device_id, device_keys in iteritems(user_keys)
- for key_id, _ in iteritems(device_keys)
+ for user_id, user_keys in json_result.items()
+ for device_id, device_keys in user_keys.items()
+ for key_id, _ in device_keys.items()
)
),
)
@@ -715,7 +739,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
# server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
- logger.warning("Ignorning non-bool allow_ip_literals flag")
+ logger.warning("Ignoring non-bool allow_ip_literals flag")
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
@@ -729,7 +753,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
# next, check the deny list
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
- logger.warning("Ignorning non-list deny ACL %s", deny)
+ logger.warning("Ignoring non-list deny ACL %s", deny)
deny = []
for e in deny:
if _acl_entry_matches(server_name, e):
@@ -739,7 +763,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
- logger.warning("Ignorning non-list allow ACL %s", allow)
+ logger.warning("Ignoring non-list allow ACL %s", allow)
allow = []
for e in allow:
if _acl_entry_matches(server_name, e):
@@ -752,7 +776,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
- if not isinstance(acl_entry, six.string_types):
+ if not isinstance(acl_entry, str):
logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
)
@@ -761,16 +785,35 @@ def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
return regex.match(server_name)
-class FederationHandlerRegistry(object):
+class FederationHandlerRegistry:
"""Allows classes to register themselves as handlers for a given EDU or
query type for incoming federation traffic.
"""
- def __init__(self):
- self.edu_handlers = {}
- self.query_handlers = {}
+ def __init__(self, hs: "HomeServer"):
+ self.config = hs.config
+ self.http_client = hs.get_simple_http_client()
+ self.clock = hs.get_clock()
+ self._instance_name = hs.get_instance_name()
- def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]):
+ # These are safe to load in monolith mode, but will explode if we try
+ # and use them. However we have guards before we use them to ensure that
+ # we don't route to ourselves, and in monolith mode that will always be
+ # the case.
+ self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
+ self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
+
+ self.edu_handlers = (
+ {}
+ ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
+ self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
+
+ # Map from type to instance name that we should route EDU handling to.
+ self._edu_type_to_instance = {} # type: Dict[str, str]
+
+ def register_edu_handler(
+ self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]]
+ ):
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
@@ -807,66 +850,56 @@ class FederationHandlerRegistry(object):
self.query_handlers[query_type] = handler
+ def register_instance_for_edu(self, edu_type: str, instance_name: str):
+ """Register that the EDU handler is on a different instance than master.
+ """
+ self._edu_type_to_instance[edu_type] = instance_name
+
async def on_edu(self, edu_type: str, origin: str, content: dict):
+ if not self.config.use_presence and edu_type == "m.presence":
+ return
+
+ # Check if we have a handler on this instance
handler = self.edu_handlers.get(edu_type)
- if not handler:
- logger.warning("No handler registered for EDU type %s", edu_type)
+ if handler:
+ with start_active_span_from_edu(content, "handle_edu"):
+ try:
+ await handler(origin, content)
+ except SynapseError as e:
+ logger.info("Failed to handle edu %r: %r", edu_type, e)
+ except Exception:
+ logger.exception("Failed to handle edu %r", edu_type)
return
- with start_active_span_from_edu(content, "handle_edu"):
+ # Check if we can route it somewhere else that isn't us
+ route_to = self._edu_type_to_instance.get(edu_type, "master")
+ if route_to != self._instance_name:
try:
- await handler(origin, content)
+ await self._send_edu(
+ instance_name=route_to,
+ edu_type=edu_type,
+ origin=origin,
+ content=content,
+ )
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception:
logger.exception("Failed to handle edu %r", edu_type)
-
- def on_query(self, query_type: str, args: dict) -> defer.Deferred:
- handler = self.query_handlers.get(query_type)
- if not handler:
- logger.warning("No handler registered for query type %s", query_type)
- raise NotFoundError("No handler for Query type '%s'" % (query_type,))
-
- return handler(args)
-
-
-class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
- """A FederationHandlerRegistry for worker processes.
-
- When receiving EDU or queries it will check if an appropriate handler has
- been registered on the worker, if there isn't one then it calls off to the
- master process.
- """
-
- def __init__(self, hs):
- self.config = hs.config
- self.http_client = hs.get_simple_http_client()
- self.clock = hs.get_clock()
-
- self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
- self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
-
- super(ReplicationFederationHandlerRegistry, self).__init__()
-
- async def on_edu(self, edu_type: str, origin: str, content: dict):
- """Overrides FederationHandlerRegistry
- """
- if not self.config.use_presence and edu_type == "m.presence":
return
- handler = self.edu_handlers.get(edu_type)
- if handler:
- return await super(ReplicationFederationHandlerRegistry, self).on_edu(
- edu_type, origin, content
- )
-
- return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
+ # Oh well, let's just log and move on.
+ logger.warning("No handler registered for EDU type %s", edu_type)
async def on_query(self, query_type: str, args: dict):
- """Overrides FederationHandlerRegistry
- """
handler = self.query_handlers.get(query_type)
if handler:
return await handler(args)
- return await self._get_query_client(query_type=query_type, args=args)
+ # Check if we can route it somewhere else that isn't us
+ if self._instance_name == "master":
+ return await self._get_query_client(query_type=query_type, args=args)
+
+ # Uh oh, no handler! Let's raise an exception so the request returns an
+ # error.
+ logger.warning("No handler registered for query type %s", query_type)
+ raise NotFoundError("No handler for Query type '%s'" % (query_type,))
|