diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4dbd8e1d98..b5aaa244dd 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,8 +19,9 @@
import itertools
import logging
+from collections import Container
from http import HTTPStatus
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import attr
from signedjson.key import decode_verify_key_bytes
@@ -742,6 +743,9 @@ class FederationHandler(BaseHandler):
# device and recognize the algorithm then we can work out the
# exact key to expect. Otherwise check it matches any key we
# have for that device.
+
+ current_keys = [] # type: Container[str]
+
if device:
keys = device.get("keys", {}).get("keys", {})
@@ -758,15 +762,15 @@ class FederationHandler(BaseHandler):
current_keys = keys.values()
elif device_id:
# We don't have any keys for the device ID.
- current_keys = []
+ pass
else:
# The event didn't include a device ID, so we just look for
# keys across all devices.
- current_keys = (
+ current_keys = [
key
for device in cached_devices
for key in device.get("keys", {}).get("keys", {}).values()
- )
+ ]
# We now check that the sender key matches (one of) the expected
# keys.
@@ -1011,7 +1015,7 @@ class FederationHandler(BaseHandler):
if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
- joined_domains = {}
+ joined_domains = {} # type: Dict[str, int]
for u, d in joined_users:
try:
dom = get_domain_from_id(u)
@@ -1277,14 +1281,15 @@ class FederationHandler(BaseHandler):
try:
# Try the host we successfully got a response to /make_join/
# request first.
+ host_list = list(target_hosts)
try:
- target_hosts.remove(origin)
- target_hosts.insert(0, origin)
+ host_list.remove(origin)
+ host_list.insert(0, origin)
except ValueError:
pass
ret = await self.federation_client.send_join(
- target_hosts, event, room_version_obj
+ host_list, event, room_version_obj
)
origin = ret["origin"]
@@ -1584,13 +1589,14 @@ class FederationHandler(BaseHandler):
# Try the host that we succesfully called /make_leave/ on first for
# the /send_leave/ request.
+ host_list = list(target_hosts)
try:
- target_hosts.remove(origin)
- target_hosts.insert(0, origin)
+ host_list.remove(origin)
+ host_list.insert(0, origin)
except ValueError:
pass
- await self.federation_client.send_leave(target_hosts, event)
+ await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event)
stream_id = await self.persist_events_and_notify([(event, context)])
@@ -1604,7 +1610,7 @@ class FederationHandler(BaseHandler):
user_id: str,
membership: str,
content: JsonDict = {},
- params: Optional[Dict[str, str]] = None,
+ params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
) -> Tuple[str, EventBase, RoomVersion]:
(
origin,
@@ -2018,8 +2024,8 @@ class FederationHandler(BaseHandler):
auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
- auth_events = await self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+ auth_events_x = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
@@ -2055,76 +2061,67 @@ class FederationHandler(BaseHandler):
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
- do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
- if do_soft_fail_check:
- extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
-
- extrem_ids = set(extrem_ids)
- prev_event_ids = set(event.prev_event_ids())
-
- if extrem_ids == prev_event_ids:
- # If they're the same then the current state is the same as the
- # state at the event, so no point rechecking auth for soft fail.
- do_soft_fail_check = False
-
- if do_soft_fail_check:
- room_version = await self.store.get_room_version_id(event.room_id)
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
- # Calculate the "current state".
- if state is not None:
- # If we're explicitly given the state then we won't have all the
- # prev events, and so we have a gap in the graph. In this case
- # we want to be a little careful as we might have been down for
- # a while and have an incorrect view of the current state,
- # however we still want to do checks as gaps are easy to
- # maliciously manufacture.
- #
- # So we use a "current state" that is actually a state
- # resolution across the current forward extremities and the
- # given state at the event. This should correctly handle cases
- # like bans, especially with state res v2.
+ if backfilled or event.internal_metadata.is_outlier():
+ return
- state_sets = await self.state_store.get_state_groups(
- event.room_id, extrem_ids
- )
- state_sets = list(state_sets.values())
- state_sets.append(state)
- current_state_ids = await self.state_handler.resolve_events(
- room_version, state_sets, event
- )
- current_state_ids = {
- k: e.event_id for k, e in current_state_ids.items()
- }
- else:
- current_state_ids = await self.state_handler.get_current_state_ids(
- event.room_id, latest_event_ids=extrem_ids
- )
+ extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
+ extrem_ids = set(extrem_ids)
+ prev_event_ids = set(event.prev_event_ids())
- logger.debug(
- "Doing soft-fail check for %s: state %s",
- event.event_id,
- current_state_ids,
+ if extrem_ids == prev_event_ids:
+ # If they're the same then the current state is the same as the
+ # state at the event, so no point rechecking auth for soft fail.
+ return
+
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ # Calculate the "current state".
+ if state is not None:
+ # If we're explicitly given the state then we won't have all the
+ # prev events, and so we have a gap in the graph. In this case
+ # we want to be a little careful as we might have been down for
+ # a while and have an incorrect view of the current state,
+ # however we still want to do checks as gaps are easy to
+ # maliciously manufacture.
+ #
+ # So we use a "current state" that is actually a state
+ # resolution across the current forward extremities and the
+ # given state at the event. This should correctly handle cases
+ # like bans, especially with state res v2.
+
+ state_sets = await self.state_store.get_state_groups(
+ event.room_id, extrem_ids
+ )
+ state_sets = list(state_sets.values())
+ state_sets.append(state)
+ current_state_ids = await self.state_handler.resolve_events(
+ room_version, state_sets, event
+ )
+ current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
+ else:
+ current_state_ids = await self.state_handler.get_current_state_ids(
+ event.room_id, latest_event_ids=extrem_ids
)
- # Now check if event pass auth against said current state
- auth_types = auth_types_for_event(event)
- current_state_ids = [
- e for k, e in current_state_ids.items() if k in auth_types
- ]
+ logger.debug(
+ "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
+ )
- current_auth_events = await self.store.get_events(current_state_ids)
- current_auth_events = {
- (e.type, e.state_key): e for e in current_auth_events.values()
- }
+ # Now check if event pass auth against said current state
+ auth_types = auth_types_for_event(event)
+ current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
- try:
- event_auth.check(
- room_version_obj, event, auth_events=current_auth_events
- )
- except AuthError as e:
- logger.warning("Soft-failing %r because %s", event, e)
- event.internal_metadata.soft_failed = True
+ current_auth_events = await self.store.get_events(current_state_ids)
+ current_auth_events = {
+ (e.type, e.state_key): e for e in current_auth_events.values()
+ }
+
+ try:
+ event_auth.check(room_version_obj, event, auth_events=current_auth_events)
+ except AuthError as e:
+ logger.warning("Soft-failing %r because %s", event, e)
+ event.internal_metadata.soft_failed = True
async def on_query_auth(
self, origin, event_id, room_id, remote_auth_chain, rejects, missing
@@ -2293,10 +2290,10 @@ class FederationHandler(BaseHandler):
remote_auth_chain = await self.federation_client.get_event_auth(
origin, event.room_id, event.event_id
)
- except RequestSendFailed as e:
+ except RequestSendFailed as e1:
# The other side isn't around or doesn't implement the
# endpoint, so lets just bail out.
- logger.info("Failed to get event auth from remote: %s", e)
+ logger.info("Failed to get event auth from remote: %s", e1)
return context
seen_remotes = await self.store.have_seen_events(
@@ -2774,7 +2771,8 @@ class FederationHandler(BaseHandler):
logger.debug("Checking auth on event %r", event.content)
- last_exception = None
+ last_exception = None # type: Optional[Exception]
+
# for each public key in the 3pid invite event
for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
try:
@@ -2828,6 +2826,12 @@ class FederationHandler(BaseHandler):
return
except Exception as e:
last_exception = e
+
+ if last_exception is None:
+ # we can only get here if get_public_keys() returned an empty list
+ # TODO: make this better
+ raise RuntimeError("no public key in invite event")
+
raise last_exception
async def _check_key_revocation(self, public_key, url):
|