diff options
Diffstat (limited to 'synapse/handlers')
-rw-r--r-- | synapse/handlers/device.py | 4 | ||||
-rw-r--r-- | synapse/handlers/e2e_keys.py | 12 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 161 | ||||
-rw-r--r-- | synapse/handlers/presence.py | 278 | ||||
-rw-r--r-- | synapse/handlers/room_member.py | 75 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 10 |
6 files changed, 398 insertions, 142 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 54293d0b9c..7e76db3e2a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -631,7 +631,7 @@ class DeviceListUpdater: max_len=10000, expiry_ms=30 * 60 * 1000, iterable=True, - ) + ) # type: ExpiringCache[str, Set[str]] # Attempt to resync out of sync device lists every 30s. self._resync_retry_in_progress = False @@ -760,7 +760,7 @@ class DeviceListUpdater: """Given a list of updates for a user figure out if we need to do a full resync, or whether we have enough data that we can just apply the delta. """ - seen_updates = self._seen_updates.get(user_id, set()) + seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str] extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 739653a3fa..92b18378fc 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -38,7 +38,6 @@ from synapse.types import ( ) from synapse.util import json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer -from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: @@ -1292,17 +1291,6 @@ class SigningKeyEduUpdater: # user_id -> list of updates waiting to be handled. self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]] - # Recently seen stream ids. We don't bother keeping these in the DB, - # but they're useful to have them about to reduce the number of spurious - # resyncs. - self._seen_updates = ExpiringCache( - cache_name="signing_key_update_edu", - clock=self.clock, - max_len=10000, - expiry_ms=30 * 60 * 1000, - iterable=True, - ) - async def incoming_signing_key_update( self, origin: str, edu_content: JsonDict ) -> None: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3ebee38ebe..5ea8a7b603 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -21,7 +21,17 @@ import itertools import logging from collections.abc import Container from http import HTTPStatus -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) import attr from signedjson.key import decode_verify_key_bytes @@ -171,15 +181,17 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: + async def on_receive_pdu( + self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False + ) -> None: """Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events Args: - origin (str): server which initiated the /send/ transaction. Will + origin: server which initiated the /send/ transaction. Will be used to fetch missing events or state. - pdu (FrozenEvent): received PDU - sent_to_us_directly (bool): True if this event was pushed to us; False if + pdu: received PDU + sent_to_us_directly: True if this event was pushed to us; False if we pulled it as the result of a missing prev_event. """ @@ -411,13 +423,15 @@ class FederationHandler(BaseHandler): await self._process_received_pdu(origin, pdu, state=state) - async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): + async def _get_missing_events_for_pdu( + self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int + ) -> None: """ Args: - origin (str): Origin of the pdu. Will be called to get the missing events + origin: Origin of the pdu. Will be called to get the missing events pdu: received pdu - prevs (set(str)): List of event ids which we are missing - min_depth (int): Minimum depth of events to return. + prevs: List of event ids which we are missing + min_depth: Minimum depth of events to return. """ room_id = pdu.room_id @@ -778,7 +792,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, state: Optional[Iterable[EventBase]], - ): + ) -> None: """Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. @@ -887,7 +901,9 @@ class FederationHandler(BaseHandler): logger.exception("Failed to resync device for %s", sender) @log_function - async def backfill(self, dest, room_id, limit, extremities): + async def backfill( + self, dest: str, room_id: str, limit: int, extremities: List[str] + ) -> List[EventBase]: """Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side @@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler): curr_state = await self.state_handler.get_current_state(room_id) - def get_domains_from_state(state): + def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: """Get joined domains from state Args: - state (dict[tuple, FrozenEvent]): State map from type/state - key to event. + state: State map from type/state key to event. Returns: - list[tuple[str, int]]: Returns a list of servers with the - lowest depth of their joins. Sorted by lowest depth first. + Returns a list of servers with the lowest depth of their joins. + Sorted by lowest depth first. """ joined_users = [ (state_key, int(event.depth)) @@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler): domain for domain, depth in curr_domains if domain != self.server_name ] - async def try_backfill(domains): + async def try_backfill(domains: List[str]) -> bool: # TODO: Should we try multiple of these at a time? for dom in domains: try: @@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler): } for e_id, _ in sorted_extremeties_tuple: - likely_domains = get_domains_from_state(states[e_id]) + likely_extremeties_domains = get_domains_from_state(states[e_id]) success = await try_backfill( - [dom for dom, _ in likely_domains if dom not in tried_domains] + [ + dom + for dom, _ in likely_extremeties_domains + if dom not in tried_domains + ] ) if success: return True - tried_domains.update(dom for dom, _ in likely_domains) + tried_domains.update(dom for dom, _ in likely_extremeties_domains) return False async def _get_events_and_persist( self, destination: str, room_id: str, events: Iterable[str] - ): + ) -> None: """Fetch the given events from a server, and persist them as outliers. This function *does not* recursively get missing auth events of the @@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler): event_infos, ) - def _sanity_check_event(self, ev): + def _sanity_check_event(self, ev: EventBase) -> None: """ Do some early sanity checks of a received event @@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler): or cascade of event fetches. Args: - ev (synapse.events.EventBase): event to be checked - - Returns: None + ev: event to be checked Raises: SynapseError if the event does not pass muster @@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler): ) raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") - async def send_invite(self, target_host, event): + async def send_invite(self, target_host: str, event: EventBase) -> EventBase: """Sends the invite to the remote server for signing. Invites must be signed by the invitee's server before distribution. @@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler): run_in_background(self._handle_queued_pdus, room_queue) - async def _handle_queued_pdus(self, room_queue): + async def _handle_queued_pdus( + self, room_queue: List[Tuple[EventBase, str]] + ) -> None: """Process PDUs which got queued up while we were busy send_joining. Args: - room_queue (list[FrozenEvent, str]): list of PDUs to be processed - and the servers that sent them + room_queue: list of PDUs to be processed and the servers that sent them """ for p, origin in room_queue: try: @@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler): return event - async def on_send_join_request(self, origin, pdu): + async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict: """We have received a join event for a room. Fully process it and respond with the current state and auth chains. """ @@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler): async def on_invite_request( self, origin: str, event: EventBase, room_version: RoomVersion - ): + ) -> EventBase: """We've got an invite event. Process and persist it. Sign it. Respond with the now signed event. @@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler): return event - async def on_send_leave_request(self, origin, pdu): + async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None: """ We have received a leave event for a room. Fully process it.""" event = pdu @@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler): else: return None - async def get_min_depth_for_context(self, context): + async def get_min_depth_for_context(self, context: str) -> int: return await self.store.get_min_depth(context) async def _handle_new_event( - self, origin, event, state=None, auth_events=None, backfilled=False - ): + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]] = None, + auth_events: Optional[MutableStateMap[EventBase]] = None, + backfilled: bool = False, + ) -> EventContext: context = await self._prep_event( origin, event, state=state, auth_events=auth_events, backfilled=backfilled ) @@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler): logger.warning("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True - async def on_query_auth( - self, origin, event_id, room_id, remote_auth_chain, rejects, missing - ): - in_room = await self.auth.check_host_in_room(room_id, origin) - if not in_room: - raise AuthError(403, "Host not in room.") - - event = await self.store.get_event(event_id, check_room_id=room_id) - - # Just go through and process each event in `remote_auth_chain`. We - # don't want to fall into the trap of `missing` being wrong. - for e in remote_auth_chain: - try: - await self._handle_new_event(origin, e) - except AuthError: - pass - - # Now get the current auth_chain for the event. - local_auth_chain = await self.store.get_auth_chain( - room_id, list(event.auth_event_ids()), include_given=True - ) - - # TODO: Check if we would now reject event_id. If so we need to tell - # everyone. - - ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain) - - logger.debug("on_query_auth returning: %s", ret) - - return ret - async def on_get_missing_events( - self, origin, room_id, earliest_events, latest_events, limit - ): + self, + origin: str, + room_id: str, + earliest_events: List[str], + latest_events: List[str], + limit: int, + ) -> List[EventBase]: in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler): assumes that we have already processed all events in remote_auth Params: - local_auth (list) - remote_auth (list) + local_auth + remote_auth Returns: dict @@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler): @log_function async def exchange_third_party_invite( - self, sender_user_id, target_user_id, room_id, signed - ): + self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict + ) -> None: third_party_invite = {"signed": signed} event_dict = { @@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler): await member_handler.send_membership_event(None, event, context) async def add_display_name_to_third_party_invite( - self, room_version, event_dict, event, context - ): + self, + room_version: str, + event_dict: JsonDict, + event: EventBase, + context: EventContext, + ) -> Tuple[EventBase, EventContext]: key = ( EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"], @@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler): EventValidator().validate_new(event, self.config) return (event, context) - async def _check_signature(self, event, context): + async def _check_signature(self, event: EventBase, context: EventContext) -> None: """ Checks that the signature in the event is consistent with its invite. Args: - event (Event): The m.room.member event to check - context (EventContext): + event: The m.room.member event to check + context: Raises: AuthError: if signature didn't match any keys, or key has been @@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler): raise last_exception - async def _check_key_revocation(self, public_key, url): + async def _check_key_revocation(self, public_key: str, url: str) -> None: """ Checks whether public_key has been revoked. Args: - public_key (str): base-64 encoded public key. - url (str): Key revocation URL. + public_key: base-64 encoded public key. + url: Key revocation URL. Raises: AuthError: if they key has been revoked. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index da92feacc9..c817f2952d 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -25,7 +25,17 @@ The methods that define policy are: import abc import logging from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple +from typing import ( + TYPE_CHECKING, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) from prometheus_client import Counter from typing_extensions import ContextManager @@ -34,6 +44,7 @@ import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState +from synapse.events.presence_router import PresenceRouter from synapse.logging.context import run_in_background from synapse.logging.utils import log_function from synapse.metrics import LaterGauge @@ -42,7 +53,7 @@ from synapse.state import StateHandler from synapse.storage.databases.main import DataStore from synapse.types import Collection, JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -209,6 +220,7 @@ class PresenceHandler(BasePresenceHandler): self.notifier = hs.get_notifier() self.federation = hs.get_federation_sender() self.state = hs.get_state_handler() + self.presence_router = hs.get_presence_router() self._presence_enabled = hs.config.use_presence federation_registry = hs.get_federation_registry() @@ -653,7 +665,7 @@ class PresenceHandler(BasePresenceHandler): """ stream_id, max_token = await self.store.update_presence(states) - parties = await get_interested_parties(self.store, states) + parties = await get_interested_parties(self.store, self.presence_router, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -1041,7 +1053,12 @@ class PresenceEventSource: # # Presence -> Notifier -> PresenceEventSource -> Presence # + # Same with get_module_api, get_presence_router + # + # AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler self.get_presence_handler = hs.get_presence_handler + self.get_module_api = hs.get_module_api + self.get_presence_router = hs.get_presence_router self.clock = hs.get_clock() self.store = hs.get_datastore() self.state = hs.get_state_handler() @@ -1055,7 +1072,7 @@ class PresenceEventSource: include_offline=True, explicit_room_id=None, **kwargs - ): + ) -> Tuple[List[UserPresenceState], int]: # The process for getting presence events are: # 1. Get the rooms the user is in. # 2. Get the list of user in the rooms. @@ -1068,7 +1085,17 @@ class PresenceEventSource: # We don't try and limit the presence updates by the current token, as # sending down the rare duplicate is not a concern. + user_id = user.to_string() + stream_change_cache = self.store.presence_stream_cache + with Measure(self.clock, "presence.get_new_events"): + if user_id in self.get_module_api()._send_full_presence_to_local_users: + # This user has been specified by a module to receive all current, online + # user presence. Removing from_key and setting include_offline to false + # will do effectively this. + from_key = None + include_offline = False + if from_key is not None: from_key = int(from_key) @@ -1091,59 +1118,209 @@ class PresenceEventSource: # doesn't return. C.f. #5503. return [], max_token - presence = self.get_presence_handler() - stream_change_cache = self.store.presence_stream_cache - + # Figure out which other users this user should receive updates for users_interested_in = await self._get_interested_in(user, explicit_room_id) - user_ids_changed = set() # type: Collection[str] - changed = None - if from_key: - changed = stream_change_cache.get_all_entities_changed(from_key) + # We have a set of users that we're interested in the presence of. We want to + # cross-reference that with the users that have actually changed their presence. - if changed is not None and len(changed) < 500: - assert isinstance(user_ids_changed, set) + # Check whether this user should see all user updates - # For small deltas, its quicker to get all changes and then - # work out if we share a room or they're in our presence list - get_updates_counter.labels("stream").inc() - for other_user_id in changed: - if other_user_id in users_interested_in: - user_ids_changed.add(other_user_id) - else: - # Too many possible updates. Find all users we can see and check - # if any of them have changed. - get_updates_counter.labels("full").inc() + if users_interested_in == PresenceRouter.ALL_USERS: + # Provide presence state for all users + presence_updates = await self._filter_all_presence_updates_for_user( + user_id, include_offline, from_key + ) - if from_key: - user_ids_changed = stream_change_cache.get_entities_changed( - users_interested_in, from_key + # Remove the user from the list of users to receive all presence + if user_id in self.get_module_api()._send_full_presence_to_local_users: + self.get_module_api()._send_full_presence_to_local_users.remove( + user_id ) + + return presence_updates, max_token + + # Make mypy happy. users_interested_in should now be a set + assert not isinstance(users_interested_in, str) + + # The set of users that we're interested in and that have had a presence update. + # We'll actually pull the presence updates for these users at the end. + interested_and_updated_users = ( + set() + ) # type: Union[Set[str], FrozenSet[str]] + + if from_key: + # First get all users that have had a presence update + updated_users = stream_change_cache.get_all_entities_changed(from_key) + + # Cross-reference users we're interested in with those that have had updates. + # Use a slightly-optimised method for processing smaller sets of updates. + if updated_users is not None and len(updated_users) < 500: + # For small deltas, it's quicker to get all changes and then + # cross-reference with the users we're interested in + get_updates_counter.labels("stream").inc() + for other_user_id in updated_users: + if other_user_id in users_interested_in: + # mypy thinks this variable could be a FrozenSet as it's possibly set + # to one in the `get_entities_changed` call below, and `add()` is not + # method on a FrozenSet. That doesn't affect us here though, as + # `interested_and_updated_users` is clearly a set() above. + interested_and_updated_users.add(other_user_id) # type: ignore else: - user_ids_changed = users_interested_in + # Too many possible updates. Find all users we can see and check + # if any of them have changed. + get_updates_counter.labels("full").inc() - updates = await presence.current_state_for_users(user_ids_changed) + interested_and_updated_users = ( + stream_change_cache.get_entities_changed( + users_interested_in, from_key + ) + ) + else: + # No from_key has been specified. Return the presence for all users + # this user is interested in + interested_and_updated_users = users_interested_in + + # Retrieve the current presence state for each user + users_to_state = await self.get_presence_handler().current_state_for_users( + interested_and_updated_users + ) + presence_updates = list(users_to_state.values()) - if include_offline: - return (list(updates.values()), max_token) + # Remove the user from the list of users to receive all presence + if user_id in self.get_module_api()._send_full_presence_to_local_users: + self.get_module_api()._send_full_presence_to_local_users.remove(user_id) + + if not include_offline: + # Filter out offline presence states + presence_updates = self._filter_offline_presence_state(presence_updates) + + return presence_updates, max_token + + async def _filter_all_presence_updates_for_user( + self, + user_id: str, + include_offline: bool, + from_key: Optional[int] = None, + ) -> List[UserPresenceState]: + """ + Computes the presence updates a user should receive. + + First pulls presence updates from the database. Then consults PresenceRouter + for whether any updates should be excluded by user ID. + + Args: + user_id: The User ID of the user to compute presence updates for. + include_offline: Whether to include offline presence states from the results. + from_key: The minimum stream ID of updates to pull from the database + before filtering. + + Returns: + A list of presence states for the given user to receive. + """ + if from_key: + # Only return updates since the last sync + updated_users = self.store.presence_stream_cache.get_all_entities_changed( + from_key + ) + if not updated_users: + updated_users = [] + + # Get the actual presence update for each change + users_to_state = await self.get_presence_handler().current_state_for_users( + updated_users + ) + presence_updates = list(users_to_state.values()) + + if not include_offline: + # Filter out offline states + presence_updates = self._filter_offline_presence_state(presence_updates) else: - return ( - [s for s in updates.values() if s.state != PresenceState.OFFLINE], - max_token, + users_to_state = await self.store.get_presence_for_all_users( + include_offline=include_offline ) + presence_updates = list(users_to_state.values()) + + # TODO: This feels wildly inefficient, and it's unfortunate we need to ask the + # module for information on a number of users when we then only take the info + # for a single user + + # Filter through the presence router + users_to_state_set = await self.get_presence_router().get_users_for_states( + presence_updates + ) + + # We only want the mapping for the syncing user + presence_updates = list(users_to_state_set[user_id]) + + # Return presence information for all users + return presence_updates + + def _filter_offline_presence_state( + self, presence_updates: Iterable[UserPresenceState] + ) -> List[UserPresenceState]: + """Given an iterable containing user presence updates, return a list with any offline + presence states removed. + + Args: + presence_updates: Presence states to filter + + Returns: + A new list with any offline presence states removed. + """ + return [ + update + for update in presence_updates + if update.state != PresenceState.OFFLINE + ] + def get_current_key(self): return self.store.get_current_presence_token() @cached(num_args=2, cache_context=True) - async def _get_interested_in(self, user, explicit_room_id, cache_context): + async def _get_interested_in( + self, + user: UserID, + explicit_room_id: Optional[str] = None, + cache_context: Optional[_CacheContext] = None, + ) -> Union[Set[str], str]: """Returns the set of users that the given user should see presence - updates for + updates for. + + Args: + user: The user to retrieve presence updates for. + explicit_room_id: The users that are in the room will be returned. + + Returns: + A set of user IDs to return presence updates for, or "ALL" to return all + known updates. """ user_id = user.to_string() users_interested_in = set() users_interested_in.add(user_id) # So that we receive our own presence + # cache_context isn't likely to ever be None due to the @cached decorator, + # but we can't have a non-optional argument after the optional argument + # explicit_room_id either. Assert cache_context is not None so we can use it + # without mypy complaining. + assert cache_context + + # Check with the presence router whether we should poll additional users for + # their presence information + additional_users = await self.get_presence_router().get_interested_users( + user.to_string() + ) + if additional_users == PresenceRouter.ALL_USERS: + # If the module requested that this user see the presence updates of *all* + # users, then simply return that instead of calculating what rooms this + # user shares + return PresenceRouter.ALL_USERS + + # Add the additional users from the router + users_interested_in.update(additional_users) + + # Find the users who share a room with this user users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id, on_invalidate=cache_context.invalidate ) @@ -1314,14 +1491,15 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): async def get_interested_parties( - store: DataStore, states: List[UserPresenceState] + store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState] ) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]: """Given a list of states return which entities (rooms, users) are interested in the given states. Args: - store - states + store: The homeserver's data store. + presence_router: A module for augmenting the destinations for presence updates. + states: A list of incoming user presence updates. Returns: A 2-tuple of `(room_ids_to_states, users_to_states)`, @@ -1337,11 +1515,22 @@ async def get_interested_parties( # Always notify self users_to_states.setdefault(state.user_id, []).append(state) + # Ask a presence routing module for any additional parties if one + # is loaded. + router_users_to_states = await presence_router.get_users_for_states(states) + + # Update the dictionaries with additional destinations and state to send + for user_id, user_states in router_users_to_states.items(): + users_to_states.setdefault(user_id, []).extend(user_states) + return room_ids_to_states, users_to_states async def get_interested_remotes( - store: DataStore, states: List[UserPresenceState], state_handler: StateHandler + store: DataStore, + presence_router: PresenceRouter, + states: List[UserPresenceState], + state_handler: StateHandler, ) -> List[Tuple[Collection[str], List[UserPresenceState]]]: """Given a list of presence states figure out which remote servers should be sent which. @@ -1349,9 +1538,10 @@ async def get_interested_remotes( All the presence states should be for local users only. Args: - store - states - state_handler + store: The homeserver's data store. + presence_router: A module for augmenting the destinations for presence updates. + states: A list of incoming user presence updates. + state_handler: Returns: A list of 2-tuples of destinations and states, where for @@ -1363,7 +1553,9 @@ async def get_interested_remotes( # First we look up the rooms each user is in (as well as any explicit # subscriptions), then for each distinct room we look up the remote # hosts in those rooms. - room_ids_to_states, users_to_states = await get_interested_parties(store, states) + room_ids_to_states, users_to_states = await get_interested_parties( + store, presence_router, states + ) for room_id, states in room_ids_to_states.items(): hosts = await state_handler.get_current_hosts_in_room(room_id) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 1cf12f3255..894ef859f4 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -20,7 +20,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse import types -from synapse.api.constants import AccountDataTypes, EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership from synapse.api.errors import ( AuthError, Codes, @@ -29,6 +29,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.ratelimiting import Ratelimiter +from synapse.api.room_versions import RoomVersion from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID @@ -178,6 +179,62 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id) + async def _can_join_without_invite( + self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str + ) -> bool: + """ + Check whether a user can join a room without an invite. + + When joining a room with restricted joined rules (as defined in MSC3083), + the membership of spaces must be checked during join. + + Args: + state_ids: The state of the room as it currently is. + room_version: The room version of the room being joined. + user_id: The user joining the room. + + Returns: + True if the user can join the room, false otherwise. + """ + # This only applies to room versions which support the new join rule. + if not room_version.msc3083_join_rules: + return True + + # If there's no join rule, then it defaults to public (so this doesn't apply). + join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None) + if not join_rules_event_id: + return True + + # If the join rule is not restricted, this doesn't apply. + join_rules_event = await self.store.get_event(join_rules_event_id) + if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED: + return True + + # If allowed is of the wrong form, then only allow invited users. + allowed_spaces = join_rules_event.content.get("allow", []) + if not isinstance(allowed_spaces, list): + return False + + # Get the list of joined rooms and see if there's an overlap. + joined_rooms = await self.store.get_rooms_for_user(user_id) + + # Pull out the other room IDs, invalid data gets filtered. + for space in allowed_spaces: + if not isinstance(space, dict): + continue + + space_id = space.get("space") + if not isinstance(space_id, str): + continue + + # The user was joined to one of the spaces specified, they can join + # this room! + if space_id in joined_rooms: + return True + + # The user was not in any of the required spaces. + return False + async def _local_membership_update( self, requester: Requester, @@ -235,9 +292,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if event.membership == Membership.JOIN: newly_joined = True + user_is_invited = False if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN + user_is_invited = prev_member_event.membership == Membership.INVITE + + # If the member is not already in the room and is not accepting an invite, + # check if they should be allowed access via membership in a space. + if ( + newly_joined + and not user_is_invited + and not await self._can_join_without_invite( + prev_state_ids, event.room_version, user_id + ) + ): + raise AuthError( + 403, + "You do not belong to any of the required spaces to join this room.", + ) # Only rate-limit if the user actually joined the room, otherwise we'll end # up blocking profile updates. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7b356ba7e5..ff11266c67 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -252,13 +252,13 @@ class SyncHandler: self.storage = hs.get_storage() self.state_store = self.storage.state - # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) + # ExpiringCache((User, Device)) -> LruCache(user_id => event_id) self.lazy_loaded_members_cache = ExpiringCache( "lazy_loaded_members_cache", self.clock, max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, - ) + ) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]] async def wait_for_sync_for_user( self, @@ -733,8 +733,10 @@ class SyncHandler: def get_lazy_loaded_members_cache( self, cache_key: Tuple[str, Optional[str]] - ) -> LruCache: - cache = self.lazy_loaded_members_cache.get(cache_key) + ) -> LruCache[str, str]: + cache = self.lazy_loaded_members_cache.get( + cache_key + ) # type: Optional[LruCache[str, str]] if cache is None: logger.debug("creating LruCache for %r", cache_key) cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) |