diff options
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r-- | synapse/handlers/federation.py | 166 |
1 files changed, 85 insertions, 81 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4dbd8e1d98..b5aaa244dd 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,8 +19,9 @@ import itertools import logging +from collections import Container from http import HTTPStatus -from typing import Dict, Iterable, List, Optional, Sequence, Tuple +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union import attr from signedjson.key import decode_verify_key_bytes @@ -742,6 +743,9 @@ class FederationHandler(BaseHandler): # device and recognize the algorithm then we can work out the # exact key to expect. Otherwise check it matches any key we # have for that device. + + current_keys = [] # type: Container[str] + if device: keys = device.get("keys", {}).get("keys", {}) @@ -758,15 +762,15 @@ class FederationHandler(BaseHandler): current_keys = keys.values() elif device_id: # We don't have any keys for the device ID. - current_keys = [] + pass else: # The event didn't include a device ID, so we just look for # keys across all devices. - current_keys = ( + current_keys = [ key for device in cached_devices for key in device.get("keys", {}).get("keys", {}).values() - ) + ] # We now check that the sender key matches (one of) the expected # keys. @@ -1011,7 +1015,7 @@ class FederationHandler(BaseHandler): if e_type == EventTypes.Member and event.membership == Membership.JOIN ] - joined_domains = {} + joined_domains = {} # type: Dict[str, int] for u, d in joined_users: try: dom = get_domain_from_id(u) @@ -1277,14 +1281,15 @@ class FederationHandler(BaseHandler): try: # Try the host we successfully got a response to /make_join/ # request first. + host_list = list(target_hosts) try: - target_hosts.remove(origin) - target_hosts.insert(0, origin) + host_list.remove(origin) + host_list.insert(0, origin) except ValueError: pass ret = await self.federation_client.send_join( - target_hosts, event, room_version_obj + host_list, event, room_version_obj ) origin = ret["origin"] @@ -1584,13 +1589,14 @@ class FederationHandler(BaseHandler): # Try the host that we succesfully called /make_leave/ on first for # the /send_leave/ request. + host_list = list(target_hosts) try: - target_hosts.remove(origin) - target_hosts.insert(0, origin) + host_list.remove(origin) + host_list.insert(0, origin) except ValueError: pass - await self.federation_client.send_leave(target_hosts, event) + await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) stream_id = await self.persist_events_and_notify([(event, context)]) @@ -1604,7 +1610,7 @@ class FederationHandler(BaseHandler): user_id: str, membership: str, content: JsonDict = {}, - params: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, Union[str, Iterable[str]]]] = None, ) -> Tuple[str, EventBase, RoomVersion]: ( origin, @@ -2018,8 +2024,8 @@ class FederationHandler(BaseHandler): auth_events_ids = await self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + auth_events_x = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. @@ -2055,76 +2061,67 @@ class FederationHandler(BaseHandler): # For new (non-backfilled and non-outlier) events we check if the event # passes auth based on the current state. If it doesn't then we # "soft-fail" the event. - do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier() - if do_soft_fail_check: - extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id) - - extrem_ids = set(extrem_ids) - prev_event_ids = set(event.prev_event_ids()) - - if extrem_ids == prev_event_ids: - # If they're the same then the current state is the same as the - # state at the event, so no point rechecking auth for soft fail. - do_soft_fail_check = False - - if do_soft_fail_check: - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - # Calculate the "current state". - if state is not None: - # If we're explicitly given the state then we won't have all the - # prev events, and so we have a gap in the graph. In this case - # we want to be a little careful as we might have been down for - # a while and have an incorrect view of the current state, - # however we still want to do checks as gaps are easy to - # maliciously manufacture. - # - # So we use a "current state" that is actually a state - # resolution across the current forward extremities and the - # given state at the event. This should correctly handle cases - # like bans, especially with state res v2. + if backfilled or event.internal_metadata.is_outlier(): + return - state_sets = await self.state_store.get_state_groups( - event.room_id, extrem_ids - ) - state_sets = list(state_sets.values()) - state_sets.append(state) - current_state_ids = await self.state_handler.resolve_events( - room_version, state_sets, event - ) - current_state_ids = { - k: e.event_id for k, e in current_state_ids.items() - } - else: - current_state_ids = await self.state_handler.get_current_state_ids( - event.room_id, latest_event_ids=extrem_ids - ) + extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id) + extrem_ids = set(extrem_ids) + prev_event_ids = set(event.prev_event_ids()) - logger.debug( - "Doing soft-fail check for %s: state %s", - event.event_id, - current_state_ids, + if extrem_ids == prev_event_ids: + # If they're the same then the current state is the same as the + # state at the event, so no point rechecking auth for soft fail. + return + + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + # Calculate the "current state". + if state is not None: + # If we're explicitly given the state then we won't have all the + # prev events, and so we have a gap in the graph. In this case + # we want to be a little careful as we might have been down for + # a while and have an incorrect view of the current state, + # however we still want to do checks as gaps are easy to + # maliciously manufacture. + # + # So we use a "current state" that is actually a state + # resolution across the current forward extremities and the + # given state at the event. This should correctly handle cases + # like bans, especially with state res v2. + + state_sets = await self.state_store.get_state_groups( + event.room_id, extrem_ids + ) + state_sets = list(state_sets.values()) + state_sets.append(state) + current_state_ids = await self.state_handler.resolve_events( + room_version, state_sets, event + ) + current_state_ids = {k: e.event_id for k, e in current_state_ids.items()} + else: + current_state_ids = await self.state_handler.get_current_state_ids( + event.room_id, latest_event_ids=extrem_ids ) - # Now check if event pass auth against said current state - auth_types = auth_types_for_event(event) - current_state_ids = [ - e for k, e in current_state_ids.items() if k in auth_types - ] + logger.debug( + "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids, + ) - current_auth_events = await self.store.get_events(current_state_ids) - current_auth_events = { - (e.type, e.state_key): e for e in current_auth_events.values() - } + # Now check if event pass auth against said current state + auth_types = auth_types_for_event(event) + current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] - try: - event_auth.check( - room_version_obj, event, auth_events=current_auth_events - ) - except AuthError as e: - logger.warning("Soft-failing %r because %s", event, e) - event.internal_metadata.soft_failed = True + current_auth_events = await self.store.get_events(current_state_ids) + current_auth_events = { + (e.type, e.state_key): e for e in current_auth_events.values() + } + + try: + event_auth.check(room_version_obj, event, auth_events=current_auth_events) + except AuthError as e: + logger.warning("Soft-failing %r because %s", event, e) + event.internal_metadata.soft_failed = True async def on_query_auth( self, origin, event_id, room_id, remote_auth_chain, rejects, missing @@ -2293,10 +2290,10 @@ class FederationHandler(BaseHandler): remote_auth_chain = await self.federation_client.get_event_auth( origin, event.room_id, event.event_id ) - except RequestSendFailed as e: + except RequestSendFailed as e1: # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. - logger.info("Failed to get event auth from remote: %s", e) + logger.info("Failed to get event auth from remote: %s", e1) return context seen_remotes = await self.store.have_seen_events( @@ -2774,7 +2771,8 @@ class FederationHandler(BaseHandler): logger.debug("Checking auth on event %r", event.content) - last_exception = None + last_exception = None # type: Optional[Exception] + # for each public key in the 3pid invite event for public_key_object in self.hs.get_auth().get_public_keys(invite_event): try: @@ -2828,6 +2826,12 @@ class FederationHandler(BaseHandler): return except Exception as e: last_exception = e + + if last_exception is None: + # we can only get here if get_public_keys() returned an empty list + # TODO: make this better + raise RuntimeError("no public key in invite event") + raise last_exception async def _check_key_revocation(self, public_key, url): |