diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 5c991e5412..4b115aac04 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -25,19 +25,15 @@ from twisted.python.failure import Failure
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
-from synapse.api.room_versions import (
- KNOWN_ROOM_VERSIONS,
- EventFormatVersions,
- RoomVersion,
-)
+from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.crypto.event_signing import check_event_content_hash
from synapse.crypto.keyring import Keyring
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
- LoggingContext,
PreserveLoggingContext,
+ current_context,
make_deferred_yieldable,
)
from synapse.types import JsonDict, get_domain_from_id
@@ -55,13 +51,15 @@ class FederationBase(object):
self.store = hs.get_datastore()
self._clock = hs.get_clock()
- def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
+ def _check_sigs_and_hash(
+ self, room_version: RoomVersion, pdu: EventBase
+ ) -> Deferred:
return make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0]
)
def _check_sigs_and_hashes(
- self, room_version: str, pdus: List[EventBase]
+ self, room_version: RoomVersion, pdus: List[EventBase]
) -> List[Deferred]:
"""Checks that each of the received events is correctly signed by the
sending server.
@@ -80,7 +78,7 @@ class FederationBase(object):
"""
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
- ctx = LoggingContext.current_context()
+ ctx = current_context()
def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
@@ -146,7 +144,7 @@ class PduToCheckSig(
def _check_sigs_on_pdus(
- keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
+ keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
) -> List[Deferred]:
"""Check that the given events are correctly signed
@@ -191,10 +189,6 @@ def _check_sigs_on_pdus(
for p in pdus
]
- v = KNOWN_ROOM_VERSIONS.get(room_version)
- if not v:
- raise RuntimeError("Unrecognized room version %s" % (room_version,))
-
# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
@@ -204,7 +198,7 @@ def _check_sigs_on_pdus(
(
p.sender_domain,
p.redacted_pdu_json,
- p.pdu.origin_server_ts if v.enforce_key_validity else 0,
+ p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_sender
@@ -227,7 +221,7 @@ def _check_sigs_on_pdus(
# event id's domain (normally only the case for joins/leaves), and add additional
# checks. Only do this if the room version has a concept of event ID domain
# (ie, the room version uses old-style non-hash event IDs).
- if v.event_format == EventFormatVersions.V1:
+ if room_version.event_format == EventFormatVersions.V1:
pdus_to_check_event_id = [
p
for p in pdus_to_check
@@ -239,7 +233,7 @@ def _check_sigs_on_pdus(
(
get_domain_from_id(p.pdu.event_id),
p.redacted_pdu_json,
- p.pdu.origin_server_ts if v.enforce_key_validity else 0,
+ p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_event_id
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 8c6b839478..687cd841ac 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -220,8 +220,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully.
pdus[:] = await make_deferred_yieldable(
defer.gatherResults(
- self._check_sigs_and_hashes(room_version.identifier, pdus),
- consumeErrors=True,
+ self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
).addErrback(unwrapFirstError)
)
@@ -291,9 +290,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
- signed_pdu = await self._check_sigs_and_hash(
- room_version.identifier, pdu
- )
+ signed_pdu = await self._check_sigs_and_hash(room_version, pdu)
break
@@ -350,7 +347,7 @@ class FederationClient(FederationBase):
self,
origin: str,
pdus: List[EventBase],
- room_version: str,
+ room_version: RoomVersion,
outlier: bool = False,
include_none: bool = False,
) -> List[EventBase]:
@@ -396,7 +393,7 @@ class FederationClient(FederationBase):
self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
- room_version=room_version, # type: ignore
+ room_version=room_version,
outlier=outlier,
timeout=10000,
)
@@ -434,7 +431,7 @@ class FederationClient(FederationBase):
]
signed_auth = await self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True, room_version=room_version.identifier
+ destination, auth_chain, outlier=True, room_version=room_version
)
signed_auth.sort(key=lambda e: e.depth)
@@ -661,7 +658,7 @@ class FederationClient(FederationBase):
destination,
list(pdus.values()),
outlier=True,
- room_version=room_version.identifier,
+ room_version=room_version,
)
valid_pdus_map = {p.event_id: p for p in valid_pdus}
@@ -756,7 +753,7 @@ class FederationClient(FederationBase):
pdu = event_from_pdu_json(pdu_dict, room_version)
# Check signatures are correct.
- pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
# FIXME: We should handle signature failures more gracefully.
@@ -886,18 +883,37 @@ class FederationClient(FederationBase):
def get_public_rooms(
self,
- destination,
- limit=None,
- since_token=None,
- search_filter=None,
- include_all_networks=False,
- third_party_instance_id=None,
+ remote_server: str,
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[Dict] = None,
+ include_all_networks: bool = False,
+ third_party_instance_id: Optional[str] = None,
):
- if destination == self.server_name:
- return
+ """Get the list of public rooms from a remote homeserver
+ Args:
+ remote_server: The name of the remote server
+ limit: Maximum amount of rooms to return
+ since_token: Used for result pagination
+ search_filter: A filter dictionary to send the remote homeserver
+ and filter the result set
+ include_all_networks: Whether to include results from all third party instances
+ third_party_instance_id: Whether to only include results from a specific third
+ party instance
+
+ Returns:
+ Deferred[Dict[str, Any]]: The response from the remote server, or None if
+ `remote_server` is the same as the local server_name
+
+ Raises:
+ HttpResponseException: There was an exception returned from the remote server
+ SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom
+ requests over federation
+
+ """
return self.transport_layer.get_public_rooms(
- destination,
+ remote_server,
limit,
since_token,
search_filter,
@@ -948,7 +964,7 @@ class FederationClient(FederationBase):
]
signed_events = await self._check_sigs_and_hash_and_fetch(
- destination, events, outlier=False, room_version=room_version.identifier
+ destination, events, outlier=False, room_version=room_version
)
except HttpResponseException as e:
if not e.code == 400:
@@ -960,14 +976,13 @@ class FederationClient(FederationBase):
return signed_events
- @defer.inlineCallbacks
- def forward_third_party_invite(self, destinations, room_id, event_dict):
+ async def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:
if destination == self.server_name:
continue
try:
- yield self.transport_layer.exchange_third_party_invite(
+ await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict
)
return None
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 275b9c99d7..32a8a2ee46 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Dict
+from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
import six
from six import iteritems
@@ -38,6 +38,7 @@ from synapse.api.errors import (
UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import EventBase
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
@@ -94,7 +95,9 @@ class FederationServer(FederationBase):
# come in waves.
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
- async def on_backfill_request(self, origin, room_id, versions, limit):
+ async def on_backfill_request(
+ self, origin: str, room_id: str, versions: List[str], limit: int
+ ) -> Tuple[int, Dict[str, Any]]:
with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -107,23 +110,25 @@ class FederationServer(FederationBase):
return 200, res
- async def on_incoming_transaction(self, origin, transaction_data):
+ async def on_incoming_transaction(
+ self, origin: str, transaction_data: JsonDict
+ ) -> Tuple[int, Dict[str, Any]]:
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
transaction = Transaction(**transaction_data)
- if not transaction.transaction_id:
+ if not transaction.transaction_id: # type: ignore
raise Exception("Transaction missing transaction_id")
- logger.debug("[%s] Got transaction", transaction.transaction_id)
+ logger.debug("[%s] Got transaction", transaction.transaction_id) # type: ignore
# use a linearizer to ensure that we don't process the same transaction
# multiple times in parallel.
with (
await self._transaction_linearizer.queue(
- (origin, transaction.transaction_id)
+ (origin, transaction.transaction_id) # type: ignore
)
):
result = await self._handle_incoming_transaction(
@@ -132,31 +137,33 @@ class FederationServer(FederationBase):
return result
- async def _handle_incoming_transaction(self, origin, transaction, request_time):
+ async def _handle_incoming_transaction(
+ self, origin: str, transaction: Transaction, request_time: int
+ ) -> Tuple[int, Dict[str, Any]]:
""" Process an incoming transaction and return the HTTP response
Args:
- origin (unicode): the server making the request
- transaction (Transaction): incoming transaction
- request_time (int): timestamp that the HTTP request arrived at
+ origin: the server making the request
+ transaction: incoming transaction
+ request_time: timestamp that the HTTP request arrived at
Returns:
- Deferred[(int, object)]: http response code and body
+ HTTP response code and body
"""
response = await self.transaction_actions.have_responded(origin, transaction)
if response:
logger.debug(
"[%s] We've already responded to this request",
- transaction.transaction_id,
+ transaction.transaction_id, # type: ignore
)
return response
- logger.debug("[%s] Transaction is new", transaction.transaction_id)
+ logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
# Reject if PDU count > 50 or EDU count > 100
- if len(transaction.pdus) > 50 or (
- hasattr(transaction, "edus") and len(transaction.edus) > 100
+ if len(transaction.pdus) > 50 or ( # type: ignore
+ hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
):
logger.info("Transaction PDU or EDU count too large. Returning 400")
@@ -204,13 +211,13 @@ class FederationServer(FederationBase):
report back to the sending server.
"""
- received_pdus_counter.inc(len(transaction.pdus))
+ received_pdus_counter.inc(len(transaction.pdus)) # type: ignore
origin_host, _ = parse_server_name(origin)
- pdus_by_room = {}
+ pdus_by_room = {} # type: Dict[str, List[EventBase]]
- for p in transaction.pdus:
+ for p in transaction.pdus: # type: ignore
if "unsigned" in p:
unsigned = p["unsigned"]
if "age" in unsigned:
@@ -254,7 +261,7 @@ class FederationServer(FederationBase):
# require callouts to other servers to fetch missing events), but
# impose a limit to avoid going too crazy with ram/cpu.
- async def process_pdus_for_room(room_id):
+ async def process_pdus_for_room(room_id: str):
logger.debug("Processing PDUs for %s", room_id)
try:
await self.check_server_matches_acl(origin_host, room_id)
@@ -310,7 +317,9 @@ class FederationServer(FederationBase):
TRANSACTION_CONCURRENCY_LIMIT,
)
- async def on_context_state_request(self, origin, room_id, event_id):
+ async def on_context_state_request(
+ self, origin: str, room_id: str, event_id: str
+ ) -> Tuple[int, Dict[str, Any]]:
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -338,7 +347,9 @@ class FederationServer(FederationBase):
return 200, resp
- async def on_state_ids_request(self, origin, room_id, event_id):
+ async def on_state_ids_request(
+ self, origin: str, room_id: str, event_id: str
+ ) -> Tuple[int, Dict[str, Any]]:
if not event_id:
raise NotImplementedError("Specify an event")
@@ -354,7 +365,9 @@ class FederationServer(FederationBase):
return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
- async def _on_context_state_request_compute(self, room_id, event_id):
+ async def _on_context_state_request_compute(
+ self, room_id: str, event_id: str
+ ) -> Dict[str, list]:
if event_id:
pdus = await self.handler.get_state_for_pdu(room_id, event_id)
else:
@@ -367,7 +380,9 @@ class FederationServer(FederationBase):
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}
- async def on_pdu_request(self, origin, event_id):
+ async def on_pdu_request(
+ self, origin: str, event_id: str
+ ) -> Tuple[int, Union[JsonDict, str]]:
pdu = await self.handler.get_persisted_pdu(origin, event_id)
if pdu:
@@ -375,12 +390,16 @@ class FederationServer(FederationBase):
else:
return 404, ""
- async def on_query_request(self, query_type, args):
+ async def on_query_request(
+ self, query_type: str, args: Dict[str, str]
+ ) -> Tuple[int, Dict[str, Any]]:
received_queries_counter.labels(query_type).inc()
resp = await self.registry.on_query(query_type, args)
return 200, resp
- async def on_make_join_request(self, origin, room_id, user_id, supported_versions):
+ async def on_make_join_request(
+ self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
+ ) -> Dict[str, Any]:
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -397,7 +416,7 @@ class FederationServer(FederationBase):
async def on_invite_request(
self, origin: str, content: JsonDict, room_version_id: str
- ):
+ ) -> Dict[str, Any]:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not room_version:
raise SynapseError(
@@ -409,12 +428,14 @@ class FederationServer(FederationBase):
pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, pdu.room_id)
- pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
time_now = self._clock.time_msec()
return {"event": ret_pdu.get_pdu_json(time_now)}
- async def on_send_join_request(self, origin, content, room_id):
+ async def on_send_join_request(
+ self, origin: str, content: JsonDict, room_id: str
+ ) -> Dict[str, Any]:
logger.debug("on_send_join_request: content: %s", content)
room_version = await self.store.get_room_version(room_id)
@@ -425,7 +446,7 @@ class FederationServer(FederationBase):
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
- pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
@@ -434,7 +455,9 @@ class FederationServer(FederationBase):
"auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
}
- async def on_make_leave_request(self, origin, room_id, user_id):
+ async def on_make_leave_request(
+ self, origin: str, room_id: str, user_id: str
+ ) -> Dict[str, Any]:
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
pdu = await self.handler.on_make_leave_request(origin, room_id, user_id)
@@ -444,7 +467,9 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- async def on_send_leave_request(self, origin, content, room_id):
+ async def on_send_leave_request(
+ self, origin: str, content: JsonDict, room_id: str
+ ) -> dict:
logger.debug("on_send_leave_request: content: %s", content)
room_version = await self.store.get_room_version(room_id)
@@ -455,12 +480,14 @@ class FederationServer(FederationBase):
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
- pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
await self.handler.on_send_leave_request(origin, pdu)
return {}
- async def on_event_auth(self, origin, room_id, event_id):
+ async def on_event_auth(
+ self, origin: str, room_id: str, event_id: str
+ ) -> Tuple[int, Dict[str, Any]]:
with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -471,15 +498,21 @@ class FederationServer(FederationBase):
return 200, res
@log_function
- def on_query_client_keys(self, origin, content):
- return self.on_query_request("client_keys", content)
-
- async def on_query_user_devices(self, origin: str, user_id: str):
+ async def on_query_client_keys(
+ self, origin: str, content: Dict[str, str]
+ ) -> Tuple[int, Dict[str, Any]]:
+ return await self.on_query_request("client_keys", content)
+
+ async def on_query_user_devices(
+ self, origin: str, user_id: str
+ ) -> Tuple[int, Dict[str, Any]]:
keys = await self.device_handler.on_federation_query_user_devices(user_id)
return 200, keys
@trace
- async def on_claim_client_keys(self, origin, content):
+ async def on_claim_client_keys(
+ self, origin: str, content: JsonDict
+ ) -> Dict[str, Any]:
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
@@ -488,7 +521,7 @@ class FederationServer(FederationBase):
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self.store.claim_e2e_one_time_keys(query)
- json_result = {}
+ json_result = {} # type: Dict[str, Dict[str, dict]]
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
@@ -511,8 +544,13 @@ class FederationServer(FederationBase):
return {"one_time_keys": json_result}
async def on_get_missing_events(
- self, origin, room_id, earliest_events, latest_events, limit
- ):
+ self,
+ origin: str,
+ room_id: str,
+ earliest_events: List[str],
+ latest_events: List[str],
+ limit: int,
+ ) -> Dict[str, list]:
with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -541,11 +579,11 @@ class FederationServer(FederationBase):
return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
@log_function
- def on_openid_userinfo(self, token):
+ async def on_openid_userinfo(self, token: str) -> Optional[str]:
ts_now_ms = self._clock.time_msec()
- return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
+ return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
- def _transaction_from_pdus(self, pdu_list):
+ def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction:
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
@@ -558,7 +596,7 @@ class FederationServer(FederationBase):
destination=None,
)
- async def _handle_received_pdu(self, origin, pdu):
+ async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
""" Process a PDU received in a federation /send/ transaction.
If the event is invalid, then this method throws a FederationError.
@@ -579,10 +617,8 @@ class FederationServer(FederationBase):
until we try to backfill across the discontinuity.
Args:
- origin (str): server which sent the pdu
- pdu (FrozenEvent): received pdu
-
- Returns (Deferred): completes with None
+ origin: server which sent the pdu
+ pdu: received pdu
Raises: FederationError if the signatures / hash do not match, or
if the event was unacceptable for any other reason (eg, too large,
@@ -611,7 +647,7 @@ class FederationServer(FederationBase):
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
# We've already checked that we know the room version by this point
- room_version = await self.store.get_room_version_id(pdu.room_id)
+ room_version = await self.store.get_room_version(pdu.room_id)
# Check signature.
try:
@@ -625,25 +661,27 @@ class FederationServer(FederationBase):
return "<ReplicationLayer(%s)>" % self.server_name
async def exchange_third_party_invite(
- self, sender_user_id, target_user_id, room_id, signed
+ self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
):
ret = await self.handler.exchange_third_party_invite(
sender_user_id, target_user_id, room_id, signed
)
return ret
- async def on_exchange_third_party_invite_request(self, room_id, event_dict):
+ async def on_exchange_third_party_invite_request(
+ self, room_id: str, event_dict: Dict
+ ):
ret = await self.handler.on_exchange_third_party_invite_request(
room_id, event_dict
)
return ret
- async def check_server_matches_acl(self, server_name, room_id):
+ async def check_server_matches_acl(self, server_name: str, room_id: str):
"""Check if the given server is allowed by the server ACLs in the room
Args:
- server_name (str): name of server, *without any port part*
- room_id (str): ID of the room to check
+ server_name: name of server, *without any port part*
+ room_id: ID of the room to check
Raises:
AuthError if the server does not match the ACL
@@ -661,15 +699,15 @@ class FederationServer(FederationBase):
raise AuthError(code=403, msg="Server is banned from room")
-def server_matches_acl_event(server_name, acl_event):
+def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
"""Check if the given server is allowed by the ACL event
Args:
- server_name (str): name of server, without any port part
- acl_event (EventBase): m.room.server_acl event
+ server_name: name of server, without any port part
+ acl_event: m.room.server_acl event
Returns:
- bool: True if this server is allowed by the ACLs
+ True if this server is allowed by the ACLs
"""
logger.debug("Checking %s against acl %s", server_name, acl_event.content)
@@ -713,7 +751,7 @@ def server_matches_acl_event(server_name, acl_event):
return False
-def _acl_entry_matches(server_name, acl_entry):
+def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
if not isinstance(acl_entry, six.string_types):
logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
@@ -732,13 +770,13 @@ class FederationHandlerRegistry(object):
self.edu_handlers = {}
self.query_handlers = {}
- def register_edu_handler(self, edu_type, handler):
+ def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]):
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
Args:
- edu_type (str): The type of the incoming EDU to register handler for
- handler (Callable[[str, dict]]): A callable invoked on incoming EDU
+ edu_type: The type of the incoming EDU to register handler for
+ handler: A callable invoked on incoming EDU
of the given type. The arguments are the origin server name and
the EDU contents.
"""
@@ -749,14 +787,16 @@ class FederationHandlerRegistry(object):
self.edu_handlers[edu_type] = handler
- def register_query_handler(self, query_type, handler):
+ def register_query_handler(
+ self, query_type: str, handler: Callable[[dict], defer.Deferred]
+ ):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
Args:
- query_type (str): Category name of the query, which should match
+ query_type: Category name of the query, which should match
the string used by make_query.
- handler (Callable[[dict], Deferred[dict]]): Invoked to handle
+ handler: Invoked to handle
incoming queries of this type. The return will be yielded
on and the result used as the response to the query request.
"""
@@ -767,10 +807,11 @@ class FederationHandlerRegistry(object):
self.query_handlers[query_type] = handler
- async def on_edu(self, edu_type, origin, content):
+ async def on_edu(self, edu_type: str, origin: str, content: dict):
handler = self.edu_handlers.get(edu_type)
if not handler:
logger.warning("No handler registered for EDU type %s", edu_type)
+ return
with start_active_span_from_edu(content, "handle_edu"):
try:
@@ -780,7 +821,7 @@ class FederationHandlerRegistry(object):
except Exception:
logger.exception("Failed to handle edu %r", edu_type)
- def on_query(self, query_type, args):
+ def on_query(self, query_type: str, args: dict) -> defer.Deferred:
handler = self.query_handlers.get(query_type)
if not handler:
logger.warning("No handler registered for query type %s", query_type)
@@ -807,7 +848,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
super(ReplicationFederationHandlerRegistry, self).__init__()
- async def on_edu(self, edu_type, origin, content):
+ async def on_edu(self, edu_type: str, origin: str, content: dict):
"""Overrides FederationHandlerRegistry
"""
if not self.config.use_presence and edu_type == "m.presence":
@@ -821,7 +862,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
- async def on_query(self, query_type, args):
+ async def on_query(self, query_type: str, args: dict):
"""Overrides FederationHandlerRegistry
"""
handler = self.query_handlers.get(query_type)
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 876fb0e245..52f4f54215 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,6 +31,7 @@ Events are replicated via a separate events stream.
import logging
from collections import namedtuple
+from typing import Dict, List, Tuple, Type
from six import iteritems
@@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
- self.presence_map = {} # Pending presence map user_id -> UserPresenceState
- self.presence_changed = SortedDict() # Stream position -> list[user_id]
+ # Pending presence map user_id -> UserPresenceState
+ self.presence_map = {} # type: Dict[str, UserPresenceState]
+
+ # Stream position -> list[user_id]
+ self.presence_changed = SortedDict() # type: SortedDict[int, List[str]]
# Stores the destinations we need to explicitly send presence to about a
# given user.
# Stream position -> (user_id, destinations)
- self.presence_destinations = SortedDict()
+ self.presence_destinations = (
+ SortedDict()
+ ) # type: SortedDict[int, Tuple[str, List[str]]]
+
+ # (destination, key) -> EDU
+ self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu]
- self.keyed_edu = {} # (destination, key) -> EDU
- self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
+ # stream position -> (destination, key)
+ self.keyed_edu_changed = (
+ SortedDict()
+ ) # type: SortedDict[int, Tuple[str, tuple]]
- self.edus = SortedDict() # stream position -> Edu
+ self.edus = SortedDict() # type: SortedDict[int, Edu]
+ # stream ID for the next entry into presence_changed/keyed_edu_changed/edus.
self.pos = 1
- self.pos_time = SortedDict()
+
+ # map from stream ID to the time that stream entry was generated, so that we
+ # can clear out entries after a while
+ self.pos_time = SortedDict() # type: SortedDict[int, int]
# EVERYTHING IS SAD. In particular, python only makes new scopes when
# we make a new function, so we need to make a new function so the inner
@@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object):
for edu_key in self.keyed_edu_changed.values():
live_keys.add(edu_key)
- to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
- for edu_key in to_del:
+ keys_to_del = [
+ edu_key for edu_key in self.keyed_edu if edu_key not in live_keys
+ ]
+ for edu_key in keys_to_del:
del self.keyed_edu[edu_key]
# Delete things out of edu map
@@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object):
self._clear_queue_before_pos(token)
async def get_replication_rows(
- self, from_token, to_token, limit, federation_ack=None
- ):
+ self, instance_name: str, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
"""Get rows to be sent over federation between the two tokens
Args:
- from_token (int)
- to_token(int)
- limit (int)
- federation_ack (int): Optional. The position where the worker is
- explicitly acknowledged it has handled. Allows us to drop
- data from before that point
+ instance_name: the name of the current process
+ from_token: the previous stream token: the starting point for fetching the
+ updates
+ to_token: the new stream token: the point to get updates up to
+ target_row_count: a target for the number of rows to be returned.
+
+ Returns: a triplet `(updates, new_last_token, limited)`, where:
+ * `updates` is a list of `(token, row)` entries.
+ * `new_last_token` is the new position in stream.
+ * `limited` is whether there are more updates to fetch.
"""
- # TODO: Handle limit.
+ # TODO: Handle target_row_count.
# To handle restarts where we wrap around
if from_token > self.pos:
@@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object):
# list of tuple(int, BaseFederationRow), where the first is the position
# of the federation stream.
- rows = []
-
- # There should be only one reader, so lets delete everything its
- # acknowledged its seen.
- if federation_ack:
- self._clear_queue_before_pos(federation_ack)
+ rows = [] # type: List[Tuple[int, BaseFederationRow]]
# Fetch changed presence
i = self.presence_changed.bisect_right(from_token)
@@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object):
# Sort rows based on pos
rows.sort()
- return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+ return (
+ [(pos, (row.TypeId, row.to_data())) for pos, row in rows],
+ to_token,
+ False,
+ )
class BaseFederationRow(object):
@@ -341,7 +361,7 @@ class BaseFederationRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
+ TypeId = "" # Unique string that ids the type. Must be overriden in sub classes.
@staticmethod
def from_data(data):
@@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
-TypeToRow = {
- Row.TypeId: Row
- for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,)
-}
+_rowtypes = (
+ PresenceRow,
+ PresenceDestinationsRow,
+ KeyedEduRow,
+ EduRow,
+) # type: Tuple[Type[BaseFederationRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
ParsedFederationStreamData = namedtuple(
@@ -477,7 +501,7 @@ def process_rows_for_federation(transaction_queue, rows):
Args:
transaction_queue (FederationSender)
- rows (list(synapse.replication.tcp.streams.FederationStreamRow))
+ rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow))
"""
# The federation stream contains a bunch of different types of
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 233cb33daf..d473576902 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Dict, Hashable, Iterable, List, Optional, Set
+from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
from six import itervalues
@@ -498,5 +498,16 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
- def get_current_token(self) -> int:
+ @staticmethod
+ def get_current_token() -> int:
+ # Dummy implementation for case where federation sender isn't offloaded
+ # to a worker.
return 0
+
+ @staticmethod
+ async def get_replication_rows(
+ instance_name: str, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
+ # Dummy implementation for case where federation sender isn't offloaded
+ # to a worker.
+ return [], 0, False
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index e13cd20ffa..276a2b596f 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,11 +15,10 @@
# limitations under the License.
import datetime
import logging
-from typing import Dict, Hashable, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
from prometheus_client import Counter
-import synapse.server
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
@@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
+if TYPE_CHECKING:
+ import synapse.server
+
# This is defined in the Matrix spec and enforced by the receiver.
MAX_EDUS_PER_TRANSACTION = 100
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 3c2a02a3b3..a2752a54a5 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List
+from typing import TYPE_CHECKING, List
from canonicaljson import json
-import synapse.server
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
@@ -31,6 +30,9 @@ from synapse.logging.opentracing import (
)
from synapse.util.metrics import measure_func
+if TYPE_CHECKING:
+ import synapse.server
+
logger = logging.getLogger(__name__)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 383e3fdc8b..060bf07197 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -15,13 +15,14 @@
# limitations under the License.
import logging
-from typing import Any, Dict
+from typing import Any, Dict, Optional
from six.moves import urllib
from twisted.internet import defer
from synapse.api.constants import Membership
+from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX,
FEDERATION_V1_PREFIX,
@@ -326,18 +327,25 @@ class TransportLayerClient(object):
@log_function
def get_public_rooms(
self,
- remote_server,
- limit,
- since_token,
- search_filter=None,
- include_all_networks=False,
- third_party_instance_id=None,
+ remote_server: str,
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[Dict] = None,
+ include_all_networks: bool = False,
+ third_party_instance_id: Optional[str] = None,
):
+ """Get the list of public rooms from a remote homeserver
+
+ See synapse.federation.federation_client.FederationClient.get_public_rooms for
+ more information.
+ """
if search_filter:
# this uses MSC2197 (Search Filtering over Federation)
path = _create_v1_path("/publicRooms")
- data = {"include_all_networks": "true" if include_all_networks else "false"}
+ data = {
+ "include_all_networks": "true" if include_all_networks else "false"
+ } # type: Dict[str, Any]
if third_party_instance_id:
data["third_party_instance_id"] = third_party_instance_id
if limit:
@@ -347,9 +355,19 @@ class TransportLayerClient(object):
data["filter"] = search_filter
- response = yield self.client.post_json(
- destination=remote_server, path=path, data=data, ignore_backoff=True
- )
+ try:
+ response = yield self.client.post_json(
+ destination=remote_server, path=path, data=data, ignore_backoff=True
+ )
+ except HttpResponseException as e:
+ if e.code == 403:
+ raise SynapseError(
+ 403,
+ "You are not allowed to view the public rooms list of %s"
+ % (remote_server,),
+ errcode=Codes.FORBIDDEN,
+ )
+ raise
else:
path = _create_v1_path("/publicRooms")
@@ -363,9 +381,19 @@ class TransportLayerClient(object):
if since_token:
args["since"] = [since_token]
- response = yield self.client.get_json(
- destination=remote_server, path=path, args=args, ignore_backoff=True
- )
+ try:
+ response = yield self.client.get_json(
+ destination=remote_server, path=path, args=args, ignore_backoff=True
+ )
+ except HttpResponseException as e:
+ if e.code == 403:
+ raise SynapseError(
+ 403,
+ "You are not allowed to view the public rooms list of %s"
+ % (remote_server,),
+ errcode=Codes.FORBIDDEN,
+ )
+ raise
return response
|