diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bb81c0e81d..d29b066a56 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
class BaseHandler:
"""
Common base class for the event handlers.
+
+    Deprecated: new code should not use this. Instead, Handler classes should define the
+    fields they actually need. The utility methods should either be factored out into
+    standalone helper functions or moved to more specific Handler classes.
"""
def __init__(self, hs: "HomeServer"):
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 341135822e..b1a5df9638 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,14 +13,157 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import random
from typing import TYPE_CHECKING, List, Tuple
+from synapse.replication.http.account_data import (
+ ReplicationAddTagRestServlet,
+ ReplicationRemoveTagRestServlet,
+ ReplicationRoomAccountDataRestServlet,
+ ReplicationUserAccountDataRestServlet,
+)
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
+class AccountDataHandler:
+ def __init__(self, hs: "HomeServer"):
+ self._store = hs.get_datastore()
+ self._instance_name = hs.get_instance_name()
+ self._notifier = hs.get_notifier()
+
+ self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
+ self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
+ self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
+ self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
+ self._account_data_writers = hs.config.worker.writers.account_data
+
+ async def add_account_data_to_room(
+ self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
+ ) -> int:
+ """Add some account_data to a room for a user.
+
+ Args:
+            user_id: The user to add the account data for.
+            room_id: The room to add the account data to.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the account data.
+
+ Returns:
+ The maximum stream ID.
+ """
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.add_account_data_to_room(
+ user_id, room_id, account_data_type, content
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+
+ return max_stream_id
+ else:
+ response = await self._room_data_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ room_id=room_id,
+ account_data_type=account_data_type,
+ content=content,
+ )
+ return response["max_stream_id"]
+
+ async def add_account_data_for_user(
+ self, user_id: str, account_data_type: str, content: JsonDict
+ ) -> int:
+ """Add some account_data to a room for a user.
+
+ Args:
+ user_id: The user to add a tag for.
+ account_data_type: The type of account_data to add.
+ content: A json object to associate with the tag.
+
+ Returns:
+ The maximum stream ID.
+ """
+
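+        # As above: write locally if this instance is an account-data writer,
+        # otherwise hand off to one over replication.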
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.add_account_data_for_user(
+ user_id, account_data_type, content
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+ return max_stream_id
+ else:
+ response = await self._user_data_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ account_data_type=account_data_type,
+ content=content,
+ )
+ return response["max_stream_id"]
+
+ async def add_tag_to_room(
+ self, user_id: str, room_id: str, tag: str, content: JsonDict
+ ) -> int:
+ """Add a tag to a room for a user.
+
+ Args:
+ user_id: The user to add a tag for.
+ room_id: The room to add a tag for.
+ tag: The tag name to add.
+ content: A json object to associate with the tag.
+
+ Returns:
+ The next account data ID.
+ """
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.add_tag_to_room(
+ user_id, room_id, tag, content
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+ return max_stream_id
+ else:
+ response = await self._add_tag_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ room_id=room_id,
+ tag=tag,
+ content=content,
+ )
+ return response["max_stream_id"]
+
+ async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
+ """Remove a tag from a room for a user.
+
+ Returns:
+ The next account data ID.
+ """
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.remove_tag_from_room(
+ user_id, room_id, tag
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+ return max_stream_id
+ else:
+ response = await self._remove_tag_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ room_id=room_id,
+ tag=tag,
+ )
+ return response["max_stream_id"]
+
+
class AccountDataEventSource:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 8476256a59..5ecb2da1ac 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
import twisted
import twisted.internet.error
@@ -22,6 +23,9 @@ from twisted.web.resource import Resource
from synapse.app import check_bind_error
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
ACME_REGISTER_FAIL_ERROR = """
@@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
class AcmeHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.reactor = hs.get_reactor()
self._acme_domain = hs.config.acme_domain
- async def start_listening(self):
+ async def start_listening(self) -> None:
from synapse.handlers import acme_issuing_service
# Configure logging for txacme, if you need to debug
@@ -85,7 +89,7 @@ class AcmeHandler:
logger.error(ACME_REGISTER_FAIL_ERROR)
raise
- async def provision_certificate(self):
+ async def provision_certificate(self) -> None:
logger.warning("Reprovisioning %s", self._acme_domain)
@@ -110,5 +114,3 @@ class AcmeHandler:
except Exception:
logger.exception("Failed saving!")
raise
-
- return True
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index 7294649d71..ae2a9dd9c2 100644
--- a/synapse/handlers/acme_issuing_service.py
+++ b/synapse/handlers/acme_issuing_service.py
@@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to
imported conditionally.
"""
import logging
+from typing import Dict, Iterable, List
import attr
+import pem
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from josepy import JWKRSA
@@ -36,20 +38,27 @@ from txacme.util import generate_private_key
from zope.interface import implementer
from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTCP
from twisted.python.filepath import FilePath
from twisted.python.url import URL
+from twisted.web.resource import IResource
logger = logging.getLogger(__name__)
-def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
+def create_issuing_service(
+ reactor: IReactorTCP,
+ acme_url: str,
+ account_key_file: str,
+ well_known_resource: IResource,
+) -> AcmeIssuingService:
"""Create an ACME issuing service, and attach it to a web Resource
Args:
reactor: twisted reactor
- acme_url (str): URL to use to request certificates
- account_key_file (str): where to store the account key
- well_known_resource (twisted.web.IResource): web resource for .well-known.
+ acme_url: URL to use to request certificates
+ account_key_file: where to store the account key
+ well_known_resource: web resource for .well-known.
we will attach a child resource for "acme-challenge".
Returns:
@@ -83,18 +92,20 @@ class ErsatzStore:
A store that only stores in memory.
"""
- certs = attr.ib(default=attr.Factory(dict))
+ certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
- def store(self, server_name, pem_objects):
+ def store(
+ self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
+ ) -> defer.Deferred:
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
return defer.succeed(None)
-def load_or_create_client_key(key_file):
+def load_or_create_client_key(key_file: str) -> JWKRSA:
"""Load the ACME account key from a file, creating it if it does not exist.
Args:
- key_file (str): name of the file to use as the account key
+ key_file: name of the file to use as the account key
"""
# this is based on txacme.endpoint.load_or_create_client_key, but doesn't
# hardcode the 'client.key' filename
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index a703944543..37e63da9b1 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -13,27 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
-from typing import List
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from synapse.api.constants import Membership
-from synapse.events import FrozenEvent
-from synapse.types import RoomStreamToken, StateMap
+from synapse.events import EventBase
+from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class AdminHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.storage = hs.get_storage()
self.state_store = self.storage.state
- async def get_whois(self, user):
+ async def get_whois(self, user: UserID) -> JsonDict:
connections = []
sessions = await self.store.get_user_ip_and_agents(user)
@@ -53,7 +57,7 @@ class AdminHandler(BaseHandler):
return ret
- async def get_user(self, user):
+ async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
ret = await self.store.get_user_by_id(user.to_string())
if ret:
@@ -64,12 +68,12 @@ class AdminHandler(BaseHandler):
ret["threepids"] = threepids
return ret
- async def export_user_data(self, user_id, writer):
+ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
"""Write all data we have on the user to the given writer.
Args:
- user_id (str)
- writer (ExfiltrationWriter)
+ user_id: The user ID to fetch data of.
+ writer: The writer to write to.
Returns:
Resolves when all data for a user has been written.
@@ -128,7 +132,8 @@ class AdminHandler(BaseHandler):
from_key = RoomStreamToken(0, 0)
to_key = RoomStreamToken(None, stream_ordering)
- written_events = set() # Events that we've processed in this room
+ # Events that we've processed in this room
+ written_events = set() # type: Set[str]
# We need to track gaps in the events stream so that we can then
# write out the state at those events. We do this by keeping track
@@ -140,8 +145,8 @@ class AdminHandler(BaseHandler):
# The reverse mapping to above, i.e. map from unseen event to events
# that have the unseen event in their prev_events, i.e. the unseen
- # events "children". dict[str, set[str]]
- unseen_to_child_events = {}
+ # events "children".
+ unseen_to_child_events = {} # type: Dict[str, Set[str]]
# We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most
@@ -197,38 +202,46 @@ class AdminHandler(BaseHandler):
return writer.finished()
-class ExfiltrationWriter:
+class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""Interface used to specify how to write exported data.
"""
- def write_events(self, room_id: str, events: List[FrozenEvent]):
+ @abc.abstractmethod
+ def write_events(self, room_id: str, events: List[EventBase]) -> None:
"""Write a batch of events for a room.
"""
- pass
+ raise NotImplementedError()
- def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
+ @abc.abstractmethod
+ def write_state(
+ self, room_id: str, event_id: str, state: StateMap[EventBase]
+ ) -> None:
"""Write the state at the given event in the room.
This only gets called for backward extremities rather than for each
event.
"""
- pass
+ raise NotImplementedError()
- def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
+ @abc.abstractmethod
+ def write_invite(
+ self, room_id: str, event: EventBase, state: StateMap[dict]
+ ) -> None:
"""Write an invite for the room, with associated invite state.
Args:
- room_id
- event
- state: A subset of the state at the
- invite, with a subset of the event keys (type, state_key
- content and sender)
+ room_id: The room ID the invite is for.
+ event: The invite event.
+ state: A subset of the state at the invite, with a subset of the
+ event keys (type, state_key content and sender).
"""
+ raise NotImplementedError()
- def finished(self):
+ @abc.abstractmethod
+ def finished(self) -> Any:
"""Called when all data has successfully been exported and written.
This functions return value is passed to the caller of
`export_user_data`.
"""
- pass
+ raise NotImplementedError()
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7dc07008a..a19c556437 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
import time
import unicodedata
@@ -22,6 +21,7 @@ import urllib.parse
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
Dict,
Iterable,
@@ -36,6 +36,8 @@ import attr
import bcrypt
import pymacaroons
+from twisted.web.http import Request
+
from synapse.api.constants import LoginType
from synapse.api.errors import (
AuthError,
@@ -47,20 +49,25 @@ from synapse.api.errors import (
UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
-from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.ui_auth import (
+ INTERACTIVE_AUTH_CHECKERS,
+ UIAuthSessionDataConstants,
+)
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http import get_request_user_agent
from synapse.http.server import finish_request, respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
+from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
+from synapse.util.async_helpers import maybe_awaitable
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@@ -193,39 +200,27 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
- self._sso_enabled = (
- hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
- )
-
- # we keep this as a list despite the O(N^2) implication so that we can
- # keep PASSWORD first and avoid confusing clients which pick the first
- # type in the list. (NB that the spec doesn't require us to do so and
- # clients which favour types that they don't understand over those that
- # they do are technically broken)
+ self._password_localdb_enabled = hs.config.password_localdb_enabled
# start out by assuming PASSWORD is enabled; we will remove it later if not.
- login_types = []
- if hs.config.password_localdb_enabled:
- login_types.append(LoginType.PASSWORD)
+ login_types = set()
+ if self._password_localdb_enabled:
+ login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
- if hasattr(provider, "get_supported_login_types"):
- for t in provider.get_supported_login_types().keys():
- if t not in login_types:
- login_types.append(t)
+ login_types.update(provider.get_supported_login_types().keys())
if not self._password_enabled:
+ login_types.discard(LoginType.PASSWORD)
+
+ # Some clients just pick the first type in the list. In this case, we want
+ # them to use PASSWORD (rather than token or whatever), so we want to make sure
+ # that comes first, where it's present.
+ self._supported_login_types = []
+ if LoginType.PASSWORD in login_types:
+ self._supported_login_types.append(LoginType.PASSWORD)
login_types.remove(LoginType.PASSWORD)
-
- self._supported_login_types = login_types
-
- # Login types and UI Auth types have a heavy overlap, but are not
- # necessarily identical. Login types have SSO (and other login types)
- # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
- ui_auth_types = login_types.copy()
- if self._sso_enabled:
- ui_auth_types.append(LoginType.SSO)
- self._supported_ui_auth_types = ui_auth_types
+ self._supported_login_types.extend(login_types)
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
@@ -235,6 +230,9 @@ class AuthHandler(BaseHandler):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
+ # The number of seconds to keep a UI auth session active.
+ self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout
+
# Ratelimitier for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
@@ -266,10 +264,6 @@ class AuthHandler(BaseHandler):
# authenticating for an operation to occur on their account.
self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
- # The following template is shown after a successful user interactive
- # authentication session. It tells the user they can close the window.
- self._sso_auth_success_template = hs.config.sso_auth_success_template
-
# The following template is shown during the SSO authentication process if
# the account is deactivated.
self._sso_account_deactivated_template = (
@@ -290,9 +284,8 @@ class AuthHandler(BaseHandler):
requester: Requester,
request: SynapseRequest,
request_body: Dict[str, Any],
- clientip: str,
description: str,
- ) -> Tuple[dict, str]:
+ ) -> Tuple[dict, Optional[str]]:
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -307,8 +300,6 @@ class AuthHandler(BaseHandler):
request_body: The body of the request sent by the client
- clientip: The IP address of the client.
-
description: A human readable string to be displayed to the user that
describes the operation happening on their account.
@@ -319,7 +310,8 @@ class AuthHandler(BaseHandler):
have been given only in a previous call).
'session_id' is the ID of this session, either passed in by the
- client or assigned by this call
+ client or assigned by this call. This is None if UI auth was
+ skipped (by re-using a previous validation).
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
@@ -333,40 +325,90 @@ class AuthHandler(BaseHandler):
"""
- user_id = requester.user.to_string()
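+        # If the user's access token was validated recently enough, skip the UI
+        # auth checks entirely.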
+ if self._ui_auth_session_timeout:
+ last_validated = await self.store.get_access_token_last_validated(
+ requester.access_token_id
+ )
+ if self.clock.time_msec() - last_validated < self._ui_auth_session_timeout:
+ # Return the input parameters, minus the auth key, which matches
+ # the logic in check_ui_auth.
+ request_body.pop("auth", None)
+ return request_body, None
+
+ requester_user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts
- self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
+ self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
# build a list of supported flows
- flows = [[login_type] for login_type in self._supported_ui_auth_types]
+ supported_ui_auth_types = await self._get_available_ui_auth_types(
+ requester.user
+ )
+ flows = [[login_type] for login_type in supported_ui_auth_types]
+
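+        # Seed any new UI auth session with the ID of the user being verified, so
+        # that we can later check that the same user completed the auth.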
+ def get_new_session_data() -> JsonDict:
+ return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
try:
result, params, session_id = await self.check_ui_auth(
- flows, request, request_body, clientip, description
+ flows, request, request_body, description, get_new_session_data,
)
except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
- self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
+ self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
raise
# find the completed login type
- for login_type in self._supported_ui_auth_types:
+ for login_type in supported_ui_auth_types:
if login_type not in result:
continue
- user_id = result[login_type]
+ validated_user_id = result[login_type]
break
else:
# this can't happen
raise Exception("check_auth returned True but no successful login type")
# check that the UI auth matched the access token
- if user_id != requester.user.to_string():
+ if validated_user_id != requester_user_id:
raise AuthError(403, "Invalid auth")
+ # Note that the access token has been validated.
+ await self.store.update_access_token_last_validated(requester.access_token_id)
+
return params, session_id
+ async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
+ """Get a list of the authentication types this user can use
+ """
+
+ ui_auth_types = set()
+
+ # if the HS supports password auth, and the user has a non-null password, we
+ # support password auth
+ if self._password_localdb_enabled and self._password_enabled:
+ lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
+ if lookupres:
+ _, password_hash = lookupres
+ if password_hash:
+ ui_auth_types.add(LoginType.PASSWORD)
+
+ # also allow auth from password providers
+ for provider in self.password_providers:
+ for t in provider.get_supported_login_types().keys():
+ if t == LoginType.PASSWORD and not self._password_enabled:
+ continue
+ ui_auth_types.add(t)
+
+ # if sso is enabled, allow the user to log in via SSO iff they have a mapping
+ # from sso to mxid.
+ if await self.hs.get_sso_handler().get_identity_providers_for_user(
+ user.to_string()
+ ):
+ ui_auth_types.add(LoginType.SSO)
+
+ return ui_auth_types
+
def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types
@@ -380,8 +422,8 @@ class AuthHandler(BaseHandler):
flows: List[List[str]],
request: SynapseRequest,
clientdict: Dict[str, Any],
- clientip: str,
description: str,
+ get_new_session_data: Optional[Callable[[], JsonDict]] = None,
) -> Tuple[dict, dict, str]:
"""
Takes a dictionary sent by the client in the login / registration
@@ -402,11 +444,16 @@ class AuthHandler(BaseHandler):
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
- clientip: The IP address of the client.
-
description: A human readable string to be displayed to the user that
describes the operation happening on their account.
+ get_new_session_data:
+                An optional callback which will be called when starting a new session.
+                It should return data to be stored as part of the session.
+
+ The keys of the returned data should be entries in
+ UIAuthSessionDataConstants.
+
Returns:
A tuple of (creds, params, session_id).
@@ -423,13 +470,10 @@ class AuthHandler(BaseHandler):
all the stages in any of the permitted flows.
"""
- authdict = None
sid = None # type: Optional[str]
- if clientdict and "auth" in clientdict:
- authdict = clientdict["auth"]
- del clientdict["auth"]
- if "session" in authdict:
- sid = authdict["session"]
+ authdict = clientdict.pop("auth", {})
+ if "session" in authdict:
+ sid = authdict["session"]
# Convert the URI and method to strings.
uri = request.uri.decode("utf-8")
@@ -437,10 +481,15 @@ class AuthHandler(BaseHandler):
# If there's no session ID, create a new session.
if not sid:
+ new_session_data = get_new_session_data() if get_new_session_data else {}
+
session = await self.store.create_ui_auth_session(
clientdict, uri, method, description
)
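+            # Store any caller-supplied seed data in the new session.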
+ for k, v in new_session_data.items():
+ await self.set_session_data(session.session_id, k, v)
+
else:
try:
session = await self.store.get_ui_auth_session(sid)
@@ -496,7 +545,8 @@ class AuthHandler(BaseHandler):
# authentication flow.
await self.store.set_ui_auth_clientdict(sid, clientdict)
- user_agent = request.get_user_agent("")
+ user_agent = get_request_user_agent(request)
+ clientip = request.getClientIP()
await self.store.add_user_agent_ip_to_ui_auth_session(
session.session_id, user_agent, clientip
@@ -518,22 +568,14 @@ class AuthHandler(BaseHandler):
session.session_id, login_type, result
)
except LoginError as e:
- if login_type == LoginType.EMAIL_IDENTITY:
- # riot used to have a bug where it would request a new
- # validation token (thus sending a new email) each time it
- # got a 401 with a 'flows' field.
- # (https://github.com/vector-im/vector-web/issues/2447).
- #
- # Grandfather in the old behaviour for now to avoid
- # breaking old riot deployments.
- raise
-
# this step failed. Merge the error dict into the response
# so that the client can have another go.
errordict = e.error_dict()
creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows:
+ # If all the required credentials have been supplied, the user has
+ # successfully completed the UI auth process!
if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
# include the password in the case of registering, so only log
@@ -599,7 +641,8 @@ class AuthHandler(BaseHandler):
Args:
session_id: The ID of this session as returned from check_auth
- key: The key to store the data under
+ key: The key to store the data under. An entry from
+ UIAuthSessionDataConstants.
value: The data to store
"""
try:
@@ -615,7 +658,8 @@ class AuthHandler(BaseHandler):
Args:
session_id: The ID of this session as returned from check_auth
- key: The key to store the data under
+ key: The key the data was stored under. An entry from
+ UIAuthSessionDataConstants.
default: Value to return if the key has not been set
"""
try:
@@ -709,6 +753,7 @@ class AuthHandler(BaseHandler):
device_id: Optional[str],
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
+ is_appservice_ghost: bool = False,
) -> str:
"""
Creates a new access token for the user with the given user ID.
@@ -725,6 +770,7 @@ class AuthHandler(BaseHandler):
we should always have a device ID)
valid_until_ms: when the token is valid until. None for
no expiry.
+            is_appservice_ghost: Whether the user is an application service ghost user.
Returns:
The access token for the user's session.
Raises:
@@ -745,7 +791,11 @@ class AuthHandler(BaseHandler):
"Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
)
- await self.auth.check_auth_blocking(user_id)
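+        # Appservice ghost users are exempt from the auth-blocking checks, unless
+        # we are configured to track appservice user IPs (in which case they count
+        # towards the monthly-active-user limits).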
+ if (
+ not is_appservice_ghost
+ or self.hs.config.appservice.track_appservice_user_ips
+ ):
+ await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
await self.store.add_access_token_to_user(
@@ -831,7 +881,7 @@ class AuthHandler(BaseHandler):
async def validate_login(
self, login_submission: Dict[str, Any], ratelimit: bool = False,
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't
@@ -974,7 +1024,7 @@ class AuthHandler(BaseHandler):
async def _validate_userid_login(
self, username: str, login_submission: Dict[str, Any],
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
@@ -1029,7 +1079,7 @@ class AuthHandler(BaseHandler):
if result:
return result
- if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
+ if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True
# we've already checked that there is a (valid) password field
@@ -1052,7 +1102,7 @@ class AuthHandler(BaseHandler):
async def check_password_provider_3pid(
self, medium: str, address: str, password: str
- ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@@ -1283,12 +1333,12 @@ class AuthHandler(BaseHandler):
else:
return False
- async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+ async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
"""
Get the HTML for the SSO redirect confirmation page.
Args:
- redirect_url: The URL to redirect to the SSO provider.
+ request: The incoming HTTP request
session_id: The user interactive authentication session ID.
Returns:
@@ -1298,38 +1348,48 @@ class AuthHandler(BaseHandler):
session = await self.store.get_ui_auth_session(session_id)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
- return self._sso_auth_confirm_template.render(
- description=session.description, redirect_url=redirect_url,
+
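+        # Fetch the user ID that this UI auth session is verifying, and look up
+        # the SSO IdPs they have previously used so we can re-authenticate them
+        # against one of those.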
+ user_id_to_verify = await self.get_session_data(
+ session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+ ) # type: str
+
+ idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
+ user_id_to_verify
)
- async def complete_sso_ui_auth(
- self, registered_user_id: str, session_id: str, request: SynapseRequest,
- ):
- """Having figured out a mxid for this user, complete the HTTP request
+ if not idps:
+ # we checked that the user had some remote identities before offering an SSO
+ # flow, so either it's been deleted or the client has requested SSO despite
+ # it not being offered.
+ raise SynapseError(400, "User has no SSO identities")
- Args:
- registered_user_id: The registered user ID to complete SSO login for.
- request: The request to complete.
- client_redirect_url: The URL to which to redirect the user at the end of the
- process.
- """
- # Mark the stage of the authentication as successful.
- # Save the user who authenticated with SSO, this will be used to ensure
- # that the account be modified is also the person who logged in.
- await self.store.mark_ui_auth_stage_complete(
- session_id, LoginType.SSO, registered_user_id
+ # for now, just pick one
+ idp_id, sso_auth_provider = next(iter(idps.items()))
+        if len(idps) > 1:
+ logger.warning(
+ "User %r has previously logged in with multiple SSO IdPs; arbitrarily "
+ "picking %r",
+ user_id_to_verify,
+ idp_id,
+ )
+
+ redirect_url = await sso_auth_provider.handle_redirect_request(
+ request, None, session_id
)
- # Render the HTML and return.
- html = self._sso_auth_success_template
- respond_with_html(request, 200, html)
+ return self._sso_auth_confirm_template.render(
+ description=session.description,
+ redirect_url=redirect_url,
+ idp=sso_auth_provider,
+ )
async def complete_sso_login(
self,
registered_user_id: str,
- request: SynapseRequest,
+ request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
+ new_user: bool = False,
):
"""Having figured out a mxid for this user, complete the HTTP request
@@ -1340,6 +1400,8 @@ class AuthHandler(BaseHandler):
process.
extra_attributes: Extra attributes which will be passed to the client
during successful login. Must be JSON serializable.
+ new_user: True if we should use wording appropriate to a user who has just
+ registered.
"""
# If the account has been deactivated, do not proceed with the login
# flow.
@@ -1348,22 +1410,37 @@ class AuthHandler(BaseHandler):
respond_with_html(request, 403, self._sso_account_deactivated_template)
return
+ profile = await self.store.get_profileinfo(
+ UserID.from_string(registered_user_id).localpart
+ )
+
self._complete_sso_login(
- registered_user_id, request, client_redirect_url, extra_attributes
+ registered_user_id,
+ request,
+ client_redirect_url,
+ extra_attributes,
+ new_user=new_user,
+ user_profile_data=profile,
)
def _complete_sso_login(
self,
registered_user_id: str,
- request: SynapseRequest,
+ request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
+ new_user: bool = False,
+ user_profile_data: Optional[ProfileInfo] = None,
):
"""
The synchronous portion of complete_sso_login.
This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
"""
+
+ if user_profile_data is None:
+ user_profile_data = ProfileInfo(None, None)
+
# Store any extra attributes which will be passed in the login response.
# Note that this is per-user so it may overwrite a previous value, this
# is considered OK since the newest SSO attributes should be most valid.
@@ -1401,6 +1478,9 @@ class AuthHandler(BaseHandler):
display_url=redirect_url_no_params,
redirect_url=redirect_url,
server_name=self._server_name,
+ new_user=new_user,
+ user_id=registered_user_id,
+ user_profile=user_profile_data,
)
respond_with_html(request, 200, html)
@@ -1438,8 +1518,8 @@ class AuthHandler(BaseHandler):
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
url_parts = list(urllib.parse.urlparse(url))
- query = dict(urllib.parse.parse_qsl(url_parts[4]))
- query.update({param_name: param})
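+        # Use a list of tuples rather than a dict so that duplicate query
+        # parameters and blank values are preserved.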
+ query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
+ query.append((param_name, param))
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)
@@ -1609,6 +1689,6 @@ class PasswordProvider:
# This might return an awaitable, if it does block the log out
# until it completes.
- result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
- if inspect.isawaitable(result):
- await result
+ await maybe_awaitable(
+ g(user_id=user_id, device_id=device_id, access_token=access_token,)
+ )
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index f4ea0a9767..bd35d1fb87 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -13,13 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import urllib
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+import urllib.parse
+from typing import TYPE_CHECKING, Dict, Optional
from xml.etree import ElementTree as ET
+import attr
+
from twisted.web.client import PartialDownloadError
-from synapse.api.errors import Codes, LoginError
+from synapse.api.errors import HttpResponseException
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart
@@ -29,6 +32,26 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class CasError(Exception):
+ """Used to catch errors when validating the CAS ticket.
+ """
+
+ def __init__(self, error, error_description=None):
+ self.error = error
+ self.error_description = error_description
+
+ def __str__(self):
+ if self.error_description:
+ return "{}: {}".format(self.error, self.error_description)
+ return self.error
+
+
+@attr.s(slots=True, frozen=True)
+class CasResponse:
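+    """The parsed contents of a successful CAS serviceResponse."""
+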
+ username = attr.ib(type=str)
+ attributes = attr.ib(type=Dict[str, Optional[str]])
+
+
class CasHandler:
"""
Utility class for to handle the response from a CAS SSO service.
@@ -40,6 +63,7 @@ class CasHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self._hostname = hs.hostname
+ self._store = hs.get_datastore()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
@@ -50,6 +74,21 @@ class CasHandler:
self._http_client = hs.get_proxied_http_client()
+ # identifier for the external_ids table
+ self.idp_id = "cas"
+
+ # user-facing name of this auth provider
+ self.idp_name = "CAS"
+
+ # we do not currently support brands/icons for CAS auth, but this is required by
+ # the SsoIdentityProvider protocol type.
+ self.idp_icon = None
+ self.idp_brand = None
+
+ self._sso_handler = hs.get_sso_handler()
+
+ self._sso_handler.register_identity_provider(self)
+
def _build_service_param(self, args: Dict[str, str]) -> str:
"""
Generates a value to use as the "service" parameter when redirecting or
@@ -61,22 +100,24 @@ class CasHandler:
Returns:
The URL to use as a "service" parameter.
"""
- return "%s%s?%s" % (
- self._cas_service_url,
- "/_matrix/client/r0/login/cas/ticket",
- urllib.parse.urlencode(args),
- )
+ return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]
- ) -> Tuple[str, Optional[str]]:
+ ) -> CasResponse:
"""
- Validate a CAS ticket with the server, parse the response, and return the user and display name.
+        Validate a CAS ticket with the server, and return the parsed response.
Args:
ticket: The CAS ticket from the client.
service_args: Additional arguments to include in the service URL.
- Should be the same as those passed to `get_redirect_url`.
+ Should be the same as those passed to `handle_redirect_request`.
+
+ Raises:
+ CasError: If there's an error parsing the CAS response.
+
+ Returns:
+ The parsed CAS response.
"""
uri = self._cas_server_url + "/proxyValidate"
args = {
@@ -89,77 +130,91 @@ class CasHandler:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
+ except HttpResponseException as e:
+            description = (
+                'CAS server responded with a "{status}" error '
+                "while validating the ticket."
+            ).format(status=e.code)
+            raise CasError("server_error", description) from e
- user, attributes = self._parse_cas_response(body)
- displayname = attributes.pop(self._cas_displayname_attribute, None)
-
- for required_attribute, required_value in self._cas_required_attributes.items():
- # If required attribute was not in CAS Response - Forbidden
- if required_attribute not in attributes:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- # Also need to check value
- if required_value is not None:
- actual_value = attributes[required_attribute]
- # If required attribute value does not match expected - Forbidden
- if required_value != actual_value:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- return user, displayname
+ return self._parse_cas_response(body)
- def _parse_cas_response(
- self, cas_response_body: bytes
- ) -> Tuple[str, Dict[str, Optional[str]]]:
+ def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
"""
Retrieve the user and other parameters from the CAS response.
Args:
cas_response_body: The response from the CAS query.
+ Raises:
+ CasError: If there's an error parsing the CAS response.
+
Returns:
- A tuple of the user and a mapping of other attributes.
+ The parsed CAS response.
"""
- user = None
- attributes = {}
- try:
- root = ET.fromstring(cas_response_body)
- if not root.tag.endswith("serviceResponse"):
- raise Exception("root of CAS response is not serviceResponse")
- success = root[0].tag.endswith("authenticationSuccess")
- for child in root[0]:
- if child.tag.endswith("user"):
- user = child.text
- if child.tag.endswith("attributes"):
- for attribute in child:
- # ElementTree library expands the namespace in
- # attribute tags to the full URL of the namespace.
- # We don't care about namespace here and it will always
- # be encased in curly braces, so we remove them.
- tag = attribute.tag
- if "}" in tag:
- tag = tag.split("}")[1]
- attributes[tag] = attribute.text
- if user is None:
- raise Exception("CAS response does not contain user")
- except Exception:
- logger.exception("Error parsing CAS response")
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
- if not success:
- raise LoginError(
- 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
+
+ # Ensure the response is valid.
+ root = ET.fromstring(cas_response_body)
+ if not root.tag.endswith("serviceResponse"):
+ raise CasError(
+ "missing_service_response",
+ "root of CAS response is not serviceResponse",
)
- return user, attributes
- def get_redirect_url(self, service_args: Dict[str, str]) -> str:
- """
- Generates a URL for the CAS server where the client should be redirected.
+ success = root[0].tag.endswith("authenticationSuccess")
+ if not success:
+ raise CasError("unsucessful_response", "Unsuccessful CAS response")
+
+ # Iterate through the nodes and pull out the user and any extra attributes.
+ user = None
+ attributes = {}
+ for child in root[0]:
+ if child.tag.endswith("user"):
+ user = child.text
+ if child.tag.endswith("attributes"):
+ for attribute in child:
+ # ElementTree library expands the namespace in
+ # attribute tags to the full URL of the namespace.
+ # We don't care about namespace here and it will always
+ # be encased in curly braces, so we remove them.
+ tag = attribute.tag
+ if "}" in tag:
+ tag = tag.split("}")[1]
+ attributes[tag] = attribute.text
+
+ # Ensure a user was found.
+ if user is None:
+ raise CasError("no_user", "CAS response does not contain user")
+
+ return CasResponse(user, attributes)
+
+ async def handle_redirect_request(
+ self,
+ request: SynapseRequest,
+ client_redirect_url: Optional[bytes],
+ ui_auth_session_id: Optional[str] = None,
+ ) -> str:
+ """Generates a URL for the CAS server where the client should be redirected.
Args:
- service_args: Additional arguments to include in the final redirect URL.
+ request: the incoming HTTP request
+ client_redirect_url: the URL that we should redirect the
+ client to after login (or None for UI Auth).
+ ui_auth_session_id: The session ID of the ongoing UI Auth (or
+ None if this is a login).
Returns:
- The URL to redirect the client to.
+ URL to redirect to
"""
+
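+        # Embed either the UI auth session ID or the client redirect URL in the
+        # "service" URL, so that the eventual /cas/ticket callback can tell which
+        # flow the response belongs to.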
+ if ui_auth_session_id:
+ service_args = {"session": ui_auth_session_id}
+ else:
+ assert client_redirect_url
+ service_args = {"redirectUrl": client_redirect_url.decode("utf8")}
+
args = urllib.parse.urlencode(
{"service": self._build_service_param(service_args)}
)
@@ -201,59 +256,150 @@ class CasHandler:
args["redirectUrl"] = client_redirect_url
if session:
args["session"] = session
- username, user_display_name = await self._validate_ticket(ticket, args)
-
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
- # Get the matrix ID from the CAS username.
- user_id = await self._map_cas_user_to_matrix_user(
- username, user_display_name, user_agent, ip_address
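+        # Validate the ticket with the CAS server; on failure, render an error
+        # page to the user rather than propagating the exception.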
+ try:
+ cas_response = await self._validate_ticket(ticket, args)
+ except CasError as e:
+ logger.exception("Could not validate ticket")
+ self._sso_handler.render_error(request, e.error, e.error_description, 401)
+ return
+
+ await self._handle_cas_response(
+ request, cas_response, client_redirect_url, session
)
+ async def _handle_cas_response(
+ self,
+ request: SynapseRequest,
+ cas_response: CasResponse,
+ client_redirect_url: Optional[str],
+ session: Optional[str],
+ ) -> None:
+ """Handle a CAS response to a ticket request.
+
+ Assumes that the response has been validated. Maps the user onto an MXID,
+ registering them if necessary, and returns a response to the browser.
+
+ Args:
+ request: the incoming request from the browser. We'll respond to it with an
+ HTML page or a redirect
+
+ cas_response: The parsed CAS response.
+
+ client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
+ This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
+
+ session: The session parameter from the `/cas/ticket` HTTP request, if given.
+ This should be the UI Auth session id.
+ """
+
+ # first check if we're doing a UIA
if session:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, session, request,
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self.idp_id, cas_response.username, session, request,
)
- else:
- # If this not a UI auth request than there must be a redirect URL.
- assert client_redirect_url
- await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url
- )
+ # otherwise, we're handling a login request.
+
+ # Ensure that the attributes of the logged in user meet the required
+ # attributes.
+ for required_attribute, required_value in self._cas_required_attributes.items():
+ # If required attribute was not in CAS Response - Forbidden
+ if required_attribute not in cas_response.attributes:
+ self._sso_handler.render_error(
+ request,
+ "unauthorised",
+ "You are not authorised to log in here.",
+ 401,
+ )
+ return
+
+ # Also need to check value
+ if required_value is not None:
+ actual_value = cas_response.attributes[required_attribute]
+ # If required attribute value does not match expected - Forbidden
+ if required_value != actual_value:
+ self._sso_handler.render_error(
+ request,
+ "unauthorised",
+ "You are not authorised to log in here.",
+ 401,
+ )
+ return
+
+ # Call the mapper to register/login the user
+
+        # If this is not a UI auth request then there must be a redirect URL.
+ assert client_redirect_url is not None
- async def _map_cas_user_to_matrix_user(
+ try:
+ await self._complete_cas_login(cas_response, request, client_redirect_url)
+ except MappingException as e:
+ logger.exception("Could not map user")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
+
+ async def _complete_cas_login(
self,
- remote_user_id: str,
- display_name: Optional[str],
- user_agent: str,
- ip_address: str,
- ) -> str:
+ cas_response: CasResponse,
+ request: SynapseRequest,
+ client_redirect_url: str,
+ ) -> None:
"""
- Given a CAS username, retrieve the user ID for it and possibly register the user.
+ Given a CAS response, complete the login flow
- Args:
- remote_user_id: The username from the CAS response.
- display_name: The display name from the CAS response.
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
+ Retrieves the remote user ID, registers the user if necessary, and serves
+ a redirect back to the client with a login-token.
- Returns:
- The user ID associated with this response.
+ Args:
+ cas_response: The parsed CAS response.
+ request: The request to respond to
+ client_redirect_url: The redirect URL passed in by the client.
+
+ Raises:
+            MappingException: if there was a problem mapping the response to a user.
+ RedirectException: some mapping providers may raise this if they need
+ to redirect to an interstitial page.
"""
+ # Note that CAS does not support a mapping provider, so the logic is hard-coded.
+ localpart = map_username_to_mxid_localpart(cas_response.username)
+
+ async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
+ """
+ Map from CAS attributes to user attributes.
+ """
+ # Due to the grandfathering logic matching any previously registered
+ # mxids it isn't expected for there to be any failures.
+ if failures:
+ raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
+
+ display_name = cas_response.attributes.get(
+ self._cas_displayname_attribute, None
+ )
- localpart = map_username_to_mxid_localpart(remote_user_id)
- user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = await self._auth_handler.check_user_exists(user_id)
+ return UserAttributes(localpart=localpart, display_name=display_name)
- # If the user does not exist, register it.
- if not registered_user_id:
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=display_name,
- user_agent_ips=[(user_agent, ip_address)],
+ async def grandfather_existing_users() -> Optional[str]:
+ # Since CAS did not always use the user_external_ids table, always
+            # attempt to map to existing users.
+ user_id = UserID(localpart, self._hostname).to_string()
+
+ logger.debug(
+ "Looking for existing account based on mapped %s", user_id,
)
- return registered_user_id
+ users = await self._store.get_users_by_id_case_insensitive(user_id)
+ if users:
+ registered_user_id = list(users.keys())[0]
+ logger.info("Grandfathering mapping to %s", registered_user_id)
+ return registered_user_id
+
+ return None
+
+ await self._sso_handler.complete_sso_login_request(
+ self.idp_id,
+ cas_response.username,
+ request,
+ client_redirect_url,
+ cas_response_to_user_attributes,
+ grandfather_existing_users,
+ )
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index e808142365..c4a3b26a84 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID, create_requester
+from synapse.types import Requester, UserID, create_requester
from ._base import BaseHandler
@@ -38,6 +38,7 @@ class DeactivateAccountHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_identity_handler()
+ self._profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
self._server_name = hs.hostname
@@ -52,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler):
self._account_validity_enabled = hs.config.account_validity.enabled
async def deactivate_account(
- self, user_id: str, erase_data: bool, id_server: Optional[str] = None
+ self,
+ user_id: str,
+ erase_data: bool,
+ requester: Requester,
+ id_server: Optional[str] = None,
+ by_admin: bool = False,
) -> bool:
"""Deactivate a user's account
Args:
user_id: ID of user to be deactivated
erase_data: whether to GDPR-erase the user's data
+ requester: The user attempting to make this change.
id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
+ by_admin: Whether this change was made by an administrator.
Returns:
True if identity server supports removing threepids, otherwise False.
@@ -121,6 +129,12 @@ class DeactivateAccountHandler(BaseHandler):
# Mark the user as erased, if they asked for that
if erase_data:
+ user = UserID.from_string(user_id)
+ # Remove avatar URL from this user
+ await self._profile_handler.set_avatar_url(user, requester, "", by_admin)
+ # Remove displayname from this user
+ await self._profile_handler.set_displayname(user, requester, "", by_admin)
+
logger.info("Marking %s as erased", user_id)
await self.store.mark_user_erased(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index debb1b4f29..0863154f7a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api import errors
from synapse.api.constants import EventTypes
@@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
@trace
- async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
+ async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
Retrieve the given user's devices
@@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler):
return devices
@trace
- async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
+ async def get_device(self, user_id: str, device_id: str) -> JsonDict:
""" Retrieve the given device
Args:
@@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips(
- device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
+ device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@@ -946,8 +946,8 @@ class DeviceListUpdater:
async def process_cross_signing_key_update(
self,
user_id: str,
- master_key: Optional[Dict[str, Any]],
- self_signing_key: Optional[Dict[str, Any]],
+ master_key: Optional[JsonDict],
+ self_signing_key: Optional[JsonDict],
) -> List[str]:
"""Process the given new master and self-signing key for the given remote user.
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 9cac5a8463..0c7737e09d 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -24,6 +24,7 @@ from synapse.logging.opentracing import (
set_tag,
start_active_span,
)
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string
@@ -44,13 +45,37 @@ class DeviceMessageHandler:
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.is_mine = hs.is_mine
- self.federation = hs.get_federation_sender()
- hs.get_federation_registry().register_edu_handler(
- "m.direct_to_device", self.on_direct_to_device_edu
- )
+ # We only need to poke the federation sender explicitly if its on the
+ # same instance. Other federation sender instances will get notified by
+ # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+ # in the to-device replication stream.
+ self.federation_sender = None
+ if hs.should_send_federation():
+ self.federation_sender = hs.get_federation_sender()
+
+ # If we can handle the to device EDUs we do so, otherwise we route them
+ # to the appropriate worker.
+ if hs.get_instance_name() in hs.config.worker.writers.to_device:
+ hs.get_federation_registry().register_edu_handler(
+ "m.direct_to_device", self.on_direct_to_device_edu
+ )
+ else:
+ hs.get_federation_registry().register_instances_for_edu(
+ "m.direct_to_device", hs.config.worker.writers.to_device,
+ )
- self._device_list_updater = hs.get_device_handler().device_list_updater
+ # The handler to call when we think a user's device list might be out of
+ # sync. We do all device list resyncing on the master instance, so if
+ # we're on a worker we hit the device resync replication API.
+ if hs.config.worker.worker_app is None:
+ self._user_device_resync = (
+ hs.get_device_handler().device_list_updater.user_device_resync
+ )
+ else:
+ self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
+ hs
+ )
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
local_messages = {}
@@ -138,9 +163,7 @@ class DeviceMessageHandler:
await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
# Immediately attempt a resync in the background
- run_in_background(
- self._device_list_updater.user_device_resync, sender_user_id
- )
+ run_in_background(self._user_device_resync, user_id=sender_user_id)
async def send_device_message(
self,
@@ -195,7 +218,8 @@ class DeviceMessageHandler:
)
log_kv({"remote_messages": remote_messages})
- for destination in remote_messages.keys():
- # Enqueue a new federation transaction to send the new
- # device messages to each remote destination.
- self.federation.send_device_messages(destination)
+ if self.federation_sender:
+ for destination in remote_messages.keys():
+ # Enqueue a new federation transaction to send the new
+ # device messages to each remote destination.
+ self.federation_sender.send_device_messages(destination)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index ad5683d251..abcf86352d 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler):
403, "You must be in the room to create an alias for it"
)
- if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
+ if not await self.spam_checker.user_may_create_room_alias(
+ user_id, room_alias
+ ):
raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed(
@@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_publish_room(user_id, room_id):
+ if not await self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError(
403, "This user is not permitted to publish rooms to the room list"
)
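
These spam-checker calls are now awaited, so third-party spam-checker callbacks may themselves be coroutines, e.g. ones that consult an external service. A hedged sketch of what such a module callback could look like (class name and bodies are illustrative only; synchronous callbacks presumably keep working if the core wraps results in an awaitable):

    class ExampleSpamChecker:
        """Illustrative module implementing the two checks awaited above."""

        async def user_may_create_room_alias(self, user_id, room_alias) -> bool:
            # free to await I/O now, e.g. a reputation lookup
            return not await self._is_flagged(user_id)

        async def user_may_publish_room(self, user_id, room_id) -> bool:
            return not await self._is_flagged(user_id)

        async def _is_flagged(self, user_id: str) -> bool:
            return user_id.startswith("@spammer:")  # toy heuristic
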
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 929752150d..8f3a6b35a4 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -16,7 +16,7 @@
# limitations under the License.
import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
+ JsonDict,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
@@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class E2eKeysHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
@@ -78,7 +82,9 @@ class E2eKeysHandler:
)
@trace
- async def query_devices(self, query_body, timeout, from_user_id):
+ async def query_devices(
+ self, query_body: JsonDict, timeout: int, from_user_id: str
+ ) -> JsonDict:
""" Handle a device key query from a client
{
@@ -98,12 +104,14 @@ class E2eKeysHandler:
}
Args:
- from_user_id (str): the user making the query. This is used when
+ from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
"""
- device_keys_query = query_body.get("device_keys", {})
+ device_keys_query = query_body.get(
+ "device_keys", {}
+ ) # type: Dict[str, Iterable[str]]
# separate users by domain.
# make a map from domain to user_id to device_ids
@@ -121,7 +129,8 @@ class E2eKeysHandler:
set_tag("remote_key_query", remote_queries)
# First get local devices.
- failures = {}
+ # A map of destination -> failure response.
+ failures = {} # type: Dict[str, JsonDict]
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
@@ -135,9 +144,10 @@ class E2eKeysHandler:
)
# Now attempt to get any remote devices from our local cache.
- remote_queries_not_in_cache = {}
+ # A map of destination -> user ID -> device IDs.
+ remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries:
- query_list = []
+ query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
@@ -284,15 +294,15 @@ class E2eKeysHandler:
return ret
async def get_cross_signing_keys_from_cache(
- self, query, from_user_id
+ self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]:
"""Get cross-signing keys for users from the database
Args:
- query (Iterable[string]) an iterable of user IDs. A dict whose keys
+ query: an iterable of user IDs. A dict whose keys
are user IDs satisfies this, so the query format used for
query_devices can be used here.
- from_user_id (str): the user making the query. This is used when
+ from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
@@ -315,14 +325,12 @@ class E2eKeysHandler:
if "self_signing" in user_info:
self_signing_keys[user_id] = user_info["self_signing"]
- if (
- from_user_id in keys
- and keys[from_user_id] is not None
- and "user_signing" in keys[from_user_id]
- ):
- # users can see other users' master and self-signing keys, but can
- # only see their own user-signing keys
- user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+ # users can see other users' master and self-signing keys, but can
+ # only see their own user-signing keys
+ if from_user_id:
+ from_user_key = keys.get(from_user_id)
+ if from_user_key and "user_signing" in from_user_key:
+ user_signing_keys[from_user_id] = from_user_key["user_signing"]
return {
"master_keys": master_keys,
@@ -344,9 +352,9 @@ class E2eKeysHandler:
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
- local_query = []
+ local_query = [] # type: List[Tuple[str, Optional[str]]]
- result_dict = {}
+ result_dict = {} # type: Dict[str, Dict[str, dict]]
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
@@ -380,10 +388,14 @@ class E2eKeysHandler:
log_kv(results)
return result_dict
- async def on_federation_query_client_keys(self, query_body):
+ async def on_federation_query_client_keys(
+ self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
+ ) -> JsonDict:
""" Handle a device key query from a federated server
"""
- device_keys_query = query_body.get("device_keys", {})
+ device_keys_query = query_body.get(
+ "device_keys", {}
+ ) # type: Dict[str, Optional[List[str]]]
res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}
@@ -397,31 +409,34 @@ class E2eKeysHandler:
return ret
@trace
- async def claim_one_time_keys(self, query, timeout):
- local_query = []
- remote_queries = {}
+ async def claim_one_time_keys(
+ self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+ ) -> JsonDict:
+ local_query = [] # type: List[Tuple[str, str, str]]
+ remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
- for user_id, device_keys in query.get("one_time_keys", {}).items():
+ for user_id, one_time_keys in query.get("one_time_keys", {}).items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
- for device_id, algorithm in device_keys.items():
+ for device_id, algorithm in one_time_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
domain = get_domain_from_id(user_id)
- remote_queries.setdefault(domain, {})[user_id] = device_keys
+ remote_queries.setdefault(domain, {})[user_id] = one_time_keys
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
results = await self.store.claim_e2e_one_time_keys(local_query)
- json_result = {}
- failures = {}
+ # A map of user ID -> device ID -> key ID -> key.
+ json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+ failures = {} # type: Dict[str, JsonDict]
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
- for key_id, json_bytes in keys.items():
+ for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
- key_id: json_decoder.decode(json_bytes)
+ key_id: json_decoder.decode(json_str)
}
@trace
@@ -468,7 +483,9 @@ class E2eKeysHandler:
return {"one_time_keys": json_result, "failures": failures}
@tag_args
- async def upload_keys_for_user(self, user_id, device_id, keys):
+ async def upload_keys_for_user(
+ self, user_id: str, device_id: str, keys: JsonDict
+ ) -> JsonDict:
time_now = self.clock.time_msec()
@@ -543,8 +560,8 @@ class E2eKeysHandler:
return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user(
- self, user_id, device_id, time_now, one_time_keys
- ):
+ self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+ ) -> None:
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(),
@@ -585,12 +602,14 @@ class E2eKeysHandler:
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
- async def upload_signing_keys_for_user(self, user_id, keys):
+ async def upload_signing_keys_for_user(
+ self, user_id: str, keys: JsonDict
+ ) -> JsonDict:
"""Upload signing keys for cross-signing
Args:
- user_id (string): the user uploading the keys
- keys (dict[string, dict]): the signing keys
+ user_id: the user uploading the keys
+ keys: the signing keys
"""
# if a master key is uploaded, then check it. Otherwise, load the
@@ -667,16 +686,17 @@ class E2eKeysHandler:
return {}
- async def upload_signatures_for_device_keys(self, user_id, signatures):
+ async def upload_signatures_for_device_keys(
+ self, user_id: str, signatures: JsonDict
+ ) -> JsonDict:
"""Upload device signatures for cross-signing
Args:
- user_id (string): the user uploading the signatures
- signatures (dict[string, dict[string, dict]]): map of users to
- devices to signed keys. This is the submission from the user; an
- exception will be raised if it is malformed.
+ user_id: the user uploading the signatures
+ signatures: map of users to devices to signed keys. This is the submission
+ from the user; an exception will be raised if it is malformed.
Returns:
- dict: response to be sent back to the client. The response will have
+ The response to be sent back to the client. The response will have
a "failures" key, which will be a dict mapping users to devices
to errors for the signatures that failed.
Raises:
@@ -719,7 +739,9 @@ class E2eKeysHandler:
return {"failures": failures}
- async def _process_self_signatures(self, user_id, signatures):
+ async def _process_self_signatures(
+ self, user_id: str, signatures: JsonDict
+ ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
"""Process uploaded signatures of the user's own keys.
Signatures of the user's own keys from this API come in two forms:
@@ -731,15 +753,14 @@ class E2eKeysHandler:
signatures (dict[string, dict]): map of devices to signed keys
Returns:
- (list[SignatureListItem], dict[string, dict[string, dict]]):
- a list of signatures to store, and a map of users to devices to failure
- reasons
+ A tuple of a list of signatures to store, and a map of users to
+ devices to failure reasons
Raises:
SynapseError: if the input is malformed
"""
- signature_list = []
- failures = {}
+ signature_list = [] # type: List[SignatureListItem]
+ failures = {} # type: Dict[str, Dict[str, JsonDict]]
if not signatures:
return signature_list, failures
@@ -834,19 +855,24 @@ class E2eKeysHandler:
return signature_list, failures
def _check_master_key_signature(
- self, user_id, master_key_id, signed_master_key, stored_master_key, devices
- ):
+ self,
+ user_id: str,
+ master_key_id: str,
+ signed_master_key: JsonDict,
+ stored_master_key: JsonDict,
+ devices: Dict[str, Dict[str, JsonDict]],
+ ) -> List["SignatureListItem"]:
"""Check signatures of a user's master key made by their devices.
Args:
- user_id (string): the user whose master key is being checked
- master_key_id (string): the ID of the user's master key
- signed_master_key (dict): the user's signed master key that was uploaded
- stored_master_key (dict): our previously-stored copy of the user's master key
- devices (iterable(dict)): the user's devices
+ user_id: the user whose master key is being checked
+ master_key_id: the ID of the user's master key
+ signed_master_key: the user's signed master key that was uploaded
+ stored_master_key: our previously-stored copy of the user's master key
+ devices: the user's devices
Returns:
- list[SignatureListItem]: a list of signatures to store
+ A list of signatures to store
Raises:
SynapseError: if a signature is invalid
@@ -877,25 +903,26 @@ class E2eKeysHandler:
return master_key_signature_list
- async def _process_other_signatures(self, user_id, signatures):
+ async def _process_other_signatures(
+ self, user_id: str, signatures: Dict[str, dict]
+ ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
"""Process uploaded signatures of other users' keys. These will be the
target user's master keys, signed by the uploading user's user-signing
key.
Args:
- user_id (string): the user uploading the keys
- signatures (dict[string, dict]): map of users to devices to signed keys
+ user_id: the user uploading the keys
+ signatures: map of users to devices to signed keys
Returns:
- (list[SignatureListItem], dict[string, dict[string, dict]]):
- a list of signatures to store, and a map of users to devices to failure
+ A list of signatures to store, and a map of users to devices to failure
reasons
Raises:
SynapseError: if the input is malformed
"""
- signature_list = []
- failures = {}
+ signature_list = [] # type: List[SignatureListItem]
+ failures = {} # type: Dict[str, Dict[str, JsonDict]]
if not signatures:
return signature_list, failures
@@ -983,7 +1010,7 @@ class E2eKeysHandler:
async def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: str = None
- ):
+ ) -> Tuple[JsonDict, str, VerifyKey]:
"""Fetch locally or remotely query for a cross-signing public key.
First, attempt to fetch the cross-signing public key from storage.
@@ -997,8 +1024,7 @@ class E2eKeysHandler:
This affects what signatures are fetched.
Returns:
- dict, str, VerifyKey: the raw key data, the key ID, and the
- signedjson verify key
+ The raw key data, the key ID, and the signedjson verify key
Raises:
NotFoundError: if the key is not found
@@ -1135,16 +1161,18 @@ class E2eKeysHandler:
return desired_key, desired_key_id, desired_verify_key
-def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
+def _check_cross_signing_key(
+ key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
+) -> None:
"""Check a cross-signing key uploaded by a user. Performs some basic sanity
checking, and ensures that it is signed, if a signature is required.
Args:
- key (dict): the key data to verify
- user_id (str): the user whose key is being checked
- key_type (str): the type of key that the key should be
- signing_key (VerifyKey): (optional) the signing key that the key should
- be signed with. If omitted, signatures will not be checked.
+ key: the key data to verify
+ user_id: the user whose key is being checked
+ key_type: the type of key that the key should be
+ signing_key: the signing key that the key should be signed with. If
+ omitted, signatures will not be checked.
"""
if (
key.get("user_id") != user_id
@@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
)
-def _check_device_signature(user_id, verify_key, signed_device, stored_device):
+def _check_device_signature(
+ user_id: str,
+ verify_key: VerifyKey,
+ signed_device: JsonDict,
+ stored_device: JsonDict,
+) -> None:
"""Check that a signature on a device or cross-signing key is correct and
matches the copy of the device/key that we have stored. Throws an
exception if an error is detected.
Args:
- user_id (str): the user ID whose signature is being checked
- verify_key (VerifyKey): the key to verify the device with
- signed_device (dict): the uploaded signed device data
- stored_device (dict): our previously stored copy of the device
+ user_id: the user ID whose signature is being checked
+ verify_key: the key to verify the device with
+ signed_device: the uploaded signed device data
+ stored_device: our previously stored copy of the device
Raises:
SynapseError: if the signature was invalid or the sent device is not the
@@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device):
raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
-def _exception_to_failure(e):
+def _exception_to_failure(e: Exception) -> JsonDict:
if isinstance(e, SynapseError):
return {"status": e.code, "errcode": e.errcode, "message": str(e)}
@@ -1218,7 +1251,7 @@ def _exception_to_failure(e):
return {"status": 503, "message": str(e)}
-def _one_time_keys_match(old_key_json, new_key):
+def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
old_key = json_decoder.decode(old_key_json)
# if either is a string rather than an object, they must match exactly
@@ -1239,16 +1272,16 @@ class SignatureListItem:
"""An item in the signature list as used by upload_signatures_for_device_keys.
"""
- signing_key_id = attr.ib()
- target_user_id = attr.ib()
- target_device_id = attr.ib()
- signature = attr.ib()
+ signing_key_id = attr.ib(type=str)
+ target_user_id = attr.ib(type=str)
+ target_device_id = attr.ib(type=str)
+ signature = attr.ib(type=JsonDict)
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""
- def __init__(self, hs, e2e_keys_handler):
+ def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
@@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled.
- self._pending_updates = {}
+ self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
@@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater:
iterable=True,
)
- async def incoming_signing_key_update(self, origin, edu_content):
+ async def incoming_signing_key_update(
+ self, origin: str, edu_content: JsonDict
+ ) -> None:
"""Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list.
Args:
- origin (string): the server that sent the EDU
- edu_content (dict): the contents of the EDU
+ origin: the server that sent the EDU
+ edu_content: the contents of the EDU
"""
user_id = edu_content.pop("user_id")
@@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater:
await self._handle_signing_key_updates(user_id)
- async def _handle_signing_key_updates(self, user_id):
+ async def _handle_signing_key_updates(self, user_id: str) -> None:
"""Actually handle pending updates.
Args:
- user_id (string): the user whose updates we are processing
+ user_id: the user whose updates we are processing
"""
device_handler = self.e2e_keys_handler.device_handler
@@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater:
# This can happen since we batch updates
return
- device_ids = []
+ device_ids = [] # type: List[str]
logger.info("pending updates: %r", pending_updates)
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f01b090772..622cae23be 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, List, Optional
from synapse.api.errors import (
Codes,
@@ -24,8 +25,12 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.opentracing import log_kv, trace
+from synapse.types import JsonDict
from synapse.util.async_helpers import Linearizer
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -37,7 +42,7 @@ class E2eRoomKeysHandler:
The actual payload of the encrypted keys is completely opaque to the handler.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
# Used to lock whenever a client is uploading key data. This prevents collisions
@@ -48,21 +53,27 @@ class E2eRoomKeysHandler:
self._upload_linearizer = Linearizer("upload_room_keys_lock")
@trace
- async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def get_room_keys(
+ self,
+ user_id: str,
+ version: str,
+ room_id: Optional[str] = None,
+ session_id: Optional[str] = None,
+ ) -> List[JsonDict]:
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
Args:
- user_id(str): the user whose keys we're getting
- version(str): the version ID of the backup we're getting keys from
- room_id(string): room ID to get keys for, for None to get keys for all rooms
- session_id(string): session ID to get keys for, for None to get keys for all
+ user_id: the user whose keys we're getting
+ version: the version ID of the backup we're getting keys from
+ room_id: room ID to get keys for, or None to get keys for all rooms
+ session_id: session ID to get keys for, or None to get keys for all
sessions
Raises:
NotFoundError: if the backup version does not exist
Returns:
- A deferred list of dicts giving the session_data and message metadata for
+ A list of dicts giving the session_data and message metadata for
these room keys.
"""
@@ -86,17 +97,23 @@ class E2eRoomKeysHandler:
return results
@trace
- async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def delete_room_keys(
+ self,
+ user_id: str,
+ version: str,
+ room_id: Optional[str] = None,
+ session_id: Optional[str] = None,
+ ) -> JsonDict:
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
Args:
- user_id(str): the user whose backup we're deleting
- version(str): the version ID of the backup we're deleting
- room_id(string): room ID to delete keys for, for None to delete keys for all
+ user_id: the user whose backup we're deleting
+ version: the version ID of the backup we're deleting
+ room_id: room ID to delete keys for, or None to delete keys for all
rooms
- session_id(string): session ID to delete keys for, for None to delete keys
+ session_id: session ID to delete keys for, or None to delete keys
for all sessions
Raises:
NotFoundError: if the backup version does not exist
@@ -128,15 +145,17 @@ class E2eRoomKeysHandler:
return {"etag": str(version_etag), "count": count}
@trace
- async def upload_room_keys(self, user_id, version, room_keys):
+ async def upload_room_keys(
+ self, user_id: str, version: str, room_keys: JsonDict
+ ) -> JsonDict:
"""Bulk upload a list of room keys into a given backup version, asserting
that the given version is the current backup version. room_keys are merged
into the current backup as described in RoomKeysServlet.on_PUT().
Args:
- user_id(str): the user whose backup we're setting
- version(str): the version ID of the backup we're updating
- room_keys(dict): a nested dict describing the room_keys we're setting:
+ user_id: the user whose backup we're setting
+ version: the version ID of the backup we're updating
+ room_keys: a nested dict describing the room_keys we're setting:
{
"rooms": {
@@ -254,14 +273,16 @@ class E2eRoomKeysHandler:
return {"etag": str(version_etag), "count": count}
@staticmethod
- def _should_replace_room_key(current_room_key, room_key):
+ def _should_replace_room_key(
+ current_room_key: Optional[JsonDict], room_key: JsonDict
+ ) -> bool:
"""
Determine whether to replace a given current_room_key (if any)
with a newly uploaded room_key backup
Args:
- current_room_key (dict): Optional, the current room_key dict if any
- room_key (dict): The new room_key dict which may or may not be fit to
+ current_room_key: Optional, the current room_key dict if any
+ room_key: The new room_key dict which may or may not be fit to
replace the current_room_key
Returns:
@@ -286,14 +307,14 @@ class E2eRoomKeysHandler:
return True
@trace
- async def create_version(self, user_id, version_info):
+ async def create_version(self, user_id: str, version_info: JsonDict) -> str:
"""Create a new backup version. This automatically becomes the new
backup version for the user's keys; previous backups will no longer be
writeable.
Args:
- user_id(str): the user whose backup version we're creating
- version_info(dict): metadata about the new version being created
+ user_id: the user whose backup version we're creating
+ version_info: metadata about the new version being created
{
"algorithm": "m.megolm_backup.v1",
@@ -301,7 +322,7 @@ class E2eRoomKeysHandler:
}
Returns:
- A deferred of a string that gives the new version number.
+ The new version number.
"""
# TODO: Validate the JSON to make sure it has the right keys.
@@ -313,17 +334,19 @@ class E2eRoomKeysHandler:
)
return new_version
- async def get_version_info(self, user_id, version=None):
+ async def get_version_info(
+ self, user_id: str, version: Optional[str] = None
+ ) -> JsonDict:
"""Get the info about a given version of the user's backup
Args:
- user_id(str): the user whose current backup version we're querying
- version(str): Optional; if None gives the most recent version
+ user_id: the user whose current backup version we're querying
+ version: Optional; if None gives the most recent version
otherwise a historical one.
Raises:
NotFoundError: if the requested backup version doesn't exist
Returns:
- A deferred of a info dict that gives the info about the new version.
+ An info dict that gives the info about the new version.
{
"version": "1234",
@@ -346,7 +369,7 @@ class E2eRoomKeysHandler:
return res
@trace
- async def delete_version(self, user_id, version=None):
+ async def delete_version(self, user_id: str, version: Optional[str] = None) -> None:
"""Deletes a given version of the user's e2e_room_keys backup
Args:
@@ -366,17 +389,19 @@ class E2eRoomKeysHandler:
raise
@trace
- async def update_version(self, user_id, version, version_info):
+ async def update_version(
+ self, user_id: str, version: str, version_info: JsonDict
+ ) -> JsonDict:
"""Update the info about a given version of the user's backup
Args:
- user_id(str): the user whose current backup version we're updating
- version(str): the backup version we're updating
- version_info(dict): the new information about the backup
+ user_id: the user whose current backup version we're updating
+ version: the backup version we're updating
+ version_info: the new information about the backup
Raises:
NotFoundError: if the requested backup version doesn't exist
Returns:
- A deferred of an empty dict.
+ An empty dict.
"""
if "version" not in version_info:
version_info["version"] = version
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b9799090f7..eddc7582d0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -140,7 +140,7 @@ class FederationHandler(BaseHandler):
self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
- self.http_client = hs.get_simple_http_client()
+ self.http_client = hs.get_proxied_blacklisted_http_client()
self._instance_name = hs.get_instance_name()
self._replication = hs.get_replication_data_handler()
@@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
- if not self.spam_checker.user_may_invite(
+ if not await self.spam_checker.user_may_invite(
event.sender, event.state_key, event.room_id
):
raise SynapseError(
@@ -1617,6 +1617,12 @@ class FederationHandler(BaseHandler):
if event.state_key == self._server_notices_mxid:
raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
+ # We retrieve the room member handler here so as not to cause a cyclic dependency
+ member_handler = self.hs.get_room_member_handler()
+ # We don't rate limit based on room ID, as that should be done by the
+ # sending server.
+ member_handler.ratelimit_invite(None, event.state_key)
+
# keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a
# join dance).
@@ -2093,6 +2099,11 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
+ # If we are going to send this event over federation we precalculate
+ # the joined hosts.
+ if event.internal_metadata.get_send_on_behalf_of():
+ await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
return context
async def _check_for_soft_fail(
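
Passing `None` as the first argument to `ratelimit_invite` keys the limiter on the invitee alone, matching the comment that per-room limiting is the sending server's job. Using the `Ratelimiter` API as configured elsewhere in this diff, the keying looks roughly like this (the rate values are placeholders, not Synapse defaults):

    from synapse.api.ratelimiting import Ratelimiter

    def make_invite_ratelimiter(hs):
        return Ratelimiter(clock=hs.get_clock(), rate_hz=0.3, burst_count=10)

    def ratelimit_invite_sketch(limiter, room_id, invitee_user_id):
        # room_id may be None (as in the federation path above), so remote
        # invites are limited purely per-invitee; local invites can include
        # the room in the key.
        limiter.ratelimit((room_id, invitee_user_id))
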
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 7c06cc529e..754a22bd5f 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -15,9 +15,13 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import GroupID, get_domain_from_id
+from synapse.types import GroupID, JsonDict, get_domain_from_id
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -29,7 +33,7 @@ def _create_rerouter(func_name):
async def f(self, group_id, *args, **kwargs):
if not GroupID.is_valid(group_id):
- raise SynapseError(400, "%s was not legal group ID" % (group_id,))
+ raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
if self.is_mine_id(group_id):
return await getattr(self.groups_server_handler, func_name)(
@@ -56,7 +60,7 @@ def _create_rerouter(func_name):
class GroupsLocalWorkerHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.room_list_handler = hs.get_room_list_handler()
@@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler:
get_group_role = _create_rerouter("get_group_role")
get_group_roles = _create_rerouter("get_group_roles")
- async def get_group_summary(self, group_id, requester_user_id):
+ async def get_group_summary(
+ self, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get the group summary for a group.
If the group is remote we check that the users have valid attestations.
@@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler:
return res
- async def get_users_in_group(self, group_id, requester_user_id):
+ async def get_users_in_group(
+ self, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get users in a group
"""
if self.is_mine_id(group_id):
- res = await self.groups_server_handler.get_users_in_group(
+ return await self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
)
- return res
group_server_name = get_domain_from_id(group_id)
@@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler:
return res
- async def get_joined_groups(self, user_id):
+ async def get_joined_groups(self, user_id: str) -> JsonDict:
group_ids = await self.store.get_joined_groups(user_id)
return {"groups": group_ids}
- async def get_publicised_groups_for_user(self, user_id):
+ async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
if self.hs.is_mine_id(user_id):
result = await self.store.get_publicised_groups_for_user(user_id)
@@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler:
# TODO: Verify attestations
return {"groups": result}
- async def bulk_get_publicised_groups(self, user_ids, proxy=True):
- destinations = {}
+ async def bulk_get_publicised_groups(
+ self, user_ids: Iterable[str], proxy: bool = True
+ ) -> JsonDict:
+ destinations = {} # type: Dict[str, Set[str]]
local_users = set()
for user_id in user_ids:
@@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler:
raise SynapseError(400, "Some user_ids are not local")
results = {}
- failed_results = []
+ failed_results = [] # type: List[str]
for destination, dest_user_ids in destinations.items():
try:
r = await self.transport_client.bulk_get_publicised_groups(
@@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler:
class GroupsLocalHandler(GroupsLocalWorkerHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
# Ensure attestations get renewed
@@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
set_group_join_policy = _create_rerouter("set_group_join_policy")
- async def create_group(self, group_id, user_id, content):
+ async def create_group(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Create a group
"""
@@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
local_attestation = None
remote_attestation = None
else:
- local_attestation = self.attestations.create_attestation(group_id, user_id)
- content["attestation"] = local_attestation
-
- content["user_profile"] = await self.profile_handler.get_profile(user_id)
-
- try:
- res = await self.transport_client.create_group(
- get_domain_from_id(group_id), group_id, user_id, content
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- remote_attestation = res["attestation"]
- await self.attestations.verify_attestation(
- remote_attestation,
- group_id=group_id,
- user_id=user_id,
- server_name=get_domain_from_id(group_id),
- )
+ raise SynapseError(400, "Unable to create remote groups")
is_publicised = content.get("publicise", False)
token = await self.store.register_user_group_membership(
@@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- async def join_group(self, group_id, user_id, content):
+ async def join_group(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Request to join a group
"""
if self.is_mine_id(group_id):
@@ -365,7 +358,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
- async def accept_invite(self, group_id, user_id, content):
+ async def accept_invite(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
@@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
- async def invite(self, group_id, user_id, requester_user_id, config):
+ async def invite(
+ self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
+ ) -> JsonDict:
"""Invite a user to a group
"""
content = {"requester_user_id": requester_user_id, "config": config}
@@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- async def on_invite(self, group_id, user_id, content):
+ async def on_invite(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""One of our users were invited to a group
"""
# TODO: Support auto join and rejection
@@ -484,8 +483,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
async def remove_user_from_group(
- self, group_id, user_id, requester_user_id, content
- ):
+ self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Remove a user from a group
"""
if user_id == requester_user_id:
@@ -518,7 +517,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- async def user_removed_from_group(self, group_id, user_id, content):
+ async def user_removed_from_group(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> None:
"""One of our users was removed/kicked from a group
"""
# TODO: Check if user in group
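
The typed-up handlers in this file mostly ride on `_create_rerouter`, which fabricates an async method that serves the request locally when we own the group and proxies it over federation otherwise. Reduced to its essentials (the remote branch's error translation is assembled from the error-handling shown elsewhere in this file, not quoted verbatim):

    from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
    from synapse.types import GroupID, get_domain_from_id

    def create_rerouter(func_name):
        async def f(self, group_id, *args, **kwargs):
            if not GroupID.is_valid(group_id):
                raise SynapseError(400, "%s is not a legal group ID" % (group_id,))

            if self.is_mine_id(group_id):
                # we are the authoritative groups server for this group
                return await getattr(self.groups_server_handler, func_name)(
                    group_id, *args, **kwargs
                )

            # otherwise ask the server that owns the group, translating
            # transport failures into Synapse errors
            destination = get_domain_from_id(group_id)
            try:
                return await getattr(self.transport_client, func_name)(
                    destination, group_id, *args, **kwargs
                )
            except HttpResponseException as e:
                raise e.to_synapse_error()
            except RequestSendFailed:
                raise SynapseError(502, "Failed to contact group server")

        return f
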
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9b3c6b4551..8fc1e8b91c 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -27,9 +27,11 @@ from synapse.api.errors import (
HttpResponseException,
SynapseError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient
+from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, Requester
from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64
@@ -46,15 +48,43 @@ class IdentityHandler(BaseHandler):
def __init__(self, hs):
super().__init__(hs)
+ # An HTTP client for contacting trusted URLs.
self.http_client = SimpleHttpClient(hs)
- # We create a blacklisting instance of SimpleHttpClient for contacting identity
- # servers specified by clients
+ # An HTTP client for contacting identity servers specified by clients.
self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
)
- self.federation_http_client = hs.get_http_client()
+ self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
+ self._web_client_location = hs.config.invite_client_location
+
+ # Ratelimiters for `/requestToken` endpoints.
+ self._3pid_validation_ratelimiter_ip = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+ burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+ )
+ self._3pid_validation_ratelimiter_address = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+ burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+ )
+
+ def ratelimit_request_token_requests(
+ self, request: SynapseRequest, medium: str, address: str,
+ ):
+ """Used to ratelimit requests to `/requestToken` by IP and address.
+
+ Args:
+ request: The associated request
+ medium: The type of threepid, e.g. "msisdn" or "email"
+ address: The actual threepid ID, e.g. the phone number or email address
+ """
+
+ self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
+ self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
+
async def threepid_from_creds(
self, id_server: str, creds: Dict[str, str]
) -> Optional[JsonDict]:
@@ -474,6 +504,8 @@ class IdentityHandler(BaseHandler):
except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
+ # It is already checked that public_baseurl is configured since this code
+ # should only be used if account_threepid_delegate_msisdn is true.
assert self.hs.config.public_baseurl
# we need to tell the client to send the token back to us, since it doesn't
@@ -803,6 +835,9 @@ class IdentityHandler(BaseHandler):
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
+ # If a custom web client location is available, include it in the request.
+ if self._web_client_location:
+ invite_config["org.matrix.web_client_location"] = self._web_client_location
# Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present
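
Both new limiters share the `rc_3pid_validation` config but key on different things, so one abusive IP can't burn through many addresses and one address can't be hammered from many IPs. A sketch of how a `/requestToken` endpoint would invoke the helper (the servlet shape and the handler accessor are assumptions for illustration):

    from synapse.http.servlet import parse_json_object_from_request

    class RequestTokenServletSketch:
        def __init__(self, hs):
            self.identity_handler = hs.get_handlers().identity_handler  # assumed accessor

        async def on_POST(self, request):
            body = parse_json_object_from_request(request)
            email = body["email"]

            # Checks (medium, client IP) and (medium, address); rejects the
            # request with a 429 if either bucket is exhausted.
            self.identity_handler.ratelimit_request_token_requests(
                request, "email", email
            )

            # ... proceed to send the validation token ...
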
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index cb11754bf8..fbd8df9dcc 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -323,9 +323,7 @@ class InitialSyncHandler(BaseHandler):
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
- room_state = await self.state_store.get_state_for_events([member_event_id])
-
- room_state = room_state[member_event_id]
+ room_state = await self.state_store.get_state_for_event(member_event_id)
limit = pagin_config.limit if pagin_config else None
if limit is None:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 96843338ae..a15336bf00 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -174,7 +174,7 @@ class MessageHandler:
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = await filter_events_for_client(
- self.storage, user_id, last_events, filter_send_to_client=False
+ self.storage, user_id, last_events, filter_send_to_client=False,
)
event = last_events[0]
@@ -432,6 +432,8 @@ class EventCreationHandler:
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+ self._external_cache = hs.get_external_cache()
+
async def create_event(
self,
requester: Requester,
@@ -744,7 +746,7 @@ class EventCreationHandler:
event.sender,
)
- spam_error = self.spam_checker.check_event_for_spam(event)
+ spam_error = await self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, str):
spam_error = "Spam is not permitted here"
@@ -939,6 +941,8 @@ class EventCreationHandler:
await self.action_generator.handle_push_actions_for_event(event, context)
+ await self.cache_joined_hosts_for_event(event)
+
try:
# If we're a worker we need to hit out to the master.
writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -978,6 +982,44 @@ class EventCreationHandler:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
+ async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
+ """Precalculate the joined hosts at the event, when using Redis, so that
+ external federation senders don't have to recalculate it themselves.
+ """
+
+ if not self._external_cache.is_enabled():
+ return
+
+ # We actually store two mappings, event ID -> prev state group,
+ # state group -> joined hosts, which is much more space efficient
+ # than event ID -> joined hosts.
+ #
+ # Note: We have to cache event ID -> prev state group, as we don't
+ # store that in the DB.
+ #
+ # Note: We always set the state group -> joined hosts cache, even if
+ # it is already set, so that the expiry time is reset.
+
+ state_entry = await self.state.resolve_state_groups_for_events(
+ event.room_id, event_ids=event.prev_event_ids()
+ )
+
+ if state_entry.state_group:
+ joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+
+ await self._external_cache.set(
+ "event_to_prev_state_group",
+ event.event_id,
+ state_entry.state_group,
+ expiry_ms=60 * 60 * 1000,
+ )
+ await self._external_cache.set(
+ "get_joined_hosts",
+ str(state_entry.state_group),
+ list(joined_hosts),
+ expiry_ms=60 * 60 * 1000,
+ )
+
async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str
) -> None:
@@ -1261,7 +1303,7 @@ class EventCreationHandler:
event, context = await self.create_event(
requester,
{
- "type": "org.matrix.dummy_event",
+ "type": EventTypes.Dummy,
"content": {},
"room_id": room_id,
"sender": user_id,
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index c605f7082a..71008ec50d 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -14,7 +14,7 @@
# limitations under the License.
import inspect
import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
from urllib.parse import urlencode
import attr
@@ -35,7 +35,7 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
-from synapse.handlers._base import BaseHandler
+from synapse.config.oidc_config import OidcProviderConfig
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
@@ -71,6 +71,144 @@ JWK = Dict[str, str]
JWKS = TypedDict("JWKS", {"keys": List[JWK]})
+class OidcHandler:
+ """Handles requests related to the OpenID Connect login flow.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ self._sso_handler = hs.get_sso_handler()
+
+ provider_confs = hs.config.oidc.oidc_providers
+ # we should not have been instantiated if there is no configured provider.
+ assert provider_confs
+
+ self._token_generator = OidcSessionTokenGenerator(hs)
+ self._providers = {
+ p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
+ } # type: Dict[str, OidcProvider]
+
+ async def load_metadata(self) -> None:
+ """Validate the config and load the metadata from the remote endpoint.
+
+ Called at startup to ensure we have everything we need.
+ """
+ for idp_id, p in self._providers.items():
+ try:
+ await p.load_metadata()
+ await p.load_jwks()
+ except Exception as e:
+ raise Exception(
+ "Error while initialising OIDC provider %r" % (idp_id,)
+ ) from e
+
+ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+ """Handle an incoming request to /_synapse/client/oidc/callback
+
+ Since we might want to display OIDC-related errors in a user-friendly
+ way, we don't raise SynapseError from here. Instead, we call
+ ``self._sso_handler.render_error`` which displays an HTML page for the error.
+
+ Most of the OpenID Connect logic happens here:
+
+ - first, we check if there was any error returned by the provider and
+ display it
+ - then we fetch the session cookie, decode and verify it
+ - the ``state`` query parameter should match with the one stored in the
+ session cookie
+
+ Once we know the session is legit, we then delegate to the OIDC Provider
+ implementation, which will exchange the code with the provider and complete the
+ login/authentication.
+
+ Args:
+ request: the incoming request from the browser.
+ """
+
+ # The provider might redirect with an error.
+ # In that case, just display it as-is.
+ if b"error" in request.args:
+ # error response from the auth server. see:
+ # https://tools.ietf.org/html/rfc6749#section-4.1.2.1
+ # https://openid.net/specs/openid-connect-core-1_0.html#AuthError
+ error = request.args[b"error"][0].decode()
+ description = request.args.get(b"error_description", [b""])[0].decode()
+
+ # Most of the errors returned by the provider could be due to
+ # either the provider misbehaving or Synapse being misconfigured.
+ # The only exception to that is "access_denied", where the user
+ # probably cancelled the login flow. In other cases, log those errors.
+ if error != "access_denied":
+ logger.error("Error from the OIDC provider: %s %s", error, description)
+
+ self._sso_handler.render_error(request, error, description)
+ return
+
+ # otherwise, it is presumably a successful response. see:
+ # https://tools.ietf.org/html/rfc6749#section-4.1.2
+
+ # Fetch the session cookie
+ session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
+ if session is None:
+ logger.info("No session cookie found")
+ self._sso_handler.render_error(
+ request, "missing_session", "No session cookie found"
+ )
+ return
+
+ # Remove the cookie. There is a good chance that if the callback failed
+ # once, it will fail next time and the code will already be exchanged.
+ # Removing it early avoids spamming the provider with token requests.
+ request.addCookie(
+ SESSION_COOKIE_NAME,
+ b"",
+ path="/_synapse/oidc",
+ expires="Thu, Jan 01 1970 00:00:00 UTC",
+ httpOnly=True,
+ sameSite="lax",
+ )
+
+ # Check for the state query parameter
+ if b"state" not in request.args:
+ logger.info("State parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "State parameter is missing"
+ )
+ return
+
+ state = request.args[b"state"][0].decode()
+
+ # Deserialize the session token and verify it.
+ try:
+ session_data = self._token_generator.verify_oidc_session_token(
+ session, state
+ )
+ except (MacaroonDeserializationException, ValueError) as e:
+ logger.exception("Invalid session")
+ self._sso_handler.render_error(request, "invalid_session", str(e))
+ return
+ except MacaroonInvalidSignatureException as e:
+ logger.exception("Could not verify session")
+ self._sso_handler.render_error(request, "mismatching_session", str(e))
+ return
+
+ oidc_provider = self._providers.get(session_data.idp_id)
+ if not oidc_provider:
+ logger.error("OIDC session uses unknown IdP %r", oidc_provider)
+ self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
+ return
+
+ if b"code" not in request.args:
+ logger.info("Code parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "Code parameter is missing"
+ )
+ return
+
+ code = request.args[b"code"][0].decode()
+
+ await oidc_provider.handle_oidc_callback(request, session_data, code)
+
+
class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint
"""
@@ -85,46 +223,64 @@ class OidcError(Exception):
return self.error
-class OidcHandler(BaseHandler):
- """Handles requests related to the OpenID Connect login flow.
+class OidcProvider:
+ """Wraps the config for a single OIDC IdentityProvider
+
+ Provides methods for handling redirect requests and callbacks via that particular
+ IdP.
"""
- def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ def __init__(
+ self,
+ hs: "HomeServer",
+ token_generator: "OidcSessionTokenGenerator",
+ provider: OidcProviderConfig,
+ ):
+ self._store = hs.get_datastore()
+
+ self._token_generator = token_generator
+
self._callback_url = hs.config.oidc_callback_url # type: str
- self._scopes = hs.config.oidc_scopes # type: List[str]
- self._user_profile_method = hs.config.oidc_user_profile_method # type: str
+
+ self._scopes = provider.scopes
+ self._user_profile_method = provider.user_profile_method
self._client_auth = ClientAuth(
- hs.config.oidc_client_id,
- hs.config.oidc_client_secret,
- hs.config.oidc_client_auth_method,
+ provider.client_id, provider.client_secret, provider.client_auth_method,
) # type: ClientAuth
- self._client_auth_method = hs.config.oidc_client_auth_method # type: str
+ self._client_auth_method = provider.client_auth_method
self._provider_metadata = OpenIDProviderMetadata(
- issuer=hs.config.oidc_issuer,
- authorization_endpoint=hs.config.oidc_authorization_endpoint,
- token_endpoint=hs.config.oidc_token_endpoint,
- userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
- jwks_uri=hs.config.oidc_jwks_uri,
+ issuer=provider.issuer,
+ authorization_endpoint=provider.authorization_endpoint,
+ token_endpoint=provider.token_endpoint,
+ userinfo_endpoint=provider.userinfo_endpoint,
+ jwks_uri=provider.jwks_uri,
) # type: OpenIDProviderMetadata
- self._provider_needs_discovery = hs.config.oidc_discover # type: bool
- self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
- hs.config.oidc_user_mapping_provider_config
- ) # type: OidcMappingProvider
- self._skip_verification = hs.config.oidc_skip_verification # type: bool
- self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
+ self._provider_needs_discovery = provider.discover
+ self._user_mapping_provider = provider.user_mapping_provider_class(
+ provider.user_mapping_provider_config
+ )
+ self._skip_verification = provider.skip_verification
+ self._allow_existing_users = provider.allow_existing_users
self._http_client = hs.get_proxied_http_client()
- self._auth_handler = hs.get_auth_handler()
- self._registration_handler = hs.get_registration_handler()
self._server_name = hs.config.server_name # type: str
- self._macaroon_secret_key = hs.config.macaroon_secret_key
# identifier for the external_ids table
- self._auth_provider_id = "oidc"
+ self.idp_id = provider.idp_id
+
+ # user-facing name of this auth provider
+ self.idp_name = provider.idp_name
+
+ # MXC URI for icon for this auth provider
+ self.idp_icon = provider.idp_icon
+
+ # optional brand identifier for this auth provider
+ self.idp_brand = provider.idp_brand
self._sso_handler = hs.get_sso_handler()
+ self._sso_handler.register_identity_provider(self)
+
def _validate_metadata(self):
"""Verifies the provider metadata.
@@ -477,7 +633,7 @@ class OidcHandler(BaseHandler):
async def handle_redirect_request(
self,
request: SynapseRequest,
- client_redirect_url: bytes,
+ client_redirect_url: Optional[bytes],
ui_auth_session_id: Optional[str] = None,
) -> str:
"""Handle an incoming request to /login/sso/redirect
@@ -487,7 +643,7 @@ class OidcHandler(BaseHandler):
- ``client_id``: the client ID set in ``oidc_config.client_id``
- ``response_type``: ``code``
- - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
+ - ``redirect_uri``: the callback URL; ``{base url}/_synapse/client/oidc/callback``
- ``scope``: the list of scopes set in ``oidc_config.scopes``
- ``state``: a random string
- ``nonce``: a random string
@@ -501,7 +657,7 @@ class OidcHandler(BaseHandler):
request: the incoming request from the browser.
We'll respond to it with a redirect and a cookie.
client_redirect_url: the URL that we should redirect the client to
- when everything is done
+ when everything is done (or None for UI Auth)
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
@@ -513,16 +669,22 @@ class OidcHandler(BaseHandler):
state = generate_token()
nonce = generate_token()
- cookie = self._generate_oidc_session_token(
+ if not client_redirect_url:
+ client_redirect_url = b""
+
+ cookie = self._token_generator.generate_oidc_session_token(
state=state,
- nonce=nonce,
- client_redirect_url=client_redirect_url.decode(),
- ui_auth_session_id=ui_auth_session_id,
+ session_data=OidcSessionData(
+ idp_id=self.idp_id,
+ nonce=nonce,
+ client_redirect_url=client_redirect_url.decode(),
+ ui_auth_session_id=ui_auth_session_id,
+ ),
)
request.addCookie(
SESSION_COOKIE_NAME,
cookie,
- path="/_synapse/oidc",
+ path="/_synapse/client/oidc",
max_age="3600",
httpOnly=True,
sameSite="lax",
@@ -540,22 +702,16 @@ class OidcHandler(BaseHandler):
nonce=nonce,
)
- async def handle_oidc_callback(self, request: SynapseRequest) -> None:
- """Handle an incoming request to /_synapse/oidc/callback
+ async def handle_oidc_callback(
+ self, request: SynapseRequest, session_data: "OidcSessionData", code: str
+ ) -> None:
+ """Handle an incoming request to /_synapse/client/oidc/callback
- Since we might want to display OIDC-related errors in a user-friendly
- way, we don't raise SynapseError from here. Instead, we call
- ``self._sso_handler.render_error`` which displays an HTML page for the error.
-
- Most of the OpenID Connect logic happens here:
+ By this time we have already validated the session on the Synapse side, and
+ now need to do the provider-specific operations. This includes:
- - first, we check if there was any error returned by the provider and
- display it
- - then we fetch the session cookie, decode and verify it
- - the ``state`` query parameter should match with the one stored in the
- session cookie
- - once we known this session is legit, exchange the code with the
- provider using the ``token_endpoint`` (see ``_exchange_code``)
+ - exchange the code with the provider using the ``token_endpoint`` (see
+ ``_exchange_code``)
- once we have the token, use it to either extract the UserInfo from
the ``id_token`` (``_parse_id_token``), or use the ``access_token``
to fetch UserInfo from the ``userinfo_endpoint``
@@ -565,88 +721,12 @@ class OidcHandler(BaseHandler):
Args:
request: the incoming request from the browser.
+ session_data: the session data, extracted from our cookie
+ code: The authorization code we got from the callback.
"""
-
- # The provider might redirect with an error.
- # In that case, just display it as-is.
- if b"error" in request.args:
- # error response from the auth server. see:
- # https://tools.ietf.org/html/rfc6749#section-4.1.2.1
- # https://openid.net/specs/openid-connect-core-1_0.html#AuthError
- error = request.args[b"error"][0].decode()
- description = request.args.get(b"error_description", [b""])[0].decode()
-
- # Most of the errors returned by the provider could be due by
- # either the provider misbehaving or Synapse being misconfigured.
- # The only exception of that is "access_denied", where the user
- # probably cancelled the login flow. In other cases, log those errors.
- if error != "access_denied":
- logger.error("Error from the OIDC provider: %s %s", error, description)
-
- self._sso_handler.render_error(request, error, description)
- return
-
- # otherwise, it is presumably a successful response. see:
- # https://tools.ietf.org/html/rfc6749#section-4.1.2
-
- # Fetch the session cookie
- session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
- if session is None:
- logger.info("No session cookie found")
- self._sso_handler.render_error(
- request, "missing_session", "No session cookie found"
- )
- return
-
- # Remove the cookie. There is a good chance that if the callback failed
- # once, it will fail next time and the code will already be exchanged.
- # Removing it early avoids spamming the provider with token requests.
- request.addCookie(
- SESSION_COOKIE_NAME,
- b"",
- path="/_synapse/oidc",
- expires="Thu, Jan 01 1970 00:00:00 UTC",
- httpOnly=True,
- sameSite="lax",
- )
-
- # Check for the state query parameter
- if b"state" not in request.args:
- logger.info("State parameter is missing")
- self._sso_handler.render_error(
- request, "invalid_request", "State parameter is missing"
- )
- return
-
- state = request.args[b"state"][0].decode()
-
- # Deserialize the session token and verify it.
- try:
- (
- nonce,
- client_redirect_url,
- ui_auth_session_id,
- ) = self._verify_oidc_session_token(session, state)
- except MacaroonDeserializationException as e:
- logger.exception("Invalid session")
- self._sso_handler.render_error(request, "invalid_session", str(e))
- return
- except MacaroonInvalidSignatureException as e:
- logger.exception("Could not verify session")
- self._sso_handler.render_error(request, "mismatching_session", str(e))
- return
-
# Exchange the code with the provider
- if b"code" not in request.args:
- logger.info("Code parameter is missing")
- self._sso_handler.render_error(
- request, "invalid_request", "Code parameter is missing"
- )
- return
-
- logger.debug("Exchanging code")
- code = request.args[b"code"][0].decode()
try:
+ logger.debug("Exchanging code")
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")
@@ -668,25 +748,131 @@ class OidcHandler(BaseHandler):
else:
logger.debug("Extracting userinfo from id_token")
try:
- userinfo = await self._parse_id_token(token, nonce=nonce)
+ userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e))
return
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
+ # first check if we're doing a UIA
+ if session_data.ui_auth_session_id:
+ try:
+ remote_user_id = self._remote_id_from_userinfo(userinfo)
+ except Exception as e:
+ logger.exception("Could not extract remote user id")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
+ return
+
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
+ )
+
+ # otherwise, it's a login
# Call the mapper to register/login the user
try:
- user_id = await self._map_userinfo_to_user(
- userinfo, token, user_agent, ip_address
+ await self._complete_oidc_login(
+ userinfo, token, request, session_data.client_redirect_url
)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
- return
+
+ async def _complete_oidc_login(
+ self,
+ userinfo: UserInfo,
+ token: Token,
+ request: SynapseRequest,
+ client_redirect_url: str,
+ ) -> None:
+ """Given a UserInfo response, complete the login flow
+
+ UserInfo should have a claim that uniquely identifies users. This claim
+ is usually `sub`, but can be configured with `oidc_config.subject_claim`.
+ It is then used as an `external_id`.
+
+ If we don't find the user that way, we should register the user,
+ mapping the localpart and the display name from the UserInfo.
+
+ If a user already exists with the mxid we've mapped and allow_existing_users
+ is disabled, raise an exception.
+
+ Otherwise, render a redirect back to the client_redirect_url with a loginToken.
+
+ Args:
+ userinfo: an object representing the user
+ token: a dict with the tokens obtained from the provider
+ request: The request to respond to
+ client_redirect_url: The redirect URL passed in by the client.
+
+ Raises:
+ MappingException: if there was an error while mapping some properties
+ """
+ try:
+ remote_user_id = self._remote_id_from_userinfo(userinfo)
+ except Exception as e:
+ raise MappingException(
+ "Failed to extract subject from OIDC response: %s" % (e,)
+ )
+
+ # Older mapping providers don't accept the `failures` argument, so we
+ # try and detect support.
+ mapper_signature = inspect.signature(
+ self._user_mapping_provider.map_user_attributes
+ )
+ supports_failures = "failures" in mapper_signature.parameters
+
+ async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
+ """
+ Call the mapping provider to map the OIDC userinfo and token to user attributes.
+
+ This is a backwards-compatibility shim for the SSO handler abstraction.
+ """
+ if supports_failures:
+ attributes = await self._user_mapping_provider.map_user_attributes(
+ userinfo, token, failures
+ )
+ else:
+ # If the mapping provider does not support processing failures,
+ # there is no point generating the same Matrix ID again: it will
+ # still be in use. Note that the error raised is
+ # arbitrary and will get turned into a MappingException.
+ if failures:
+ raise MappingException(
+ "Mapping provider does not support de-duplicating Matrix IDs"
+ )
+
+ attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
+ userinfo, token
+ )
+
+ return UserAttributes(**attributes)
+
+ async def grandfather_existing_users() -> Optional[str]:
+ if self._allow_existing_users:
+ # If allowing existing users we want to generate a single localpart
+ # and attempt to match it.
+ attributes = await oidc_response_to_user_attributes(failures=0)
+
+ user_id = UserID(attributes.localpart, self._server_name).to_string()
+ users = await self._store.get_users_by_id_case_insensitive(user_id)
+ if users:
+ # If an existing matrix ID is returned, then use it.
+ if len(users) == 1:
+ previously_registered_user_id = next(iter(users))
+ elif user_id in users:
+ previously_registered_user_id = user_id
+ else:
+ # Do not attempt to continue generating Matrix IDs.
+ raise MappingException(
+ "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+ user_id, users
+ )
+ )
+
+ return previously_registered_user_id
+
+ return None
# Mapping providers might not have get_extra_attributes: only call this
# method if it exists.
@@ -697,22 +883,42 @@ class OidcHandler(BaseHandler):
if get_extra_attributes:
extra_attributes = await get_extra_attributes(userinfo, token)
- # and finally complete the login
- if ui_auth_session_id:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, ui_auth_session_id, request
- )
- else:
- await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url, extra_attributes
- )
+ await self._sso_handler.complete_sso_login_request(
+ self.idp_id,
+ remote_user_id,
+ request,
+ client_redirect_url,
+ oidc_response_to_user_attributes,
+ grandfather_existing_users,
+ extra_attributes,
+ )
+
+ def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
+ """Extract the unique remote id from an OIDC UserInfo block
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+ Returns:
+ remote user id
+ """
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+ # Some OIDC providers use integer IDs, but Synapse expects external IDs
+ # to be strings.
+ return str(remote_user_id)
+
+
+class OidcSessionTokenGenerator:
+ """Methods for generating and checking OIDC Session cookies."""
+
+ def __init__(self, hs: "HomeServer"):
+ self._clock = hs.get_clock()
+ self._server_name = hs.hostname
+ self._macaroon_secret_key = hs.config.key.macaroon_secret_key
- def _generate_oidc_session_token(
+ def generate_oidc_session_token(
self,
state: str,
- nonce: str,
- client_redirect_url: str,
- ui_auth_session_id: Optional[str],
+ session_data: "OidcSessionData",
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.
@@ -725,11 +931,7 @@ class OidcHandler(BaseHandler):
Args:
state: The ``state`` parameter passed to the OIDC provider.
- nonce: The ``nonce`` parameter passed to the OIDC provider.
- client_redirect_url: The URL the client gave when it initiated the
- flow.
- ui_auth_session_id: The session ID of the ongoing UI Auth (or
- None if this is a login).
+ session_data: data to include in the session token.
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.
@@ -742,23 +944,24 @@ class OidcHandler(BaseHandler):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = %s" % (state,))
- macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
+ macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
+ macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
macaroon.add_first_party_caveat(
- "client_redirect_url = %s" % (client_redirect_url,)
+ "client_redirect_url = %s" % (session_data.client_redirect_url,)
)
- if ui_auth_session_id:
+ if session_data.ui_auth_session_id:
macaroon.add_first_party_caveat(
- "ui_auth_session_id = %s" % (ui_auth_session_id,)
+ "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
)
- now = self.clock.time_msec()
+ now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
- def _verify_oidc_session_token(
+ def verify_oidc_session_token(
self, session: bytes, state: str
- ) -> Tuple[str, str, Optional[str]]:
+ ) -> "OidcSessionData":
"""Verifies and extract an OIDC session token.
This verifies that a given session token was issued by this homeserver
@@ -769,7 +972,10 @@ class OidcHandler(BaseHandler):
state: The state the OIDC provider gave back
Returns:
- The nonce, client_redirect_url, and ui_auth_session_id for this session
+ The data extracted from the session cookie
+
+ Raises:
+ ValueError: if an expected caveat is missing from the macaroon.
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
@@ -778,6 +984,7 @@ class OidcHandler(BaseHandler):
v.satisfy_exact("type = session")
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
+ v.satisfy_general(lambda c: c.startswith("idp_id = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
@@ -786,9 +993,9 @@ class OidcHandler(BaseHandler):
v.verify(macaroon, self._macaroon_secret_key)
- # Extract the `nonce`, `client_redirect_url`, and maybe the
- # `ui_auth_session_id` from the token.
+ # Extract the session data from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
+ idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
@@ -799,7 +1006,12 @@ class OidcHandler(BaseHandler):
except ValueError:
ui_auth_session_id = None
- return nonce, client_redirect_url, ui_auth_session_id
+ return OidcSessionData(
+ nonce=nonce,
+ idp_id=idp_id,
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=ui_auth_session_id,
+ )
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
@@ -812,7 +1024,7 @@ class OidcHandler(BaseHandler):
The extracted value
Raises:
- Exception: if the caveat was not in the macaroon
+ ValueError: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
@@ -825,117 +1037,30 @@ class OidcHandler(BaseHandler):
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
- now = self.clock.time_msec()
+ now = self._clock.time_msec()
return now < expiry
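
For reference, the macaroon round trip that OidcSessionTokenGenerator performs can be exercised standalone with pymacaroons. A minimal sketch — the key, location, and caveat values here are illustrative, not the real configuration:

import pymacaroons

secret = "macaroon-secret-key"  # stands in for hs.config.key.macaroon_secret_key

macaroon = pymacaroons.Macaroon(location="example.com", identifier="key", key=secret)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = abc123")
macaroon.add_first_party_caveat("idp_id = oidc")
token = macaroon.serialize()

# Verification replays the caveats: exact ones must match verbatim, general
# ones are checked with a predicate.
restored = pymacaroons.Macaroon.deserialize(token)
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = session")
v.satisfy_exact("state = abc123")
v.satisfy_general(lambda c: c.startswith("idp_id = "))
v.verify(restored, secret)  # raises MacaroonInvalidSignatureException on tampering
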
- async def _map_userinfo_to_user(
- self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
- ) -> str:
- """Maps a UserInfo object to a mxid.
-
- UserInfo should have a claim that uniquely identifies users. This claim
- is usually `sub`, but can be configured with `oidc_config.subject_claim`.
- It is then used as an `external_id`.
-
- If we don't find the user that way, we should register the user,
- mapping the localpart and the display name from the UserInfo.
-
- If a user already exists with the mxid we've mapped and allow_existing_users
- is disabled, raise an exception.
-
- Args:
- userinfo: an object representing the user
- token: a dict with the tokens obtained from the provider
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
-
- Raises:
- MappingException: if there was an error while mapping some properties
-
- Returns:
- The mxid of the user
- """
- try:
- remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
- except Exception as e:
- raise MappingException(
- "Failed to extract subject from OIDC response: %s" % (e,)
- )
- # Some OIDC providers use integer IDs, but Synapse expects external IDs
- # to be strings.
- remote_user_id = str(remote_user_id)
-
- # Older mapping providers don't accept the `failures` argument, so we
- # try and detect support.
- mapper_signature = inspect.signature(
- self._user_mapping_provider.map_user_attributes
- )
- supports_failures = "failures" in mapper_signature.parameters
-
- async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
- """
- Call the mapping provider to map the OIDC userinfo and token to user attributes.
-
- This is backwards compatibility for abstraction for the SSO handler.
- """
- if supports_failures:
- attributes = await self._user_mapping_provider.map_user_attributes(
- userinfo, token, failures
- )
- else:
- # If the mapping provider does not support processing failures,
- # do not continually generate the same Matrix ID since it will
- # continue to already be in use. Note that the error raised is
- # arbitrary and will get turned into a MappingException.
- if failures:
- raise MappingException(
- "Mapping provider does not support de-duplicating Matrix IDs"
- )
-
- attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
- userinfo, token
- )
- return UserAttributes(**attributes)
+@attr.s(frozen=True, slots=True)
+class OidcSessionData:
+ """The attributes which are stored in a OIDC session cookie"""
- async def grandfather_existing_users() -> Optional[str]:
- if self._allow_existing_users:
- # If allowing existing users we want to generate a single localpart
- # and attempt to match it.
- attributes = await oidc_response_to_user_attributes(failures=0)
+ # The Identity Provider being used.
+ idp_id = attr.ib(type=str)
- user_id = UserID(attributes.localpart, self.server_name).to_string()
- users = await self.store.get_users_by_id_case_insensitive(user_id)
- if users:
- # If an existing matrix ID is returned, then use it.
- if len(users) == 1:
- previously_registered_user_id = next(iter(users))
- elif user_id in users:
- previously_registered_user_id = user_id
- else:
- # Do not attempt to continue generating Matrix IDs.
- raise MappingException(
- "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
- user_id, users
- )
- )
+ # The `nonce` parameter passed to the OIDC provider.
+ nonce = attr.ib(type=str)
- return previously_registered_user_id
+ # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
+ client_redirect_url = attr.ib(type=str)
- return None
-
- return await self._sso_handler.get_mxid_from_sso(
- self._auth_provider_id,
- remote_user_id,
- user_agent,
- ip_address,
- oidc_response_to_user_attributes,
- grandfather_existing_users,
- )
+ # The session ID of the ongoing UI Auth (None if this is a login)
+ ui_auth_session_id = attr.ib(type=Optional[str], default=None)
UserAttributeDict = TypedDict(
- "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
+ "UserAttributeDict",
+ {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
)
C = TypeVar("C")
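
With the widened UserAttributeDict above, a mapping provider may now return a null localpart and a (possibly empty) list of emails. A hypothetical value, restating the TypedDict for self-containment:

from typing import List, Optional
from typing_extensions import TypedDict

UserAttributeDict = TypedDict(
    "UserAttributeDict",
    {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
)

attrs = UserAttributeDict(
    localpart=None,  # None is now allowed, e.g. when no localpart_template is set
    display_name="Alice",
    emails=["alice@example.com"],
)
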
@@ -1014,12 +1139,13 @@ def jinja_finalize(thing):
env = Environment(finalize=jinja_finalize)
-@attr.s
+@attr.s(slots=True, frozen=True)
class JinjaOidcMappingConfig:
- subject_claim = attr.ib() # type: str
- localpart_template = attr.ib() # type: Template
- display_name_template = attr.ib() # type: Optional[Template]
- extra_attributes = attr.ib() # type: Dict[str, Template]
+ subject_claim = attr.ib(type=str)
+ localpart_template = attr.ib(type=Optional[Template])
+ display_name_template = attr.ib(type=Optional[Template])
+ email_template = attr.ib(type=Optional[Template])
+ extra_attributes = attr.ib(type=Dict[str, Template])
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1035,50 +1161,37 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
def parse_config(config: dict) -> JinjaOidcMappingConfig:
subject_claim = config.get("subject_claim", "sub")
- if "localpart_template" not in config:
- raise ConfigError(
- "missing key: oidc_config.user_mapping_provider.config.localpart_template"
- )
-
- try:
- localpart_template = env.from_string(config["localpart_template"])
- except Exception as e:
- raise ConfigError(
- "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
- % (e,)
- )
-
- display_name_template = None # type: Optional[Template]
- if "display_name_template" in config:
+ def parse_template_config(option_name: str) -> Optional[Template]:
+ if option_name not in config:
+ return None
try:
- display_name_template = env.from_string(config["display_name_template"])
+ return env.from_string(config[option_name])
except Exception as e:
- raise ConfigError(
- "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
- % (e,)
- )
+ raise ConfigError("invalid jinja template", path=[option_name]) from e
+
+ localpart_template = parse_template_config("localpart_template")
+ display_name_template = parse_template_config("display_name_template")
+ email_template = parse_template_config("email_template")
extra_attributes = {} # type: Dict[str, Template]
if "extra_attributes" in config:
extra_attributes_config = config.get("extra_attributes") or {}
if not isinstance(extra_attributes_config, dict):
- raise ConfigError(
- "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
- )
+ raise ConfigError("must be a dict", path=["extra_attributes"])
for key, value in extra_attributes_config.items():
try:
extra_attributes[key] = env.from_string(value)
except Exception as e:
raise ConfigError(
- "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
- % (key, e)
- )
+ "invalid jinja template", path=["extra_attributes", key]
+ ) from e
return JinjaOidcMappingConfig(
subject_claim=subject_claim,
localpart_template=localpart_template,
display_name_template=display_name_template,
+ email_template=email_template,
extra_attributes=extra_attributes,
)
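
parse_template_config gives all three template options the same treatment, and the structured path= argument lets ConfigError point at the offending key. A quick illustration of the underlying Jinja behaviour (the config values are made up):

from jinja2 import Environment

env = Environment()

config = {"localpart_template": "{{ user.preferred_username }}"}

# Options are optional; a missing one simply yields None.
localpart_template = (
    env.from_string(config["localpart_template"])
    if "localpart_template" in config
    else None
)

# A malformed template fails at parse time, so bad config is caught on startup.
try:
    env.from_string("{{ user.name")  # unbalanced braces
except Exception as e:
    print("invalid jinja template:", e)
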
@@ -1088,25 +1201,35 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
async def map_user_attributes(
self, userinfo: UserInfo, token: Token, failures: int
) -> UserAttributeDict:
- localpart = self._config.localpart_template.render(user=userinfo).strip()
+ localpart = None
- # Ensure only valid characters are included in the MXID.
- localpart = map_username_to_mxid_localpart(localpart)
+ if self._config.localpart_template:
+ localpart = self._config.localpart_template.render(user=userinfo).strip()
- # Append suffix integer if last call to this function failed to produce
- # a usable mxid.
- localpart += str(failures) if failures else ""
+ # Ensure only valid characters are included in the MXID.
+ localpart = map_username_to_mxid_localpart(localpart)
- display_name = None # type: Optional[str]
- if self._config.display_name_template is not None:
- display_name = self._config.display_name_template.render(
- user=userinfo
- ).strip()
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
- if display_name == "":
- display_name = None
+ def render_template_field(template: Optional[Template]) -> Optional[str]:
+ if template is None:
+ return None
+ return template.render(user=userinfo).strip()
- return UserAttributeDict(localpart=localpart, display_name=display_name)
+ display_name = render_template_field(self._config.display_name_template)
+ if display_name == "":
+ display_name = None
+
+ emails = [] # type: List[str]
+ email = render_template_field(self._config.email_template)
+ if email:
+ emails.append(email)
+
+ return UserAttributeDict(
+ localpart=localpart, display_name=display_name, emails=emails
+ )
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]
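
Tying map_user_attributes together: the localpart is rendered from the template, sanitised, and retried with an integer suffix after a collision. A rough standalone equivalent — the lowercasing below merely stands in for synapse's map_username_to_mxid_localpart:

from typing import Optional
from jinja2 import Environment, Template

env = Environment()
localpart_template = env.from_string(
    "{{ user.preferred_username }}"
)  # type: Optional[Template]
userinfo = {"preferred_username": "Alice"}

def render_localpart(failures: int) -> Optional[str]:
    if localpart_template is None:
        return None
    localpart = localpart_template.render(user=userinfo).strip()
    localpart = localpart.lower()  # stand-in for map_username_to_mxid_localpart
    # Append a suffix if a previous attempt collided with an existing mxid.
    return localpart + (str(failures) if failures else "")

print(render_localpart(0))  # alice
print(render_localpart(1))  # alice1
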
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index dee0ef45e7..c02b951031 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
- return result["displayname"]
+ return result.get("displayname")
async def set_displayname(
self,
@@ -246,7 +246,7 @@ class ProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
- return result["avatar_url"]
+ return result.get("avatar_url")
async def set_avatar_url(
self,
@@ -286,13 +286,19 @@ class ProfileHandler(BaseHandler):
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
+ avatar_url_to_set = new_avatar_url # type: Optional[str]
+ if new_avatar_url == "":
+ avatar_url_to_set = None
+
# Same like set_displayname
if by_admin:
requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
)
- await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+ await self.store.set_profile_avatar_url(
+ target_user.localpart, avatar_url_to_set
+ )
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
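
Two small behaviour fixes here: remote profile lookups no longer raise KeyError when the field is absent, and setting an empty avatar URL now clears the stored value instead of persisting "". In miniature:

from typing import Optional

result = {}  # e.g. a federation profile response with no displayname
displayname = result.get("displayname")  # None, where result["displayname"] raised

def normalise_avatar_url(new_avatar_url: str) -> Optional[str]:
    # An empty string means "unset the avatar"; store NULL rather than "".
    return new_avatar_url if new_avatar_url != "" else None

assert normalise_avatar_url("") is None
assert normalise_avatar_url("mxc://example.com/abc") == "mxc://example.com/abc"
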
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index a7550806e6..6bb2fd936b 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler):
super().__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
+ self.account_data_handler = hs.get_account_data_handler()
self.read_marker_linearizer = Linearizer(name="read_marker")
- self.notifier = hs.get_notifier()
async def received_client_read_marker(
self, room_id: str, user_id: str, event_id: str
@@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler):
if should_update:
content = {"event_id": event_id}
- max_id = await self.store.add_account_data_to_room(
+ await self.account_data_handler.add_account_data_to_room(
user_id, room_id, "m.fully_read", content
)
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
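
This hunk shows the migration pattern for account data callers: rather than writing to the store and poking the notifier themselves (which is only correct on the process that owns the account data stream), they delegate to the account data handler. A sketch of the calling convention:

async def mark_fully_read(self, user_id: str, room_id: str, content: dict) -> None:
    # Before: the caller had to be the account data writer, and had to poke
    # the notifier itself:
    #   max_id = await self.store.add_account_data_to_room(
    #       user_id, room_id, "m.fully_read", content
    #   )
    #   self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
    #
    # After: correct on any worker; routing and notification are the
    # account data handler's job.
    await self.account_data_handler.add_account_data_to_room(
        user_id, room_id, "m.fully_read", content
    )
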
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 153cbae7b9..cc21fc2284 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -13,31 +13,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
-from synapse.util.async_helpers import maybe_awaitable
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class ReceiptsHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.hs = hs
- self.federation = hs.get_federation_sender()
- hs.get_federation_registry().register_edu_handler(
- "m.receipt", self._received_remote_receipt
- )
+
+ # We only need to poke the federation sender explicitly if it's on the
+ # same instance. Other federation sender instances will get notified by
+ # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+ # in the receipts stream.
+ self.federation_sender = None
+ if hs.should_send_federation():
+ self.federation_sender = hs.get_federation_sender()
+
+ # If we can handle the receipt EDUs we do so, otherwise we route them
+ # to the appropriate worker.
+ if hs.get_instance_name() in hs.config.worker.writers.receipts:
+ hs.get_federation_registry().register_edu_handler(
+ "m.receipt", self._received_remote_receipt
+ )
+ else:
+ hs.get_federation_registry().register_instances_for_edu(
+ "m.receipt", hs.config.worker.writers.receipts,
+ )
+
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
- async def _received_remote_receipt(self, origin, content):
+ async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
receipts = []
@@ -64,11 +82,11 @@ class ReceiptsHandler(BaseHandler):
await self._handle_new_receipts(receipts)
- async def _handle_new_receipts(self, receipts):
+ async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier.
"""
- min_batch_id = None
- max_batch_id = None
+ min_batch_id = None # type: Optional[int]
+ max_batch_id = None # type: Optional[int]
for receipt in receipts:
res = await self.store.insert_receipt(
@@ -90,7 +108,8 @@ class ReceiptsHandler(BaseHandler):
if max_batch_id is None or max_persisted_id > max_batch_id:
max_batch_id = max_persisted_id
- if min_batch_id is None:
+ # Either both of these should be None or neither.
+ if min_batch_id is None or max_batch_id is None:
# no new receipts
return False
@@ -98,15 +117,15 @@ class ReceiptsHandler(BaseHandler):
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
- await maybe_awaitable(
- self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids
- )
+ await self.hs.get_pusherpool().on_new_receipts(
+ min_batch_id, max_batch_id, affected_room_ids
)
return True
- async def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
+ async def received_client_receipt(
+ self, room_id: str, receipt_type: str, user_id: str, event_id: str
+ ) -> None:
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
@@ -122,14 +141,17 @@ class ReceiptsHandler(BaseHandler):
if not is_new:
return
- await self.federation.send_read_receipt(receipt)
+ if self.federation_sender:
+ await self.federation_sender.send_read_receipt(receipt)
class ReceiptEventSource:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- async def get_new_events(self, from_key, room_ids, **kwargs):
+ async def get_new_events(
+ self, from_key: int, room_ids: List[str], **kwargs
+ ) -> Tuple[List[JsonDict], int]:
from_key = int(from_key)
to_key = self.get_current_key()
@@ -174,5 +196,5 @@ class ReceiptEventSource:
return (events, to_key)
- def get_current_key(self, direction="f"):
+ def get_current_key(self, direction: str = "f") -> int:
return self.store.get_max_receipt_stream_id()
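
The new Optional[int] annotations make the "no new receipts" case in _handle_new_receipts explicit: both bounds start as None and are tightened per receipt. The accumulation, extracted into a standalone helper for illustration:

from typing import Iterable, Optional, Tuple

def stream_id_bounds(stream_ids: Iterable[int]) -> Optional[Tuple[int, int]]:
    """Return (min, max) of the given stream IDs, or None if there were none."""
    min_id = None  # type: Optional[int]
    max_id = None  # type: Optional[int]
    for stream_id in stream_ids:
        if min_id is None or stream_id < min_id:
            min_id = stream_id
        if max_id is None or stream_id > max_id:
            max_id = stream_id
    # Either both are None (no receipts persisted) or neither is.
    if min_id is None or max_id is None:
        return None
    return min_id, max_id

assert stream_id_bounds([]) is None
assert stream_id_bounds([5, 3, 9]) == (3, 9)
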
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 0d85fd0868..49b085269b 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -14,8 +14,9 @@
# limitations under the License.
"""Contains functions for registering clients."""
+
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@@ -152,7 +153,7 @@ class RegistrationHandler(BaseHandler):
user_type: Optional[str] = None,
default_display_name: Optional[str] = None,
address: Optional[str] = None,
- bind_emails: List[str] = [],
+ bind_emails: Iterable[str] = [],
by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
) -> str:
@@ -187,7 +188,7 @@ class RegistrationHandler(BaseHandler):
"""
self.check_registration_ratelimit(address)
- result = self.spam_checker.check_registration_for_spam(
+ result = await self.spam_checker.check_registration_for_spam(
threepid, localpart, user_agent_ips or [],
)
@@ -630,6 +631,7 @@ class RegistrationHandler(BaseHandler):
device_id: Optional[str],
initial_display_name: Optional[str],
is_guest: bool = False,
+ is_appservice_ghost: bool = False,
) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
@@ -651,6 +653,7 @@ class RegistrationHandler(BaseHandler):
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
+ is_appservice_ghost=is_appservice_ghost,
)
return r["device_id"], r["access_token"]
@@ -672,7 +675,10 @@ class RegistrationHandler(BaseHandler):
)
else:
access_token = await self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
+ user_id,
+ device_id=registered_device_id,
+ valid_until_ms=valid_until_ms,
+ is_appservice_ghost=is_appservice_ghost,
)
return (registered_device_id, access_token)
@@ -688,6 +694,8 @@ class RegistrationHandler(BaseHandler):
access_token: The access token of the newly logged in device, or
None if `inhibit_login` enabled.
"""
+ # TODO: 3pid registration can actually happen on the workers. Consider
+ # refactoring it.
if self.hs.config.worker_app:
await self._post_registration_client(
user_id=user_id, auth_result=auth_result, access_token=access_token
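
check_registration_for_spam is now awaited, i.e. the spam checker interface has gone async (the same change shows up for user_may_create_room and user_may_invite below). Where third-party modules might still return plain values, the usual bridge is something like synapse's maybe_awaitable; a minimal equivalent:

import asyncio
import inspect

async def maybe_await(value):
    """Await `value` if it is awaitable; otherwise return it unchanged."""
    if inspect.isawaitable(value):
        return await value
    return value

def sync_checker():  # a legacy, synchronous callback
    return False

async def async_checker():  # a modern, async callback
    return False

async def main():
    assert await maybe_await(sync_checker()) is False
    assert await maybe_await(async_checker()) is False

asyncio.run(main())
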
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 930047e730..07b2187eb1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
from synapse.api.constants import (
EventTypes,
+ HistoryVisibility,
JoinRules,
Membership,
RoomCreationPreset,
@@ -37,7 +38,6 @@ from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
-from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
@@ -54,6 +54,7 @@ from synapse.types import (
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import parse_and_validate_server_name
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -81,21 +82,21 @@ class RoomCreationHandler(BaseHandler):
self._presets_dict = {
RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
- "history_visibility": "shared",
+ "history_visibility": HistoryVisibility.SHARED,
"original_invitees_have_ops": False,
"guest_can_join": True,
"power_level_content_override": {"invite": 0},
},
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
- "history_visibility": "shared",
+ "history_visibility": HistoryVisibility.SHARED,
"original_invitees_have_ops": True,
"guest_can_join": True,
"power_level_content_override": {"invite": 0},
},
RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC,
- "history_visibility": "shared",
+ "history_visibility": HistoryVisibility.SHARED,
"original_invitees_have_ops": False,
"guest_can_join": False,
"power_level_content_override": {},
@@ -125,6 +126,10 @@ class RoomCreationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules()
+ self._invite_burst_count = (
+ hs.config.ratelimiting.rc_invites_per_room.burst_count
+ )
+
async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
) -> str:
@@ -358,13 +363,13 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_create_room(user_id):
+ if not await self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
"room_version": new_room_version.identifier,
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
- }
+ } # type: JsonDict
# Check if old room was non-federatable
@@ -440,6 +445,7 @@ class RoomCreationHandler(BaseHandler):
invite_list=[],
initial_state=initial_state,
creation_content=creation_content,
+ ratelimit=False,
)
# Transfer membership events
@@ -608,7 +614,7 @@ class RoomCreationHandler(BaseHandler):
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
- if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -660,6 +666,9 @@ class RoomCreationHandler(BaseHandler):
invite_3pid_list = []
invite_list = []
+ if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count:
+ raise SynapseError(400, "Cannot invite so many users at once")
+
await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override")
@@ -735,6 +744,7 @@ class RoomCreationHandler(BaseHandler):
room_alias=room_alias,
power_level_content_override=power_level_content_override,
creator_join_profile=creator_join_profile,
+ ratelimit=ratelimit,
)
if "name" in config:
@@ -838,6 +848,7 @@ class RoomCreationHandler(BaseHandler):
room_alias: Optional[RoomAlias] = None,
power_level_content_override: Optional[JsonDict] = None,
creator_join_profile: Optional[JsonDict] = None,
+ ratelimit: bool = True,
) -> int:
"""Sends the initial events into a new room.
@@ -884,7 +895,7 @@ class RoomCreationHandler(BaseHandler):
creator.user,
room_id,
"join",
- ratelimit=False,
+ ratelimit=ratelimit,
content=creator_join_profile,
)
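
The room presets now use named constants rather than bare strings. The values are fixed by the Matrix spec's m.room.history_visibility event, so the new class in synapse.api.constants presumably looks roughly like this sketch:

class HistoryVisibility:
    """Values for the history_visibility field of m.room.history_visibility."""

    INVITED = "invited"
    JOINED = "joined"
    SHARED = "shared"
    WORLD_READABLE = "world_readable"
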
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 4a13c8e912..14f14db449 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -15,19 +15,22 @@
import logging
from collections import namedtuple
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Optional, Tuple
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.api.errors import Codes, HttpResponseException
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached
from synapse.util.caches.response_cache import ResponseCache
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
@@ -37,37 +40,38 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
- self.response_cache = ResponseCache(hs, "room_list")
+ self.response_cache = ResponseCache(
+ hs, "room_list"
+ ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
self.remote_response_cache = ResponseCache(
hs, "remote_room_list", timeout_ms=30 * 1000
- )
+ ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
async def get_local_public_room_list(
self,
- limit=None,
- since_token=None,
- search_filter=None,
- network_tuple=EMPTY_THIRD_PARTY_ID,
- from_federation=False,
- ):
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[dict] = None,
+ network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
+ from_federation: bool = False,
+ ) -> JsonDict:
"""Generate a local public room list.
There are multiple different lists: the main one plus one per third
party network. A client can ask for a specific list or to return all.
Args:
- limit (int|None)
- since_token (str|None)
- search_filter (dict|None)
- network_tuple (ThirdPartyInstanceID): Which public list to use.
+ limit
+ since_token
+ search_filter
+ network_tuple: Which public list to use.
This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
- from_federation (bool): true iff the request comes from the federation
- API
+ from_federation: true iff the request comes from the federation API
"""
if not self.enable_room_list_search:
return {"chunk": [], "total_room_count_estimate": 0}
@@ -107,10 +111,10 @@ class RoomListHandler(BaseHandler):
self,
limit: Optional[int] = None,
since_token: Optional[str] = None,
- search_filter: Optional[Dict] = None,
+ search_filter: Optional[dict] = None,
network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
from_federation: bool = False,
- ) -> Dict[str, Any]:
+ ) -> JsonDict:
"""Generate a public room list.
Args:
limit: Maximum amount of rooms to return.
@@ -131,13 +135,17 @@ class RoomListHandler(BaseHandler):
if since_token:
batch_token = RoomListNextBatch.from_token(since_token)
- bounds = (batch_token.last_joined_members, batch_token.last_room_id)
+ bounds = (
+ batch_token.last_joined_members,
+ batch_token.last_room_id,
+ ) # type: Optional[Tuple[int, str]]
forwards = batch_token.direction_is_forward
+ has_batch_token = True
else:
- batch_token = None
bounds = None
forwards = True
+ has_batch_token = False
# we request one more than wanted to see if there are more pages to come
probing_limit = limit + 1 if limit is not None else None
@@ -159,7 +167,8 @@ class RoomListHandler(BaseHandler):
"canonical_alias": room["canonical_alias"],
"num_joined_members": room["joined_members"],
"avatar_url": room["avatar"],
- "world_readable": room["history_visibility"] == "world_readable",
+ "world_readable": room["history_visibility"]
+ == HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join",
}
@@ -168,7 +177,7 @@ class RoomListHandler(BaseHandler):
results = [build_room_entry(r) for r in results]
- response = {}
+ response = {} # type: JsonDict
num_results = len(results)
if limit is not None:
more_to_come = num_results == probing_limit
@@ -186,7 +195,7 @@ class RoomListHandler(BaseHandler):
initial_entry = results[0]
if forwards:
- if batch_token:
+ if has_batch_token:
# If there was a token given then we assume that there
# must be previous results.
response["prev_batch"] = RoomListNextBatch(
@@ -202,7 +211,7 @@ class RoomListHandler(BaseHandler):
direction_is_forward=True,
).to_token()
else:
- if batch_token:
+ if has_batch_token:
response["next_batch"] = RoomListNextBatch(
last_joined_members=final_entry["num_joined_members"],
last_room_id=final_entry["room_id"],
@@ -292,7 +301,7 @@ class RoomListHandler(BaseHandler):
return None
# Return whether this room is open to federation users or not
- create_event = current_state.get((EventTypes.Create, ""))
+ create_event = current_state[EventTypes.Create, ""]
result["m.federate"] = create_event.content.get("m.federate", True)
name_event = current_state.get((EventTypes.Name, ""))
@@ -317,7 +326,7 @@ class RoomListHandler(BaseHandler):
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
- result["world_readable"] = visibility == "world_readable"
+ result["world_readable"] = visibility == HistoryVisibility.WORLD_READABLE
guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None
@@ -335,13 +344,13 @@ class RoomListHandler(BaseHandler):
async def get_remote_public_room_list(
self,
- server_name,
- limit=None,
- since_token=None,
- search_filter=None,
- include_all_networks=False,
- third_party_instance_id=None,
- ):
+ server_name: str,
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[dict] = None,
+ include_all_networks: bool = False,
+ third_party_instance_id: Optional[str] = None,
+ ) -> JsonDict:
if not self.enable_room_list_search:
return {"chunk": [], "total_room_count_estimate": 0}
@@ -398,13 +407,13 @@ class RoomListHandler(BaseHandler):
async def _get_remote_list_cached(
self,
- server_name,
- limit=None,
- since_token=None,
- search_filter=None,
- include_all_networks=False,
- third_party_instance_id=None,
- ):
+ server_name: str,
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[dict] = None,
+ include_all_networks: bool = False,
+ third_party_instance_id: Optional[str] = None,
+ ) -> JsonDict:
repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
@@ -455,24 +464,24 @@ class RoomListNextBatch(
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
@classmethod
- def from_token(cls, token):
+ def from_token(cls, token: str) -> "RoomListNextBatch":
decoded = msgpack.loads(decode_base64(token), raw=False)
return RoomListNextBatch(
**{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
)
- def to_token(self):
+ def to_token(self) -> str:
return encode_base64(
msgpack.dumps(
{self.KEY_DICT[key]: val for key, val in self._asdict().items()}
)
)
- def copy_and_replace(self, **kwds):
+ def copy_and_replace(self, **kwds) -> "RoomListNextBatch":
return self._replace(**kwds)
-def _matches_room_entry(room_entry, search_filter):
+def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:
if search_filter and search_filter.get("generic_search_term", None):
generic_search_term = search_filter["generic_search_term"].upper()
if generic_search_term in room_entry.get("name", "").upper():
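
Background for the has_batch_token refactor above (which lets mypy keep batch_token non-Optional): a RoomListNextBatch token is just a msgpack map with abbreviated keys, base64-encoded. A round trip, with illustrative key abbreviations:

import msgpack
from unpaddedbase64 import decode_base64, encode_base64

# Abbreviations chosen for the demo; the real KEY_DICT lives on RoomListNextBatch.
KEY_DICT = {"last_joined_members": "m", "last_room_id": "r", "direction_is_forward": "d"}
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}

def to_token(batch: dict) -> str:
    return encode_base64(msgpack.dumps({KEY_DICT[k]: v for k, v in batch.items()}))

def from_token(token: str) -> dict:
    decoded = msgpack.loads(decode_base64(token), raw=False)
    return {REVERSE_KEY_DICT[k]: v for k, v in decoded.items()}

batch = {
    "last_joined_members": 42,
    "last_room_id": "!abc:example.com",
    "direction_is_forward": True,
}
assert from_token(to_token(batch)) == batch
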
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index c002886324..a5da97cfe0 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.account_data_handler = hs.get_account_data_handler()
self.member_linearizer = Linearizer(name="member")
@@ -84,6 +85,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
)
+ self._invites_per_room_limiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
+ burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
+ )
+ self._invites_per_user_limiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
+ burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
+ )
+
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
# it doesn't store state.
@@ -143,6 +155,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
"""
raise NotImplementedError()
+ def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str):
+ """Ratelimit invites by room and by target user.
+
+ If the room ID is missing, we just rate limit by the target user.
+ """
+ if room_id:
+ self._invites_per_room_limiter.ratelimit(room_id)
+
+ self._invites_per_user_limiter.ratelimit(invitee_user_id)
+
async def _local_membership_update(
self,
requester: Requester,
@@ -203,7 +225,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
- if newly_joined:
+ if newly_joined and ratelimit:
time_now_s = self.clock.time()
(
allowed,
@@ -253,7 +275,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
direct_rooms[key].append(new_room_id)
# Save back to user's m.direct account data
- await self.store.add_account_data_for_user(
+ await self.account_data_handler.add_account_data_for_user(
user_id, AccountDataTypes.DIRECT, direct_rooms
)
break
@@ -263,7 +285,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Copy each room tag to the new room
for tag, tag_content in room_tags.items():
- await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
+ await self.account_data_handler.add_tag_to_room(
+ user_id, new_room_id, tag, tag_content
+ )
async def update_membership(
self,
@@ -384,8 +408,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise SynapseError(403, "This room has been blocked on this server")
if effective_membership_state == Membership.INVITE:
+ target_id = target.to_string()
+ if ratelimit:
+ # Don't ratelimit application services.
+ if not requester.app_service or requester.app_service.is_rate_limited():
+ self.ratelimit_invite(room_id, target_id)
+
# block any attempts to invite the server notices mxid
- if target.to_string() == self._server_notices_mxid:
+ if target_id == self._server_notices_mxid:
raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
block_invite = False
@@ -408,8 +438,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
block_invite = True
- if not self.spam_checker.user_may_invite(
- requester.user.to_string(), target.to_string(), room_id
+ if not await self.spam_checker.user_may_invite(
+ requester.user.to_string(), target_id, room_id
):
logger.info("Blocking invite due to spam checker")
block_invite = True
@@ -488,17 +518,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
- time_now_s = self.clock.time()
- (
- allowed,
- time_allowed,
- ) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
-
- if not allowed:
- raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now_s))
+ if ratelimit:
+ time_now_s = self.clock.time()
+ (
+ allowed,
+ time_allowed,
+ ) = self._join_rate_limiter_remote.can_requester_do_action(
+ requester,
)
+ if not allowed:
+ raise LimitExceededError(
+ retry_after_ms=int(1000 * (time_allowed - time_now_s))
+ )
+
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
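
Invites are now rate limited twice, keyed by room and by target user, before the spam checks run. Ratelimiter here is synapse's own class; to show what rate_hz/burst_count mean, a toy token bucket with the same shape (not the real implementation):

import time
from typing import Dict, Tuple

class ToyRatelimiter:
    """Allows `burst_count` actions per key, refilling at `rate_hz` per second."""

    def __init__(self, rate_hz: float, burst_count: int):
        self.rate_hz = rate_hz
        self.burst_count = burst_count
        self._buckets = {}  # type: Dict[str, Tuple[float, float]]

    def ratelimit(self, key: str) -> None:
        now = time.monotonic()
        tokens, last = self._buckets.get(key, (float(self.burst_count), now))
        tokens = min(float(self.burst_count), tokens + (now - last) * self.rate_hz)
        if tokens < 1:
            raise RuntimeError("rate limit exceeded for %r" % (key,))
        self._buckets[key] = (tokens - 1, now)

limiter = ToyRatelimiter(rate_hz=0.3, burst_count=10)
limiter.ratelimit("!room:example.com")  # an 11th rapid call for this key would raise
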
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 76d4169fe2..e88fd59749 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -34,7 +34,6 @@ from synapse.types import (
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
-from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
@@ -59,8 +58,6 @@ class SamlHandler(BaseHandler):
super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2_idp_entityid
- self._auth_handler = hs.get_auth_handler()
- self._registration_handler = hs.get_registration_handler()
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
@@ -76,30 +73,46 @@ class SamlHandler(BaseHandler):
)
# identifier for the external_ids table
- self._auth_provider_id = "saml"
+ self.idp_id = "saml"
+
+ # user-facing name of this auth provider
+ self.idp_name = "SAML"
+
+ # we do not currently support icons/brands for SAML auth, but this is required by
+ # the SsoIdentityProvider protocol type.
+ self.idp_icon = None
+ self.idp_brand = None
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
- # a lock on the mappings
- self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
-
self._sso_handler = hs.get_sso_handler()
+ self._sso_handler.register_identity_provider(self)
- def handle_redirect_request(
- self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
- ) -> bytes:
+ async def handle_redirect_request(
+ self,
+ request: SynapseRequest,
+ client_redirect_url: Optional[bytes],
+ ui_auth_session_id: Optional[str] = None,
+ ) -> str:
"""Handle an incoming request to /login/sso/redirect
Args:
+ request: the incoming HTTP request
client_redirect_url: the URL that we should redirect the
- client to when everything is done
+ client to after login (or None for UI Auth).
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
Returns:
URL to redirect to
"""
+ if not client_redirect_url:
+ # Some SAML identity providers (e.g. Google) require a
+ # RelayState parameter on requests, so pass in a dummy redirect URL
+ # (which will never get used).
+ client_redirect_url = b"unused"
+
reqid, info = self._saml_client.prepare_for_authenticate(
entityid=self._saml_idp_entityid, relay_state=client_redirect_url
)
@@ -120,7 +133,7 @@ class SamlHandler(BaseHandler):
raise Exception("prepare_for_authenticate didn't return a Location header")
async def handle_saml_response(self, request: SynapseRequest) -> None:
- """Handle an incoming request to /_matrix/saml2/authn_response
+ """Handle an incoming request to /_synapse/client/saml2/authn_response
Args:
request: the incoming request from the browser. We'll
@@ -167,6 +180,29 @@ class SamlHandler(BaseHandler):
return
logger.debug("SAML2 response: %s", saml2_auth.origxml)
+
+ await self._handle_authn_response(request, saml2_auth, relay_state)
+
+ async def _handle_authn_response(
+ self,
+ request: SynapseRequest,
+ saml2_auth: saml2.response.AuthnResponse,
+ relay_state: str,
+ ) -> None:
+ """Handle an AuthnResponse, having parsed it from the request params
+
+ Assumes that the signature on the response object has been checked. Maps
+ the user onto an MXID, registering them if necessary, and returns a response
+ to the browser.
+
+ Args:
+ request: the incoming request from the browser. We'll respond to it with an
+ HTML page or a redirect
+ saml2_auth: the parsed AuthnResponse object
+ relay_state: the RelayState query param, which encodes the URI to redirect
+ back to
+ """
+
for assertion in saml2_auth.assertions:
# kibana limits the length of a log field, whereas this is all rather
# useful, so split it up.
@@ -183,6 +219,24 @@ class SamlHandler(BaseHandler):
saml2_auth.in_response_to, None
)
+ # first check if we're doing a UIA
+ if current_session and current_session.ui_auth_session_id:
+ try:
+ remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+ except MappingException as e:
+ logger.exception("Failed to extract remote user id from SAML response")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
+ return
+
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self.idp_id,
+ remote_user_id,
+ current_session.ui_auth_session_id,
+ request,
+ )
+
+ # otherwise, we're handling a login request.
+
# Ensure that the attributes of the logged in user meet the required
# attributes.
for requirement in self._saml2_attribute_requirements:
@@ -192,63 +246,39 @@ class SamlHandler(BaseHandler):
)
return
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
-
# Call the mapper to register/login the user
try:
- user_id = await self._map_saml_response_to_user(
- saml2_auth, relay_state, user_agent, ip_address
- )
+ await self._complete_saml_login(saml2_auth, request, relay_state)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
- return
-
- # Complete the interactive auth session or the login.
- if current_session and current_session.ui_auth_session_id:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, current_session.ui_auth_session_id, request
- )
-
- else:
- await self._auth_handler.complete_sso_login(user_id, request, relay_state)
- async def _map_saml_response_to_user(
+ async def _complete_saml_login(
self,
saml2_auth: saml2.response.AuthnResponse,
+ request: SynapseRequest,
client_redirect_url: str,
- user_agent: str,
- ip_address: str,
- ) -> str:
+ ) -> None:
"""
- Given a SAML response, retrieve the user ID for it and possibly register the user.
+ Given a SAML response, complete the login flow
+
+ Retrieves the remote user ID, registers the user if necessary, and serves
+ a redirect back to the client with a login-token.
Args:
saml2_auth: The parsed SAML2 response.
+ request: The request to respond to
client_redirect_url: The redirect URL passed in by the client.
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
-
- Returns:
- The user ID associated with this response.
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
-
- remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ remote_user_id = self._remote_id_from_saml_response(
saml2_auth, client_redirect_url
)
- if not remote_user_id:
- raise MappingException(
- "Failed to extract remote user id from SAML response"
- )
-
async def saml_response_to_remapped_user_attributes(
failures: int,
) -> UserAttributes:
@@ -294,16 +324,44 @@ class SamlHandler(BaseHandler):
return None
- with (await self._mapping_lock.queue(self._auth_provider_id)):
- return await self._sso_handler.get_mxid_from_sso(
- self._auth_provider_id,
- remote_user_id,
- user_agent,
- ip_address,
- saml_response_to_remapped_user_attributes,
- grandfather_existing_users,
+ await self._sso_handler.complete_sso_login_request(
+ self.idp_id,
+ remote_user_id,
+ request,
+ client_redirect_url,
+ saml_response_to_remapped_user_attributes,
+ grandfather_existing_users,
+ )
+
+ def _remote_id_from_saml_response(
+ self,
+ saml2_auth: saml2.response.AuthnResponse,
+ client_redirect_url: Optional[str],
+ ) -> str:
+ """Extract the unique remote id from a SAML2 AuthnResponse
+
+ Args:
+ saml2_auth: The parsed SAML2 response.
+ client_redirect_url: The redirect URL passed in by the client.
+ Returns:
+ remote user id
+
+ Raises:
+ MappingException if there was an error extracting the user id
+ """
+ # It's not obvious why we need to pass in the redirect URI to the mapping
+ # provider, but we do :/
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ saml2_auth, client_redirect_url
+ )
+
+ if not remote_user_id:
+ raise MappingException(
+ "Failed to extract remote user id from SAML response"
)
+ return remote_user_id
+
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
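
The SAML handler now registers itself with the SSO handler instead of being special-cased, hence the new idp_name/idp_icon/idp_brand fields next to idp_id. Judging from what the handlers expose, the SsoIdentityProvider protocol presumably requires at least something like this (a sketch, not the actual definition):

from typing import Optional
from typing_extensions import Protocol

class SsoIdentityProvider(Protocol):
    """Rough shape of what the SSO handler expects a provider to supply."""

    idp_id: str  # identifier for the external_ids table
    idp_name: str  # user-facing name
    idp_icon: Optional[str]  # mxc:// URI of an icon, if any
    idp_brand: Optional[str]  # brand hint for clients, if any

    async def handle_redirect_request(
        self,
        request,  # a SynapseRequest
        client_redirect_url: Optional[bytes],
        ui_auth_session_id: Optional[str] = None,
    ) -> str:
        """Return the URL to redirect the browser to."""
        ...
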
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 66f1bbcfc4..94062e79cb 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -15,23 +15,28 @@
import itertools
import logging
-from typing import Iterable
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
+from synapse.events import EventBase
from synapse.storage.state import StateFilter
+from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
@@ -87,13 +92,15 @@ class SearchHandler(BaseHandler):
return historical_room_ids
- async def search(self, user, content, batch=None):
+ async def search(
+ self, user: UserID, content: JsonDict, batch: Optional[str] = None
+ ) -> JsonDict:
"""Performs a full text search for a user.
Args:
- user (UserID)
- content (dict): Search parameters
- batch (str): The next_batch parameter. Used for pagination.
+ user
+ content: Search parameters
+ batch: The next_batch parameter. Used for pagination.
Returns:
dict to be returned to the client with results of search
@@ -186,7 +193,7 @@ class SearchHandler(BaseHandler):
# If doing a subset of all rooms search, check if any of the rooms
# are from an upgraded room, and search their contents as well
if search_filter.rooms:
- historical_room_ids = []
+ historical_room_ids = [] # type: List[str]
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -209,8 +216,10 @@ class SearchHandler(BaseHandler):
rank_map = {} # event_id -> rank of event
allowed_events = []
- room_groups = {} # Holds result of grouping by room, if applicable
- sender_group = {} # Holds result of grouping by sender, if applicable
+ # Holds result of grouping by room, if applicable
+ room_groups = {} # type: Dict[str, JsonDict]
+ # Holds result of grouping by sender, if applicable
+ sender_group = {} # type: Dict[str, JsonDict]
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
@@ -254,7 +263,7 @@ class SearchHandler(BaseHandler):
s["results"].append(e.event_id)
elif order_by == "recent":
- room_events = []
+ room_events = [] # type: List[EventBase]
i = 0
pagination_token = batch_token
@@ -418,13 +427,10 @@ class SearchHandler(BaseHandler):
state_results = {}
if include_state:
- rooms = {e.room_id for e in allowed_events}
- for room_id in rooms:
+ for room_id in {e.room_id for e in allowed_events}:
state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
- state_results.values()
-
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise the 'age' will be wrong
@@ -448,9 +454,9 @@ class SearchHandler(BaseHandler):
if state_results:
s = {}
- for room_id, state in state_results.items():
+ for room_id, state_events in state_results.items():
s[room_id] = await self._event_serializer.serialize_events(
- state, time_now
+ state_events, time_now
)
rooms_cat_res["state"] = s
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index a5d67f828f..84af2dde7e 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -13,24 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class SetPasswordHandler(BaseHandler):
"""Handler which deals with changing user account passwords"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
- self._password_policy_handler = hs.get_password_policy_handler()
async def set_password(
self,
@@ -38,7 +40,7 @@ class SetPasswordHandler(BaseHandler):
password_hash: str,
logout_devices: bool,
requester: Optional[Requester] = None,
- ):
+ ) -> None:
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 47ad96f97e..b450668f1c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -12,15 +12,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ Mapping,
+ Optional,
+ Set,
+)
+from urllib.parse import urlencode
import attr
+from typing_extensions import NoReturn, Protocol
-from synapse.api.errors import RedirectException
-from synapse.handlers._base import BaseHandler
-from synapse.http.server import respond_with_html
-from synapse.types import UserID, contains_invalid_mxid_characters
+from twisted.web.http import Request
+from twisted.web.iweb import IRequest
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
+from synapse.http import get_request_user_agent
+from synapse.http.server import respond_with_html, respond_with_redirect
+from synapse.http.site import SynapseRequest
+from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters
+from synapse.util.async_helpers import Linearizer
+from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -35,24 +55,190 @@ class MappingException(Exception):
"""
+class SsoIdentityProvider(Protocol):
+ """Abstract base class to be implemented by SSO Identity Providers
+
+ An Identity Provider, or IdP, is an external HTTP service which authenticates a user
+ to say whether they should be allowed to log in, or perform a given action.
+
+ Synapse supports various implementations of IdPs, including OpenID Connect, SAML,
+ and CAS.
+
+ The main entry point is `handle_redirect_request`, which should return a URI to
+ redirect the user's browser to the IdP's authentication page.
+
+ Each IdP should be registered with the SsoHandler via
+ `hs.get_sso_handler().register_identity_provider()`, so that requests to
+ `/_matrix/client/r0/login/sso/redirect` can be correctly dispatched.
+ """
+
+ @property
+ @abc.abstractmethod
+ def idp_id(self) -> str:
+ """A unique identifier for this SSO provider
+
+ E.g. "saml", "cas", "github"
+ """
+
+ @property
+ @abc.abstractmethod
+ def idp_name(self) -> str:
+ """User-facing name for this provider"""
+
+ @property
+ def idp_icon(self) -> Optional[str]:
+ """Optional MXC URI for user-facing icon"""
+ return None
+
+ @property
+ def idp_brand(self) -> Optional[str]:
+ """Optional branding identifier"""
+ return None
+
+ @abc.abstractmethod
+ async def handle_redirect_request(
+ self,
+ request: SynapseRequest,
+ client_redirect_url: Optional[bytes],
+ ui_auth_session_id: Optional[str] = None,
+ ) -> str:
+ """Handle an incoming request to /login/sso/redirect
+
+ Args:
+ request: the incoming HTTP request
+ client_redirect_url: the URL that we should redirect the
+ client to after login (or None for UI Auth).
+ ui_auth_session_id: The session ID of the ongoing UI Auth (or
+ None if this is a login).
+
+ Returns:
+ URL to redirect to
+ """
+ raise NotImplementedError()
+
+
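For illustration, a bare-bones implementation of this protocol might look as follows (a sketch only: the class, URL and attribute values are invented; the method signature and registration call come from the interface above):

```python
from typing import Optional

from synapse.http.site import SynapseRequest


class ExampleIdentityProvider:
    """Hypothetical SsoIdentityProvider implementation, for illustration."""

    idp_id = "example"
    idp_name = "Example IdP"
    idp_icon = None  # no MXC icon
    idp_brand = None

    async def handle_redirect_request(
        self,
        request: SynapseRequest,
        client_redirect_url: Optional[bytes],
        ui_auth_session_id: Optional[str] = None,
    ) -> str:
        # a real IdP would build an authorization URL carrying enough state
        # to resume the login (or UI auth) when the user is redirected back
        return "https://idp.example.com/authorize?client_id=synapse"


# registered at startup, so /login/sso/redirect can dispatch to it:
# hs.get_sso_handler().register_identity_provider(ExampleIdentityProvider())
```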
@attr.s
class UserAttributes:
- localpart = attr.ib(type=str)
+ # the localpart of the mxid that the mapper has assigned to the user.
+ # if `None`, the mapper has not picked a userid, and the user should be prompted to
+ # enter one.
+ localpart = attr.ib(type=Optional[str])
display_name = attr.ib(type=Optional[str], default=None)
- emails = attr.ib(type=List[str], default=attr.Factory(list))
+ emails = attr.ib(type=Collection[str], default=attr.Factory(list))
+
+
+@attr.s(slots=True)
+class UsernameMappingSession:
+ """Data we track about SSO sessions"""
+
+ # A unique identifier for this SSO provider, e.g. "oidc" or "saml".
+ auth_provider_id = attr.ib(type=str)
+
+ # user ID on the IdP server
+ remote_user_id = attr.ib(type=str)
+ # attributes returned by the ID mapper
+ display_name = attr.ib(type=Optional[str])
+ emails = attr.ib(type=Collection[str])
-class SsoHandler(BaseHandler):
+ # An optional dictionary of extra attributes to be provided to the client in the
+ # login response.
+ extra_login_attributes = attr.ib(type=Optional[JsonDict])
+
+ # where to redirect the client back to
+ client_redirect_url = attr.ib(type=str)
+
+ # expiry time for the session, in milliseconds
+ expiry_time_ms = attr.ib(type=int)
+
+ # choices made by the user
+ chosen_localpart = attr.ib(type=Optional[str], default=None)
+ use_display_name = attr.ib(type=bool, default=True)
+ emails_to_use = attr.ib(type=Collection[str], default=())
+ terms_accepted_version = attr.ib(type=Optional[str], default=None)
+
+
+# the HTTP cookie used to track the mapping session id
+USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
+
+
+class SsoHandler:
# The number of attempts to ask the mapping provider for when generating an MXID.
_MAP_USERNAME_RETRIES = 1000
+ # the time a UsernameMappingSession remains valid for
+ _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
+
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self._clock = hs.get_clock()
+ self._store = hs.get_datastore()
+ self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
+ self._auth_handler = hs.get_auth_handler()
self._error_template = hs.config.sso_error_template
+ self._bad_user_template = hs.config.sso_auth_bad_user_template
+
+ # The following template is shown after a successful user interactive
+ # authentication session. It tells the user they can close the window.
+ self._sso_auth_success_template = hs.config.sso_auth_success_template
+
+ # a lock on the mappings
+ self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
+
+ # a map from session id to session data
+ self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
+
+ # map from idp_id to SsoIdentityProvider
+ self._identity_providers = {} # type: Dict[str, SsoIdentityProvider]
+
+ self._consent_at_registration = hs.config.consent.user_consent_at_registration
+
+ def register_identity_provider(self, p: SsoIdentityProvider):
+ p_id = p.idp_id
+ assert p_id not in self._identity_providers
+ self._identity_providers[p_id] = p
+
+ def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
+ """Get the configured identity providers"""
+ return self._identity_providers
+
+ async def get_identity_providers_for_user(
+ self, user_id: str
+ ) -> Mapping[str, SsoIdentityProvider]:
+ """Get the SsoIdentityProviders which a user has used
+
+ Given a user id, get the identity providers which that user has previously used
+ to log in (and thus could use to re-identify themselves for UI Auth).
+
+ Args:
+ user_id: MXID of user to look up
+
+ Returns:
+ a map of idp_id to SsoIdentityProvider
+ """
+ external_ids = await self._store.get_external_ids_by_user(user_id)
+
+ valid_idps = {}
+ for idp_id, _ in external_ids:
+ idp = self._identity_providers.get(idp_id)
+ if not idp:
+ logger.warning(
+ "User %r has an SSO mapping for IdP %r, but this is no longer "
+ "configured.",
+ user_id,
+ idp_id,
+ )
+ else:
+ valid_idps[idp_id] = idp
+
+ return valid_idps
def render_error(
- self, request, error: str, error_description: Optional[str] = None
+ self,
+ request: Request,
+ error: str,
+ error_description: Optional[str] = None,
+ code: int = 400,
) -> None:
"""Renders the error template and responds with it.
@@ -64,11 +250,53 @@ class SsoHandler(BaseHandler):
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
+ code: The integer error code (an HTTP response code)
"""
html = self._error_template.render(
error=error, error_description=error_description
)
- respond_with_html(request, 400, html)
+ respond_with_html(request, code, html)
+
+ async def handle_redirect_request(
+ self,
+ request: SynapseRequest,
+ client_redirect_url: bytes,
+ idp_id: Optional[str],
+ ) -> str:
+ """Handle a request to /login/sso/redirect
+
+ Args:
+ request: incoming HTTP request
+ client_redirect_url: the URL that we should redirect the
+ client to after login.
+ idp_id: optional identity provider chosen by the client
+
+ Returns:
+ the URI to redirect to
+ """
+ if not self._identity_providers:
+ raise SynapseError(
+ 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
+ )
+
+ # if the client chose an IdP, use that
+ idp = None # type: Optional[SsoIdentityProvider]
+ if idp_id:
+ idp = self._identity_providers.get(idp_id)
+ if not idp:
+ raise NotFoundError("Unknown identity provider")
+
+ # if we only have one auth provider, redirect to it directly
+ elif len(self._identity_providers) == 1:
+ idp = next(iter(self._identity_providers.values()))
+
+ if idp:
+ return await idp.handle_redirect_request(request, client_redirect_url)
+
+ # otherwise, redirect to the IDP picker
+ return "/_synapse/client/pick_idp?" + urlencode(
+ (("redirectUrl", client_redirect_url),)
+ )
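A quick note on the fallback branch: `urlencode` percent-encodes the (possibly `bytes`) redirect URL into the picker's query string, e.g.:

```python
from urllib.parse import urlencode

client_redirect_url = b"https://client.example.com/?next=/home"
print("/_synapse/client/pick_idp?" + urlencode((("redirectUrl", client_redirect_url),)))
# /_synapse/client/pick_idp?redirectUrl=https%3A%2F%2Fclient.example.com%2F%3Fnext%3D%2Fhome
```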
async def get_sso_user_by_remote_user_id(
self, auth_provider_id: str, remote_user_id: str
@@ -95,7 +323,7 @@ class SsoHandler(BaseHandler):
)
# Check if we already have a mapping for this user.
- previously_registered_user_id = await self.store.get_user_by_external_id(
+ previously_registered_user_id = await self._store.get_user_by_external_id(
auth_provider_id, remote_user_id,
)
@@ -112,15 +340,16 @@ class SsoHandler(BaseHandler):
# No match.
return None
- async def get_mxid_from_sso(
+ async def complete_sso_login_request(
self,
auth_provider_id: str,
remote_user_id: str,
- user_agent: str,
- ip_address: str,
+ request: SynapseRequest,
+ client_redirect_url: str,
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
- grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
- ) -> str:
+ grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
+ extra_login_attributes: Optional[JsonDict] = None,
+ ) -> None:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -139,12 +368,18 @@ class SsoHandler(BaseHandler):
given user-agent and IP address and the SSO ID is linked to this matrix
ID for subsequent calls.
+ Finally, we generate a redirect to the supplied redirect uri, with a login token
+
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
+
remote_user_id: The unique identifier from the SSO provider.
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
+
+ request: The request to respond to
+
+ client_redirect_url: The redirect URL passed in by the client.
+
sso_to_matrix_id_mapper: A callable to generate the user attributes.
The only parameter is an integer which represents the number of
times the returned mxid localpart mapping has failed.
@@ -156,12 +391,13 @@ class SsoHandler(BaseHandler):
to the user.
RedirectException to redirect to an additional page (e.g.
to prompt the user for more information).
+
grandfather_existing_users: A callable which can return a previously
existing matrix ID. The SSO ID is then linked to the returned
matrix ID.
- Returns:
- The user ID associated with the SSO response.
+ extra_login_attributes: An optional dictionary of extra
+ attributes to be provided to the client in the login response.
Raises:
MappingException if there was a problem mapping the response to a user.
@@ -169,24 +405,62 @@ class SsoHandler(BaseHandler):
to an additional page. (e.g. to prompt for more information)
"""
- # first of all, check if we already have a mapping for this user
- previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
- auth_provider_id, remote_user_id,
- )
- if previously_registered_user_id:
- return previously_registered_user_id
+ new_user = False
+
+ # grab a lock while we try to find a mapping for this user. This seems...
+ # optimistic, especially for implementations that end up redirecting to
+ # interstitial pages.
+ with await self._mapping_lock.queue(auth_provider_id):
+ # first of all, check if we already have a mapping for this user
+ user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
- # Check for grandfathering of users.
- if grandfather_existing_users:
- previously_registered_user_id = await grandfather_existing_users()
- if previously_registered_user_id:
- # Future logins should also match this user ID.
- await self.store.record_user_external_id(
- auth_provider_id, remote_user_id, previously_registered_user_id
+ # Check for grandfathering of users.
+ if not user_id:
+ user_id = await grandfather_existing_users()
+ if user_id:
+ # Future logins should also match this user ID.
+ await self._store.record_user_external_id(
+ auth_provider_id, remote_user_id, user_id
+ )
+
+ # Otherwise, generate a new user.
+ if not user_id:
+ attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+
+ if attributes.localpart is None:
+ # the mapper doesn't return a username. bail out with a redirect to
+ # the username picker.
+ await self._redirect_to_username_picker(
+ auth_provider_id,
+ remote_user_id,
+ attributes,
+ client_redirect_url,
+ extra_login_attributes,
+ )
+
+ user_id = await self._register_mapped_user(
+ attributes,
+ auth_provider_id,
+ remote_user_id,
+ get_request_user_agent(request),
+ request.getClientIP(),
)
- return previously_registered_user_id
+ new_user = True
+
+ await self._auth_handler.complete_sso_login(
+ user_id,
+ request,
+ client_redirect_url,
+ extra_login_attributes,
+ new_user=new_user,
+ )
- # Otherwise, generate a new user.
+ async def _call_attribute_mapper(
+ self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+ ) -> UserAttributes:
+ """Call the attribute mapper function in a loop, until we get a unique userid"""
for i in range(self._MAP_USERNAME_RETRIES):
try:
attributes = await sso_to_matrix_id_mapper(i)
@@ -208,14 +482,12 @@ class SsoHandler(BaseHandler):
)
if not attributes.localpart:
- raise MappingException(
- "Error parsing SSO response: SSO mapping provider plugin "
- "did not return a localpart value"
- )
+ # the mapper has not picked a localpart
+ return attributes
# Check if this mxid already exists
- user_id = UserID(attributes.localpart, self.server_name).to_string()
- if not await self.store.get_users_by_id_case_insensitive(user_id):
+ user_id = UserID(attributes.localpart, self._server_name).to_string()
+ if not await self._store.get_users_by_id_case_insensitive(user_id):
# This mxid is free
break
else:
@@ -224,10 +496,101 @@ class SsoHandler(BaseHandler):
raise MappingException(
"Unable to generate a Matrix ID from the SSO response"
)
+ return attributes
+
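The integer passed to the mapper is the number of localparts that have already been rejected, which lets a mapper implement a de-duplication strategy. A hypothetical mapper (illustration only; the username and suffixing scheme are invented):

```python
from synapse.handlers.sso import UserAttributes


async def example_mapper(failures: int) -> UserAttributes:
    # first attempt: the bare username; on collision, append a counter
    # and let _call_attribute_mapper re-check availability
    localpart = "alice" if failures == 0 else "alice%d" % (failures,)
    return UserAttributes(localpart=localpart, display_name="Alice")
```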
+ async def _redirect_to_username_picker(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ attributes: UserAttributes,
+ client_redirect_url: str,
+ extra_login_attributes: Optional[JsonDict],
+ ) -> NoReturn:
+ """Creates a UsernameMappingSession and redirects the browser
+
+ Called if the user mapping provider doesn't return a localpart for a new user.
+ Raises a RedirectException which redirects the browser to the username picker.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+
+ remote_user_id: The unique identifier from the SSO provider.
+
+ attributes: the user attributes returned by the user mapping provider.
+
+ client_redirect_url: The redirect URL passed in by the client, which we
+ will eventually redirect back to.
+
+ extra_login_attributes: An optional dictionary of extra
+ attributes to be provided to the client in the login response.
+
+ Raises:
+ RedirectException
+ """
+ session_id = random_string(16)
+ now = self._clock.time_msec()
+ session = UsernameMappingSession(
+ auth_provider_id=auth_provider_id,
+ remote_user_id=remote_user_id,
+ display_name=attributes.display_name,
+ emails=attributes.emails,
+ client_redirect_url=client_redirect_url,
+ expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS,
+ extra_login_attributes=extra_login_attributes,
+ )
+
+ self._username_mapping_sessions[session_id] = session
+ logger.info("Recorded registration session id %s", session_id)
+
+ # Set the cookie and redirect to the username picker
+ e = RedirectException(b"/_synapse/client/pick_username/account_details")
+ e.cookies.append(
+ b"%s=%s; path=/"
+ % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
+ )
+ raise e
+
+ async def _register_mapped_user(
+ self,
+ attributes: UserAttributes,
+ auth_provider_id: str,
+ remote_user_id: str,
+ user_agent: str,
+ ip_address: str,
+ ) -> str:
+ """Register a new SSO user.
+
+ This is called once we have successfully mapped the remote user id onto a local
+ user id, one way or another.
+
+ Args:
+ attributes: user attributes returned by the user mapping provider,
+ including a non-empty localpart.
+
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+
+ remote_user_id: The unique identifier from the SSO provider.
+
+ user_agent: The user-agent in the HTTP request (used for potential
+ shadow-banning.)
+
+ ip_address: The IP address of the requester (used for potential
+ shadow-banning.)
+
+ Raises:
+ a MappingException if the localpart is invalid.
+
+ a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart
+ is already taken.
+ """
# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
- if contains_invalid_mxid_characters(attributes.localpart):
+ if not attributes.localpart or contains_invalid_mxid_characters(
+ attributes.localpart
+ ):
raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
logger.debug("Mapped SSO user to local part %s", attributes.localpart)
@@ -238,7 +601,292 @@ class SsoHandler(BaseHandler):
user_agent_ips=[(user_agent, ip_address)],
)
- await self.store.record_user_external_id(
+ await self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
+
+ async def complete_sso_ui_auth_request(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ ui_auth_session_id: str,
+ request: Request,
+ ) -> None:
+ """
+ Given an SSO ID, retrieve the user ID for it and complete UIA.
+
+ Note that this requires that the user is mapped in the "user_external_ids"
+ table. This will be the case if they have ever logged in via SAML or OIDC in
+ recentish synapse versions, but may not be for older users.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The unique identifier from the SSO provider.
+ ui_auth_session_id: The ID of the user-interactive auth session.
+ request: The request to complete.
+ """
+
+ user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
+
+ user_id_to_verify = await self._auth_handler.get_session_data(
+ ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+ ) # type: str
+
+ if not user_id:
+ logger.warning(
+ "Remote user %s/%s has not previously logged in here: UIA will fail",
+ auth_provider_id,
+ remote_user_id,
+ )
+ elif user_id != user_id_to_verify:
+ logger.warning(
+ "Remote user %s/%s mapped onto incorrect user %s: UIA will fail",
+ auth_provider_id,
+ remote_user_id,
+ user_id,
+ )
+ else:
+ # success!
+ # Mark the stage of the authentication as successful.
+ await self._store.mark_ui_auth_stage_complete(
+ ui_auth_session_id, LoginType.SSO, user_id
+ )
+
+ # Render the HTML confirmation page and return.
+ html = self._sso_auth_success_template
+ respond_with_html(request, 200, html)
+ return
+
+ # the user_id didn't match: mark the stage of the authentication as unsuccessful
+ await self._store.mark_ui_auth_stage_complete(
+ ui_auth_session_id, LoginType.SSO, ""
+ )
+
+ # render an error page.
+ html = self._bad_user_template.render(
+ server_name=self._server_name, user_id_to_verify=user_id_to_verify,
+ )
+ respond_with_html(request, 200, html)
+
+ def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
+ """Look up the given username mapping session
+
+ If it is not found, raises a SynapseError with an http code of 400
+
+ Args:
+ session_id: session to look up
+ Returns:
+ active mapping session
+ Raises:
+ SynapseError if the session is not found/has expired
+ """
+ self._expire_old_sessions()
+ session = self._username_mapping_sessions.get(session_id)
+ if session:
+ return session
+ logger.info("Couldn't find session id %s", session_id)
+ raise SynapseError(400, "unknown session")
+
+ async def check_username_availability(
+ self, localpart: str, session_id: str,
+ ) -> bool:
+ """Handle an "is username available" callback check
+
+ Args:
+ localpart: desired localpart
+ session_id: the session id for the username picker
+ Returns:
+ True if the username is available
+ Raises:
+ SynapseError if the localpart is invalid or the session is unknown
+ """
+
+ # make sure that there is a valid mapping session, to stop people dictionary-
+ # scanning for accounts
+ self.get_mapping_session(session_id)
+
+ logger.info(
+ "[session %s] Checking for availability of username %s",
+ session_id,
+ localpart,
+ )
+
+ if contains_invalid_mxid_characters(localpart):
+ raise SynapseError(400, "localpart is invalid: %s" % (localpart,))
+ user_id = UserID(localpart, self._server_name).to_string()
+ user_infos = await self._store.get_users_by_id_case_insensitive(user_id)
+
+ logger.info("[session %s] users: %s", session_id, user_infos)
+ return not user_infos
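A servlet backing the picker's availability check might use this as below (a sketch: the servlet shape and query parameter are assumptions, only the two handler calls are taken from this file):

```python
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
from synapse.http.servlet import parse_string


async def on_GET(self, request):  # hypothetical servlet method
    # raises a 400 if the cookie is missing; expired sessions are
    # rejected inside check_username_availability
    session_id = get_username_mapping_session_cookie_from_request(request)
    localpart = parse_string(request, "username", required=True)
    is_available = await self.sso_handler.check_username_availability(
        localpart, session_id
    )
    return 200, {"available": is_available}
```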
+
+ async def handle_submit_username_request(
+ self,
+ request: SynapseRequest,
+ session_id: str,
+ localpart: str,
+ use_display_name: bool,
+ emails_to_use: Iterable[str],
+ ) -> None:
+ """Handle a request to the username-picker 'submit' endpoint
+
+ Will serve an HTTP response to the request.
+
+ Args:
+ request: HTTP request
+ session_id: ID of the username mapping session, extracted from a cookie
+ localpart: localpart requested by the user
+ use_display_name: whether the user wants to use the suggested display name
+ emails_to_use: emails that the user would like to use
+ """
+ session = self.get_mapping_session(session_id)
+
+ # update the session with the user's choices
+ session.chosen_localpart = localpart
+ session.use_display_name = use_display_name
+
+ emails_from_idp = set(session.emails)
+ filtered_emails = set() # type: Set[str]
+
+ # we iterate through the list rather than just building a set intersection, so
+ # that we can log attempts to use unknown addresses
+ for email in emails_to_use:
+ if email in emails_from_idp:
+ filtered_emails.add(email)
+ else:
+ logger.warning(
+ "[session %s] ignoring user request to use unknown email address %r",
+ session_id,
+ email,
+ )
+ session.emails_to_use = filtered_emails
+
+ # we may now need to collect consent from the user, in which case, redirect
+ # to the consent-extraction-unit
+ if self._consent_at_registration:
+ redirect_url = b"/_synapse/client/new_user_consent"
+
+ # otherwise, redirect to the completion page
+ else:
+ redirect_url = b"/_synapse/client/sso_register"
+
+ respond_with_redirect(request, redirect_url)
+
+ async def handle_terms_accepted(
+ self, request: Request, session_id: str, terms_version: str
+ ):
+ """Handle a request to the new-user 'consent' endpoint
+
+ Will serve an HTTP response to the request.
+
+ Args:
+ request: HTTP request
+ session_id: ID of the username mapping session, extracted from a cookie
+ terms_version: the version of the terms which the user viewed and consented
+ to
+ """
+ logger.info(
+ "[session %s] User consented to terms version %s",
+ session_id,
+ terms_version,
+ )
+ session = self.get_mapping_session(session_id)
+ session.terms_accepted_version = terms_version
+
+ # we're done; now we can register the user
+ respond_with_redirect(request, b"/_synapse/client/sso_register")
+
+ async def register_sso_user(self, request: Request, session_id: str) -> None:
+ """Called once we have all the info we need to register a new user.
+
+ Does so and serves an HTTP response
+
+ Args:
+ request: HTTP request
+ session_id: ID of the username mapping session, extracted from a cookie
+ """
+ session = self.get_mapping_session(session_id)
+
+ logger.info(
+ "[session %s] Registering localpart %s",
+ session_id,
+ session.chosen_localpart,
+ )
+
+ attributes = UserAttributes(
+ localpart=session.chosen_localpart, emails=session.emails_to_use,
+ )
+
+ if session.use_display_name:
+ attributes.display_name = session.display_name
+
+ # the following will raise a 400 error if the username has been taken in the
+ # meantime.
+ user_id = await self._register_mapped_user(
+ attributes,
+ session.auth_provider_id,
+ session.remote_user_id,
+ get_request_user_agent(request),
+ request.getClientIP(),
+ )
+
+ logger.info(
+ "[session %s] Registered userid %s with attributes %s",
+ session_id,
+ user_id,
+ attributes,
+ )
+
+ # delete the mapping session and the cookie
+ del self._username_mapping_sessions[session_id]
+
+ # delete the cookie
+ request.addCookie(
+ USERNAME_MAPPING_SESSION_COOKIE_NAME,
+ b"",
+ expires=b"Thu, 01 Jan 1970 00:00:00 GMT",
+ path=b"/",
+ )
+
+ auth_result = {}
+ if session.terms_accepted_version:
+ # TODO: make this less awful.
+ auth_result[LoginType.TERMS] = True
+
+ await self._registration_handler.post_registration_actions(
+ user_id, auth_result, access_token=None
+ )
+
+ await self._auth_handler.complete_sso_login(
+ user_id,
+ request,
+ session.client_redirect_url,
+ session.extra_login_attributes,
+ new_user=True,
+ )
+
+ def _expire_old_sessions(self):
+ to_expire = []
+ now = int(self._clock.time_msec())
+
+ for session_id, session in self._username_mapping_sessions.items():
+ if session.expiry_time_ms <= now:
+ to_expire.append(session_id)
+
+ for session_id in to_expire:
+ logger.info("Expiring mapping session %s", session_id)
+ del self._username_mapping_sessions[session_id]
+
+
+def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
+ """Extract the session ID from the cookie
+
+ Raises a SynapseError if the cookie isn't found
+ """
+ session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+ if not session_id:
+ raise SynapseError(code=400, msg="missing session_id")
+ return session_id.decode("ascii", errors="replace")
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index fb4f70e8e2..b3f9875358 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -14,15 +14,25 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class StateDeltasHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+ async def _get_key_change(
+ self,
+ prev_event_id: Optional[str],
+ event_id: Optional[str],
+ key_name: str,
+ public_value: str,
+ ) -> Optional[bool]:
"""Given two events check if the `key_name` field in content changed
from not matching `public_value` to doing so.
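The tri-state result is worth spelling out: `None` means no effective change, `True` means the field changed to `public_value`, and `False` means it changed away from it. A simplified model of the comparison (sketch only, ignoring the event lookups):

```python
from typing import Optional


def key_change(prev_value: str, value: str, public_value: str) -> Optional[bool]:
    was_public = prev_value == public_value
    is_public = value == public_value
    if was_public == is_public:
        return None       # no change in publicity
    return is_public      # True: became public; False: stopped being public
```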
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index dc62b21c06..d261d7cd4e 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -12,13 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
+
+from typing_extensions import Counter as CounterType
from synapse.api.constants import EventTypes, Membership
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -31,7 +37,7 @@ class StatsHandler:
Heavily derived from UserDirectoryHandler
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
@@ -44,7 +50,7 @@ class StatsHandler:
self.stats_enabled = hs.config.stats_enabled
# The current position in the current_state_delta stream
- self.pos = None
+ self.pos = None # type: Optional[int]
# Guard to ensure we only process deltas one at a time
self._is_processing = False
@@ -56,7 +62,7 @@ class StatsHandler:
# we start populating stats
self.clock.call_later(0, self.notify_new_event)
- def notify_new_event(self):
+ def notify_new_event(self) -> None:
"""Called when there may be more deltas to process
"""
if not self.stats_enabled or self._is_processing:
@@ -72,7 +78,7 @@ class StatsHandler:
run_as_background_process("stats.notify_new_event", process)
- async def _unsafe_process(self):
+ async def _unsafe_process(self) -> None:
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
self.pos = await self.store.get_stats_positions()
@@ -110,10 +116,10 @@ class StatsHandler:
)
for room_id, fields in room_count.items():
- room_deltas.setdefault(room_id, {}).update(fields)
+ room_deltas.setdefault(room_id, Counter()).update(fields)
for user_id, fields in user_count.items():
- user_deltas.setdefault(user_id, {}).update(fields)
+ user_deltas.setdefault(user_id, Counter()).update(fields)
logger.debug("room_deltas: %s", room_deltas)
logger.debug("user_deltas: %s", user_deltas)
@@ -131,19 +137,20 @@ class StatsHandler:
self.pos = max_pos
- async def _handle_deltas(self, deltas):
+ async def _handle_deltas(
+ self, deltas: Iterable[JsonDict]
+ ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
"""Called with the state deltas to process
Returns:
- tuple[dict[str, Counter], dict[str, counter]]
Two dicts: the room deltas and the user deltas,
mapping from room/user ID to changes in the various fields.
"""
- room_to_stats_deltas = {}
- user_to_stats_deltas = {}
+ room_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
+ user_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
- room_to_state_updates = {}
+ room_to_state_updates = {} # type: Dict[str, Dict[str, Any]]
for delta in deltas:
typ = delta["type"]
@@ -173,7 +180,7 @@ class StatsHandler:
)
continue
- event_content = {}
+ event_content = {} # type: JsonDict
sender = None
if event_id is not None:
@@ -257,13 +264,13 @@ class StatsHandler:
)
if has_changed_joinedness:
- delta = +1 if membership == Membership.JOIN else -1
+ membership_delta = +1 if membership == Membership.JOIN else -1
user_to_stats_deltas.setdefault(user_id, Counter())[
"joined_rooms"
- ] += delta
+ ] += membership_delta
- room_stats_delta["local_users_in_room"] += delta
+ room_stats_delta["local_users_in_room"] += membership_delta
elif typ == EventTypes.Create:
room_state["is_federatable"] = (
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9827c7eb8d..5c7590f38e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -554,7 +554,7 @@ class SyncHandler:
event.event_id, state_filter=state_filter
)
if event.is_state():
- state_ids = state_ids.copy()
+ state_ids = dict(state_ids)
state_ids[(event.type, event.state_key)] = event.event_id
return state_ids
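`dict(state_ids)` always produces a plain mutable dict, whereas `.copy()` on some mapping types preserves their (possibly immutable) type; the item assignment that follows needs the former. For example:

```python
from types import MappingProxyType

# a read-only mapping, standing in for whatever StateMap implementation is in use
state_ids = MappingProxyType({("m.room.name", ""): "$event1"})

mutable = dict(state_ids)                  # always a plain dict
mutable[("m.room.topic", "")] = "$event2"  # safe to mutate
```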
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e919a8f9ed..3f0dfc7a74 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,13 +15,13 @@
import logging
import random
from collections import namedtuple
-from typing import TYPE_CHECKING, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -65,17 +65,17 @@ class FollowerTypingHandler:
)
# map room IDs to serial numbers
- self._room_serials = {}
+ self._room_serials = {} # type: Dict[str, int]
# map room IDs to sets of users currently typing
- self._room_typing = {}
+ self._room_typing = {} # type: Dict[str, Set[str]]
- self._member_last_federation_poke = {}
+ self._member_last_federation_poke = {} # type: Dict[RoomMember, int]
self.wheel_timer = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
self.clock.looping_call(self._handle_timeouts, 5000)
- def _reset(self):
+ def _reset(self) -> None:
"""Reset the typing handler's data caches.
"""
# map room IDs to serial numbers
@@ -86,7 +86,7 @@ class FollowerTypingHandler:
self._member_last_federation_poke = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
- def _handle_timeouts(self):
+ def _handle_timeouts(self) -> None:
logger.debug("Checking for typing timeouts")
now = self.clock.time_msec()
@@ -96,7 +96,7 @@ class FollowerTypingHandler:
for member in members:
self._handle_timeout_for_member(now, member)
- def _handle_timeout_for_member(self, now: int, member: RoomMember):
+ def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
if not self.is_typing(member):
# Nothing to do if they're no longer typing
return
@@ -114,10 +114,10 @@ class FollowerTypingHandler:
# each person typing.
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
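For context, the wheel timer groups deadlines into coarse buckets (5 seconds here) and hands back everything that has expired; a minimal usage sketch, using the `insert`/`fetch` calls as they appear in this handler:

```python
from synapse.util.wheel_timer import WheelTimer

timer = WheelTimer(bucket_size=5000)              # 5s buckets, as above
timer.insert(now=0, obj="member-a", then=60_000)  # re-check in ~60s
timer.insert(now=0, obj="member-b", then=15_000)

due = timer.fetch(20_000)  # -> ["member-b"]; "member-a" is not yet due
```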
- def is_typing(self, member):
+ def is_typing(self, member: RoomMember) -> bool:
return member.user_id in self._room_typing.get(member.room_id, [])
- async def _push_remote(self, member, typing):
+ async def _push_remote(self, member: RoomMember, typing: bool) -> None:
if not self.federation:
return
@@ -148,7 +148,7 @@ class FollowerTypingHandler:
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
- ):
+ ) -> None:
"""Should be called whenever we receive updates for typing stream.
"""
@@ -178,7 +178,7 @@ class FollowerTypingHandler:
async def _send_changes_in_typing_to_remotes(
self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
- ):
+ ) -> None:
"""Process a change in typing of a room from replication, sending EDUs
for any local users.
"""
@@ -194,12 +194,12 @@ class FollowerTypingHandler:
if self.is_mine_id(user_id):
await self._push_remote(RoomMember(room_id, user_id), False)
- def get_current_token(self):
+ def get_current_token(self) -> int:
return self._latest_room_serial
class TypingWriterHandler(FollowerTypingHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
assert hs.config.worker.writers.typing == hs.get_instance_name()
@@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
- self._member_typing_until = {} # clock time we expect to stop
+ # clock time we expect to stop
+ self._member_typing_until = {} # type: Dict[RoomMember, int]
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
"TypingStreamChangeCache", self._latest_room_serial
)
- def _handle_timeout_for_member(self, now: int, member: RoomMember):
+ def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
super()._handle_timeout_for_member(now, member)
if not self.is_typing(member):
@@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler):
self._stopped_typing(member)
return
- async def started_typing(self, target_user, requester, room_id, timeout):
+ async def started_typing(
+ self, target_user: UserID, requester: Requester, room_id: str, timeout: int
+ ) -> None:
target_user_id = target_user.to_string()
auth_user_id = requester.user.to_string()
@@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler):
if was_present:
# No point sending another notification
- return None
+ return
self._push_update(member=member, typing=True)
- async def stopped_typing(self, target_user, requester, room_id):
+ async def stopped_typing(
+ self, target_user: UserID, requester: Requester, room_id: str
+ ) -> None:
target_user_id = target_user.to_string()
auth_user_id = requester.user.to_string()
@@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler):
self._stopped_typing(member)
- def user_left_room(self, user, room_id):
+ def user_left_room(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
if self.is_mine_id(user_id):
member = RoomMember(room_id=room_id, user_id=user_id)
self._stopped_typing(member)
- def _stopped_typing(self, member):
+ def _stopped_typing(self, member: RoomMember) -> None:
if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point
- return None
+ return
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
self._push_update(member=member, typing=False)
- def _push_update(self, member, typing):
+ def _push_update(self, member: RoomMember, typing: bool) -> None:
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
run_as_background_process(
@@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler):
self._push_update_local(member=member, typing=typing)
- async def _recv_edu(self, origin, content):
+ async def _recv_edu(self, origin: str, content: JsonDict) -> None:
room_id = content["room_id"]
user_id = content["user_id"]
@@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler):
self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
self._push_update_local(member=member, typing=content["typing"])
- def _push_update_local(self, member, typing):
+ def _push_update_local(self, member: RoomMember, typing: bool) -> None:
room_set = self._room_typing.setdefault(member.room_id, set())
if typing:
room_set.add(member.user_id)
@@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler):
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
last_id
- )
+ ) # type: Optional[Iterable[str]]
if changed_rooms is None:
changed_rooms = self._room_serials
@@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler):
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
- ):
+ ) -> None:
# The writing process should never get updates from replication.
raise Exception("Typing writer instance got typing info over replication")
class TypingNotificationEventSource:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
# We can't call get_typing_handler here because there's a cycle:
@@ -427,7 +432,7 @@ class TypingNotificationEventSource:
#
self.get_typing_handler = hs.get_typing_handler
- def _make_event_for(self, room_id):
+ def _make_event_for(self, room_id: str) -> JsonDict:
typing = self.get_typing_handler()._room_typing[room_id]
return {
"type": "m.typing",
@@ -462,7 +467,9 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
- async def get_new_events(self, from_key, room_ids, **kwargs):
+ async def get_new_events(
+ self, from_key: int, room_ids: Iterable[str], **kwargs
+ ) -> Tuple[List[JsonDict], int]:
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
handler = self.get_typing_handler()
@@ -478,5 +485,5 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
- def get_current_key(self):
+ def get_current_key(self) -> int:
return self.get_typing_handler()._latest_room_serial
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 824f37f8f8..a68d5e790e 100644
--- a/synapse/handlers/ui_auth/__init__.py
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -20,3 +20,18 @@ TODO: move more stuff out of AuthHandler in here.
"""
from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401
+
+
+class UIAuthSessionDataConstants:
+ """Constants for use with AuthHandler.set_session_data"""
+
+ # used during registration and password reset to store a hashed copy of the
+ # password, so that the client does not need to submit it each time.
+ PASSWORD_HASH = "password_hash"
+
+ # used during registration to store the mxid of the registered user
+ REGISTERED_USER_ID = "registered_user_id"
+
+ # used by validate_user_via_ui_auth to store the mxid of the user we are validating
+ # for.
+ REQUEST_USER_ID = "request_user_id"
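These keys are used with the AuthHandler's UI-auth session store; for instance (a sketch, mirroring the `get_session_data` call in `complete_sso_ui_auth_request` above and assuming the matching `set_session_data` setter):

```python
# record which user a UI-auth session is validating:
await auth_handler.set_session_data(
    session_id,
    UIAuthSessionDataConstants.REQUEST_USER_ID,
    requester.user.to_string(),
)

# ...and read it back when the SSO response arrives:
user_id_to_verify = await auth_handler.get_session_data(
    session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
)
```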
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index afbebfc200..8aedf5072e 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -14,14 +14,19 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
import synapse.metrics
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
from synapse.handlers.state_deltas import StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
+from synapse.types import JsonDict
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -36,7 +41,7 @@ class UserDirectoryHandler(StateDeltasHandler):
be in the directory or not when necessary.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.store = hs.get_datastore()
@@ -49,7 +54,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.search_all_users = hs.config.user_directory_search_all_users
self.spam_checker = hs.get_spam_checker()
# The current position in the current_state_delta stream
- self.pos = None
+ self.pos = None # type: Optional[int]
# Guard to ensure we only process deltas one at a time
self._is_processing = False
@@ -61,7 +66,9 @@ class UserDirectoryHandler(StateDeltasHandler):
# we start populating the user directory
self.clock.call_later(0, self.notify_new_event)
- async def search_users(self, user_id, search_term, limit):
+ async def search_users(
+ self, user_id: str, search_term: str, limit: int
+ ) -> JsonDict:
"""Searches for users in directory
Returns:
@@ -81,15 +88,15 @@ class UserDirectoryHandler(StateDeltasHandler):
results = await self.store.search_user_dir(user_id, search_term, limit)
# Remove any spammy users from the results.
- results["results"] = [
- user
- for user in results["results"]
- if not self.spam_checker.check_username_for_spam(user)
- ]
+ non_spammy_users = []
+ for user in results["results"]:
+ if not await self.spam_checker.check_username_for_spam(user):
+ non_spammy_users.append(user)
+ results["results"] = non_spammy_users
return results
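This rewrite reflects `check_username_for_spam` becoming awaitable, which allows spam-checker modules to do async work (e.g. database or HTTP lookups). A hypothetical module callback:

```python
class ExampleSpamChecker:
    async def check_username_for_spam(self, user_profile: dict) -> bool:
        # user_profile carries "user_id", "display_name" and "avatar_url";
        # returning True hides the user from directory search results.
        display_name = user_profile.get("display_name") or ""
        return "spam" in display_name.lower()
```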
- def notify_new_event(self):
+ def notify_new_event(self) -> None:
"""Called when there may be more deltas to process
"""
if not self.update_user_directory:
@@ -107,35 +114,37 @@ class UserDirectoryHandler(StateDeltasHandler):
self._is_processing = True
run_as_background_process("user_directory.notify_new_event", process)
- async def handle_local_profile_change(self, user_id, profile):
+ async def handle_local_profile_change(
+ self, user_id: str, profile: ProfileInfo
+ ) -> None:
"""Called to update index of our local user profiles when they change
irrespective of any rooms the user may be in.
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
- is_support = await self.store.is_support_user(user_id)
+
# Support users are for diagnostics and should not appear in the user directory.
- if not is_support:
+ is_support = await self.store.is_support_user(user_id)
+ # Profile changes of deactivated users should not be reflected in the user directory.
+ is_deactivated = await self.store.get_user_deactivated_status(user_id)
+
+ if not (is_support or is_deactivated):
await self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
- async def handle_user_deactivated(self, user_id):
+ async def handle_user_deactivated(self, user_id: str) -> None:
"""Called when a user ID is deactivated
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
await self.store.remove_from_user_dir(user_id)
- async def _unsafe_process(self):
+ async def _unsafe_process(self) -> None:
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
self.pos = await self.store.get_user_directory_stream_pos()
- # If still None then the initial background update hasn't happened yet
- if self.pos is None:
- return None
-
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "user_dir_delta"):
@@ -162,7 +171,7 @@ class UserDirectoryHandler(StateDeltasHandler):
await self.store.update_user_directory_stream_pos(max_pos)
- async def _handle_deltas(self, deltas):
+ async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
"""Called with the state deltas to process
"""
for delta in deltas:
@@ -220,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler):
if change: # The user joined
event = await self.store.get_event(event_id, allow_none=True)
+ # It isn't expected for this event to not exist, but we
+ # don't want the entire background process to break.
+ if event is None:
+ continue
+
profile = ProfileInfo(
avatar_url=event.content.get("avatar_url"),
display_name=event.content.get("displayname"),
@@ -232,16 +246,20 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug("Ignoring irrelevant type: %r", typ)
async def _handle_room_publicity_change(
- self, room_id, prev_event_id, event_id, typ
- ):
+ self,
+ room_id: str,
+ prev_event_id: Optional[str],
+ event_id: Optional[str],
+ typ: str,
+ ) -> None:
"""Handle a room having potentially changed from/to world_readable/publicly
joinable.
Args:
- room_id (str)
- prev_event_id (str|None): The previous event before the state change
- event_id (str|None): The new event after the state change
- typ (str): Type of the event
+ room_id: The ID of the room which changed.
+ prev_event_id: The previous event before the state change
+ event_id: The new event after the state change
+ typ: Type of the event
"""
logger.debug("Handling change for %s: %s", typ, room_id)
@@ -250,7 +268,7 @@ class UserDirectoryHandler(StateDeltasHandler):
prev_event_id,
event_id,
key_name="history_visibility",
- public_value="world_readable",
+ public_value=HistoryVisibility.WORLD_READABLE,
)
elif typ == EventTypes.JoinRules:
change = await self._get_key_change(
@@ -299,12 +317,14 @@ class UserDirectoryHandler(StateDeltasHandler):
for user_id, profile in users_with_profile.items():
await self._handle_new_user(room_id, user_id, profile)
- async def _handle_new_user(self, room_id, user_id, profile):
+ async def _handle_new_user(
+ self, room_id: str, user_id: str, profile: ProfileInfo
+ ) -> None:
"""Called when we might need to add user to directory
Args:
- room_id (str): room_id that user joined or started being public
- user_id (str)
+ room_id: The room ID that the user joined or started being public in
+ user_id: The user to add to the directory
"""
logger.debug("Adding new user to dir, %r", user_id)
@@ -352,12 +372,12 @@ class UserDirectoryHandler(StateDeltasHandler):
if to_insert:
await self.store.add_users_who_share_private_room(room_id, to_insert)
- async def _handle_remove_user(self, room_id, user_id):
+ async def _handle_remove_user(self, room_id: str, user_id: str) -> None:
"""Called when we might need to remove user from directory
Args:
- room_id (str): room_id that user left or stopped being public that
- user_id (str)
+ room_id: The room ID that the user left or stopped being public in
+ user_id: The user to remove from the directory
"""
logger.debug("Removing user %r", user_id)
@@ -370,7 +390,13 @@ class UserDirectoryHandler(StateDeltasHandler):
if len(rooms_user_is_in) == 0:
await self.store.remove_from_user_dir(user_id)
- async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
+ async def _handle_profile_change(
+ self,
+ user_id: str,
+ room_id: str,
+ prev_event_id: Optional[str],
+ event_id: Optional[str],
+ ) -> None:
"""Check member event changes for any profile changes and update the
database if there are.
"""
|