summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/account_validity.py101
-rw-r--r--synapse/handlers/appservice.py4
-rw-r--r--synapse/handlers/auth.py2
-rw-r--r--synapse/handlers/cas.py (renamed from synapse/handlers/cas_handler.py)0
-rw-r--r--synapse/handlers/deactivate_account.py4
-rw-r--r--synapse/handlers/device.py23
-rw-r--r--synapse/handlers/event_auth.py86
-rw-r--r--synapse/handlers/federation.py46
-rw-r--r--synapse/handlers/identity.py29
-rw-r--r--synapse/handlers/oidc.py (renamed from synapse/handlers/oidc_handler.py)32
-rw-r--r--synapse/handlers/presence.py312
-rw-r--r--synapse/handlers/room_member.py62
-rw-r--r--synapse/handlers/saml.py (renamed from synapse/handlers/saml_handler.py)0
-rw-r--r--synapse/handlers/sso.py3
-rw-r--r--synapse/handlers/sync.py13
15 files changed, 549 insertions, 168 deletions
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py

index 66ce7e8b83..5b927f10b3 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py
@@ -17,7 +17,7 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.api.errors import StoreError, SynapseError from synapse.logging.context import make_deferred_yieldable @@ -39,28 +39,44 @@ class AccountValidityHandler: self.sendmail = self.hs.get_sendmail() self.clock = self.hs.get_clock() - self._account_validity = self.hs.config.account_validity + self._account_validity_enabled = ( + hs.config.account_validity.account_validity_enabled + ) + self._account_validity_renew_by_email_enabled = ( + hs.config.account_validity.account_validity_renew_by_email_enabled + ) + + self._account_validity_period = None + if self._account_validity_enabled: + self._account_validity_period = ( + hs.config.account_validity.account_validity_period + ) if ( - self._account_validity.enabled - and self._account_validity.renew_by_email_enabled + self._account_validity_enabled + and self._account_validity_renew_by_email_enabled ): # Don't do email-specific configuration if renewal by email is disabled. - self._template_html = self.config.account_validity_template_html - self._template_text = self.config.account_validity_template_text + self._template_html = ( + hs.config.account_validity.account_validity_template_html + ) + self._template_text = ( + hs.config.account_validity.account_validity_template_text + ) + account_validity_renew_email_subject = ( + hs.config.account_validity.account_validity_renew_email_subject + ) try: - app_name = self.hs.config.email_app_name + app_name = hs.config.email_app_name - self._subject = self._account_validity.renew_email_subject % { - "app": app_name - } + self._subject = account_validity_renew_email_subject % {"app": app_name} - self._from_string = self.hs.config.email_notif_from % {"app": app_name} + self._from_string = hs.config.email_notif_from % {"app": app_name} except Exception: # If substitution failed, fall back to the bare strings. - self._subject = self._account_validity.renew_email_subject - self._from_string = self.hs.config.email_notif_from + self._subject = account_validity_renew_email_subject + self._from_string = hs.config.email_notif_from self._raw_from = email.utils.parseaddr(self._from_string)[1] @@ -220,50 +236,87 @@ class AccountValidityHandler: attempts += 1 raise StoreError(500, "Couldn't generate a unique string as refresh string.") - async def renew_account(self, renewal_token: str) -> bool: + async def renew_account(self, renewal_token: str) -> Tuple[bool, bool, int]: """Renews the account attached to a given renewal token by pushing back the expiration date by the current validity period in the server's configuration. + If it turns out that the token is valid but has already been used, then the + token is considered stale. A token is stale if the 'token_used_ts_ms' db column + is non-null. + Args: renewal_token: Token sent with the renewal request. Returns: - Whether the provided token is valid. + A tuple containing: + * A bool representing whether the token is valid and unused. + * A bool which is `True` if the token is valid, but stale. + * An int representing the user's expiry timestamp as milliseconds since the + epoch, or 0 if the token was invalid. """ try: - user_id = await self.store.get_user_from_renewal_token(renewal_token) + ( + user_id, + current_expiration_ts, + token_used_ts, + ) = await self.store.get_user_from_renewal_token(renewal_token) except StoreError: - return False + return False, False, 0 + + # Check whether this token has already been used. + if token_used_ts: + logger.info( + "User '%s' attempted to use previously used token '%s' to renew account", + user_id, + renewal_token, + ) + return False, True, current_expiration_ts logger.debug("Renewing an account for user %s", user_id) - await self.renew_account_for_user(user_id) - return True + # Renew the account. Pass the renewal_token here so that it is not cleared. + # We want to keep the token around in case the user attempts to renew their + # account with the same token twice (clicking the email link twice). + # + # In that case, the token will be accepted, but the account's expiration ts + # will remain unchanged. + new_expiration_ts = await self.renew_account_for_user( + user_id, renewal_token=renewal_token + ) + + return True, False, new_expiration_ts async def renew_account_for_user( self, user_id: str, expiration_ts: Optional[int] = None, email_sent: bool = False, + renewal_token: Optional[str] = None, ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token: Token sent with the renewal request. + user_id: The ID of the user to renew. expiration_ts: New expiration date. Defaults to now + validity period. - email_sen: Whether an email has been sent for this validity period. - Defaults to False. + email_sent: Whether an email has been sent for this validity period. + renewal_token: Token sent with the renewal request. The user's token + will be cleared if this is None. Returns: New expiration date for this account, as a timestamp in milliseconds since epoch. """ + now = self.clock.time_msec() if expiration_ts is None: - expiration_ts = self.clock.time_msec() + self._account_validity.period + expiration_ts = now + self._account_validity_period await self.store.set_account_validity_for_user( - user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent + user_id=user_id, + expiration_ts=expiration_ts, + email_sent=email_sent, + renewal_token=renewal_token, + token_used_ts=now, ) return expiration_ts diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index d7bc4e23ed..177310f0be 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Union from prometheus_client import Counter @@ -33,7 +33,7 @@ from synapse.metrics.background_process_metrics import ( wrap_as_background_process, ) from synapse.storage.databases.main.directory import RoomAliasMapping -from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID +from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID from synapse.util.metrics import Measure if TYPE_CHECKING: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index b8a37b6477..36f2450e2e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -1248,7 +1248,7 @@ class AuthHandler(BaseHandler): # see if any of our auth providers want to know about this for provider in self.password_providers: - for token, token_id, device_id in tokens_and_devices: + for token, _, device_id in tokens_and_devices: await provider.on_logged_out( user_id=user_id, device_id=device_id, access_token=token ) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas.py
index 7346ccfe93..7346ccfe93 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas.py
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 3f6f9f7f3d..45d2404dde 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py
@@ -49,7 +49,9 @@ class DeactivateAccountHandler(BaseHandler): if hs.config.run_background_tasks: hs.get_reactor().callWhenRunning(self._start_user_parting) - self._account_validity_enabled = hs.config.account_validity.enabled + self._account_validity_enabled = ( + hs.config.account_validity.account_validity_enabled + ) async def deactivate_account( self, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d75edb184b..95bdc5902a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py
@@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple from synapse.api import errors from synapse.api.constants import EventTypes @@ -28,7 +28,6 @@ from synapse.api.errors import ( from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import ( - Collection, JsonDict, StreamToken, UserID, @@ -156,8 +155,7 @@ class DeviceWorkerHandler(BaseHandler): # The user may have left the room # TODO: Check if they actually did or if we were just invited. if room_id not in room_ids: - for key, event_id in current_state_ids.items(): - etype, state_key = key + for etype, state_key in current_state_ids.keys(): if etype != EventTypes.Member: continue possibly_left.add(state_key) @@ -179,8 +177,7 @@ class DeviceWorkerHandler(BaseHandler): log_kv( {"event": "encountered empty previous state", "room_id": room_id} ) - for key, event_id in current_state_ids.items(): - etype, state_key = key + for etype, state_key in current_state_ids.keys(): if etype != EventTypes.Member: continue possibly_changed.add(state_key) @@ -198,8 +195,7 @@ class DeviceWorkerHandler(BaseHandler): for state_dict in prev_state_ids.values(): member_event = state_dict.get((EventTypes.Member, user_id), None) if not member_event or member_event != current_member_id: - for key, event_id in current_state_ids.items(): - etype, state_key = key + for etype, state_key in current_state_ids.keys(): if etype != EventTypes.Member: continue possibly_changed.add(state_key) @@ -714,7 +710,7 @@ class DeviceListUpdater: # This can happen since we batch updates return - for device_id, stream_id, prev_ids, content in pending_updates: + for device_id, stream_id, prev_ids, _ in pending_updates: logger.debug( "Handling update %r/%r, ID: %r, prev: %r ", user_id, @@ -740,7 +736,7 @@ class DeviceListUpdater: else: # Simply update the single device, since we know that is the only # change (because of the single prev_id matching the current cache) - for device_id, stream_id, prev_ids, content in pending_updates: + for device_id, stream_id, _, content in pending_updates: await self.store.update_remote_device_list_cache_entry( user_id, device_id, content, stream_id ) @@ -929,6 +925,10 @@ class DeviceListUpdater: else: cached_devices = await self.store.get_cached_devices_for_user(user_id) if cached_devices == {d["device_id"]: d for d in devices}: + logging.info( + "Skipping device list resync for %s, as our cache matches already", + user_id, + ) devices = [] ignore_devices = True @@ -944,6 +944,9 @@ class DeviceListUpdater: await self.store.update_remote_device_list_cache( user_id, devices, stream_id ) + # mark the cache as valid, whether or not we actually processed any device + # list updates. + await self.store.mark_remote_user_device_cache_as_valid(user_id) device_ids = [device["device_id"] for device in devices] # Handle cross-signing keys. diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py new file mode 100644
index 0000000000..eff639f407 --- /dev/null +++ b/synapse/handlers/event_auth.py
@@ -0,0 +1,86 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from synapse.api.constants import EventTypes, JoinRules +from synapse.api.room_versions import RoomVersion +from synapse.types import StateMap + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class EventAuthHandler: + """ + This class contains methods for authenticating events added to room graphs. + """ + + def __init__(self, hs: "HomeServer"): + self._store = hs.get_datastore() + + async def can_join_without_invite( + self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str + ) -> bool: + """ + Check whether a user can join a room without an invite. + + When joining a room with restricted joined rules (as defined in MSC3083), + the membership of spaces must be checked during join. + + Args: + state_ids: The state of the room as it currently is. + room_version: The room version of the room being joined. + user_id: The user joining the room. + + Returns: + True if the user can join the room, false otherwise. + """ + # This only applies to room versions which support the new join rule. + if not room_version.msc3083_join_rules: + return True + + # If there's no join rule, then it defaults to invite (so this doesn't apply). + join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None) + if not join_rules_event_id: + return True + + # If the join rule is not restricted, this doesn't apply. + join_rules_event = await self._store.get_event(join_rules_event_id) + if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED: + return True + + # If allowed is of the wrong form, then only allow invited users. + allowed_spaces = join_rules_event.content.get("allow", []) + if not isinstance(allowed_spaces, list): + return False + + # Get the list of joined rooms and see if there's an overlap. + joined_rooms = await self._store.get_rooms_for_user(user_id) + + # Pull out the other room IDs, invalid data gets filtered. + for space in allowed_spaces: + if not isinstance(space, dict): + continue + + space_id = space.get("space") + if not isinstance(space_id, str): + continue + + # The user was joined to one of the spaces specified, they can join + # this room! + if space_id in joined_rooms: + return True + + # The user was not in any of the required spaces. + return False diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4b3730aa3b..9d867aaf4d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -146,6 +146,7 @@ class FederationHandler(BaseHandler): self.is_mine_id = hs.is_mine_id self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self._event_auth_handler = hs.get_event_auth_handler() self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config @@ -1673,8 +1674,40 @@ class FederationHandler(BaseHandler): # would introduce the danger of backwards-compatibility problems. event.internal_metadata.send_on_behalf_of = origin + # Calculate the event context. context = await self.state_handler.compute_event_context(event) - context = await self._auth_and_persist_event(origin, event, context) + + # Get the state before the new event. + prev_state_ids = await context.get_prev_state_ids() + + # Check if the user is already in the room or invited to the room. + user_id = event.state_key + prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) + newly_joined = True + user_is_invited = False + if prev_member_event_id: + prev_member_event = await self.store.get_event(prev_member_event_id) + newly_joined = prev_member_event.membership != Membership.JOIN + user_is_invited = prev_member_event.membership == Membership.INVITE + + # If the member is not already in the room, and not invited, check if + # they should be allowed access via membership in a space. + if ( + newly_joined + and not user_is_invited + and not await self._event_auth_handler.can_join_without_invite( + prev_state_ids, + event.room_version, + user_id, + ) + ): + raise AuthError( + 403, + "You do not belong to any of the required spaces to join this room.", + ) + + # Persist the event. + await self._auth_and_persist_event(origin, event, context) logger.debug( "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s", @@ -1682,8 +1715,6 @@ class FederationHandler(BaseHandler): event.signatures, ) - prev_state_ids = await context.get_prev_state_ids() - state_ids = list(prev_state_ids.values()) auth_chain = await self.store.get_auth_chain(event.room_id, state_ids) @@ -2006,7 +2037,7 @@ class FederationHandler(BaseHandler): state: Optional[Iterable[EventBase]] = None, auth_events: Optional[MutableStateMap[EventBase]] = None, backfilled: bool = False, - ) -> EventContext: + ) -> None: """ Process an event by performing auth checks and then persisting to the database. @@ -2028,9 +2059,6 @@ class FederationHandler(BaseHandler): event is an outlier), may be the auth events claimed by the remote server. backfilled: True if the event was backfilled. - - Returns: - The event context. """ context = await self._check_event_auth( origin, @@ -2060,8 +2088,6 @@ class FederationHandler(BaseHandler): ) raise - return context - async def _auth_and_persist_events( self, origin: str, @@ -2956,7 +2982,7 @@ class FederationHandler(BaseHandler): try: # for each sig on the third_party_invite block of the actual invite for server, signature_block in signed["signatures"].items(): - for key_name, encoded_signature in signature_block.items(): + for key_name in signature_block.keys(): if not key_name.startswith("ed25519:"): continue diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 87a8b89237..0b3b1fadb5 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py
@@ -15,7 +15,6 @@ # limitations under the License. """Utilities for interacting with Identity Servers""" - import logging import urllib.parse from typing import Awaitable, Callable, Dict, List, Optional, Tuple @@ -34,7 +33,11 @@ from synapse.http.site import SynapseRequest from synapse.types import JsonDict, Requester from synapse.util import json_decoder from synapse.util.hash import sha256_and_url_safe_base64 -from synapse.util.stringutils import assert_valid_client_secret, random_string +from synapse.util.stringutils import ( + assert_valid_client_secret, + random_string, + valid_id_server_location, +) from ._base import BaseHandler @@ -172,6 +175,11 @@ class IdentityHandler(BaseHandler): server with, if necessary. Required if use_v2 is true use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True + Raises: + SynapseError: On any of the following conditions + - the supplied id_server is not a valid identity server name + - we failed to contact the supplied identity server + Returns: The response from the identity server """ @@ -181,6 +189,12 @@ class IdentityHandler(BaseHandler): if id_access_token is None: use_v2 = False + if not valid_id_server_location(id_server): + raise SynapseError( + 400, + "id_server must be a valid hostname with optional port and path components", + ) + # Decide which API endpoint URLs to use headers = {} bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid} @@ -269,12 +283,21 @@ class IdentityHandler(BaseHandler): id_server: Identity server to unbind from Raises: - SynapseError: If we failed to contact the identity server + SynapseError: On any of the following conditions + - the supplied id_server is not a valid identity server name + - we failed to contact the supplied identity server Returns: True on success, otherwise False if the identity server doesn't support unbinding """ + + if not valid_id_server_location(id_server): + raise SynapseError( + 400, + "id_server must be a valid hostname with optional port and path components", + ) + url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc.py
index b156196a70..ee6e41c0e4 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc.py
@@ -15,7 +15,7 @@ import inspect import logging from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union -from urllib.parse import urlencode +from urllib.parse import urlencode, urlparse import attr import pymacaroons @@ -37,10 +37,7 @@ from twisted.web.client import readBody from twisted.web.http_headers import Headers from synapse.config import ConfigError -from synapse.config.oidc_config import ( - OidcProviderClientSecretJwtKey, - OidcProviderConfig, -) +from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable @@ -71,8 +68,8 @@ logger = logging.getLogger(__name__) # # Here we have the names of the cookies, and the options we use to set them. _SESSION_COOKIES = [ - (b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"), - (b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"), + (b"oidc_session", b"HttpOnly; Secure; SameSite=None"), + (b"oidc_session_no_samesite", b"HttpOnly"), ] #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and @@ -282,6 +279,13 @@ class OidcProvider: self._config = provider self._callback_url = hs.config.oidc_callback_url # type: str + # Calculate the prefix for OIDC callback paths based on the public_baseurl. + # We'll insert this into the Path= parameter of any session cookies we set. + public_baseurl_path = urlparse(hs.config.server.public_baseurl).path + self._callback_path_prefix = ( + public_baseurl_path.encode("utf-8") + b"_synapse/client/oidc" + ) + self._oidc_attribute_requirements = provider.attribute_requirements self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method @@ -782,8 +786,13 @@ class OidcProvider: for cookie_name, options in _SESSION_COOKIES: request.cookies.append( - b"%s=%s; Max-Age=3600; %s" - % (cookie_name, cookie.encode("utf-8"), options) + b"%s=%s; Max-Age=3600; Path=%s; %s" + % ( + cookie_name, + cookie.encode("utf-8"), + self._callback_path_prefix, + options, + ) ) metadata = await self.load_metadata() @@ -960,6 +969,11 @@ class OidcProvider: # and attempt to match it. attributes = await oidc_response_to_user_attributes(failures=0) + if attributes.localpart is None: + # If no localpart is returned then we will generate one, so + # there is no need to search for existing users. + return None + user_id = UserID(attributes.localpart, self._server_name).to_string() users = await self._store.get_users_by_id_case_insensitive(user_id) if users: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 6460eb9952..969c73c1e7 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
@@ -24,9 +24,11 @@ The methods that define policy are: import abc import contextlib import logging +from bisect import bisect from contextlib import contextmanager from typing import ( TYPE_CHECKING, + Collection, Dict, FrozenSet, Iterable, @@ -53,10 +55,11 @@ from synapse.replication.http.presence import ( ReplicationBumpPresenceActiveTime, ReplicationPresenceSetState, ) +from synapse.replication.http.streams import ReplicationGetStreamUpdates from synapse.replication.tcp.commands import ClearUserSyncsCommand -from synapse.state import StateHandler +from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream from synapse.storage.databases.main import DataStore -from synapse.types import Collection, JsonDict, UserID, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.metrics import Measure @@ -118,19 +121,21 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class BasePresenceHandler(abc.ABC): - """Parts of the PresenceHandler that are shared between workers and master""" + """Parts of the PresenceHandler that are shared between workers and presence + writer""" def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastore() self.presence_router = hs.get_presence_router() self.state = hs.get_state_handler() + self.is_mine_id = hs.is_mine_id self._federation = None - if hs.should_send_federation() or not hs.config.worker_app: + if hs.should_send_federation(): self._federation = hs.get_federation_sender() - self._send_federation = hs.should_send_federation() + self._federation_queue = PresenceFederationQueue(hs, self) self._busy_presence_enabled = hs.config.experimental.msc3026_enabled @@ -253,28 +258,38 @@ class BasePresenceHandler(abc.ABC): """ pass - async def process_replication_rows(self, token, rows): - """Process presence stream rows received over replication.""" - pass + async def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: list + ): + """Process streams received over replication.""" + await self._federation_queue.process_replication_rows( + stream_name, instance_name, token, rows + ) + + def get_federation_queue(self) -> "PresenceFederationQueue": + """Get the presence federation queue.""" + return self._federation_queue async def maybe_send_presence_to_interested_destinations( self, states: List[UserPresenceState] ): """If this instance is a federation sender, send the states to all - destinations that are interested. + destinations that are interested. Filters out any states for remote + users. """ - if not self._send_federation: + if not self._federation: return - # If this worker sends federation we must have a FederationSender. - assert self._federation + states = [s for s in states if self.is_mine_id(s.user_id)] + + if not states: + return hosts_and_states = await get_interested_remotes( self.store, self.presence_router, states, - self.state, ) for destinations, states in hosts_and_states: @@ -292,10 +307,17 @@ class WorkerPresenceHandler(BasePresenceHandler): def __init__(self, hs): super().__init__(hs) self.hs = hs - self.is_mine_id = hs.is_mine_id + + self._presence_writer_instance = hs.config.worker.writers.presence[0] self._presence_enabled = hs.config.use_presence + # Route presence EDUs to the right worker + hs.get_federation_registry().register_instances_for_edu( + "m.presence", + hs.config.worker.writers.presence, + ) + # The number of ongoing syncs on this process, by user id. # Empty if _presence_enabled is false. self._user_to_num_current_syncs = {} # type: Dict[str, int] @@ -303,8 +325,8 @@ class WorkerPresenceHandler(BasePresenceHandler): self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() - # user_id -> last_sync_ms. Lists the users that have stopped syncing - # but we haven't notified the master of that yet + # user_id -> last_sync_ms. Lists the users that have stopped syncing but + # we haven't notified the presence writer of that yet self.users_going_offline = {} self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) @@ -337,22 +359,23 @@ class WorkerPresenceHandler(BasePresenceHandler): ) def mark_as_coming_online(self, user_id): - """A user has started syncing. Send a UserSync to the master, unless they - had recently stopped syncing. + """A user has started syncing. Send a UserSync to the presence writer, + unless they had recently stopped syncing. Args: user_id (str) """ going_offline = self.users_going_offline.pop(user_id, None) if not going_offline: - # Safe to skip because we haven't yet told the master they were offline + # Safe to skip because we haven't yet told the presence writer they + # were offline self.send_user_sync(user_id, True, self.clock.time_msec()) def mark_as_going_offline(self, user_id): - """A user has stopped syncing. We wait before notifying the master as - its likely they'll come back soon. This allows us to avoid sending - a stopped syncing immediately followed by a started syncing notification - to the master + """A user has stopped syncing. We wait before notifying the presence + writer as its likely they'll come back soon. This allows us to avoid + sending a stopped syncing immediately followed by a started syncing + notification to the presence writer Args: user_id (str) @@ -360,8 +383,8 @@ class WorkerPresenceHandler(BasePresenceHandler): self.users_going_offline[user_id] = self.clock.time_msec() def send_stop_syncing(self): - """Check if there are any users who have stopped syncing a while ago - and haven't come back yet. If there are poke the master about them. + """Check if there are any users who have stopped syncing a while ago and + haven't come back yet. If there are poke the presence writer about them. """ now = self.clock.time_msec() for user_id, last_sync_ms in list(self.users_going_offline.items()): @@ -421,7 +444,14 @@ class WorkerPresenceHandler(BasePresenceHandler): # If this is a federation sender, notify about presence updates. await self.maybe_send_presence_to_interested_destinations(states) - async def process_replication_rows(self, token, rows): + async def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: list + ): + await super().process_replication_rows(stream_name, instance_name, token, rows) + + if stream_name != PresenceStream.NAME: + return + states = [ UserPresenceState( row.user_id, @@ -470,9 +500,12 @@ class WorkerPresenceHandler(BasePresenceHandler): if not self.hs.config.use_presence: return - # Proxy request to master + # Proxy request to instance that writes presence await self._set_state_client( - user_id=user_id, state=state, ignore_status_msg=ignore_status_msg + instance_name=self._presence_writer_instance, + user_id=user_id, + state=state, + ignore_status_msg=ignore_status_msg, ) async def bump_presence_active_time(self, user): @@ -483,16 +516,17 @@ class WorkerPresenceHandler(BasePresenceHandler): if not self.hs.config.use_presence: return - # Proxy request to master + # Proxy request to instance that writes presence user_id = user.to_string() - await self._bump_active_client(user_id=user_id) + await self._bump_active_client( + instance_name=self._presence_writer_instance, user_id=user_id + ) class PresenceHandler(BasePresenceHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs - self.is_mine_id = hs.is_mine_id self.server_name = hs.hostname self.wheel_timer = WheelTimer() self.notifier = hs.get_notifier() @@ -721,15 +755,12 @@ class PresenceHandler(BasePresenceHandler): self.store, self.presence_router, list(to_federation_ping.values()), - self.state, ) - # Since this is master we know that we have a federation sender or - # queue, and so this will be defined. - assert self._federation - for destinations, states in hosts_and_states: - self._federation.send_presence_to_destinations(states, destinations) + self._federation_queue.send_presence_to_destinations( + states, destinations + ) async def _handle_timeouts(self): """Checks the presence of users that have timed out and updates as @@ -1208,13 +1239,9 @@ class PresenceHandler(BasePresenceHandler): user_presence_states ) - # Since this is master we know that we have a federation sender or - # queue, and so this will be defined. - assert self._federation - # Send out user presence updates for each destination for destination, user_state_set in presence_destinations.items(): - self._federation.send_presence_to_destinations( + self._federation_queue.send_presence_to_destinations( destinations=[destination], states=user_state_set ) @@ -1354,7 +1381,6 @@ class PresenceEventSource: self.get_presence_router = hs.get_presence_router self.clock = hs.get_clock() self.store = hs.get_datastore() - self.state = hs.get_state_handler() @log_function async def get_new_events( @@ -1823,7 +1849,6 @@ async def get_interested_remotes( store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState], - state_handler: StateHandler, ) -> List[Tuple[Collection[str], List[UserPresenceState]]]: """Given a list of presence states figure out which remote servers should be sent which. @@ -1834,7 +1859,6 @@ async def get_interested_remotes( store: The homeserver's data store. presence_router: A module for augmenting the destinations for presence updates. states: A list of incoming user presence updates. - state_handler: Returns: A list of 2-tuples of destinations and states, where for @@ -1851,7 +1875,8 @@ async def get_interested_remotes( ) for room_id, states in room_ids_to_states.items(): - hosts = await state_handler.get_current_hosts_in_room(room_id) + user_ids = await store.get_users_in_room(room_id) + hosts = {get_domain_from_id(user_id) for user_id in user_ids} hosts_and_states.append((hosts, states)) for user_id, states in users_to_states.items(): @@ -1859,3 +1884,198 @@ async def get_interested_remotes( hosts_and_states.append(([host], states)) return hosts_and_states + + +class PresenceFederationQueue: + """Handles sending ad hoc presence updates over federation, which are *not* + due to state updates (that get handled via the presence stream), e.g. + federation pings and sending existing present states to newly joined hosts. + + Only the last N minutes will be queued, so if a federation sender instance + is down for longer then some updates will be dropped. This is OK as presence + is ephemeral, and so it will self correct eventually. + + On workers the class tracks the last received position of the stream from + replication, and handles querying for missed updates over HTTP replication, + c.f. `get_current_token` and `get_replication_rows`. + """ + + # How long to keep entries in the queue for. Workers that are down for + # longer than this duration will miss out on older updates. + _KEEP_ITEMS_IN_QUEUE_FOR_MS = 5 * 60 * 1000 + + # How often to check if we can expire entries from the queue. + _CLEAR_ITEMS_EVERY_MS = 60 * 1000 + + def __init__(self, hs: "HomeServer", presence_handler: BasePresenceHandler): + self._clock = hs.get_clock() + self._notifier = hs.get_notifier() + self._instance_name = hs.get_instance_name() + self._presence_handler = presence_handler + self._repl_client = ReplicationGetStreamUpdates.make_client(hs) + + # Should we keep a queue of recent presence updates? We only bother if + # another process may be handling federation sending. + self._queue_presence_updates = True + + # Whether this instance is a presence writer. + self._presence_writer = self._instance_name in hs.config.worker.writers.presence + + # The FederationSender instance, if this process sends federation traffic directly. + self._federation = None + + if hs.should_send_federation(): + self._federation = hs.get_federation_sender() + + # We don't bother queuing up presence states if only this instance + # is sending federation. + if hs.config.worker.federation_shard_config.instances == [ + self._instance_name + ]: + self._queue_presence_updates = False + + # The queue of recently queued updates as tuples of: `(timestamp, + # stream_id, destinations, user_ids)`. We don't store the full states + # for efficiency, and remote workers will already have the full states + # cached. + self._queue = [] # type: List[Tuple[int, int, Collection[str], Set[str]]] + + self._next_id = 1 + + # Map from instance name to current token + self._current_tokens = {} # type: Dict[str, int] + + if self._queue_presence_updates: + self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS) + + def _clear_queue(self): + """Clear out older entries from the queue.""" + clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS + + # The queue is sorted by timestamp, so we can bisect to find the right + # place to purge before. Note that we are searching using a 1-tuple with + # the time, which does The Right Thing since the queue is a tuple where + # the first item is a timestamp. + index = bisect(self._queue, (clear_before,)) + self._queue = self._queue[index:] + + def send_presence_to_destinations( + self, states: Collection[UserPresenceState], destinations: Collection[str] + ) -> None: + """Send the presence states to the given destinations. + + Will forward to the local federation sender (if there is one) and queue + to send over replication (if there are other federation sender instances.). + + Must only be called on the presence writer process. + """ + + # This should only be called on a presence writer. + assert self._presence_writer + + if self._federation: + self._federation.send_presence_to_destinations( + states=states, + destinations=destinations, + ) + + if not self._queue_presence_updates: + return + + now = self._clock.time_msec() + + stream_id = self._next_id + self._next_id += 1 + + self._queue.append((now, stream_id, destinations, {s.user_id for s in states})) + + self._notifier.notify_replication() + + def get_current_token(self, instance_name: str) -> int: + """Get the current position of the stream. + + On workers this returns the last stream ID received from replication. + """ + if instance_name == self._instance_name: + return self._next_id - 1 + else: + return self._current_tokens.get(instance_name, 0) + + async def get_replication_rows( + self, + instance_name: str, + from_token: int, + upto_token: int, + target_row_count: int, + ) -> Tuple[List[Tuple[int, Tuple[str, str]]], int, bool]: + """Get all the updates between the two tokens. + + We return rows in the form of `(destination, user_id)` to keep the size + of each row bounded (rather than returning the sets in a row). + + On workers this will query the presence writer process via HTTP replication. + """ + if instance_name != self._instance_name: + # If not local we query over http replication from the presence + # writer + result = await self._repl_client( + instance_name=instance_name, + stream_name=PresenceFederationStream.NAME, + from_token=from_token, + upto_token=upto_token, + ) + return result["updates"], result["upto_token"], result["limited"] + + # We can find the correct position in the queue by noting that there is + # exactly one entry per stream ID, and that the last entry has an ID of + # `self._next_id - 1`, so we can count backwards from the end. + # + # Since the start of the queue is periodically truncated we need to + # handle the case where `from_token` stream ID has already been dropped. + start_idx = max(from_token - self._next_id, -len(self._queue)) + + to_send = [] # type: List[Tuple[int, Tuple[str, str]]] + limited = False + new_id = upto_token + for _, stream_id, destinations, user_ids in self._queue[start_idx:]: + if stream_id > upto_token: + break + + new_id = stream_id + + to_send.extend( + (stream_id, (destination, user_id)) + for destination in destinations + for user_id in user_ids + ) + + if len(to_send) > target_row_count: + limited = True + break + + return to_send, new_id, limited + + async def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: list + ): + if stream_name != PresenceFederationStream.NAME: + return + + # We keep track of the current tokens (so that we can catch up with anything we missed after a disconnect) + self._current_tokens[instance_name] = token + + # If we're a federation sender we pull out the presence states to send + # and forward them on. + if not self._federation: + return + + hosts_to_users = {} # type: Dict[str, Set[str]] + for row in rows: + hosts_to_users.setdefault(row.destination, set()).add(row.user_id) + + for host, user_ids in hosts_to_users.items(): + states = await self._presence_handler.current_state_for_users(user_ids) + self._federation.send_presence_to_destinations( + states=states.values(), + destinations=[host], + ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 6e28677530..54c25e3557 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -19,7 +19,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse import types -from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -28,7 +28,6 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.api.room_versions import RoomVersion from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID @@ -64,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.profile_handler = hs.get_profile_handler() self.event_creation_handler = hs.get_event_creation_handler() self.account_data_handler = hs.get_account_data_handler() + self.event_auth_handler = hs.get_event_auth_handler() self.member_linearizer = Linearizer(name="member") self.member_limiter = Linearizer(max_count=10, name="member_as_limiter") @@ -179,62 +179,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id) - async def _can_join_without_invite( - self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str - ) -> bool: - """ - Check whether a user can join a room without an invite. - - When joining a room with restricted joined rules (as defined in MSC3083), - the membership of spaces must be checked during join. - - Args: - state_ids: The state of the room as it currently is. - room_version: The room version of the room being joined. - user_id: The user joining the room. - - Returns: - True if the user can join the room, false otherwise. - """ - # This only applies to room versions which support the new join rule. - if not room_version.msc3083_join_rules: - return True - - # If there's no join rule, then it defaults to public (so this doesn't apply). - join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None) - if not join_rules_event_id: - return True - - # If the join rule is not restricted, this doesn't apply. - join_rules_event = await self.store.get_event(join_rules_event_id) - if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED: - return True - - # If allowed is of the wrong form, then only allow invited users. - allowed_spaces = join_rules_event.content.get("allow", []) - if not isinstance(allowed_spaces, list): - return False - - # Get the list of joined rooms and see if there's an overlap. - joined_rooms = await self.store.get_rooms_for_user(user_id) - - # Pull out the other room IDs, invalid data gets filtered. - for space in allowed_spaces: - if not isinstance(space, dict): - continue - - space_id = space.get("space") - if not isinstance(space_id, str): - continue - - # The user was joined to one of the spaces specified, they can join - # this room! - if space_id in joined_rooms: - return True - - # The user was not in any of the required spaces. - return False - async def _local_membership_update( self, requester: Requester, @@ -303,7 +247,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if ( newly_joined and not user_is_invited - and not await self._can_join_without_invite( + and not await self.event_auth_handler.can_join_without_invite( prev_state_ids, event.room_version, user_id ) ): diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml.py
index 80ba65b9e0..80ba65b9e0 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml.py
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 8d00ffdc73..044ff06d84 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py
@@ -18,6 +18,7 @@ from typing import ( Any, Awaitable, Callable, + Collection, Dict, Iterable, List, @@ -40,7 +41,7 @@ from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest -from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters +from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer from synapse.util.stringutils import random_string diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d2ba805f86..3ffc4628cb 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -14,7 +14,17 @@ # limitations under the License. import itertools import logging -from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + FrozenSet, + List, + Optional, + Set, + Tuple, +) import attr from prometheus_client import Counter @@ -28,7 +38,6 @@ from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( - Collection, JsonDict, MutableStateMap, Requester,