summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/auth.py23
-rw-r--r--synapse/handlers/federation.py387
-rw-r--r--synapse/handlers/initial_sync.py2
-rw-r--r--synapse/handlers/register.py18
-rw-r--r--synapse/handlers/room_member.py17
-rw-r--r--synapse/handlers/room_summary.py76
-rw-r--r--synapse/handlers/sso.py2
-rw-r--r--synapse/handlers/sync.py198
-rw-r--r--synapse/handlers/ui_auth/__init__.py5
-rw-r--r--synapse/handlers/ui_auth/checkers.py75
10 files changed, 490 insertions, 313 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py

index 161b3c933c..98d3d2d97f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -627,23 +627,28 @@ class AuthHandler(BaseHandler): async def add_oob_auth( self, stagetype: str, authdict: Dict[str, Any], clientip: str - ) -> bool: + ) -> None: """ Adds the result of out-of-band authentication into an existing auth session. Currently used for adding the result of fallback auth. + + Raises: + LoginError if the stagetype is unknown or the session is missing. + LoginError is raised by check_auth if authentication fails. """ if stagetype not in self.checkers: - raise LoginError(400, "", Codes.MISSING_PARAM) + raise LoginError( + 400, f"Unknown UIA stage type: {stagetype}", Codes.INVALID_PARAM + ) if "session" not in authdict: - raise LoginError(400, "", Codes.MISSING_PARAM) + raise LoginError(400, "Missing session ID", Codes.MISSING_PARAM) + # If authentication fails a LoginError is raised. Otherwise, store + # the successful result. result = await self.checkers[stagetype].check_auth(authdict, clientip) - if result: - await self.store.mark_ui_auth_stage_complete( - authdict["session"], stagetype, result - ) - return True - return False + await self.store.mark_ui_auth_stage_complete( + authdict["session"], stagetype, result + ) def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]: """ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c0e13bdaac..246df43501 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -203,18 +203,13 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - async def on_receive_pdu( - self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False - ) -> None: - """Process a PDU received via a federation /send/ transaction, or - via backfill of missing prev_events + async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None: + """Process a PDU received via a federation /send/ transaction Args: origin: server which initiated the /send/ transaction. Will be used to fetch missing events or state. pdu: received PDU - sent_to_us_directly: True if this event was pushed to us; False if - we pulled it as the result of a missing prev_event. """ room_id = pdu.room_id @@ -276,8 +271,6 @@ class FederationHandler(BaseHandler): ) return None - state = None - # Check that the event passes auth based on the state at the event. This is # done for events that are to be added to the timeline (non-outliers). # @@ -285,23 +278,16 @@ class FederationHandler(BaseHandler): # - Fetching any missing prev events to fill in gaps in the graph # - Fetching state if we have a hole in the graph if not pdu.internal_metadata.is_outlier(): - # We only backfill backwards to the min depth. - min_depth = await self.get_min_depth_for_context(pdu.room_id) - - logger.debug("min_depth: %d", min_depth) - prevs = set(pdu.prev_event_ids()) seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if missing_prevs: + # We only backfill backwards to the min depth. + min_depth = await self.get_min_depth_for_context(pdu.room_id) + logger.debug("min_depth: %d", min_depth) - if min_depth is not None and pdu.depth < min_depth: - # This is so that we don't notify the user about this - # message, to work around the fact that some events will - # reference really really old events we really don't want to - # send to the clients. - pdu.internal_metadata.outlier = True - elif min_depth is not None and pdu.depth > min_depth: - missing_prevs = prevs - seen - if sent_to_us_directly and missing_prevs: + if min_depth is not None and pdu.depth > min_depth: # If we're missing stuff, ensure we only fetch stuff one # at a time. logger.info( @@ -325,42 +311,23 @@ class FederationHandler(BaseHandler): % (event_id, e) ) from e - # Update the set of things we've seen after trying to - # fetch the missing stuff - seen = await self.store.have_events_in_timeline(prevs) - - if not prevs - seen: - logger.info( - "Found all missing prev_events", - ) - - missing_prevs = prevs - seen - if missing_prevs: - # We've still not been able to get all of the prev_events for this event. - # - # In this case, we need to fall back to asking another server in the - # federation for the state at this event. That's ok provided we then - # resolve the state against other bits of the DAG before using it (which - # will ensure that you can't just take over a room by sending an event, - # withholding its prev_events, and declaring yourself to be an admin in - # the subsequent state request). - # - # Now, if we're pulling this event as a missing prev_event, then clearly - # this event is not going to become the only forward-extremity and we are - # guaranteed to resolve its state against our existing forward - # extremities, so that should be fine. - # - # On the other hand, if this event was pushed to us, it is possible for - # it to become the only forward-extremity in the room, and we would then - # trust its state to be the state for the whole room. This is very bad. - # Further, if the event was pushed to us, there is no excuse for us not to - # have all the prev_events. We therefore reject any such events. - # - # XXX this really feels like it could/should be merged with the above, - # but there is an interaction with min_depth that I'm not really - # following. - - if sent_to_us_directly: + # Update the set of things we've seen after trying to + # fetch the missing stuff + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if not missing_prevs: + logger.info("Found all missing prev_events") + + if missing_prevs: + # since this event was pushed to us, it is possible for it to + # become the only forward-extremity in the room, and we would then + # trust its state to be the state for the whole room. This is very + # bad. Further, if the event was pushed to us, there is no excuse + # for us not to have all the prev_events. (XXX: apart from + # min_depth?) + # + # We therefore reject any such events. logger.warning( "Rejecting: failed to fetch %d prev events: %s", len(missing_prevs), @@ -376,93 +343,7 @@ class FederationHandler(BaseHandler): affected=pdu.event_id, ) - logger.info( - "Event %s is missing prev_events %s: calculating state for a " - "backwards extremity", - event_id, - shortstr(missing_prevs), - ) - - # Calculate the state after each of the previous events, and - # resolve them to find the correct state at the current event. - event_map = {event_id: pdu} - try: - # Get the state of the events we know about - ours = await self.state_store.get_state_groups_ids(room_id, seen) - - # state_maps is a list of mappings from (type, state_key) to event_id - state_maps: List[StateMap[str]] = list(ours.values()) - - # we don't need this any more, let's delete it. - del ours - - # Ask the remote server for the states we don't - # know about - for p in missing_prevs: - logger.info("Requesting state after missing prev_event %s", p) - - with nested_logging_context(p): - # note that if any of the missing prevs share missing state or - # auth events, the requests to fetch those events are deduped - # by the get_pdu_cache in federation_client. - remote_state = ( - await self._get_state_after_missing_prev_event( - origin, room_id, p - ) - ) - - remote_state_map = { - (x.type, x.state_key): x.event_id for x in remote_state - } - state_maps.append(remote_state_map) - - for x in remote_state: - event_map[x.event_id] = x - - room_version = await self.store.get_room_version_id(room_id) - state_map = ( - await self._state_resolution_handler.resolve_events_with_store( - room_id, - room_version, - state_maps, - event_map, - state_res_store=StateResolutionStore(self.store), - ) - ) - - # We need to give _process_received_pdu the actual state events - # rather than event ids, so generate that now. - - # First though we need to fetch all the events that are in - # state_map, so we can build up the state below. - evs = await self.store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.AS_IS, - ) - event_map.update(evs) - - state = [event_map[e] for e in state_map.values()] - except Exception: - logger.warning( - "Error attempting to resolve state at missing " "prev_events", - exc_info=True, - ) - raise FederationError( - "ERROR", - 403, - "We can't get valid state history.", - affected=event_id, - ) - - # A second round of checks for all events. Check that the event passes auth - # based on `auth_events`, this allows us to assert that the event would - # have been allowed at some point. If an event passes this check its OK - # for it to be used as part of a returned `/state` request, as either - # a) we received the event as part of the original join and so trust it, or - # b) we'll do a state resolution with existing state before it becomes - # part of the "current state", which adds more protection. - await self._process_received_pdu(origin, pdu, state=state) + await self._process_received_pdu(origin, pdu, state=None) async def _get_missing_events_for_pdu( self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int @@ -562,24 +443,7 @@ class FederationHandler(BaseHandler): return logger.info("Got %d prev_events", len(missing_events)) - - # We want to sort these by depth so we process them and - # tell clients about them in order. - missing_events.sort(key=lambda x: x.depth) - - for ev in missing_events: - logger.info("Handling received prev_event %s", ev) - with nested_logging_context(ev.event_id): - try: - await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) - except FederationError as e: - if e.code == 403: - logger.warning( - "Received prev_event %s failed history check.", - ev.event_id, - ) - else: - raise + await self._process_pulled_events(origin, missing_events) async def _get_state_for_room( self, @@ -1496,6 +1360,198 @@ class FederationHandler(BaseHandler): event_infos, ) + async def _process_pulled_events( + self, origin: str, events: Iterable[EventBase] + ) -> None: + """Process a batch of events we have pulled from a remote server + + Pulls in any events required to auth the events, persists the received events, + and notifies clients, if appropriate. + + Assumes the events have already had their signatures and hashes checked. + + Params: + origin: The server we received these events from + events: The received events. + """ + + # We want to sort these by depth so we process them and + # tell clients about them in order. + sorted_events = sorted(events, key=lambda x: x.depth) + + for ev in sorted_events: + with nested_logging_context(ev.event_id): + await self._process_pulled_event(origin, ev) + + async def _process_pulled_event(self, origin: str, event: EventBase) -> None: + """Process a single event that we have pulled from a remote server + + Pulls in any events required to auth the event, persists the received event, + and notifies clients, if appropriate. + + Assumes the event has already had its signatures and hashes checked. + + This is somewhat equivalent to on_receive_pdu, but applies somewhat different + logic in the case that we are missing prev_events (in particular, it just + requests the state at that point, rather than triggering a get_missing_events) - + so is appropriate when we have pulled the event from a remote server, rather + than having it pushed to us. + + Params: + origin: The server we received this event from + events: The received event + """ + logger.info("Processing pulled event %s", event) + + # these should not be outliers. + assert not event.internal_metadata.is_outlier() + + event_id = event.event_id + + existing = await self.store.get_event( + event_id, allow_none=True, allow_rejected=True + ) + if existing: + if not existing.internal_metadata.is_outlier(): + logger.info( + "Ignoring received event %s which we have already seen", + event_id, + ) + return + logger.info("De-outliering event %s", event_id) + + try: + self._sanity_check_event(event) + except SynapseError as err: + logger.warning("Event %s failed sanity check: %s", event_id, err) + return + + try: + state = await self._resolve_state_at_missing_prevs(origin, event) + await self._process_received_pdu(origin, event, state=state) + except FederationError as e: + if e.code == 403: + logger.warning("Pulled event %s failed history check.", event_id) + else: + raise + + async def _resolve_state_at_missing_prevs( + self, dest: str, event: EventBase + ) -> Optional[Iterable[EventBase]]: + """Calculate the state at an event with missing prev_events. + + This is used when we have pulled a batch of events from a remote server, and + still don't have all the prev_events. + + If we already have all the prev_events for `event`, this method does nothing. + + Otherwise, the missing prevs become new backwards extremities, and we fall back + to asking the remote server for the state after each missing `prev_event`, + and resolving across them. + + That's ok provided we then resolve the state against other bits of the DAG + before using it - in other words, that the received event `event` is not going + to become the only forwards_extremity in the room (which will ensure that you + can't just take over a room by sending an event, withholding its prev_events, + and declaring yourself to be an admin in the subsequent state request). + + In other words: we should only call this method if `event` has been *pulled* + as part of a batch of missing prev events, or similar. + + Params: + dest: the remote server to ask for state at the missing prevs. Typically, + this will be the server we got `event` from. + event: an event to check for missing prevs. + + Returns: + if we already had all the prev events, `None`. Otherwise, returns a list of + the events in the state at `event`. + """ + room_id = event.room_id + event_id = event.event_id + + prevs = set(event.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if not missing_prevs: + return None + + logger.info( + "Event %s is missing prev_events %s: calculating state for a " + "backwards extremity", + event_id, + shortstr(missing_prevs), + ) + # Calculate the state after each of the previous events, and + # resolve them to find the correct state at the current event. + event_map = {event_id: event} + try: + # Get the state of the events we know about + ours = await self.state_store.get_state_groups_ids(room_id, seen) + + # state_maps is a list of mappings from (type, state_key) to event_id + state_maps: List[StateMap[str]] = list(ours.values()) + + # we don't need this any more, let's delete it. + del ours + + # Ask the remote server for the states we don't + # know about + for p in missing_prevs: + logger.info("Requesting state after missing prev_event %s", p) + + with nested_logging_context(p): + # note that if any of the missing prevs share missing state or + # auth events, the requests to fetch those events are deduped + # by the get_pdu_cache in federation_client. + remote_state = await self._get_state_after_missing_prev_event( + dest, room_id, p + ) + + remote_state_map = { + (x.type, x.state_key): x.event_id for x in remote_state + } + state_maps.append(remote_state_map) + + for x in remote_state: + event_map[x.event_id] = x + + room_version = await self.store.get_room_version_id(room_id) + state_map = await self._state_resolution_handler.resolve_events_with_store( + room_id, + room_version, + state_maps, + event_map, + state_res_store=StateResolutionStore(self.store), + ) + + # We need to give _process_received_pdu the actual state events + # rather than event ids, so generate that now. + + # First though we need to fetch all the events that are in + # state_map, so we can build up the state below. + evs = await self.store.get_events( + list(state_map.values()), + get_prev_content=False, + redact_behaviour=EventRedactBehaviour.AS_IS, + ) + event_map.update(evs) + + state = [event_map[e] for e in state_map.values()] + except Exception: + logger.warning( + "Error attempting to resolve state at missing prev_events", + exc_info=True, + ) + raise FederationError( + "ERROR", + 403, + "We can't get valid state history.", + affected=event_id, + ) + return state + def _sanity_check_event(self, ev: EventBase) -> None: """ Do some early sanity checks of a received event @@ -1764,7 +1820,7 @@ class FederationHandler(BaseHandler): p, ) with nested_logging_context(p.event_id): - await self.on_receive_pdu(origin, p, sent_to_us_directly=True) + await self.on_receive_pdu(origin, p) except Exception as e: logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e @@ -2375,6 +2431,7 @@ class FederationHandler(BaseHandler): not event.internal_metadata.is_outlier() and not backfilled and not context.rejected + and (await self.store.get_min_depth(event.room_id)) <= event.depth ): await self.action_generator.handle_push_actions_for_event( event, context diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index e1c544a3c9..4e8f7f1d85 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py
@@ -151,7 +151,7 @@ class InitialSyncHandler(BaseHandler): limit = 10 async def handle_room(event: RoomsForUser): - d = { + d: JsonDict = { "room_id": event.room_id, "membership": event.membership, "visibility": ( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 8cf614136e..0ed59d757b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
@@ -56,6 +56,22 @@ login_counter = Counter( ) +def init_counters_for_auth_provider(auth_provider_id: str) -> None: + """Ensure the prometheus counters for the given auth provider are initialised + + This fixes a problem where the counters are not reported for a given auth provider + until the user first logs in/registers. + """ + for is_guest in (True, False): + login_counter.labels(guest=is_guest, auth_provider=auth_provider_id) + for shadow_banned in (True, False): + registration_counter.labels( + guest=is_guest, + shadow_banned=shadow_banned, + auth_provider=auth_provider_id, + ) + + class LoginDict(TypedDict): device_id: str access_token: str @@ -96,6 +112,8 @@ class RegistrationHandler(BaseHandler): self.session_lifetime = hs.config.session_lifetime self.access_token_lifetime = hs.config.access_token_lifetime + init_counters_for_auth_provider("") + async def check_username( self, localpart: str, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index ba13196218..401b84aad1 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -36,6 +36,7 @@ from synapse.api.ratelimiting import Ratelimiter from synapse.event_auth import get_named_level, get_power_level_event from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.types import ( JsonDict, Requester, @@ -79,7 +80,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.account_data_handler = hs.get_account_data_handler() self.event_auth_handler = hs.get_event_auth_handler() - self.member_linearizer = Linearizer(name="member") + self.member_linearizer: Linearizer = Linearizer(name="member") self.clock = hs.get_clock() self.spam_checker = hs.get_spam_checker() @@ -556,6 +557,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content.pop("displayname", None) content.pop("avatar_url", None) + if len(content.get("displayname") or "") > MAX_DISPLAYNAME_LEN: + raise SynapseError( + 400, + f"Displayname is too long (max {MAX_DISPLAYNAME_LEN})", + errcode=Codes.BAD_JSON, + ) + + if len(content.get("avatar_url") or "") > MAX_AVATAR_URL_LEN: + raise SynapseError( + 400, + f"Avatar URL is too long (max {MAX_AVATAR_URL_LEN})", + errcode=Codes.BAD_JSON, + ) + effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index ac6cfc0da9..906985c754 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py
@@ -28,12 +28,11 @@ from synapse.api.constants import ( Membership, RoomTypes, ) -from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError from synapse.events import EventBase from synapse.events.utils import format_event_for_client_v2 from synapse.types import JsonDict from synapse.util.caches.response_cache import ResponseCache -from synapse.util.stringutils import random_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -76,6 +75,9 @@ class _PaginationSession: class RoomSummaryHandler: + # A unique key used for pagination sessions for the room hierarchy endpoint. + _PAGINATION_SESSION_TYPE = "room_hierarchy_pagination" + # The time a pagination session remains valid for. _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000 @@ -87,12 +89,6 @@ class RoomSummaryHandler: self._server_name = hs.hostname self._federation_client = hs.get_federation_client() - # A map of query information to the current pagination state. - # - # TODO Allow for multiple workers to share this data. - # TODO Expire pagination tokens. - self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {} - # If a user tries to fetch the same page multiple times in quick succession, # only process the first attempt and return its result to subsequent requests. self._pagination_response_cache: ResponseCache[ @@ -102,21 +98,6 @@ class RoomSummaryHandler: "get_room_hierarchy", ) - def _expire_pagination_sessions(self): - """Expire pagination session which are old.""" - expire_before = ( - self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS - ) - to_expire = [] - - for key, value in self._pagination_sessions.items(): - if value.creation_time_ms < expire_before: - to_expire.append(key) - - for key in to_expire: - logger.debug("Expiring pagination session id %s", key) - del self._pagination_sessions[key] - async def get_space_summary( self, requester: str, @@ -327,18 +308,29 @@ class RoomSummaryHandler: # If this is continuing a previous session, pull the persisted data. if from_token: - self._expire_pagination_sessions() + try: + pagination_session = await self._store.get_session( + session_type=self._PAGINATION_SESSION_TYPE, + session_id=from_token, + ) + except StoreError: + raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM) - pagination_key = _PaginationKey( - requested_room_id, suggested_only, max_depth, from_token - ) - if pagination_key not in self._pagination_sessions: + # If the requester, room ID, suggested-only, or max depth were modified + # the session is invalid. + if ( + requester != pagination_session["requester"] + or requested_room_id != pagination_session["room_id"] + or suggested_only != pagination_session["suggested_only"] + or max_depth != pagination_session["max_depth"] + ): raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM) # Load the previous state. - pagination_session = self._pagination_sessions[pagination_key] - room_queue = pagination_session.room_queue - processed_rooms = pagination_session.processed_rooms + room_queue = [ + _RoomQueueEntry(*fields) for fields in pagination_session["room_queue"] + ] + processed_rooms = set(pagination_session["processed_rooms"]) else: # The queue of rooms to process, the next room is last on the stack. room_queue = [_RoomQueueEntry(requested_room_id, ())] @@ -456,13 +448,21 @@ class RoomSummaryHandler: # If there's additional data, generate a pagination token (and persist state). if room_queue: - next_batch = random_string(24) - result["next_batch"] = next_batch - pagination_key = _PaginationKey( - requested_room_id, suggested_only, max_depth, next_batch - ) - self._pagination_sessions[pagination_key] = _PaginationSession( - self._clock.time_msec(), room_queue, processed_rooms + result["next_batch"] = await self._store.create_session( + session_type=self._PAGINATION_SESSION_TYPE, + value={ + # Information which must be identical across pagination. + "requester": requester, + "room_id": requested_room_id, + "suggested_only": suggested_only, + "max_depth": max_depth, + # The stored state. + "room_queue": [ + attr.astuple(room_entry) for room_entry in room_queue + ], + "processed_rooms": list(processed_rooms), + }, + expiry_ms=self._PAGINATION_SESSION_VALIDITY_PERIOD_MS, ) return result diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 1b855a685c..0e6ebb574e 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py
@@ -37,6 +37,7 @@ from twisted.web.server import Request from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.config.sso import SsoAttributeRequirement +from synapse.handlers.register import init_counters_for_auth_provider from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html, respond_with_redirect @@ -213,6 +214,7 @@ class SsoHandler: p_id = p.idp_id assert p_id not in self._identity_providers self._identity_providers[p_id] = p + init_counters_for_auth_provider(p_id) def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]: """Get the configured identity providers""" diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 590642f510..86c3c7f0df 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -1,5 +1,4 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2018, 2019 New Vector Ltd +# Copyright 2015-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,6 +30,8 @@ from prometheus_client import Counter from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.filtering import FilterCollection +from synapse.api.presence import UserPresenceState +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span @@ -86,20 +87,20 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100 SyncRequestKey = Tuple[Any, ...] -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class SyncConfig: - user = attr.ib(type=UserID) - filter_collection = attr.ib(type=FilterCollection) - is_guest = attr.ib(type=bool) - request_key = attr.ib(type=SyncRequestKey) - device_id = attr.ib(type=Optional[str]) + user: UserID + filter_collection: FilterCollection + is_guest: bool + request_key: SyncRequestKey + device_id: Optional[str] -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class TimelineBatch: - prev_batch = attr.ib(type=StreamToken) - events = attr.ib(type=List[EventBase]) - limited = attr.ib(type=bool) + prev_batch: StreamToken + events: List[EventBase] + limited: bool def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -113,16 +114,16 @@ class TimelineBatch: # if there are updates for it, which we check after the instance has been created. # This should not be a big deal because we update the notification counts afterwards as # well anyway. -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class JoinedSyncResult: - room_id = attr.ib(type=str) - timeline = attr.ib(type=TimelineBatch) - state = attr.ib(type=StateMap[EventBase]) - ephemeral = attr.ib(type=List[JsonDict]) - account_data = attr.ib(type=List[JsonDict]) - unread_notifications = attr.ib(type=JsonDict) - summary = attr.ib(type=Optional[JsonDict]) - unread_count = attr.ib(type=int) + room_id: str + timeline: TimelineBatch + state: StateMap[EventBase] + ephemeral: List[JsonDict] + account_data: List[JsonDict] + unread_notifications: JsonDict + summary: Optional[JsonDict] + unread_count: int def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -138,12 +139,12 @@ class JoinedSyncResult: ) -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class ArchivedSyncResult: - room_id = attr.ib(type=str) - timeline = attr.ib(type=TimelineBatch) - state = attr.ib(type=StateMap[EventBase]) - account_data = attr.ib(type=List[JsonDict]) + room_id: str + timeline: TimelineBatch + state: StateMap[EventBase] + account_data: List[JsonDict] def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -152,37 +153,37 @@ class ArchivedSyncResult: return bool(self.timeline or self.state or self.account_data) -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class InvitedSyncResult: - room_id = attr.ib(type=str) - invite = attr.ib(type=EventBase) + room_id: str + invite: EventBase def __bool__(self) -> bool: """Invited rooms should always be reported to the client""" return True -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class KnockedSyncResult: - room_id = attr.ib(type=str) - knock = attr.ib(type=EventBase) + room_id: str + knock: EventBase def __bool__(self) -> bool: """Knocked rooms should always be reported to the client""" return True -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class GroupsSyncResult: - join = attr.ib(type=JsonDict) - invite = attr.ib(type=JsonDict) - leave = attr.ib(type=JsonDict) + join: JsonDict + invite: JsonDict + leave: JsonDict def __bool__(self) -> bool: return bool(self.join or self.invite or self.leave) -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class DeviceLists: """ Attributes: @@ -190,27 +191,27 @@ class DeviceLists: left: List of user_ids whose devices we no longer track """ - changed = attr.ib(type=Collection[str]) - left = attr.ib(type=Collection[str]) + changed: Collection[str] + left: Collection[str] def __bool__(self) -> bool: return bool(self.changed or self.left) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _RoomChanges: """The set of room entries to include in the sync, plus the set of joined and left room IDs since last sync. """ - room_entries = attr.ib(type=List["RoomSyncResultBuilder"]) - invited = attr.ib(type=List[InvitedSyncResult]) - knocked = attr.ib(type=List[KnockedSyncResult]) - newly_joined_rooms = attr.ib(type=List[str]) - newly_left_rooms = attr.ib(type=List[str]) + room_entries: List["RoomSyncResultBuilder"] + invited: List[InvitedSyncResult] + knocked: List[KnockedSyncResult] + newly_joined_rooms: List[str] + newly_left_rooms: List[str] -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class SyncResult: """ Attributes: @@ -230,18 +231,18 @@ class SyncResult: groups: Group updates, if any """ - next_batch = attr.ib(type=StreamToken) - presence = attr.ib(type=List[JsonDict]) - account_data = attr.ib(type=List[JsonDict]) - joined = attr.ib(type=List[JoinedSyncResult]) - invited = attr.ib(type=List[InvitedSyncResult]) - knocked = attr.ib(type=List[KnockedSyncResult]) - archived = attr.ib(type=List[ArchivedSyncResult]) - to_device = attr.ib(type=List[JsonDict]) - device_lists = attr.ib(type=DeviceLists) - device_one_time_keys_count = attr.ib(type=JsonDict) - device_unused_fallback_key_types = attr.ib(type=List[str]) - groups = attr.ib(type=Optional[GroupsSyncResult]) + next_batch: StreamToken + presence: List[UserPresenceState] + account_data: List[JsonDict] + joined: List[JoinedSyncResult] + invited: List[InvitedSyncResult] + knocked: List[KnockedSyncResult] + archived: List[ArchivedSyncResult] + to_device: List[JsonDict] + device_lists: DeviceLists + device_one_time_keys_count: JsonDict + device_unused_fallback_key_types: List[str] + groups: Optional[GroupsSyncResult] def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -701,7 +702,7 @@ class SyncHandler: name_id = state_ids.get((EventTypes.Name, "")) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) - summary = {} + summary: JsonDict = {} empty_ms = MemberSummary([], 0) # TODO: only send these when they change. @@ -1843,6 +1844,9 @@ class SyncHandler: knocked = [] for event in room_list: + if event.room_version_id not in KNOWN_ROOM_VERSIONS: + continue + if event.membership == Membership.JOIN: room_entries.append( RoomSyncResultBuilder( @@ -2076,21 +2080,23 @@ class SyncHandler: # If the membership's stream ordering is after the given stream # ordering, we need to go and work out if the user was in the room # before. - for room_id, event_pos in joined_rooms: - if not event_pos.persisted_after(room_key): - joined_room_ids.add(room_id) + for joined_room in joined_rooms: + if not joined_room.event_pos.persisted_after(room_key): + joined_room_ids.add(joined_room.room_id) continue - logger.info("User joined room after current token: %s", room_id) + logger.info("User joined room after current token: %s", joined_room.room_id) extrems = ( await self.store.get_forward_extremities_for_room_at_stream_ordering( - room_id, event_pos.stream + joined_room.room_id, joined_room.event_pos.stream ) ) - users_in_room = await self.state.get_current_users_in_room(room_id, extrems) + users_in_room = await self.state.get_current_users_in_room( + joined_room.room_id, extrems + ) if user_id in users_in_room: - joined_room_ids.add(room_id) + joined_room_ids.add(joined_room.room_id) return frozenset(joined_room_ids) @@ -2160,7 +2166,7 @@ def _calculate_state( return {event_id_to_key[e]: e for e in state_ids} -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class SyncResultBuilder: """Used to help build up a new SyncResult for a user @@ -2172,33 +2178,33 @@ class SyncResultBuilder: joined_room_ids: List of rooms the user is joined to # The following mirror the fields in a sync response - presence (list) - account_data (list) - joined (list[JoinedSyncResult]) - invited (list[InvitedSyncResult]) - knocked (list[KnockedSyncResult]) - archived (list[ArchivedSyncResult]) - groups (GroupsSyncResult|None) - to_device (list) + presence + account_data + joined + invited + knocked + archived + groups + to_device """ - sync_config = attr.ib(type=SyncConfig) - full_state = attr.ib(type=bool) - since_token = attr.ib(type=Optional[StreamToken]) - now_token = attr.ib(type=StreamToken) - joined_room_ids = attr.ib(type=FrozenSet[str]) + sync_config: SyncConfig + full_state: bool + since_token: Optional[StreamToken] + now_token: StreamToken + joined_room_ids: FrozenSet[str] - presence = attr.ib(type=List[JsonDict], default=attr.Factory(list)) - account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list)) - joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list)) - invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list)) - knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list)) - archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list)) - groups = attr.ib(type=Optional[GroupsSyncResult], default=None) - to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list)) + presence: List[UserPresenceState] = attr.Factory(list) + account_data: List[JsonDict] = attr.Factory(list) + joined: List[JoinedSyncResult] = attr.Factory(list) + invited: List[InvitedSyncResult] = attr.Factory(list) + knocked: List[KnockedSyncResult] = attr.Factory(list) + archived: List[ArchivedSyncResult] = attr.Factory(list) + groups: Optional[GroupsSyncResult] = None + to_device: List[JsonDict] = attr.Factory(list) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class RoomSyncResultBuilder: """Stores information needed to create either a `JoinedSyncResult` or `ArchivedSyncResult`. @@ -2214,10 +2220,10 @@ class RoomSyncResultBuilder: upto_token: Latest point to return events from. """ - room_id = attr.ib(type=str) - rtype = attr.ib(type=str) - events = attr.ib(type=Optional[List[EventBase]]) - newly_joined = attr.ib(type=bool) - full_state = attr.ib(type=bool) - since_token = attr.ib(type=Optional[StreamToken]) - upto_token = attr.ib(type=StreamToken) + room_id: str + rtype: str + events: Optional[List[EventBase]] + newly_joined: bool + full_state: bool + since_token: Optional[StreamToken] + upto_token: StreamToken diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 4c3b669fae..13b0c61d2e 100644 --- a/synapse/handlers/ui_auth/__init__.py +++ b/synapse/handlers/ui_auth/__init__.py
@@ -34,3 +34,8 @@ class UIAuthSessionDataConstants: # used by validate_user_via_ui_auth to store the mxid of the user we are validating # for. REQUEST_USER_ID = "request_user_id" + + # used during registration to store the registration token used (if required) so that: + # - we can prevent a token being used twice by one session + # - we can 'use up' the token after registration has successfully completed + REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token" diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 5414ce77d8..d3828dec6b 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py
@@ -49,7 +49,7 @@ class UserInteractiveAuthChecker: clientip: The IP address of the client. Raises: - SynapseError if authentication failed + LoginError if authentication failed. Returns: The result of authentication (to pass back to the client?) @@ -131,7 +131,9 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): ) if resp_body["success"]: return True - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + raise LoginError( + 401, "Captcha authentication failed", errcode=Codes.UNAUTHORIZED + ) class _BaseThreepidAuthChecker: @@ -191,7 +193,9 @@ class _BaseThreepidAuthChecker: raise AssertionError("Unrecognized threepid medium: %s" % (medium,)) if not threepid: - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + raise LoginError( + 401, "Unable to get validated threepid", errcode=Codes.UNAUTHORIZED + ) if threepid["medium"] != medium: raise LoginError( @@ -237,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): return await self._check_threepid("msisdn", authdict) +class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): + AUTH_TYPE = LoginType.REGISTRATION_TOKEN + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self.hs = hs + self._enabled = bool(hs.config.registration_requires_token) + self.store = hs.get_datastore() + + def is_enabled(self) -> bool: + return self._enabled + + async def check_auth(self, authdict: dict, clientip: str) -> Any: + if "token" not in authdict: + raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM) + if not isinstance(authdict["token"], str): + raise LoginError( + 400, "Registration token must be a string", Codes.INVALID_PARAM + ) + if "session" not in authdict: + raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM) + + # Get these here to avoid cyclic dependencies + from synapse.handlers.ui_auth import UIAuthSessionDataConstants + + auth_handler = self.hs.get_auth_handler() + + session = authdict["session"] + token = authdict["token"] + + # If the LoginType.REGISTRATION_TOKEN stage has already been completed, + # return early to avoid incrementing `pending` again. + stored_token = await auth_handler.get_session_data( + session, UIAuthSessionDataConstants.REGISTRATION_TOKEN + ) + if stored_token: + if token != stored_token: + raise LoginError( + 400, "Registration token has changed", Codes.INVALID_PARAM + ) + else: + return token + + if await self.store.registration_token_is_valid(token): + # Increment pending counter, so that if token has limited uses it + # can't be used up by someone else in the meantime. + await self.store.set_registration_token_pending(token) + # Store the token in the UIA session, so that once registration + # is complete `completed` can be incremented. + await auth_handler.set_session_data( + session, + UIAuthSessionDataConstants.REGISTRATION_TOKEN, + token, + ) + # The token will be stored as the result of the authentication stage + # in ui_auth_sessions_credentials. This allows the pending counter + # for tokens to be decremented when expired sessions are deleted. + return token + else: + raise LoginError( + 401, "Invalid registration token", errcode=Codes.UNAUTHORIZED + ) + + INTERACTIVE_AUTH_CHECKERS = [ DummyAuthChecker, TermsAuthChecker, RecaptchaAuthChecker, EmailIdentityAuthChecker, MsisdnAuthChecker, + RegistrationTokenAuthChecker, ] """A list of UserInteractiveAuthChecker classes"""