diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index da06ab379d..32a8a2ee46 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2019 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
import six
from six import iteritems
@@ -36,21 +38,24 @@ from synapse.api.errors import (
UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.crypto.event_signing import compute_event_signature
-from synapse.events import room_version_to_event_format
+from synapse.events import EventBase
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
-from synapse.logging.context import nested_logging_context
+from synapse.logging.context import (
+ make_deferred_yieldable,
+ nested_logging_context,
+ run_in_background,
+)
from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
from synapse.logging.utils import log_function
from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
ReplicationGetQueryRestServlet,
)
-from synapse.types import get_domain_from_id
-from synapse.util import glob_to_regex
+from synapse.types import JsonDict, get_domain_from_id
+from synapse.util import glob_to_regex, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@@ -75,6 +80,9 @@ class FederationServer(FederationBase):
self.auth = hs.get_auth()
self.handler = hs.get_handlers().federation_handler
+ self.state = hs.get_state_handler()
+
+ self.device_handler = hs.get_device_handler()
self._server_linearizer = Linearizer("fed_server")
self._transaction_linearizer = Linearizer("fed_txn_handler")
@@ -87,14 +95,14 @@ class FederationServer(FederationBase):
# come in waves.
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
- @defer.inlineCallbacks
- @log_function
- def on_backfill_request(self, origin, room_id, versions, limit):
- with (yield self._server_linearizer.queue((origin, room_id))):
+ async def on_backfill_request(
+ self, origin: str, room_id: str, versions: List[str], limit: int
+ ) -> Tuple[int, Dict[str, Any]]:
+ with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- pdus = yield self.handler.on_backfill_request(
+ pdus = await self.handler.on_backfill_request(
origin, room_id, versions, limit
)
@@ -102,76 +110,114 @@ class FederationServer(FederationBase):
return 200, res
- @defer.inlineCallbacks
- @log_function
- def on_incoming_transaction(self, origin, transaction_data):
+ async def on_incoming_transaction(
+ self, origin: str, transaction_data: JsonDict
+ ) -> Tuple[int, Dict[str, Any]]:
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
transaction = Transaction(**transaction_data)
- if not transaction.transaction_id:
+ if not transaction.transaction_id: # type: ignore
raise Exception("Transaction missing transaction_id")
- logger.debug("[%s] Got transaction", transaction.transaction_id)
+ logger.debug("[%s] Got transaction", transaction.transaction_id) # type: ignore
# use a linearizer to ensure that we don't process the same transaction
# multiple times in parallel.
with (
- yield self._transaction_linearizer.queue(
- (origin, transaction.transaction_id)
+ await self._transaction_linearizer.queue(
+ (origin, transaction.transaction_id) # type: ignore
)
):
- result = yield self._handle_incoming_transaction(
+ result = await self._handle_incoming_transaction(
origin, transaction, request_time
)
return result
- @defer.inlineCallbacks
- def _handle_incoming_transaction(self, origin, transaction, request_time):
+ async def _handle_incoming_transaction(
+ self, origin: str, transaction: Transaction, request_time: int
+ ) -> Tuple[int, Dict[str, Any]]:
""" Process an incoming transaction and return the HTTP response
Args:
- origin (unicode): the server making the request
- transaction (Transaction): incoming transaction
- request_time (int): timestamp that the HTTP request arrived at
+ origin: the server making the request
+ transaction: incoming transaction
+ request_time: timestamp that the HTTP request arrived at
Returns:
- Deferred[(int, object)]: http response code and body
+ HTTP response code and body
"""
- response = yield self.transaction_actions.have_responded(origin, transaction)
+ response = await self.transaction_actions.have_responded(origin, transaction)
if response:
logger.debug(
"[%s] We've already responded to this request",
- transaction.transaction_id,
+ transaction.transaction_id, # type: ignore
)
return response
- logger.debug("[%s] Transaction is new", transaction.transaction_id)
+ logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
- # Reject if PDU count > 50 and EDU count > 100
- if len(transaction.pdus) > 50 or (
- hasattr(transaction, "edus") and len(transaction.edus) > 100
+ # Reject if PDU count > 50 or EDU count > 100
+ if len(transaction.pdus) > 50 or ( # type: ignore
+ hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
):
logger.info("Transaction PDU or EDU count too large. Returning 400")
response = {}
- yield self.transaction_actions.set_response(
+ await self.transaction_actions.set_response(
origin, transaction, 400, response
)
return 400, response
- received_pdus_counter.inc(len(transaction.pdus))
+ # We process PDUs and EDUs in parallel. This is important as we don't
+ # want to block things like to device messages from reaching clients
+ # behind the potentially expensive handling of PDUs.
+ pdu_results, _ = await make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self._handle_pdus_in_txn, origin, transaction, request_time
+ ),
+ run_in_background(self._handle_edus_in_txn, origin, transaction),
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
+
+ response = {"pdus": pdu_results}
+
+ logger.debug("Returning: %s", str(response))
+
+ await self.transaction_actions.set_response(origin, transaction, 200, response)
+ return 200, response
+
+ async def _handle_pdus_in_txn(
+ self, origin: str, transaction: Transaction, request_time: int
+ ) -> Dict[str, dict]:
+ """Process the PDUs in a received transaction.
+
+ Args:
+ origin: the server making the request
+ transaction: incoming transaction
+ request_time: timestamp that the HTTP request arrived at
+
+ Returns:
+ A map from event ID of a processed PDU to any errors we should
+ report back to the sending server.
+ """
+
+ received_pdus_counter.inc(len(transaction.pdus)) # type: ignore
origin_host, _ = parse_server_name(origin)
- pdus_by_room = {}
+ pdus_by_room = {} # type: Dict[str, List[EventBase]]
- for p in transaction.pdus:
+ for p in transaction.pdus: # type: ignore
if "unsigned" in p:
unsigned = p["unsigned"]
if "age" in unsigned:
@@ -196,24 +242,17 @@ class FederationServer(FederationBase):
continue
try:
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
except NotFoundError:
logger.info("Ignoring PDU for unknown room_id: %s", room_id)
continue
-
- try:
- format_ver = room_version_to_event_format(room_version)
- except UnsupportedRoomVersionError:
+ except UnsupportedRoomVersionError as e:
# this can happen if support for a given room version is withdrawn,
# so that we still get events for said room.
- logger.info(
- "Ignoring PDU for room %s with unknown version %s",
- room_id,
- room_version,
- )
+ logger.info("Ignoring PDU: %s", e)
continue
- event = event_from_pdu_json(p, format_ver)
+ event = event_from_pdu_json(p, room_version)
pdus_by_room.setdefault(room_id, []).append(event)
pdu_results = {}
@@ -222,13 +261,12 @@ class FederationServer(FederationBase):
# require callouts to other servers to fetch missing events), but
# impose a limit to avoid going too crazy with ram/cpu.
- @defer.inlineCallbacks
- def process_pdus_for_room(room_id):
+ async def process_pdus_for_room(room_id: str):
logger.debug("Processing PDUs for %s", room_id)
try:
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
except AuthError as e:
- logger.warn("Ignoring PDUs for room %s from banned server", room_id)
+ logger.warning("Ignoring PDUs for room %s from banned server", room_id)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
pdu_results[event_id] = e.error_dict()
@@ -238,10 +276,10 @@ class FederationServer(FederationBase):
event_id = pdu.event_id
with nested_logging_context(event_id):
try:
- yield self._handle_received_pdu(origin, pdu)
+ await self._handle_received_pdu(origin, pdu)
pdu_results[event_id] = {}
except FederationError as e:
- logger.warn("Error handling PDU %s: %s", event_id, e)
+ logger.warning("Error handling PDU %s: %s", event_id, e)
pdu_results[event_id] = {"error": str(e)}
except Exception as e:
f = failure.Failure()
@@ -252,36 +290,40 @@ class FederationServer(FederationBase):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
- yield concurrently_execute(
+ await concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
)
- if hasattr(transaction, "edus"):
- for edu in (Edu(**x) for x in transaction.edus):
- yield self.received_edu(origin, edu.edu_type, edu.content)
-
- response = {"pdus": pdu_results}
+ return pdu_results
- logger.debug("Returning: %s", str(response))
+ async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
+ """Process the EDUs in a received transaction.
+ """
- yield self.transaction_actions.set_response(origin, transaction, 200, response)
- return 200, response
+ async def _process_edu(edu_dict):
+ received_edus_counter.inc()
- @defer.inlineCallbacks
- def received_edu(self, origin, edu_type, content):
- received_edus_counter.inc()
- yield self.registry.on_edu(edu_type, origin, content)
+ edu = Edu(
+ origin=origin,
+ destination=self.server_name,
+ edu_type=edu_dict["edu_type"],
+ content=edu_dict["content"],
+ )
+ await self.registry.on_edu(edu.edu_type, origin, edu.content)
- @defer.inlineCallbacks
- @log_function
- def on_context_state_request(self, origin, room_id, event_id):
- if not event_id:
- raise NotImplementedError("Specify an event")
+ await concurrently_execute(
+ _process_edu,
+ getattr(transaction, "edus", []),
+ TRANSACTION_CONCURRENCY_LIMIT,
+ )
+ async def on_context_state_request(
+ self, origin: str, room_id: str, event_id: str
+ ) -> Tuple[int, Dict[str, Any]]:
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- in_room = yield self.auth.check_host_in_room(room_id, origin)
+ in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -290,237 +332,196 @@ class FederationServer(FederationBase):
# in the cache so we could return it without waiting for the linearizer
# - but that's non-trivial to get right, and anyway somewhat defeats
# the point of the linearizer.
- with (yield self._server_linearizer.queue((origin, room_id))):
- resp = yield self._state_resp_cache.wrap(
- (room_id, event_id),
- self._on_context_state_request_compute,
- room_id,
- event_id,
+ with (await self._server_linearizer.queue((origin, room_id))):
+ resp = dict(
+ await self._state_resp_cache.wrap(
+ (room_id, event_id),
+ self._on_context_state_request_compute,
+ room_id,
+ event_id,
+ )
)
+ room_version = await self.store.get_room_version_id(room_id)
+ resp["room_version"] = room_version
+
return 200, resp
- @defer.inlineCallbacks
- def on_state_ids_request(self, origin, room_id, event_id):
+ async def on_state_ids_request(
+ self, origin: str, room_id: str, event_id: str
+ ) -> Tuple[int, Dict[str, Any]]:
if not event_id:
raise NotImplementedError("Specify an event")
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- in_room = yield self.auth.check_host_in_room(room_id, origin)
+ in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
- state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id)
- auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
+ state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
+ auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
- @defer.inlineCallbacks
- def _on_context_state_request_compute(self, room_id, event_id):
- pdus = yield self.handler.get_state_for_pdu(room_id, event_id)
- auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus])
-
- for event in auth_chain:
- # We sign these again because there was a bug where we
- # incorrectly signed things the first time round
- if self.hs.is_mine_id(event.event_id):
- event.signatures.update(
- compute_event_signature(
- event.get_pdu_json(),
- self.hs.hostname,
- self.hs.config.signing_key[0],
- )
- )
+ async def _on_context_state_request_compute(
+ self, room_id: str, event_id: str
+ ) -> Dict[str, list]:
+ if event_id:
+ pdus = await self.handler.get_state_for_pdu(room_id, event_id)
+ else:
+ pdus = (await self.state.get_current_state(room_id)).values()
+
+ auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}
- @defer.inlineCallbacks
- @log_function
- def on_pdu_request(self, origin, event_id):
- pdu = yield self.handler.get_persisted_pdu(origin, event_id)
+ async def on_pdu_request(
+ self, origin: str, event_id: str
+ ) -> Tuple[int, Union[JsonDict, str]]:
+ pdu = await self.handler.get_persisted_pdu(origin, event_id)
if pdu:
return 200, self._transaction_from_pdus([pdu]).get_dict()
else:
return 404, ""
- @defer.inlineCallbacks
- def on_query_request(self, query_type, args):
+ async def on_query_request(
+ self, query_type: str, args: Dict[str, str]
+ ) -> Tuple[int, Dict[str, Any]]:
received_queries_counter.labels(query_type).inc()
- resp = yield self.registry.on_query(query_type, args)
+ resp = await self.registry.on_query(query_type, args)
return 200, resp
- @defer.inlineCallbacks
- def on_make_join_request(self, origin, room_id, user_id, supported_versions):
+ async def on_make_join_request(
+ self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
+ ) -> Dict[str, Any]:
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version_id(room_id)
if room_version not in supported_versions:
- logger.warn("Room version %s not in %s", room_version, supported_versions)
+ logger.warning(
+ "Room version %s not in %s", room_version, supported_versions
+ )
raise IncompatibleRoomVersionError(room_version=room_version)
- pdu = yield self.handler.on_make_join_request(origin, room_id, user_id)
+ pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- @defer.inlineCallbacks
- def on_invite_request(self, origin, content, room_version):
- if room_version not in KNOWN_ROOM_VERSIONS:
+ async def on_invite_request(
+ self, origin: str, content: JsonDict, room_version_id: str
+ ) -> Dict[str, Any]:
+ room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
+ if not room_version:
raise SynapseError(
400,
"Homeserver does not support this room version",
Codes.UNSUPPORTED_ROOM_VERSION,
)
- format_ver = room_version_to_event_format(room_version)
-
- pdu = event_from_pdu_json(content, format_ver)
+ pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, pdu.room_id)
- ret_pdu = yield self.handler.on_invite_request(origin, pdu)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
+ ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
time_now = self._clock.time_msec()
return {"event": ret_pdu.get_pdu_json(time_now)}
- @defer.inlineCallbacks
- def on_send_join_request(self, origin, content, room_id):
+ async def on_send_join_request(
+ self, origin: str, content: JsonDict, room_id: str
+ ) -> Dict[str, Any]:
logger.debug("on_send_join_request: content: %s", content)
- room_version = yield self.store.get_room_version(room_id)
- format_ver = room_version_to_event_format(room_version)
- pdu = event_from_pdu_json(content, format_ver)
+ room_version = await self.store.get_room_version(room_id)
+ pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, pdu.room_id)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
- res_pdus = yield self.handler.on_send_join_request(origin, pdu)
+
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
+
+ res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
- return (
- 200,
- {
- "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
- "auth_chain": [
- p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
- ],
- },
- )
+ return {
+ "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+ "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+ }
- @defer.inlineCallbacks
- def on_make_leave_request(self, origin, room_id, user_id):
+ async def on_make_leave_request(
+ self, origin: str, room_id: str, user_id: str
+ ) -> Dict[str, Any]:
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
- pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id)
+ await self.check_server_matches_acl(origin_host, room_id)
+ pdu = await self.handler.on_make_leave_request(origin, room_id, user_id)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version_id(room_id)
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- @defer.inlineCallbacks
- def on_send_leave_request(self, origin, content, room_id):
+ async def on_send_leave_request(
+ self, origin: str, content: JsonDict, room_id: str
+ ) -> dict:
logger.debug("on_send_leave_request: content: %s", content)
- room_version = yield self.store.get_room_version(room_id)
- format_ver = room_version_to_event_format(room_version)
- pdu = event_from_pdu_json(content, format_ver)
+ room_version = await self.store.get_room_version(room_id)
+ pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, pdu.room_id)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
- yield self.handler.on_send_leave_request(origin, pdu)
- return 200, {}
- @defer.inlineCallbacks
- def on_event_auth(self, origin, room_id, event_id):
- with (yield self._server_linearizer.queue((origin, room_id))):
- origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
- time_now = self._clock.time_msec()
- auth_pdus = yield self.handler.on_event_auth(event_id)
- res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
- return 200, res
+ await self.handler.on_send_leave_request(origin, pdu)
+ return {}
- @defer.inlineCallbacks
- def on_query_auth_request(self, origin, content, room_id, event_id):
- """
- Content is a dict with keys::
- auth_chain (list): A list of events that give the auth chain.
- missing (list): A list of event_ids indicating what the other
- side (`origin`) think we're missing.
- rejects (dict): A mapping from event_id to a 2-tuple of reason
- string and a proof (or None) of why the event was rejected.
- The keys of this dict give the list of events the `origin` has
- rejected.
-
- Args:
- origin (str)
- content (dict)
- event_id (str)
-
- Returns:
- Deferred: Results in `dict` with the same format as `content`
- """
- with (yield self._server_linearizer.queue((origin, room_id))):
+ async def on_event_auth(
+ self, origin: str, room_id: str, event_id: str
+ ) -> Tuple[int, Dict[str, Any]]:
+ with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
-
- room_version = yield self.store.get_room_version(room_id)
- format_ver = room_version_to_event_format(room_version)
-
- auth_chain = [
- event_from_pdu_json(e, format_ver) for e in content["auth_chain"]
- ]
-
- signed_auth = yield self._check_sigs_and_hash_and_fetch(
- origin, auth_chain, outlier=True, room_version=room_version
- )
-
- ret = yield self.handler.on_query_auth(
- origin,
- event_id,
- room_id,
- signed_auth,
- content.get("rejects", []),
- content.get("missing", []),
- )
+ await self.check_server_matches_acl(origin_host, room_id)
time_now = self._clock.time_msec()
- send_content = {
- "auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]],
- "rejects": ret.get("rejects", []),
- "missing": ret.get("missing", []),
- }
-
- return 200, send_content
+ auth_pdus = await self.handler.on_event_auth(event_id)
+ res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
+ return 200, res
@log_function
- def on_query_client_keys(self, origin, content):
- return self.on_query_request("client_keys", content)
+ async def on_query_client_keys(
+ self, origin: str, content: Dict[str, str]
+ ) -> Tuple[int, Dict[str, Any]]:
+ return await self.on_query_request("client_keys", content)
- def on_query_user_devices(self, origin, user_id):
- return self.on_query_request("user_devices", user_id)
+ async def on_query_user_devices(
+ self, origin: str, user_id: str
+ ) -> Tuple[int, Dict[str, Any]]:
+ keys = await self.device_handler.on_federation_query_user_devices(user_id)
+ return 200, keys
@trace
- @defer.inlineCallbacks
- @log_function
- def on_claim_client_keys(self, origin, content):
+ async def on_claim_client_keys(
+ self, origin: str, content: JsonDict
+ ) -> Dict[str, Any]:
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
- results = yield self.store.claim_e2e_one_time_keys(query)
+ results = await self.store.claim_e2e_one_time_keys(query)
- json_result = {}
+ json_result = {} # type: Dict[str, Dict[str, dict]]
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
@@ -542,16 +543,19 @@ class FederationServer(FederationBase):
return {"one_time_keys": json_result}
- @defer.inlineCallbacks
- @log_function
- def on_get_missing_events(
- self, origin, room_id, earliest_events, latest_events, limit
- ):
- with (yield self._server_linearizer.queue((origin, room_id))):
+ async def on_get_missing_events(
+ self,
+ origin: str,
+ room_id: str,
+ earliest_events: List[str],
+ latest_events: List[str],
+ limit: int,
+ ) -> Dict[str, list]:
+ with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- logger.info(
+ logger.debug(
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
" limit: %d",
earliest_events,
@@ -559,27 +563,27 @@ class FederationServer(FederationBase):
limit,
)
- missing_events = yield self.handler.on_get_missing_events(
+ missing_events = await self.handler.on_get_missing_events(
origin, room_id, earliest_events, latest_events, limit
)
if len(missing_events) < 5:
- logger.info(
+ logger.debug(
"Returning %d events: %r", len(missing_events), missing_events
)
else:
- logger.info("Returning %d events", len(missing_events))
+ logger.debug("Returning %d events", len(missing_events))
time_now = self._clock.time_msec()
return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
@log_function
- def on_openid_userinfo(self, token):
+ async def on_openid_userinfo(self, token: str) -> Optional[str]:
ts_now_ms = self._clock.time_msec()
- return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
+ return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
- def _transaction_from_pdus(self, pdu_list):
+ def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction:
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
@@ -592,8 +596,7 @@ class FederationServer(FederationBase):
destination=None,
)
- @defer.inlineCallbacks
- def _handle_received_pdu(self, origin, pdu):
+ async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
""" Process a PDU received in a federation /send/ transaction.
If the event is invalid, then this method throws a FederationError.
@@ -614,10 +617,8 @@ class FederationServer(FederationBase):
until we try to backfill across the discontinuity.
Args:
- origin (str): server which sent the pdu
- pdu (FrozenEvent): received pdu
-
- Returns (Deferred): completes with None
+ origin: server which sent the pdu
+ pdu: received pdu
Raises: FederationError if the signatures / hash do not match, or
if the event was unacceptable for any other reason (eg, too large,
@@ -646,68 +647,67 @@ class FederationServer(FederationBase):
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
# We've already checked that we know the room version by this point
- room_version = yield self.store.get_room_version(pdu.room_id)
+ room_version = await self.store.get_room_version(pdu.room_id)
# Check signature.
try:
- pdu = yield self._check_sigs_and_hash(room_version, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
except SynapseError as e:
raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)
- yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
+ await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
- @defer.inlineCallbacks
- def exchange_third_party_invite(
- self, sender_user_id, target_user_id, room_id, signed
+ async def exchange_third_party_invite(
+ self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
):
- ret = yield self.handler.exchange_third_party_invite(
+ ret = await self.handler.exchange_third_party_invite(
sender_user_id, target_user_id, room_id, signed
)
return ret
- @defer.inlineCallbacks
- def on_exchange_third_party_invite_request(self, room_id, event_dict):
- ret = yield self.handler.on_exchange_third_party_invite_request(
+ async def on_exchange_third_party_invite_request(
+ self, room_id: str, event_dict: Dict
+ ):
+ ret = await self.handler.on_exchange_third_party_invite_request(
room_id, event_dict
)
return ret
- @defer.inlineCallbacks
- def check_server_matches_acl(self, server_name, room_id):
+ async def check_server_matches_acl(self, server_name: str, room_id: str):
"""Check if the given server is allowed by the server ACLs in the room
Args:
- server_name (str): name of server, *without any port part*
- room_id (str): ID of the room to check
+ server_name: name of server, *without any port part*
+ room_id: ID of the room to check
Raises:
AuthError if the server does not match the ACL
"""
- state_ids = yield self.store.get_current_state_ids(room_id)
+ state_ids = await self.store.get_current_state_ids(room_id)
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
if not acl_event_id:
return
- acl_event = yield self.store.get_event(acl_event_id)
+ acl_event = await self.store.get_event(acl_event_id)
if server_matches_acl_event(server_name, acl_event):
return
raise AuthError(code=403, msg="Server is banned from room")
-def server_matches_acl_event(server_name, acl_event):
+def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
"""Check if the given server is allowed by the ACL event
Args:
- server_name (str): name of server, without any port part
- acl_event (EventBase): m.room.server_acl event
+ server_name: name of server, without any port part
+ acl_event: m.room.server_acl event
Returns:
- bool: True if this server is allowed by the ACLs
+ True if this server is allowed by the ACLs
"""
logger.debug("Checking %s against acl %s", server_name, acl_event.content)
@@ -715,7 +715,7 @@ def server_matches_acl_event(server_name, acl_event):
# server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
- logger.warn("Ignorning non-bool allow_ip_literals flag")
+ logger.warning("Ignorning non-bool allow_ip_literals flag")
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
@@ -729,7 +729,7 @@ def server_matches_acl_event(server_name, acl_event):
# next, check the deny list
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
- logger.warn("Ignorning non-list deny ACL %s", deny)
+ logger.warning("Ignorning non-list deny ACL %s", deny)
deny = []
for e in deny:
if _acl_entry_matches(server_name, e):
@@ -739,7 +739,7 @@ def server_matches_acl_event(server_name, acl_event):
# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
- logger.warn("Ignorning non-list allow ACL %s", allow)
+ logger.warning("Ignorning non-list allow ACL %s", allow)
allow = []
for e in allow:
if _acl_entry_matches(server_name, e):
@@ -751,9 +751,9 @@ def server_matches_acl_event(server_name, acl_event):
return False
-def _acl_entry_matches(server_name, acl_entry):
+def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
if not isinstance(acl_entry, six.string_types):
- logger.warn(
+ logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
)
return False
@@ -770,13 +770,13 @@ class FederationHandlerRegistry(object):
self.edu_handlers = {}
self.query_handlers = {}
- def register_edu_handler(self, edu_type, handler):
+ def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]):
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
Args:
- edu_type (str): The type of the incoming EDU to register handler for
- handler (Callable[[str, dict]]): A callable invoked on incoming EDU
+ edu_type: The type of the incoming EDU to register handler for
+ handler: A callable invoked on incoming EDU
of the given type. The arguments are the origin server name and
the EDU contents.
"""
@@ -787,14 +787,16 @@ class FederationHandlerRegistry(object):
self.edu_handlers[edu_type] = handler
- def register_query_handler(self, query_type, handler):
+ def register_query_handler(
+ self, query_type: str, handler: Callable[[dict], defer.Deferred]
+ ):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
Args:
- query_type (str): Category name of the query, which should match
+ query_type: Category name of the query, which should match
the string used by make_query.
- handler (Callable[[dict], Deferred[dict]]): Invoked to handle
+ handler: Invoked to handle
incoming queries of this type. The return will be yielded
on and the result used as the response to the query request.
"""
@@ -805,24 +807,24 @@ class FederationHandlerRegistry(object):
self.query_handlers[query_type] = handler
- @defer.inlineCallbacks
- def on_edu(self, edu_type, origin, content):
+ async def on_edu(self, edu_type: str, origin: str, content: dict):
handler = self.edu_handlers.get(edu_type)
if not handler:
- logger.warn("No handler registered for EDU type %s", edu_type)
+ logger.warning("No handler registered for EDU type %s", edu_type)
+ return
with start_active_span_from_edu(content, "handle_edu"):
try:
- yield handler(origin, content)
+ await handler(origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception:
logger.exception("Failed to handle edu %r", edu_type)
- def on_query(self, query_type, args):
+ def on_query(self, query_type: str, args: dict) -> defer.Deferred:
handler = self.query_handlers.get(query_type)
if not handler:
- logger.warn("No handler registered for query type %s", query_type)
+ logger.warning("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
return handler(args)
@@ -846,7 +848,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
super(ReplicationFederationHandlerRegistry, self).__init__()
- def on_edu(self, edu_type, origin, content):
+ async def on_edu(self, edu_type: str, origin: str, content: dict):
"""Overrides FederationHandlerRegistry
"""
if not self.config.use_presence and edu_type == "m.presence":
@@ -854,17 +856,17 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
handler = self.edu_handlers.get(edu_type)
if handler:
- return super(ReplicationFederationHandlerRegistry, self).on_edu(
+ return await super(ReplicationFederationHandlerRegistry, self).on_edu(
edu_type, origin, content
)
- return self._send_edu(edu_type=edu_type, origin=origin, content=content)
+ return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
- def on_query(self, query_type, args):
+ async def on_query(self, query_type: str, args: dict):
"""Overrides FederationHandlerRegistry
"""
handler = self.query_handlers.get(query_type)
if handler:
- return handler(args)
+ return await handler(args)
- return self._get_query_client(query_type=query_type, args=args)
+ return await self._get_query_client(query_type=query_type, args=args)
|