diff options
Diffstat (limited to 'synapse')
22 files changed, 705 insertions, 419 deletions
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 2e3add7ac5..ab801108ca 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -122,6 +122,7 @@ from synapse.storage.data_stores.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) from synapse.storage.data_stores.main.presence import UserPresenceState +from synapse.storage.data_stores.main.search import SearchWorkerStore from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt @@ -451,6 +452,7 @@ class GenericWorkerSlavedStore( SlavedFilteringStore, MonthlyActiveUsersWorkerStore, MediaRepositoryStore, + SearchWorkerStore, BaseSlavedStore, ): def __init__(self, database, db_conn, hs): diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index aea3985a5f..1b13e84425 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -270,7 +270,7 @@ class ApplicationService(object): def is_exclusive_room(self, room_id): return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) - def get_exlusive_user_regexes(self): + def get_exclusive_user_regexes(self): """Get the list of regexes used to determine if a user is exclusively registered by the AS """ diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 5a5b568a95..c582355146 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import Set, Tuple +from typing import List, Optional, Set, Tuple from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -29,18 +29,19 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersion, ) -from synapse.types import UserID, get_domain_from_id +from synapse.events import EventBase +from synapse.types import StateMap, UserID, get_domain_from_id logger = logging.getLogger(__name__) def check( room_version_obj: RoomVersion, - event, - auth_events, - do_sig_check=True, - do_size_check=True, -): + event: EventBase, + auth_events: StateMap[EventBase], + do_sig_check: bool = True, + do_size_check: bool = True, +) -> None: """ Checks if this event is correctly authed. Args: @@ -189,7 +190,7 @@ def check( logger.debug("Allowing! %s", event) -def _check_size_limits(event): +def _check_size_limits(event: EventBase) -> None: def too_big(field): raise EventSizeError("%s too large" % (field,)) @@ -207,13 +208,18 @@ def _check_size_limits(event): too_big("event") -def _can_federate(event, auth_events): +def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool: creation_event = auth_events.get((EventTypes.Create, "")) + # There should always be a creation event, but if not don't federate. + if not creation_event: + return False return creation_event.content.get("m.federate", True) is True -def _is_membership_change_allowed(event, auth_events): +def _is_membership_change_allowed( + event: EventBase, auth_events: StateMap[EventBase] +) -> None: membership = event.content["membership"] # Check if this is the room creator joining: @@ -339,21 +345,25 @@ def _is_membership_change_allowed(event, auth_events): raise AuthError(500, "Unknown membership %s" % membership) -def _check_event_sender_in_room(event, auth_events): +def _check_event_sender_in_room( + event: EventBase, auth_events: StateMap[EventBase] +) -> None: key = (EventTypes.Member, event.user_id) member_event = auth_events.get(key) - return _check_joined_room(member_event, event.user_id, event.room_id) + _check_joined_room(member_event, event.user_id, event.room_id) -def _check_joined_room(member, user_id, room_id): +def _check_joined_room(member: Optional[EventBase], user_id: str, room_id: str) -> None: if not member or member.membership != Membership.JOIN: raise AuthError( 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member)) ) -def get_send_level(etype, state_key, power_levels_event): +def get_send_level( + etype: str, state_key: Optional[str], power_levels_event: Optional[EventBase] +) -> int: """Get the power level required to send an event of a given type The federation spec [1] refers to this as "Required Power Level". @@ -361,13 +371,13 @@ def get_send_level(etype, state_key, power_levels_event): https://matrix.org/docs/spec/server_server/unstable.html#definitions Args: - etype (str): type of event - state_key (str|None): state_key of state event, or None if it is not + etype: type of event + state_key: state_key of state event, or None if it is not a state event. - power_levels_event (synapse.events.EventBase|None): power levels event + power_levels_event: power levels event in force at this point in the room Returns: - int: power level required to send this event. + power level required to send this event. """ if power_levels_event: @@ -388,7 +398,7 @@ def get_send_level(etype, state_key, power_levels_event): return int(send_level) -def _can_send_event(event, auth_events): +def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool: power_levels_event = _get_power_level_event(auth_events) send_level = get_send_level(event.type, event.get("state_key"), power_levels_event) @@ -410,7 +420,9 @@ def _can_send_event(event, auth_events): return True -def check_redaction(room_version_obj: RoomVersion, event, auth_events): +def check_redaction( + room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase], +) -> bool: """Check whether the event sender is allowed to redact the target event. Returns: @@ -442,7 +454,9 @@ def check_redaction(room_version_obj: RoomVersion, event, auth_events): raise AuthError(403, "You don't have permission to redact events") -def _check_power_levels(room_version_obj, event, auth_events): +def _check_power_levels( + room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase], +) -> None: user_list = event.content.get("users", {}) # Validate users for k, v in user_list.items(): @@ -473,7 +487,7 @@ def _check_power_levels(room_version_obj, event, auth_events): ("redact", None), ("kick", None), ("invite", None), - ] + ] # type: List[Tuple[str, Optional[str]]] old_list = current_state.content.get("users", {}) for user in set(list(old_list) + list(user_list)): @@ -503,12 +517,12 @@ def _check_power_levels(room_version_obj, event, auth_events): new_loc = new_loc.get(dir, {}) if level_to_check in old_loc: - old_level = int(old_loc[level_to_check]) + old_level = int(old_loc[level_to_check]) # type: Optional[int] else: old_level = None if level_to_check in new_loc: - new_level = int(new_loc[level_to_check]) + new_level = int(new_loc[level_to_check]) # type: Optional[int] else: new_level = None @@ -534,21 +548,21 @@ def _check_power_levels(room_version_obj, event, auth_events): ) -def _get_power_level_event(auth_events): +def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]: return auth_events.get((EventTypes.PowerLevels, "")) -def get_user_power_level(user_id, auth_events): +def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int: """Get a user's power level Args: - user_id (str): user's id to look up in power_levels - auth_events (dict[(str, str), synapse.events.EventBase]): + user_id: user's id to look up in power_levels + auth_events: state in force at this point in the room (or rather, a subset of it including at least the create event and power levels event. Returns: - int: the user's power level in this room. + the user's power level in this room. """ power_level_event = _get_power_level_event(auth_events) if power_level_event: @@ -574,7 +588,7 @@ def get_user_power_level(user_id, auth_events): return 0 -def _get_named_level(auth_events, name, default): +def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int: power_level_event = _get_power_level_event(auth_events) if not power_level_event: @@ -587,7 +601,7 @@ def _get_named_level(auth_events, name, default): return default -def _verify_third_party_invite(event, auth_events): +def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase]): """ Validates that the invite event is authorized by a previous third-party invite. @@ -662,7 +676,7 @@ def get_public_keys(invite_event): return public_keys -def auth_types_for_event(event) -> Set[Tuple[str, str]]: +def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]: """Given an event, return a list of (EventType, StateKey) that may be needed to auth the event. The returned list may be a superset of what would actually be required depending on the full state of the room. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 524281d2f1..75b39e878c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -80,7 +80,9 @@ class AuthHandler(BaseHandler): self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled - self._sso_enabled = hs.config.saml2_enabled or hs.config.cas_enabled + self._sso_enabled = ( + hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled + ) # we keep this as a list despite the O(N^2) implication so that we can # keep PASSWORD first and avoid confusing clients which pick the first diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 178f263439..4ba8c7fda5 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -311,7 +311,7 @@ class OidcHandler: ``ClientAuth`` to authenticate with the client with its ID and secret. Args: - code: The autorization code we got from the callback. + code: The authorization code we got from the callback. Returns: A dict containing various tokens. @@ -497,11 +497,14 @@ class OidcHandler: return UserInfo(claims) async def handle_redirect_request( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> None: + self, + request: SynapseRequest, + client_redirect_url: bytes, + ui_auth_session_id: Optional[str] = None, + ) -> str: """Handle an incoming request to /login/sso/redirect - It redirects the browser to the authorization endpoint with a few + It returns a redirect to the authorization endpoint with a few parameters: - ``client_id``: the client ID set in ``oidc_config.client_id`` @@ -511,24 +514,32 @@ class OidcHandler: - ``state``: a random string - ``nonce``: a random string - In addition to redirecting the client, we are setting a cookie with + In addition generating a redirect URL, we are setting a cookie with a signed macaroon token containing the state, the nonce and the client_redirect_url params. Those are then checked when the client comes back from the provider. - Args: request: the incoming request from the browser. We'll respond to it with a redirect and a cookie. client_redirect_url: the URL that we should redirect the client to when everything is done + ui_auth_session_id: The session ID of the ongoing UI Auth (or + None if this is a login). + + Returns: + The redirect URL to the authorization endpoint. + """ state = generate_token() nonce = generate_token() cookie = self._generate_oidc_session_token( - state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(), + state=state, + nonce=nonce, + client_redirect_url=client_redirect_url.decode(), + ui_auth_session_id=ui_auth_session_id, ) request.addCookie( SESSION_COOKIE_NAME, @@ -541,7 +552,7 @@ class OidcHandler: metadata = await self.load_metadata() authorization_endpoint = metadata.get("authorization_endpoint") - uri = prepare_grant_uri( + return prepare_grant_uri( authorization_endpoint, client_id=self._client_auth.client_id, response_type="code", @@ -550,8 +561,6 @@ class OidcHandler: state=state, nonce=nonce, ) - request.redirect(uri) - finish_request(request) async def handle_oidc_callback(self, request: SynapseRequest) -> None: """Handle an incoming request to /_synapse/oidc/callback @@ -625,7 +634,11 @@ class OidcHandler: # Deserialize the session token and verify it. try: - nonce, client_redirect_url = self._verify_oidc_session_token(session, state) + ( + nonce, + client_redirect_url, + ui_auth_session_id, + ) = self._verify_oidc_session_token(session, state) except MacaroonDeserializationException as e: logger.exception("Invalid session") self._render_error(request, "invalid_session", str(e)) @@ -678,15 +691,21 @@ class OidcHandler: return # and finally complete the login - await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url - ) + if ui_auth_session_id: + await self._auth_handler.complete_sso_ui_auth( + user_id, ui_auth_session_id, request + ) + else: + await self._auth_handler.complete_sso_login( + user_id, request, client_redirect_url + ) def _generate_oidc_session_token( self, state: str, nonce: str, client_redirect_url: str, + ui_auth_session_id: Optional[str], duration_in_ms: int = (60 * 60 * 1000), ) -> str: """Generates a signed token storing data about an OIDC session. @@ -702,6 +721,8 @@ class OidcHandler: nonce: The ``nonce`` parameter passed to the OIDC provider. client_redirect_url: The URL the client gave when it initiated the flow. + ui_auth_session_id: The session ID of the ongoing UI Auth (or + None if this is a login). duration_in_ms: An optional duration for the token in milliseconds. Defaults to an hour. @@ -718,12 +739,19 @@ class OidcHandler: macaroon.add_first_party_caveat( "client_redirect_url = %s" % (client_redirect_url,) ) + if ui_auth_session_id: + macaroon.add_first_party_caveat( + "ui_auth_session_id = %s" % (ui_auth_session_id,) + ) now = self._clock.time_msec() expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) + return macaroon.serialize() - def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]: + def _verify_oidc_session_token( + self, session: str, state: str + ) -> Tuple[str, str, Optional[str]]: """Verifies and extract an OIDC session token. This verifies that a given session token was issued by this homeserver @@ -734,7 +762,7 @@ class OidcHandler: state: The state the OIDC provider gave back Returns: - The nonce and the client_redirect_url for this session + The nonce, client_redirect_url, and ui_auth_session_id for this session """ macaroon = pymacaroons.Macaroon.deserialize(session) @@ -744,17 +772,27 @@ class OidcHandler: v.satisfy_exact("state = %s" % (state,)) v.satisfy_general(lambda c: c.startswith("nonce = ")) v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) + # Sometimes there's a UI auth session ID, it seems to be OK to attempt + # to always satisfy this. + v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) v.satisfy_general(self._verify_expiry) v.verify(macaroon, self._macaroon_secret_key) - # Extract the `nonce` and `client_redirect_url` from the token + # Extract the `nonce`, `client_redirect_url`, and maybe the + # `ui_auth_session_id` from the token. nonce = self._get_value_from_macaroon(macaroon, "nonce") client_redirect_url = self._get_value_from_macaroon( macaroon, "client_redirect_url" ) + try: + ui_auth_session_id = self._get_value_from_macaroon( + macaroon, "ui_auth_session_id" + ) # type: Optional[str] + except ValueError: + ui_auth_session_id = None - return nonce, client_redirect_url + return nonce, client_redirect_url, ui_auth_session_id def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str: """Extracts a caveat value from a macaroon token. @@ -773,7 +811,7 @@ class OidcHandler: for caveat in macaroon.caveats: if caveat.caveat_id.startswith(prefix): return caveat.caveat_id[len(prefix) :] - raise Exception("No %s caveat in macaroon" % (key,)) + raise ValueError("No %s caveat in macaroon" % (key,)) def _verify_expiry(self, caveat: str) -> bool: prefix = "time < " diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 4ddeba4c97..e51e1c32fe 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,13 +17,16 @@ import abc import logging +from typing import Dict, Iterable, List, Optional, Tuple, Union from six.moves import http_client from synapse import types from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.types import Collection, RoomID, UserID +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -74,84 +77,84 @@ class RoomMemberHandler(object): self.base_handler = BaseHandler(hs) @abc.abstractmethod - async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> Optional[dict]: """Try and join a room that this server is not in Args: - requester (Requester) - remote_room_hosts (list[str]): List of servers that can be used - to join via. - room_id (str): Room that we are trying to join - user (UserID): User who is trying to join - content (dict): A dict that should be used as the content of the - join event. - - Returns: - Deferred + requester + remote_room_hosts: List of servers that can be used to join via. + room_id: Room that we are trying to join + user: User who is trying to join + content: A dict that should be used as the content of the join event. """ raise NotImplementedError() @abc.abstractmethod async def _remote_reject_invite( - self, requester, remote_room_hosts, room_id, target, content - ): + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + target: UserID, + content: dict, + ) -> dict: """Attempt to reject an invite for a room this server is not in. If we fail to do so we locally mark the invite as rejected. Args: - requester (Requester) - remote_room_hosts (list[str]): List of servers to use to try and - reject invite - room_id (str) - target (UserID): The user rejecting the invite - content (dict): The content for the rejection event + requester + remote_room_hosts: List of servers to use to try and reject invite + room_id + target: The user rejecting the invite + content: The content for the rejection event Returns: - Deferred[dict]: A dictionary to be returned to the client, may + A dictionary to be returned to the client, may include event_id etc, or nothing if we locally rejected """ raise NotImplementedError() @abc.abstractmethod - async def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has joined the room. Args: - target (UserID) - room_id (str) - - Returns: - None + target + room_id """ raise NotImplementedError() @abc.abstractmethod - async def _user_left_room(self, target, room_id): + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has left the room. Args: - target (UserID) - room_id (str) - - Returns: - None + target + room_id """ raise NotImplementedError() async def _local_membership_update( self, - requester, - target, - room_id, - membership, + requester: Requester, + target: UserID, + room_id: str, + membership: str, prev_event_ids: Collection[str], - txn_id=None, - ratelimit=True, - content=None, - require_consent=True, - ): + txn_id: Optional[str] = None, + ratelimit: bool = True, + content: Optional[dict] = None, + require_consent: bool = True, + ) -> EventBase: user_id = target.to_string() if content is None: @@ -214,16 +217,13 @@ class RoomMemberHandler(object): async def copy_room_tags_and_direct_to_room( self, old_room_id, new_room_id, user_id - ): + ) -> None: """Copies the tags and direct room state from one room to another. Args: - old_room_id (str) - new_room_id (str) - user_id (str) - - Returns: - Deferred[None] + old_room_id: The room ID of the old room. + new_room_id: The room ID of the new room. + user_id: The user's ID. """ # Retrieve user account data for predecessor room user_account_data, _ = await self.store.get_account_data_for_user(user_id) @@ -253,17 +253,17 @@ class RoomMemberHandler(object): async def update_membership( self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - require_consent=True, - ): + requester: Requester, + target: UserID, + room_id: str, + action: str, + txn_id: Optional[str] = None, + remote_room_hosts: Optional[List[str]] = None, + third_party_signed: Optional[dict] = None, + ratelimit: bool = True, + content: Optional[dict] = None, + require_consent: bool = True, + ) -> Union[EventBase, Optional[dict]]: key = (room_id,) with (await self.member_linearizer.queue(key)): @@ -284,17 +284,17 @@ class RoomMemberHandler(object): async def _update_membership( self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - require_consent=True, - ): + requester: Requester, + target: UserID, + room_id: str, + action: str, + txn_id: Optional[str] = None, + remote_room_hosts: Optional[List[str]] = None, + third_party_signed: Optional[dict] = None, + ratelimit: bool = True, + content: Optional[dict] = None, + require_consent: bool = True, + ) -> Union[EventBase, Optional[dict]]: content_specified = bool(content) if content is None: content = {} @@ -468,12 +468,11 @@ class RoomMemberHandler(object): else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - res = await self._remote_reject_invite( + return await self._remote_reject_invite( requester, remote_room_hosts, room_id, target, content, ) - return res - res = await self._local_membership_update( + return await self._local_membership_update( requester=requester, target=target, room_id=room_id, @@ -484,9 +483,10 @@ class RoomMemberHandler(object): content=content, require_consent=require_consent, ) - return res - async def transfer_room_state_on_room_upgrade(self, old_room_id, room_id): + async def transfer_room_state_on_room_upgrade( + self, old_room_id: str, room_id: str + ) -> None: """Upon our server becoming aware of an upgraded room, either by upgrading a room ourselves or joining one, we can transfer over information from the previous room. @@ -494,12 +494,8 @@ class RoomMemberHandler(object): well as migrating the room directory state. Args: - old_room_id (str): The ID of the old room - - room_id (str): The ID of the new room - - Returns: - Deferred + old_room_id: The ID of the old room + room_id: The ID of the new room """ logger.info("Transferring room state from %s to %s", old_room_id, room_id) @@ -526,17 +522,16 @@ class RoomMemberHandler(object): # Remove the old room from those groups await self.store.remove_room_from_group(group_id, old_room_id) - async def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids): + async def copy_user_state_on_room_upgrade( + self, old_room_id: str, new_room_id: str, user_ids: Iterable[str] + ) -> None: """Copy user-specific information when they join a new room when that new room is the result of a room upgrade Args: - old_room_id (str): The ID of upgraded room - new_room_id (str): The ID of the new room - user_ids (Iterable[str]): User IDs to copy state for - - Returns: - Deferred + old_room_id: The ID of upgraded room + new_room_id: The ID of the new room + user_ids: User IDs to copy state for """ logger.debug( @@ -566,17 +561,23 @@ class RoomMemberHandler(object): ) continue - async def send_membership_event(self, requester, event, context, ratelimit=True): + async def send_membership_event( + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, + ): """ Change the membership status of a user in a room. Args: - requester (Requester): The local user who requested the membership + requester: The local user who requested the membership event. If None, certain checks, like whether this homeserver can act as the sender, will be skipped. - event (SynapseEvent): The membership event. + event: The membership event. context: The context of the event. - ratelimit (bool): Whether to rate limit this request. + ratelimit: Whether to rate limit this request. Raises: SynapseError if there was a problem changing the membership. """ @@ -636,7 +637,9 @@ class RoomMemberHandler(object): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target_user, room_id) - async def _can_guest_join(self, current_state_ids): + async def _can_guest_join( + self, current_state_ids: Dict[Tuple[str, str], str] + ) -> bool: """ Returns whether a guest can join a room based on its current state. """ @@ -653,12 +656,14 @@ class RoomMemberHandler(object): and guest_access.content["guest_access"] == "can_join" ) - async def lookup_room_alias(self, room_alias): + async def lookup_room_alias( + self, room_alias: RoomAlias + ) -> Tuple[RoomID, List[str]]: """ Get the room ID associated with a room alias. Args: - room_alias (RoomAlias): The alias to look up. + room_alias: The alias to look up. Returns: A tuple of: The room ID as a RoomID object. @@ -682,24 +687,25 @@ class RoomMemberHandler(object): return RoomID.from_string(room_id), servers - async def _get_inviter(self, user_id, room_id): + async def _get_inviter(self, user_id: str, room_id: str) -> Optional[UserID]: invite = await self.store.get_invite_for_local_user_in_room( user_id=user_id, room_id=room_id ) if invite: return UserID.from_string(invite.sender) + return None async def do_3pid_invite( self, - room_id, - inviter, - medium, - address, - id_server, - requester, - txn_id, - id_access_token=None, - ): + room_id: str, + inviter: UserID, + medium: str, + address: str, + id_server: str, + requester: Requester, + txn_id: Optional[str], + id_access_token: Optional[str] = None, + ) -> None: if self.config.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: @@ -748,15 +754,15 @@ class RoomMemberHandler(object): async def _make_and_store_3pid_invite( self, - requester, - id_server, - medium, - address, - room_id, - user, - txn_id, - id_access_token=None, - ): + requester: Requester, + id_server: str, + medium: str, + address: str, + room_id: str, + user: UserID, + txn_id: Optional[str], + id_access_token: Optional[str] = None, + ) -> None: room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" @@ -830,7 +836,9 @@ class RoomMemberHandler(object): txn_id=txn_id, ) - async def _is_host_in_room(self, current_state_ids): + async def _is_host_in_room( + self, current_state_ids: Dict[Tuple[str, str], str] + ) -> bool: # Have we just created the room, and is this about to be the very # first member event? create_event_id = current_state_ids.get(("m.room.create", "")) @@ -852,7 +860,7 @@ class RoomMemberHandler(object): return False - async def _is_server_notice_room(self, room_id): + async def _is_server_notice_room(self, room_id: str) -> bool: if self._server_notices_mxid is None: return False user_ids = await self.store.get_users_in_room(room_id) @@ -867,13 +875,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): self.distributor.declare("user_joined_room") self.distributor.declare("user_left_room") - async def _is_remote_room_too_complex(self, room_id, remote_room_hosts): + async def _is_remote_room_too_complex( + self, room_id: str, remote_room_hosts: List[str] + ) -> Optional[bool]: """ Check if complexity of a remote room is too great. Args: - room_id (str) - remote_room_hosts (list[str]) + room_id + remote_room_hosts Returns: bool of whether the complexity is too great, or None if unable to be fetched @@ -887,21 +897,26 @@ class RoomMemberMasterHandler(RoomMemberHandler): return complexity["v1"] > max_complexity return None - async def _is_local_room_too_complex(self, room_id): + async def _is_local_room_too_complex(self, room_id: str) -> bool: """ Check if the complexity of a local room is too great. Args: - room_id (str) - - Returns: bool + room_id: The room ID to check for complexity. """ max_complexity = self.hs.config.limit_remote_rooms.complexity complexity = await self.store.get_room_complexity(room_id) return complexity["v1"] > max_complexity - async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> None: """Implements RoomMemberHandler._remote_join """ # filter ourselves out of remote_room_hosts: do_invite_join ignores it @@ -961,8 +976,13 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) async def _remote_reject_invite( - self, requester, remote_room_hosts, room_id, target, content - ): + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + target: UserID, + content: dict, + ) -> dict: """Implements RoomMemberHandler._remote_reject_invite """ fed_handler = self.federation_handler @@ -983,17 +1003,17 @@ class RoomMemberMasterHandler(RoomMemberHandler): await self.store.locally_reject_invite(target.to_string(), room_id) return {} - async def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room """ - return user_joined_room(self.distributor, target, room_id) + user_joined_room(self.distributor, target, room_id) - async def _user_left_room(self, target, room_id): + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room """ - return user_left_room(self.distributor, target, room_id) + user_left_room(self.distributor, target, room_id) - async def forget(self, user, room_id): + async def forget(self, user: UserID, room_id: str) -> None: user_id = user.to_string() member = await self.state_handler.get_current_state( diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 0fc54349ab..5c776cc0be 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import List, Optional from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler @@ -22,6 +23,7 @@ from synapse.replication.http.membership import ( ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite, ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft, ) +from synapse.types import Requester, UserID logger = logging.getLogger(__name__) @@ -34,7 +36,14 @@ class RoomMemberWorkerHandler(RoomMemberHandler): self._remote_reject_client = ReplRejectInvite.make_client(hs) self._notify_change_client = ReplJoinedLeft.make_client(hs) - async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> Optional[dict]: """Implements RoomMemberHandler._remote_join """ if len(remote_room_hosts) == 0: @@ -53,8 +62,13 @@ class RoomMemberWorkerHandler(RoomMemberHandler): return ret async def _remote_reject_invite( - self, requester, remote_room_hosts, room_id, target, content - ): + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + target: UserID, + content: dict, + ) -> dict: """Implements RoomMemberHandler._remote_reject_invite """ return await self._remote_reject_client( @@ -65,16 +79,16 @@ class RoomMemberWorkerHandler(RoomMemberHandler): content=content, ) - async def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room """ - return await self._notify_change_client( + await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="joined" ) - async def _user_left_room(self, target, room_id): + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room """ - return await self._notify_change_client( + await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="left" ) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index b313720a4b..1a1a50a24f 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -15,11 +15,6 @@ # limitations under the License. import logging -from synapse.api.constants import EventTypes -from synapse.replication.tcp.streams.events import ( - EventsStreamCurrentStateRow, - EventsStreamEventRow, -) from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore from synapse.storage.data_stores.main.event_push_actions import ( EventPushActionsWorkerStore, @@ -35,7 +30,6 @@ from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker logger = logging.getLogger(__name__) @@ -62,11 +56,6 @@ class SlavedEventStore( BaseSlavedStore, ): def __init__(self, database: Database, db_conn, hs): - self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) - super(SlavedEventStore, self).__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() @@ -92,81 +81,3 @@ class SlavedEventStore( def get_room_min_stream_ordering(self): return self._backfill_id_gen.get_current_token() - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "events": - self._stream_id_gen.advance(token) - for row in rows: - self._process_event_stream_row(token, row) - elif stream_name == "backfill": - self._backfill_id_gen.advance(-token) - for row in rows: - self.invalidate_caches_for_event( - -token, - row.event_id, - row.room_id, - row.type, - row.state_key, - row.redacts, - row.relates_to, - backfilled=True, - ) - return super().process_replication_rows(stream_name, instance_name, token, rows) - - def _process_event_stream_row(self, token, row): - data = row.data - - if row.type == EventsStreamEventRow.TypeId: - self.invalidate_caches_for_event( - token, - data.event_id, - data.room_id, - data.type, - data.state_key, - data.redacts, - data.relates_to, - backfilled=False, - ) - elif row.type == EventsStreamCurrentStateRow.TypeId: - self._curr_state_delta_stream_cache.entity_has_changed( - row.data.room_id, token - ) - - if data.type == EventTypes.Member: - self.get_rooms_for_user_with_stream_ordering.invalidate( - (data.state_key,) - ) - else: - raise Exception("Unknown events stream row type %s" % (row.type,)) - - def invalidate_caches_for_event( - self, - stream_ordering, - event_id, - room_id, - etype, - state_key, - redacts, - relates_to, - backfilled, - ): - self._invalidate_get_event_cache(event_id) - - self.get_latest_event_ids_in_room.invalidate((room_id,)) - - self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) - - if not backfilled: - self._events_stream_cache.entity_has_changed(room_id, stream_ordering) - - if redacts: - self._invalidate_get_event_cache(redacts) - - if etype == EventTypes.Member: - self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) - self.get_invited_rooms_for_local_user.invalidate((state_key,)) - - if relates_to: - self.get_relations_for_event.invalidate_many((relates_to,)) - self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) - self.get_applicable_edit.invalidate((relates_to,)) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 5d5816d7eb..6adb19463a 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -15,19 +15,11 @@ # limitations under the License. from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore -from synapse.storage.database import Database -from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def __init__(self, database: Database, db_conn, hs): - self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id" - ) - super(SlavedPushRuleStore, self).__init__(database, db_conn, hs) - def get_push_rules_stream_token(self): return ( self._push_rules_stream_id_gen.get_current_token(), diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index b48a6a3e91..d42aaff055 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,14 +14,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import heapq import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + List, + Optional, + Tuple, + TypeVar, +) import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates +if TYPE_CHECKING: + import synapse.server + logger = logging.getLogger(__name__) # the number of rows to request from an update_function. @@ -37,7 +50,7 @@ Token = int # parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's # just a row from a database query, though this is dependent on the stream in question. # -StreamRow = Tuple +StreamRow = TypeVar("StreamRow", bound=Tuple) # The type returned by the update_function of a stream, as well as get_updates(), # get_updates_since, etc. @@ -533,32 +546,63 @@ class AccountDataStream(Stream): """ AccountDataStreamRow = namedtuple( - "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str + "AccountDataStream", + ("user_id", "room_id", "data_type"), # str # Optional[str] # str ) NAME = "account_data" ROW_TYPE = AccountDataStreamRow - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), current_token_without_instance(self.store.get_max_account_data_stream_id), - db_query_to_update_function(self._update_function), + self._update_function, + ) + + async def _update_function( + self, instance_name: str, from_token: int, to_token: int, limit: int + ) -> StreamUpdateResult: + limited = False + global_results = await self.store.get_updated_global_account_data( + from_token, to_token, limit ) - async def _update_function(self, from_token, to_token, limit): - global_results, room_results = await self.store.get_all_updated_account_data( - from_token, from_token, to_token, limit + # if the global results hit the limit, we'll need to limit the room results to + # the same stream token. + if len(global_results) >= limit: + to_token = global_results[-1][0] + limited = True + + room_results = await self.store.get_updated_room_account_data( + from_token, to_token, limit ) - results = list(room_results) - results.extend( - (stream_id, user_id, None, account_data_type) + # likewise, if the room results hit the limit, limit the global results to + # the same stream token. + if len(room_results) >= limit: + to_token = room_results[-1][0] + limited = True + + # convert the global results to the right format, and limit them to the to_token + # at the same time + global_rows = ( + (stream_id, (user_id, None, account_data_type)) for stream_id, user_id, account_data_type in global_results + if stream_id <= to_token + ) + + # we know that the room_results are already limited to `to_token` so no need + # for a check on `stream_id` here. + room_rows = ( + (stream_id, (user_id, room_id, account_data_type)) + for stream_id, user_id, room_id, account_data_type in room_results ) - return results + # we need to return a sorted list, so merge them together. + updates = list(heapq.merge(room_rows, global_rows)) + return updates, to_token, limited class GroupServerStream(Stream): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index de7eca21f8..d89b2e5532 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -401,19 +401,22 @@ class BaseSSORedirectServlet(RestServlet): PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) - def on_GET(self, request: SynapseRequest): + async def on_GET(self, request: SynapseRequest): args = request.args if b"redirectUrl" not in args: return 400, "Redirect URL not specified for SSO auth" client_redirect_url = args[b"redirectUrl"][0] - sso_url = self.get_sso_url(client_redirect_url) + sso_url = await self.get_sso_url(request, client_redirect_url) request.redirect(sso_url) finish_request(request) - def get_sso_url(self, client_redirect_url: bytes) -> bytes: + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: """Get the URL to redirect to, to perform SSO auth Args: + request: The client request to redirect. client_redirect_url: the URL that we should redirect the client to when everything is done @@ -428,7 +431,9 @@ class CasRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): self._cas_handler = hs.get_cas_handler() - def get_sso_url(self, client_redirect_url: bytes) -> bytes: + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: return self._cas_handler.get_redirect_url( {"redirectUrl": client_redirect_url} ).encode("ascii") @@ -465,11 +470,13 @@ class SAMLRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): self._saml_handler = hs.get_saml_handler() - def get_sso_url(self, client_redirect_url: bytes) -> bytes: + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: return self._saml_handler.handle_redirect_request(client_redirect_url) -class OIDCRedirectServlet(RestServlet): +class OIDCRedirectServlet(BaseSSORedirectServlet): """Implementation for /login/sso/redirect for the OIDC login flow.""" PATTERNS = client_patterns("/login/sso/redirect", v1=True) @@ -477,12 +484,12 @@ class OIDCRedirectServlet(RestServlet): def __init__(self, hs): self._oidc_handler = hs.get_oidc_handler() - async def on_GET(self, request): - args = request.args - if b"redirectUrl" not in args: - return 400, "Redirect URL not specified for SSO auth" - client_redirect_url = args[b"redirectUrl"][0] - await self._oidc_handler.handle_redirect_request(request, client_redirect_url) + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: + return await self._oidc_handler.handle_redirect_request( + request, client_redirect_url + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 24dd3d3e96..7bca1326d5 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -131,14 +131,19 @@ class AuthRestServlet(RestServlet): self.registration_handler = hs.get_registration_handler() # SSO configuration. - self._saml_enabled = hs.config.saml2_enabled - if self._saml_enabled: - self._saml_handler = hs.get_saml_handler() self._cas_enabled = hs.config.cas_enabled if self._cas_enabled: self._cas_handler = hs.get_cas_handler() self._cas_server_url = hs.config.cas_server_url self._cas_service_url = hs.config.cas_service_url + self._saml_enabled = hs.config.saml2_enabled + if self._saml_enabled: + self._saml_handler = hs.get_saml_handler() + self._oidc_enabled = hs.config.oidc_enabled + if self._oidc_enabled: + self._oidc_handler = hs.get_oidc_handler() + self._cas_server_url = hs.config.cas_server_url + self._cas_service_url = hs.config.cas_service_url async def on_GET(self, request, stagetype): session = parse_string(request, "session") @@ -172,11 +177,17 @@ class AuthRestServlet(RestServlet): ) elif self._saml_enabled: - client_redirect_url = "" + client_redirect_url = b"" sso_redirect_url = self._saml_handler.handle_redirect_request( client_redirect_url, session ) + elif self._oidc_enabled: + client_redirect_url = b"" + sso_redirect_url = await self._oidc_handler.handle_redirect_request( + request, client_redirect_url, session + ) + else: raise SynapseError(400, "Homeserver not configured for SSO.") diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 5df9dce79d..4b4763c701 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -24,7 +24,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( - ChainedIdGenerator, IdGenerator, MultiWriterIdGenerator, StreamIdGenerator, @@ -125,19 +124,6 @@ class DataStore( self._clock = hs.get_clock() self.database_engine = database.engine - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - extra_tables=[("local_invites", "stream_id")], - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) @@ -164,9 +150,6 @@ class DataStore( self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") - self._push_rules_stream_id_gen = ChainedIdGenerator( - self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" - ) self._pushers_id_gen = StreamIdGenerator( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index 46b494b334..f9eef1b78e 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -16,6 +16,7 @@ import abc import logging +from typing import List, Tuple from canonicaljson import json @@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore): "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) - def get_all_updated_account_data( - self, last_global_id, last_room_id, current_id, limit - ): - """Get all the client account_data that has changed on the server + async def get_updated_global_account_data( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple[int, str, str]]: + """Get the global account_data that has changed, for the account_data stream + Args: - last_global_id(int): The position to fetch from for top level data - last_room_id(int): The position to fetch from for per room data - current_id(int): The position to fetch up to. + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + Returns: - A deferred pair of lists of tuples of stream_id int, user_id string, - room_id string, and type string. + A list of tuples of stream_id int, user_id string, + and type string. """ - if last_room_id == current_id and last_global_id == current_id: - return defer.succeed(([], [])) + if last_id == current_id: + return [] - def get_updated_account_data_txn(txn): + def get_updated_global_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) - txn.execute(sql, (last_global_id, current_id, limit)) - global_results = txn.fetchall() + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return await self.db.runInteraction( + "get_updated_global_account_data", get_updated_global_account_data_txn + ) + + async def get_updated_room_account_data( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple[int, str, str, str]]: + """Get the global account_data that has changed, for the account_data stream + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: + A list of tuples of stream_id int, user_id string, + room_id string and type string. + """ + if last_id == current_id: + return [] + + def get_updated_room_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) - txn.execute(sql, (last_room_id, current_id, limit)) - room_results = txn.fetchall() - return global_results, room_results + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() - return self.db.runInteraction( - "get_all_updated_account_data_txn", get_updated_account_data_txn + return await self.db.runInteraction( + "get_updated_room_account_data", get_updated_room_account_data_txn ) def get_updated_account_data_for_user(self, user_id, stream_id): diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index efbc06c796..7a1fe8cdd2 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -30,12 +30,12 @@ logger = logging.getLogger(__name__) def _make_exclusive_regex(services_cache): - # We precompie a regex constructed from all the regexes that the AS's + # We precompile a regex constructed from all the regexes that the AS's # have registered for exclusive users. exclusive_user_regexes = [ regex.pattern for service in services_cache - for regex in service.get_exlusive_user_regexes() + for regex in service.get_exclusive_user_regexes() ] if exclusive_user_regexes: exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index 342a87a46b..eac5a4e55b 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -16,8 +16,13 @@ import itertools import logging -from typing import Any, Iterable, Optional +from typing import Any, Iterable, Optional, Tuple +from synapse.api.constants import EventTypes +from synapse.replication.tcp.streams.events import ( + EventsStreamCurrentStateRow, + EventsStreamEventRow, +) from synapse.storage._base import SQLBaseStore from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine @@ -66,7 +71,22 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "caches": + if stream_name == "events": + for row in rows: + self._process_event_stream_row(token, row) + elif stream_name == "backfill": + for row in rows: + self._invalidate_caches_for_event( + -token, + row.event_id, + row.room_id, + row.type, + row.state_key, + row.redacts, + row.relates_to, + backfilled=True, + ) + elif stream_name == "caches": if self._cache_id_gen: self._cache_id_gen.advance(instance_name, token) @@ -85,6 +105,84 @@ class CacheInvalidationWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) + def _process_event_stream_row(self, token, row): + data = row.data + + if row.type == EventsStreamEventRow.TypeId: + self._invalidate_caches_for_event( + token, + data.event_id, + data.room_id, + data.type, + data.state_key, + data.redacts, + data.relates_to, + backfilled=False, + ) + elif row.type == EventsStreamCurrentStateRow.TypeId: + self._curr_state_delta_stream_cache.entity_has_changed( + row.data.room_id, token + ) + + if data.type == EventTypes.Member: + self.get_rooms_for_user_with_stream_ordering.invalidate( + (data.state_key,) + ) + else: + raise Exception("Unknown events stream row type %s" % (row.type,)) + + def _invalidate_caches_for_event( + self, + stream_ordering, + event_id, + room_id, + etype, + state_key, + redacts, + relates_to, + backfilled, + ): + self._invalidate_get_event_cache(event_id) + + self.get_latest_event_ids_in_room.invalidate((room_id,)) + + self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) + + if not backfilled: + self._events_stream_cache.entity_has_changed(room_id, stream_ordering) + + if redacts: + self._invalidate_get_event_cache(redacts) + + if etype == EventTypes.Member: + self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) + self.get_invited_rooms_for_local_user.invalidate((state_key,)) + + if relates_to: + self.get_relations_for_event.invalidate_many((relates_to,)) + self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) + self.get_applicable_edit.invalidate((relates_to,)) + + async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + cache_func = getattr(self, cache_name, None) + if not cache_func: + return + + cache_func.invalidate(keys) + await self.db.runInteraction( + "invalidate_cache_and_stream", + self._send_invalidation_to_replication, + cache_func.__name__, + keys, + ) + def _invalidate_cache_and_stream(self, txn, cache_func, keys): """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 970c31bd05..9130b74eb5 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -37,8 +37,10 @@ from synapse.events import make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import Database +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks from synapse.util.iterutils import batch_iter @@ -74,6 +76,31 @@ class EventsWorkerStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) + if hs.config.worker_app is None: + # We are the process in charge of generating stream ids for events, + # so instantiate ID generators based on the database + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + extra_tables=[("local_invites", "stream_id")], + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + # Another process is in charge of persisting events and generating + # stream IDs: rely on the replication streams to let us know which + # IDs we can process. + self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) + self._get_event_cache = Cache( "*getEvent*", keylen=3, @@ -85,6 +112,14 @@ class EventsWorkerStore(SQLBaseStore): self._event_fetch_list = [] self._event_fetch_ongoing = 0 + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == "events": + self._stream_id_gen.advance(token) + elif stream_name == "backfill": + self._backfill_id_gen.advance(-token) + + super().process_replication_rows(stream_name, instance_name, token, rows) + def get_received_ts(self, event_id): """Get received_ts (when it was persisted) for the event. diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 0963e6c250..fb1361f1c1 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -68,24 +68,78 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_invited_users_in_group", ) - def get_rooms_in_group(self, group_id, include_private=False): + def get_rooms_in_group(self, group_id: str, include_private: bool = False): + """Retrieve the rooms that belong to a given group. Does not return rooms that + lack members. + + Args: + group_id: The ID of the group to query for rooms + include_private: Whether to return private rooms in results + + Returns: + Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the + form of: + + { + "room_id": "!a_room_id:example.com", # The ID of the room + "is_public": False # Whether this is a public room or not + } + """ # TODO: Pagination - keyvalues = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True + def _get_rooms_in_group_txn(txn): + sql = """ + SELECT room_id, is_public FROM group_rooms + WHERE group_id = ? + AND room_id IN ( + SELECT group_rooms.room_id FROM group_rooms + LEFT JOIN room_stats_current ON + group_rooms.room_id = room_stats_current.room_id + AND joined_members > 0 + AND local_users_in_room > 0 + LEFT JOIN rooms ON + group_rooms.room_id = rooms.room_id + AND (room_version <> '') = ? + ) + """ + args = [group_id, False] - return self.db.simple_select_list( - table="group_rooms", - keyvalues=keyvalues, - retcols=("room_id", "is_public"), - desc="get_rooms_in_group", - ) + if not include_private: + sql += " AND is_public = ?" + args += [True] + + txn.execute(sql, args) + + return [ + {"room_id": room_id, "is_public": is_public} + for room_id, is_public in txn + ] - def get_rooms_for_summary_by_category(self, group_id, include_private=False): + return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn) + + def get_rooms_for_summary_by_category( + self, group_id: str, include_private: bool = False, + ): """Get the rooms and categories that should be included in a summary request - Returns ([rooms], [categories]) + Args: + group_id: The ID of the group to query the summary for + include_private: Whether to return private rooms in results + + Returns: + Deferred[Tuple[List, Dict]]: A tuple containing: + + * A list of dictionaries with the keys: + * "room_id": str, the room ID + * "is_public": bool, whether the room is public + * "category_id": str|None, the category ID if set, else None + * "order": int, the sort order of rooms + + * A dictionary with the key: + * category_id (str): a dictionary with the keys: + * "is_public": bool, whether the category is public + * "profile": str, the category profile + * "order": int, the sort order of rooms in this category """ def _get_rooms_for_summary_txn(txn): @@ -97,13 +151,23 @@ class GroupServerWorkerStore(SQLBaseStore): SELECT room_id, is_public, category_id, room_order FROM group_summary_rooms WHERE group_id = ? + AND room_id IN ( + SELECT group_rooms.room_id FROM group_rooms + LEFT JOIN room_stats_current ON + group_rooms.room_id = room_stats_current.room_id + AND joined_members > 0 + AND local_users_in_room > 0 + LEFT JOIN rooms ON + group_rooms.room_id = rooms.room_id + AND (room_version <> '') = ? + ) """ if not include_private: sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) + txn.execute(sql, (group_id, False, True)) else: - txn.execute(sql, (group_id,)) + txn.execute(sql, (group_id, False)) rooms = [ { diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py index 2b52cf9c1a..bfc9369f0b 100644 --- a/synapse/storage/data_stores/main/profile.py +++ b/synapse/storage/data_stores/main/profile.py @@ -110,7 +110,7 @@ class ProfileStore(ProfileWorkerStore): return self.db.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, - values={ + updatevalues={ "displayname": displayname, "avatar_url": avatar_url, "last_check": self._clock.time_msec(), diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index b3faafa0a4..ef8f40959f 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -16,19 +16,23 @@ import abc import logging +from typing import Union from canonicaljson import json from twisted.internet import defer from synapse.push.baserules import list_with_base_rules +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.pusher import PusherWorkerStore from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore from synapse.storage.database import Database from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException +from synapse.storage.util.id_generators import ChainedIdGenerator from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -64,6 +68,7 @@ class PushRulesWorkerStore( ReceiptsWorkerStore, PusherWorkerStore, RoomMemberWorkerStore, + EventsWorkerStore, SQLBaseStore, ): """This is an abstract base class where subclasses must implement @@ -77,6 +82,15 @@ class PushRulesWorkerStore( def __init__(self, database: Database, db_conn, hs): super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) + if hs.config.worker.worker_app is None: + self._push_rules_stream_id_gen = ChainedIdGenerator( + self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" + ) # type: Union[ChainedIdGenerator, SlavedIdTracker] + else: + self._push_rules_stream_id_gen = SlavedIdTracker( + db_conn, "push_rules_stream", "stream_id" + ) + push_rules_prefill, push_rules_id = self.db.get_cache_dict( db_conn, "push_rules_stream", diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index ee75b92344..13f49d8060 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -37,7 +37,55 @@ SearchEntry = namedtuple( ) -class SearchBackgroundUpdateStore(SQLBaseStore): +class SearchWorkerStore(SQLBaseStore): + def store_search_entries_txn(self, txn, entries): + """Add entries to the search table + + Args: + txn (cursor): + entries (iterable[SearchEntry]): + entries to be added to the table + """ + if not self.hs.config.enable_search: + return + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "INSERT INTO event_search" + " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" + " VALUES (?,?,?,to_tsvector('english', ?),?,?)" + ) + + args = ( + ( + entry.event_id, + entry.room_id, + entry.key, + entry.value, + entry.stream_ordering, + entry.origin_server_ts, + ) + for entry in entries + ) + + txn.executemany(sql, args) + + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "INSERT INTO event_search (event_id, room_id, key, value)" + " VALUES (?,?,?,?)" + ) + args = ( + (entry.event_id, entry.room_id, entry.key, entry.value) + for entry in entries + ) + + txn.executemany(sql, args) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + +class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" @@ -296,52 +344,6 @@ class SearchBackgroundUpdateStore(SQLBaseStore): return num_rows - def store_search_entries_txn(self, txn, entries): - """Add entries to the search table - - Args: - txn (cursor): - entries (iterable[SearchEntry]): - entries to be added to the table - """ - if not self.hs.config.enable_search: - return - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search" - " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" - " VALUES (?,?,?,to_tsvector('english', ?),?,?)" - ) - - args = ( - ( - entry.event_id, - entry.room_id, - entry.key, - entry.value, - entry.stream_ordering, - entry.origin_server_ts, - ) - for entry in entries - ) - - txn.executemany(sql, args) - - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - args = ( - (entry.event_id, entry.room_id, entry.key, entry.value) - for entry in entries - ) - - txn.executemany(sql, args) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - class SearchStore(SearchBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 86d04ea9ac..f89ce0bed2 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -166,6 +166,7 @@ class ChainedIdGenerator(object): def __init__(self, chained_generator, db_conn, table, column): self.chained_generator = chained_generator + self._table = table self._lock = threading.Lock() self._current_max = _load_current_id(db_conn, table, column) self._unfinished_ids = deque() # type: Deque[Tuple[int, int]] @@ -204,6 +205,16 @@ class ChainedIdGenerator(object): return self._current_max, self.chained_generator.get_current_token() + def advance(self, token: int): + """Stub implementation for advancing the token when receiving updates + over replication; raises an exception as this instance should be the + only source of updates. + """ + + raise Exception( + "Attempted to advance token on source for table %r", self._table + ) + class MultiWriterIdGenerator: """An ID generator that tracks a stream that can have multiple writers. |