diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bd8e71ae56..bb81c0e81d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -169,7 +169,9 @@ class BaseHandler:
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
- requester = synapse.types.create_requester(target_user, is_guest=True)
+ requester = synapse.types.create_requester(
+ target_user, is_guest=True, authenticated_entity=self.server_name
+ )
handler = self.hs.get_room_member_handler()
await handler.update_membership(
requester,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 9fc8444228..5c6458eb52 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -226,7 +226,7 @@ class ApplicationServicesHandler:
new_token: Optional[int],
users: Collection[Union[str, UserID]],
):
- logger.info("Checking interested services for %s" % (stream_key))
+ logger.debug("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
# Only handle typing if we have the latest token
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 213baea2e3..5163afd86c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -698,8 +698,12 @@ class AuthHandler(BaseHandler):
}
async def get_access_token_for_user_id(
- self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
- ):
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ valid_until_ms: Optional[int],
+ puppets_user_id: Optional[str] = None,
+ ) -> str:
"""
Creates a new access token for the user with the given user ID.
@@ -725,13 +729,25 @@ class AuthHandler(BaseHandler):
fmt_expiry = time.strftime(
" until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
)
- logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
+
+ if puppets_user_id:
+ logger.info(
+ "Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry
+ )
+ else:
+ logger.info(
+ "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
+ )
await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
await self.store.add_access_token_to_user(
- user_id, access_token, device_id, valid_until_ms
+ user_id=user_id,
+ token=access_token,
+ device_id=device_id,
+ valid_until_ms=valid_until_ms,
+ puppets_user_id=puppets_user_id,
)
# the device *should* have been registered before we got here; however,
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 048a3b3c0b..f4ea0a9767 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
from xml.etree import ElementTree as ET
from twisted.web.client import PartialDownloadError
@@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -31,10 +34,10 @@ class CasHandler:
Utility class for to handle the response from a CAS SSO service.
Args:
- hs (synapse.server.HomeServer)
+ hs
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
@@ -200,27 +203,57 @@ class CasHandler:
args["session"] = session
username, user_display_name = await self._validate_ticket(ticket, args)
- localpart = map_username_to_mxid_localpart(username)
- user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = await self._auth_handler.check_user_exists(user_id)
+ # Pull out the user-agent and IP from the request.
+ user_agent = request.get_user_agent("")
+ ip_address = self.hs.get_ip_from_request(request)
+
+ # Get the matrix ID from the CAS username.
+ user_id = await self._map_cas_user_to_matrix_user(
+ username, user_display_name, user_agent, ip_address
+ )
if session:
await self._auth_handler.complete_sso_ui_auth(
- registered_user_id, session, request,
+ user_id, session, request,
)
-
else:
- if not registered_user_id:
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
-
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=user_display_name,
- user_agent_ips=(user_agent, ip_address),
- )
+ # If this not a UI auth request than there must be a redirect URL.
+ assert client_redirect_url
await self._auth_handler.complete_sso_login(
- registered_user_id, request, client_redirect_url
+ user_id, request, client_redirect_url
)
+
+ async def _map_cas_user_to_matrix_user(
+ self,
+ remote_user_id: str,
+ display_name: Optional[str],
+ user_agent: str,
+ ip_address: str,
+ ) -> str:
+ """
+ Given a CAS username, retrieve the user ID for it and possibly register the user.
+
+ Args:
+ remote_user_id: The username from the CAS response.
+ display_name: The display name from the CAS response.
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
+
+ Returns:
+ The user ID associated with this response.
+ """
+
+ localpart = map_username_to_mxid_localpart(remote_user_id)
+ user_id = UserID(localpart, self._hostname).to_string()
+ registered_user_id = await self._auth_handler.check_user_exists(user_id)
+
+ # If the user does not exist, register it.
+ if not registered_user_id:
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart,
+ default_display_name=display_name,
+ user_agent_ips=[(user_agent, ip_address)],
+ )
+
+ return registered_user_id
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 4efe6c530a..e808142365 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -39,6 +39,7 @@ class DeactivateAccountHandler(BaseHandler):
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_identity_handler()
self.user_directory_handler = hs.get_user_directory_handler()
+ self._server_name = hs.hostname
# Flag that indicates whether the process to part users from rooms is running
self._user_parter_running = False
@@ -152,7 +153,7 @@ class DeactivateAccountHandler(BaseHandler):
for room in pending_invites:
try:
await self._room_member_handler.update_membership(
- create_requester(user),
+ create_requester(user, authenticated_entity=self._server_name),
user,
room.room_id,
"leave",
@@ -208,7 +209,7 @@ class DeactivateAccountHandler(BaseHandler):
logger.info("User parter parting %r from %r", user_id, room_id)
try:
await self._room_member_handler.update_membership(
- create_requester(user),
+ create_requester(user, authenticated_entity=self._server_name),
user,
room_id,
"leave",
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c386957706..b9799090f7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -55,6 +55,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.handlers._base import BaseHandler
+from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
nested_logging_context,
@@ -67,7 +68,7 @@ from synapse.replication.http.devices import ReplicationUserDevicesResyncRestSer
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
ReplicationFederationSendEventsRestServlet,
- ReplicationStoreRoomOnInviteRestServlet,
+ ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@@ -152,12 +153,14 @@ class FederationHandler(BaseHandler):
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
hs
)
- self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client(
+ self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(
hs
)
else:
self._device_list_updater = hs.get_device_handler().device_list_updater
- self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite
+ self._maybe_store_room_on_outlier_membership = (
+ self.store.maybe_store_room_on_outlier_membership
+ )
# When joining a room we need to queue any events for that room up.
# For each room, a list of (pdu, origin) tuples.
@@ -1617,7 +1620,7 @@ class FederationHandler(BaseHandler):
# keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a
# join dance).
- await self._maybe_store_room_on_invite(
+ await self._maybe_store_room_on_outlier_membership(
room_id=event.room_id, room_version=room_version
)
@@ -2686,7 +2689,7 @@ class FederationHandler(BaseHandler):
)
async def on_exchange_third_party_invite_request(
- self, room_id: str, event_dict: JsonDict
+ self, event_dict: JsonDict
) -> None:
"""Handle an exchange_third_party_invite request from a remote server
@@ -2694,12 +2697,11 @@ class FederationHandler(BaseHandler):
into a normal m.room.member invite.
Args:
- room_id: The ID of the room.
-
- event_dict (dict[str, Any]): Dictionary containing the event body.
+ event_dict: Dictionary containing the event body.
"""
- room_version = await self.store.get_room_version_id(room_id)
+ assert_params_in_dict(event_dict, ["room_id"])
+ room_version = await self.store.get_room_version_id(event_dict["room_id"])
# NB: event_dict has a particular specced format we might need to fudge
# if we change event formats too much.
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 2f3f3a7ef5..11420ea996 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -472,7 +472,7 @@ class EventCreationHandler:
Returns:
Tuple of created event, Context
"""
- await self.auth.check_auth_blocking(requester.user.to_string())
+ await self.auth.check_auth_blocking(requester=requester)
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version = event_dict["content"]["room_version"]
@@ -619,7 +619,13 @@ class EventCreationHandler:
if requester.app_service is not None:
return
- user_id = requester.user.to_string()
+ user_id = requester.authenticated_entity
+ if not user_id.startswith("@"):
+ # The authenticated entity might not be a user, e.g. if it's the
+ # server puppetting the user.
+ return
+
+ user = UserID.from_string(user_id)
# exempt the system notices user
if (
@@ -639,9 +645,7 @@ class EventCreationHandler:
if u["consent_version"] == self.config.user_consent_version:
return
- consent_uri = self._consent_uri_builder.build_user_consent_uri(
- requester.user.localpart
- )
+ consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@@ -1252,7 +1256,7 @@ class EventCreationHandler:
for user_id in members:
if not self.hs.is_mine_id(user_id):
continue
- requester = create_requester(user_id)
+ requester = create_requester(user_id, authenticated_entity=self.server_name)
try:
event, context = await self.create_event(
requester,
@@ -1273,11 +1277,6 @@ class EventCreationHandler:
requester, event, context, ratelimit=False, ignore_shadow_ban=True,
)
return True
- except ConsentNotGivenError:
- logger.info(
- "Failed to send dummy event into room %s for user %s due to "
- "lack of consent. Will try another user" % (room_id, user_id)
- )
except AuthError:
logger.info(
"Failed to send dummy event into room %s for user %s due to "
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 331d4e7e96..78c4e94a9d 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
@@ -34,10 +35,11 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
+from synapse.types import JsonDict, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -83,17 +85,12 @@ class OidcError(Exception):
return self.error
-class MappingException(Exception):
- """Used to catch errors when mapping the UserInfo object
- """
-
-
-class OidcHandler:
+class OidcHandler(BaseHandler):
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: "HomeServer"):
- self.hs = hs
+ super().__init__(hs)
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
@@ -120,36 +117,13 @@ class OidcHandler:
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
- self._datastore = hs.get_datastore()
- self._clock = hs.get_clock()
- self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
- self._error_template = hs.config.sso_error_template
# identifier for the external_ids table
self._auth_provider_id = "oidc"
- def _render_error(
- self, request, error: str, error_description: Optional[str] = None
- ) -> None:
- """Render the error template and respond to the request with it.
-
- This is used to show errors to the user. The template of this page can
- be found under `synapse/res/templates/sso_error.html`.
-
- Args:
- request: The incoming request from the browser.
- We'll respond with an HTML page describing the error.
- error: A technical identifier for this error. Those include
- well-known OAuth2/OIDC error types like invalid_request or
- access_denied.
- error_description: A human-readable description of the error.
- """
- html = self._error_template.render(
- error=error, error_description=error_description
- )
- respond_with_html(request, 400, html)
+ self._sso_handler = hs.get_sso_handler()
def _validate_metadata(self):
"""Verifies the provider metadata.
@@ -571,7 +545,7 @@ class OidcHandler:
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
- ``self._render_error`` which displays an HTML page for the error.
+ ``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here:
@@ -609,7 +583,7 @@ class OidcHandler:
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
- self._render_error(request, error, description)
+ self._sso_handler.render_error(request, error, description)
return
# otherwise, it is presumably a successful response. see:
@@ -619,7 +593,9 @@ class OidcHandler:
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
- self._render_error(request, "missing_session", "No session cookie found")
+ self._sso_handler.render_error(
+ request, "missing_session", "No session cookie found"
+ )
return
# Remove the cookie. There is a good chance that if the callback failed
@@ -637,7 +613,9 @@ class OidcHandler:
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
- self._render_error(request, "invalid_request", "State parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "State parameter is missing"
+ )
return
state = request.args[b"state"][0].decode()
@@ -651,17 +629,19 @@ class OidcHandler:
) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
- self._render_error(request, "invalid_session", str(e))
+ self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
- self._render_error(request, "mismatching_session", str(e))
+ self._sso_handler.render_error(request, "mismatching_session", str(e))
return
# Exchange the code with the provider
if b"code" not in request.args:
logger.info("Code parameter is missing")
- self._render_error(request, "invalid_request", "Code parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "Code parameter is missing"
+ )
return
logger.debug("Exchanging code")
@@ -670,7 +650,7 @@ class OidcHandler:
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")
- self._render_error(request, e.error, e.error_description)
+ self._sso_handler.render_error(request, e.error, e.error_description)
return
logger.debug("Successfully obtained OAuth2 access token")
@@ -683,7 +663,7 @@ class OidcHandler:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
logger.exception("Could not fetch userinfo")
- self._render_error(request, "fetch_error", str(e))
+ self._sso_handler.render_error(request, "fetch_error", str(e))
return
else:
logger.debug("Extracting userinfo from id_token")
@@ -691,7 +671,7 @@ class OidcHandler:
userinfo = await self._parse_id_token(token, nonce=nonce)
except Exception as e:
logger.exception("Invalid id_token")
- self._render_error(request, "invalid_token", str(e))
+ self._sso_handler.render_error(request, "invalid_token", str(e))
return
# Pull out the user-agent and IP from the request.
@@ -705,7 +685,7 @@ class OidcHandler:
)
except MappingException as e:
logger.exception("Could not map user")
- self._render_error(request, "mapping_error", str(e))
+ self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Mapping providers might not have get_extra_attributes: only call this
@@ -770,7 +750,7 @@ class OidcHandler:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -845,7 +825,7 @@ class OidcHandler:
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
return now < expiry
async def _map_userinfo_to_user(
@@ -885,71 +865,51 @@ class OidcHandler:
# to be strings.
remote_user_id = str(remote_user_id)
- logger.info(
- "Looking for existing mapping for user %s:%s",
- self._auth_provider_id,
- remote_user_id,
+ # Older mapping providers don't accept the `failures` argument, so we
+ # try and detect support.
+ mapper_signature = inspect.signature(
+ self._user_mapping_provider.map_user_attributes
)
+ supports_failures = "failures" in mapper_signature.parameters
- registered_user_id = await self._datastore.get_user_by_external_id(
- self._auth_provider_id, remote_user_id,
- )
+ async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
+ """
+ Call the mapping provider to map the OIDC userinfo and token to user attributes.
- if registered_user_id is not None:
- logger.info("Found existing mapping %s", registered_user_id)
- return registered_user_id
-
- try:
- attributes = await self._user_mapping_provider.map_user_attributes(
- userinfo, token
- )
- except Exception as e:
- raise MappingException(
- "Could not extract user attributes from OIDC response: " + str(e)
- )
-
- logger.debug(
- "Retrieved user attributes from user mapping provider: %r", attributes
- )
+ This is backwards compatibility for abstraction for the SSO handler.
+ """
+ if supports_failures:
+ attributes = await self._user_mapping_provider.map_user_attributes(
+ userinfo, token, failures
+ )
+ else:
+ # If the mapping provider does not support processing failures,
+ # do not continually generate the same Matrix ID since it will
+ # continue to already be in use. Note that the error raised is
+ # arbitrary and will get turned into a MappingException.
+ if failures:
+ raise RuntimeError(
+ "Mapping provider does not support de-duplicating Matrix IDs"
+ )
- if not attributes["localpart"]:
- raise MappingException("localpart is empty")
+ attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
+ userinfo, token
+ )
- localpart = map_username_to_mxid_localpart(attributes["localpart"])
+ return UserAttributes(**attributes)
- user_id = UserID(localpart, self._hostname).to_string()
- users = await self._datastore.get_users_by_id_case_insensitive(user_id)
- if users:
- if self._allow_existing_users:
- if len(users) == 1:
- registered_user_id = next(iter(users))
- elif user_id in users:
- registered_user_id = user_id
- else:
- raise MappingException(
- "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
- user_id, list(users.keys())
- )
- )
- else:
- # This mxid is taken
- raise MappingException("mxid '{}' is already taken".format(user_id))
- else:
- # It's the first time this user is logging in and the mapped mxid was
- # not taken, register the user
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=attributes["display_name"],
- user_agent_ips=(user_agent, ip_address),
- )
- await self._datastore.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id,
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ oidc_response_to_user_attributes,
+ self._allow_existing_users,
)
- return registered_user_id
-UserAttribute = TypedDict(
- "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+UserAttributeDict = TypedDict(
+ "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
)
C = TypeVar("C")
@@ -992,13 +952,15 @@ class OidcMappingProvider(Generic[C]):
raise NotImplementedError()
async def map_user_attributes(
- self, userinfo: UserInfo, token: Token
- ) -> UserAttribute:
+ self, userinfo: UserInfo, token: Token, failures: int
+ ) -> UserAttributeDict:
"""Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
token: A dict with the tokens returned by the provider
+ failures: How many times a call to this function with this
+ UserInfo has resulted in a failure.
Returns:
A dict containing the ``localpart`` and (optionally) the ``display_name``
@@ -1098,10 +1060,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
return userinfo[self._config.subject_claim]
async def map_user_attributes(
- self, userinfo: UserInfo, token: Token
- ) -> UserAttribute:
+ self, userinfo: UserInfo, token: Token, failures: int
+ ) -> UserAttributeDict:
localpart = self._config.localpart_template.render(user=userinfo).strip()
+ # Ensure only valid characters are included in the MXID.
+ localpart = map_username_to_mxid_localpart(localpart)
+
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
+
display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
display_name = self._config.display_name_template.render(
@@ -1111,7 +1080,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "":
display_name = None
- return UserAttribute(localpart=localpart, display_name=display_name)
+ return UserAttributeDict(localpart=localpart, display_name=display_name)
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 8e014c9bb5..22d1e9d35c 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,7 +25,7 @@ The methods that define policy are:
import abc
import logging
from contextlib import contextmanager
-from typing import Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
from prometheus_client import Counter
from typing_extensions import ContextManager
@@ -46,8 +46,7 @@ from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 74a1ddd780..dee0ef45e7 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -206,7 +206,9 @@ class ProfileHandler(BaseHandler):
# the join event to update the displayname in the rooms.
# This must be done by the target user himself.
if by_admin:
- requester = create_requester(target_user)
+ requester = create_requester(
+ target_user, authenticated_entity=requester.authenticated_entity,
+ )
await self.store.set_profile_displayname(
target_user.localpart, displayname_to_set
@@ -286,7 +288,9 @@ class ProfileHandler(BaseHandler):
# Same like set_displayname
if by_admin:
- requester = create_requester(target_user)
+ requester = create_requester(
+ target_user, authenticated_entity=requester.authenticated_entity
+ )
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index c242c409cf..153cbae7b9 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -158,7 +158,8 @@ class ReceiptEventSource:
if from_key == to_key:
return [], to_key
- # We first need to fetch all new receipts
+ # Fetch all read receipts for all rooms, up to a limit of 100. This is ordered
+ # by most recent.
rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms(
from_key=from_key, to_key=to_key
)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ed1ff62599..0d85fd0868 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,10 +15,12 @@
"""Contains functions for registering clients."""
import logging
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
@@ -32,16 +34,14 @@ from synapse.types import RoomAlias, UserID, create_requester
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class RegistrationHandler(BaseHandler):
- def __init__(self, hs):
- """
-
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
@@ -52,6 +52,7 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid
+ self._server_name = hs.hostname
self.spam_checker = hs.get_spam_checker()
@@ -70,7 +71,10 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
async def check_username(
- self, localpart, guest_access_token=None, assigned_user_id=None
+ self,
+ localpart: str,
+ guest_access_token: Optional[str] = None,
+ assigned_user_id: Optional[str] = None,
):
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
@@ -139,39 +143,45 @@ class RegistrationHandler(BaseHandler):
async def register_user(
self,
- localpart=None,
- password_hash=None,
- guest_access_token=None,
- make_guest=False,
- admin=False,
- threepid=None,
- user_type=None,
- default_display_name=None,
- address=None,
- bind_emails=[],
- by_admin=False,
- user_agent_ips=None,
- ):
+ localpart: Optional[str] = None,
+ password_hash: Optional[str] = None,
+ guest_access_token: Optional[str] = None,
+ make_guest: bool = False,
+ admin: bool = False,
+ threepid: Optional[dict] = None,
+ user_type: Optional[str] = None,
+ default_display_name: Optional[str] = None,
+ address: Optional[str] = None,
+ bind_emails: List[str] = [],
+ by_admin: bool = False,
+ user_agent_ips: Optional[List[Tuple[str, str]]] = None,
+ ) -> str:
"""Registers a new client on the server.
Args:
localpart: The local part of the user ID to register. If None,
one will be generated.
- password_hash (str|None): The hashed password to assign to this user so they can
+ password_hash: The hashed password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
- user_type (str|None): type of user. One of the values from
+ guest_access_token: The access token used when this was a guest
+ account.
+ make_guest: True if the the new user should be guest,
+ false to add a regular user account.
+ admin: True if the user should be registered as a server admin.
+ threepid: The threepid used for registering, if any.
+ user_type: type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
- default_display_name (unicode|None): if set, the new user's displayname
+ default_display_name: if set, the new user's displayname
will be set to this. Defaults to 'localpart'.
- address (str|None): the IP address used to perform the registration.
- bind_emails (List[str]): list of emails to bind to this account.
- by_admin (bool): True if this registration is being made via the
+ address: the IP address used to perform the registration.
+ bind_emails: list of emails to bind to this account.
+ by_admin: True if this registration is being made via the
admin api, otherwise False.
- user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
+ user_agent_ips: Tuples of IP addresses and user-agents used
during the registration process.
Returns:
- str: user_id
+ The registere user_id.
Raises:
SynapseError if there was a problem registering.
"""
@@ -235,8 +245,10 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
fail_count = 0
- user = None
- while not user:
+ # If a default display name is not given, generate one.
+ generate_display_name = default_display_name is None
+ # This breaks on successful registration *or* errors after 10 failures.
+ while True:
# Fail after being unable to find a suitable ID a few times
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")
@@ -245,7 +257,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
- if default_display_name is None:
+ if generate_display_name:
default_display_name = localpart
try:
await self.register_with_store(
@@ -261,8 +273,6 @@ class RegistrationHandler(BaseHandler):
break
except SynapseError:
# if user id is taken, just generate another
- user = None
- user_id = None
fail_count += 1
if not self.hs.config.user_consent_at_registration:
@@ -294,7 +304,7 @@ class RegistrationHandler(BaseHandler):
return user_id
- async def _create_and_join_rooms(self, user_id: str):
+ async def _create_and_join_rooms(self, user_id: str) -> None:
"""
Create the auto-join rooms and join or invite the user to them.
@@ -317,7 +327,8 @@ class RegistrationHandler(BaseHandler):
requires_join = False
if self.hs.config.registration.auto_join_user_id:
fake_requester = create_requester(
- self.hs.config.registration.auto_join_user_id
+ self.hs.config.registration.auto_join_user_id,
+ authenticated_entity=self._server_name,
)
# If the room requires an invite, add the user to the list of invites.
@@ -329,7 +340,9 @@ class RegistrationHandler(BaseHandler):
# being necessary this will occur after the invite was sent.
requires_join = True
else:
- fake_requester = create_requester(user_id)
+ fake_requester = create_requester(
+ user_id, authenticated_entity=self._server_name
+ )
# Choose whether to federate the new room.
if not self.hs.config.registration.autocreate_auto_join_rooms_federated:
@@ -362,7 +375,9 @@ class RegistrationHandler(BaseHandler):
# created it, then ensure the first user joins it.
if requires_join:
await room_member_handler.update_membership(
- requester=create_requester(user_id),
+ requester=create_requester(
+ user_id, authenticated_entity=self._server_name
+ ),
target=UserID.from_string(user_id),
room_id=info["room_id"],
# Since it was just created, there are no remote hosts.
@@ -370,15 +385,10 @@ class RegistrationHandler(BaseHandler):
action="join",
ratelimit=False,
)
-
- except ConsentNotGivenError as e:
- # Technically not necessary to pull out this error though
- # moving away from bare excepts is a good thing to do.
- logger.error("Failed to join new user to %r: %r", r, e)
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
- async def _join_rooms(self, user_id: str):
+ async def _join_rooms(self, user_id: str) -> None:
"""
Join or invite the user to the auto-join rooms.
@@ -424,9 +434,13 @@ class RegistrationHandler(BaseHandler):
# Send the invite, if necessary.
if requires_invite:
+ # If an invite is required, there must be a auto-join user ID.
+ assert self.hs.config.registration.auto_join_user_id
+
await room_member_handler.update_membership(
requester=create_requester(
- self.hs.config.registration.auto_join_user_id
+ self.hs.config.registration.auto_join_user_id,
+ authenticated_entity=self._server_name,
),
target=UserID.from_string(user_id),
room_id=room_id,
@@ -437,7 +451,9 @@ class RegistrationHandler(BaseHandler):
# Send the join.
await room_member_handler.update_membership(
- requester=create_requester(user_id),
+ requester=create_requester(
+ user_id, authenticated_entity=self._server_name
+ ),
target=UserID.from_string(user_id),
room_id=room_id,
remote_room_hosts=remote_room_hosts,
@@ -452,7 +468,7 @@ class RegistrationHandler(BaseHandler):
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
- async def _auto_join_rooms(self, user_id: str):
+ async def _auto_join_rooms(self, user_id: str) -> None:
"""Automatically joins users to auto join rooms - creating the room in the first place
if the user is the first to be created.
@@ -475,16 +491,16 @@ class RegistrationHandler(BaseHandler):
else:
await self._join_rooms(user_id)
- async def post_consent_actions(self, user_id):
+ async def post_consent_actions(self, user_id: str) -> None:
"""A series of registration actions that can only be carried out once consent
has been granted
Args:
- user_id (str): The user to join
+ user_id: The user to join
"""
await self._auto_join_rooms(user_id)
- async def appservice_register(self, user_localpart, as_token):
+ async def appservice_register(self, user_localpart: str, as_token: str) -> str:
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@@ -509,7 +525,9 @@ class RegistrationHandler(BaseHandler):
)
return user_id
- def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
+ def check_user_id_not_appservice_exclusive(
+ self, user_id: str, allowed_appservice: Optional[ApplicationService] = None
+ ) -> None:
# don't allow people to register the server notices mxid
if self._server_notices_mxid is not None:
if user_id == self._server_notices_mxid:
@@ -533,12 +551,12 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
- def check_registration_ratelimit(self, address):
+ def check_registration_ratelimit(self, address: Optional[str]) -> None:
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
Args:
- address (str|None): the IP address used to perform the registration. If this is
+ address: the IP address used to perform the registration. If this is
None, no ratelimiting will be performed.
Raises:
@@ -549,42 +567,39 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter.ratelimit(address)
- def register_with_store(
+ async def register_with_store(
self,
- user_id,
- password_hash=None,
- was_guest=False,
- make_guest=False,
- appservice_id=None,
- create_profile_with_displayname=None,
- admin=False,
- user_type=None,
- address=None,
- shadow_banned=False,
- ):
+ user_id: str,
+ password_hash: Optional[str] = None,
+ was_guest: bool = False,
+ make_guest: bool = False,
+ appservice_id: Optional[str] = None,
+ create_profile_with_displayname: Optional[str] = None,
+ admin: bool = False,
+ user_type: Optional[str] = None,
+ address: Optional[str] = None,
+ shadow_banned: bool = False,
+ ) -> None:
"""Register user in the datastore.
Args:
- user_id (str): The desired user ID to register.
- password_hash (str|None): Optional. The password hash for this user.
- was_guest (bool): Optional. Whether this is a guest account being
+ user_id: The desired user ID to register.
+ password_hash: Optional. The password hash for this user.
+ was_guest: Optional. Whether this is a guest account being
upgraded to a non-guest account.
- make_guest (boolean): True if the the new user should be guest,
+ make_guest: True if the the new user should be guest,
false to add a regular user account.
- appservice_id (str|None): The ID of the appservice registering the user.
- create_profile_with_displayname (unicode|None): Optionally create a
+ appservice_id: The ID of the appservice registering the user.
+ create_profile_with_displayname: Optionally create a
profile for the user, setting their displayname to the given value
- admin (boolean): is an admin user?
- user_type (str|None): type of user. One of the values from
+ admin: is an admin user?
+ user_type: type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
- address (str|None): the IP address used to perform the registration.
- shadow_banned (bool): Whether to shadow-ban the user
-
- Returns:
- Awaitable
+ address: the IP address used to perform the registration.
+ shadow_banned: Whether to shadow-ban the user
"""
if self.hs.config.worker_app:
- return self._register_client(
+ await self._register_client(
user_id=user_id,
password_hash=password_hash,
was_guest=was_guest,
@@ -597,7 +612,7 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
else:
- return self.store.register_user(
+ await self.store.register_user(
user_id=user_id,
password_hash=password_hash,
was_guest=was_guest,
@@ -610,22 +625,24 @@ class RegistrationHandler(BaseHandler):
)
async def register_device(
- self, user_id, device_id, initial_display_name, is_guest=False
- ):
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ is_guest: bool = False,
+ ) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
The access token will be limited by the homeserver's session_lifetime config.
Args:
- user_id (str): full canonical @user:id
- device_id (str|None): The device ID to check, or None to generate
- a new one.
- initial_display_name (str|None): An optional display name for the
- device.
- is_guest (bool): Whether this is a guest account
+ user_id: full canonical @user:id
+ device_id: The device ID to check, or None to generate a new one.
+ initial_display_name: An optional display name for the device.
+ is_guest: Whether this is a guest account
Returns:
- tuple[str, str]: Tuple of device ID and access token
+ Tuple of device ID and access token
"""
if self.hs.config.worker_app:
@@ -645,7 +662,7 @@ class RegistrationHandler(BaseHandler):
)
valid_until_ms = self.clock.time_msec() + self.session_lifetime
- device_id = await self.device_handler.check_device_registered(
+ registered_device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
@@ -655,20 +672,21 @@ class RegistrationHandler(BaseHandler):
)
else:
access_token = await self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id, valid_until_ms=valid_until_ms
+ user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
)
- return (device_id, access_token)
+ return (registered_device_id, access_token)
- async def post_registration_actions(self, user_id, auth_result, access_token):
+ async def post_registration_actions(
+ self, user_id: str, auth_result: dict, access_token: Optional[str]
+ ) -> None:
"""A user has completed registration
Args:
- user_id (str): The user ID that consented
- auth_result (dict): The authenticated credentials of the newly
- registered user.
- access_token (str|None): The access token of the newly logged in
- device, or None if `inhibit_login` enabled.
+ user_id: The user ID that consented
+ auth_result: The authenticated credentials of the newly registered user.
+ access_token: The access token of the newly logged in device, or
+ None if `inhibit_login` enabled.
"""
if self.hs.config.worker_app:
await self._post_registration_client(
@@ -694,19 +712,20 @@ class RegistrationHandler(BaseHandler):
if auth_result and LoginType.TERMS in auth_result:
await self._on_user_consented(user_id, self.hs.config.user_consent_version)
- async def _on_user_consented(self, user_id, consent_version):
+ async def _on_user_consented(self, user_id: str, consent_version: str) -> None:
"""A user consented to the terms on registration
Args:
- user_id (str): The user ID that consented.
- consent_version (str): version of the policy the user has
- consented to.
+ user_id: The user ID that consented.
+ consent_version: version of the policy the user has consented to.
"""
logger.info("%s has consented to the privacy policy", user_id)
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
- async def _register_email_threepid(self, user_id, threepid, token):
+ async def _register_email_threepid(
+ self, user_id: str, threepid: dict, token: Optional[str]
+ ) -> None:
"""Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the
@@ -715,10 +734,9 @@ class RegistrationHandler(BaseHandler):
Must be called on master.
Args:
- user_id (str): id of user
- threepid (object): m.login.email.identity auth response
- token (str|None): access_token for the user, or None if not logged
- in.
+ user_id: id of user
+ threepid: m.login.email.identity auth response
+ token: access_token for the user, or None if not logged in.
"""
reqd = ("medium", "address", "validated_at")
if any(x not in threepid for x in reqd):
@@ -744,6 +762,8 @@ class RegistrationHandler(BaseHandler):
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = await self.store.get_user_by_access_token(token)
+ # The token better still exist.
+ assert user_tuple
token_id = user_tuple.token_id
await self.pusher_pool.add_pusher(
@@ -758,14 +778,14 @@ class RegistrationHandler(BaseHandler):
data={},
)
- async def _register_msisdn_threepid(self, user_id, threepid):
+ async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None:
"""Add a phone number as a 3pid identifier
Must be called on master.
Args:
- user_id (str): id of user
- threepid (object): m.login.msisdn auth response
+ user_id: id of user
+ threepid: m.login.msisdn auth response
"""
try:
assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e73031475f..930047e730 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -587,7 +587,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- await self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(requester=requester)
if (
self._server_notices_mxid is not None
@@ -1257,7 +1257,9 @@ class RoomShutdownHandler:
400, "User must be our own: %s" % (new_room_user_id,)
)
- room_creator_requester = create_requester(new_room_user_id)
+ room_creator_requester = create_requester(
+ new_room_user_id, authenticated_entity=requester_user_id
+ )
info, stream_id = await self._room_creation_handler.create_room(
room_creator_requester,
@@ -1297,7 +1299,9 @@ class RoomShutdownHandler:
try:
# Kick users from room
- target_requester = create_requester(user_id)
+ target_requester = create_requester(
+ user_id, authenticated_entity=requester_user_id
+ )
_, stream_id = await self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7e5e53a56f..4e693a419e 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -31,7 +31,6 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room
@@ -535,10 +534,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
- invite = await self.store.get_invite_for_local_user_in_room(
- user_id=target.to_string(), room_id=room_id
- ) # type: Optional[RoomsForUser]
- if not invite:
+ (
+ current_membership_type,
+ current_membership_event_id,
+ ) = await self.store.get_local_current_membership_for_user_in_room(
+ target.to_string(), room_id
+ )
+ if (
+ current_membership_type != Membership.INVITE
+ or not current_membership_event_id
+ ):
logger.info(
"%s sent a leave request to %s, but that is not an active room "
"on this server, and there is no pending invite",
@@ -548,6 +553,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise SynapseError(404, "Not a known room")
+ invite = await self.store.get_event(current_membership_event_id)
logger.info(
"%s rejects invite to %s from %s", target, room_id, invite.sender
)
@@ -985,6 +991,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor = hs.get_distributor()
self.distributor.declare("user_left_room")
+ self._server_name = hs.hostname
async def _is_remote_room_too_complex(
self, room_id: str, remote_room_hosts: List[str]
@@ -1079,7 +1086,9 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return event_id, stream_id
# The room is too large. Leave.
- requester = types.create_requester(user, None, False, False, None)
+ requester = types.create_requester(
+ user, authenticated_entity=self._server_name
+ )
await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave"
)
@@ -1124,32 +1133,34 @@ class RoomMemberMasterHandler(RoomMemberHandler):
#
logger.warning("Failed to reject invite: %s", e)
- return await self._locally_reject_invite(
+ return await self._generate_local_out_of_band_leave(
invite_event, txn_id, requester, content
)
- async def _locally_reject_invite(
+ async def _generate_local_out_of_band_leave(
self,
- invite_event: EventBase,
+ previous_membership_event: EventBase,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
) -> Tuple[str, int]:
- """Generate a local invite rejection
+ """Generate a local leave event for a room
- This is called after we fail to reject an invite via a remote server. It
- generates an out-of-band membership event locally.
+ This can be called after we e.g fail to reject an invite via a remote server.
+ It generates an out-of-band membership event locally.
Args:
- invite_event: the invite to be rejected
+ previous_membership_event: the previous membership event for this user
txn_id: optional transaction ID supplied by the client
- requester: user making the rejection request, according to the access token
- content: additional content to include in the rejection event.
+ requester: user making the request, according to the access token
+ content: additional content to include in the leave event.
Normally an empty dict.
- """
- room_id = invite_event.room_id
- target_user = invite_event.state_key
+ Returns:
+ A tuple containing (event_id, stream_id of the leave event)
+ """
+ room_id = previous_membership_event.room_id
+ target_user = previous_membership_event.state_key
content["membership"] = Membership.LEAVE
@@ -1161,12 +1172,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"state_key": target_user,
}
- # the auth events for the new event are the same as that of the invite, plus
- # the invite itself.
+ # the auth events for the new event are the same as that of the previous event, plus
+ # the event itself.
#
- # the prev_events are just the invite.
- prev_event_ids = [invite_event.event_id]
- auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
+ # the prev_events consist solely of the previous membership event.
+ prev_event_ids = [previous_membership_event.event_id]
+ auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids
event, context = await self.event_creation_handler.create_event(
requester,
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index fd6c5e9ea8..34db10ffe4 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -24,7 +24,8 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
@@ -37,15 +38,11 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
- import synapse.server
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class MappingException(Exception):
- """Used to catch errors when mapping the SAML2 response to a user."""
-
-
@attr.s(slots=True)
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
@@ -57,17 +54,14 @@ class Saml2SessionData:
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
-class SamlHandler:
- def __init__(self, hs: "synapse.server.HomeServer"):
- self.hs = hs
+class SamlHandler(BaseHandler):
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
- self._auth = hs.get_auth()
+ self._saml_idp_entityid = hs.config.saml2_idp_entityid
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
- self._clock = hs.get_clock()
- self._datastore = hs.get_datastore()
- self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
@@ -88,26 +82,9 @@ class SamlHandler:
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings
- self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
-
- def _render_error(
- self, request, error: str, error_description: Optional[str] = None
- ) -> None:
- """Render the error template and respond to the request with it.
-
- This is used to show errors to the user. The template of this page can
- be found under `synapse/res/templates/sso_error.html`.
+ self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
- Args:
- request: The incoming request from the browser.
- We'll respond with an HTML page describing the error.
- error: A technical identifier for this error.
- error_description: A human-readable description of the error.
- """
- html = self._error_template.render(
- error=error, error_description=error_description
- )
- respond_with_html(request, 400, html)
+ self._sso_handler = hs.get_sso_handler()
def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
@@ -124,13 +101,13 @@ class SamlHandler:
URL to redirect to
"""
reqid, info = self._saml_client.prepare_for_authenticate(
- relay_state=client_redirect_url
+ entityid=self._saml_idp_entityid, relay_state=client_redirect_url
)
# Since SAML sessions timeout it is useful to log when they were created.
logger.info("Initiating a new SAML session: %s" % (reqid,))
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData(
creation_time=now, ui_auth_session_id=ui_auth_session_id,
)
@@ -171,12 +148,12 @@ class SamlHandler:
# in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later.
logger.warning(str(e))
- self._render_error(
+ self._sso_handler.render_error(
request, "unsolicited_response", "Unexpected SAML2 login."
)
return
except Exception as e:
- self._render_error(
+ self._sso_handler.render_error(
request,
"invalid_response",
"Unable to parse SAML2 response: %s." % (e,),
@@ -184,7 +161,7 @@ class SamlHandler:
return
if saml2_auth.not_signed:
- self._render_error(
+ self._sso_handler.render_error(
request, "unsigned_respond", "SAML2 response was not signed."
)
return
@@ -210,7 +187,7 @@ class SamlHandler:
# attributes.
for requirement in self._saml2_attribute_requirements:
if not _check_attribute_requirement(saml2_auth.ava, requirement):
- self._render_error(
+ self._sso_handler.render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return
@@ -226,7 +203,7 @@ class SamlHandler:
)
except MappingException as e:
logger.exception("Could not map user")
- self._render_error(request, "mapping_error", str(e))
+ self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Complete the interactive auth session or the login.
@@ -272,20 +249,26 @@ class SamlHandler:
"Failed to extract remote user id from SAML response"
)
- with (await self._mapping_lock.queue(self._auth_provider_id)):
- # first of all, check if we already have a mapping for this user
- logger.info(
- "Looking for existing mapping for user %s:%s",
- self._auth_provider_id,
- remote_user_id,
+ async def saml_response_to_remapped_user_attributes(
+ failures: int,
+ ) -> UserAttributes:
+ """
+ Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
+
+ This is backwards compatibility for abstraction for the SSO handler.
+ """
+ # Call the mapping provider.
+ result = self._user_mapping_provider.saml_response_to_user_attributes(
+ saml2_auth, failures, client_redirect_url
)
- registered_user_id = await self._datastore.get_user_by_external_id(
- self._auth_provider_id, remote_user_id
+ # Remap some of the results.
+ return UserAttributes(
+ localpart=result.get("mxid_localpart"),
+ display_name=result.get("displayname"),
+ emails=result.get("emails"),
)
- if registered_user_id is not None:
- logger.info("Found existing mapping %s", registered_user_id)
- return registered_user_id
+ with (await self._mapping_lock.queue(self._auth_provider_id)):
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
@@ -294,75 +277,34 @@ class SamlHandler:
):
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
user_id = UserID(
- map_username_to_mxid_localpart(attrval), self._hostname
+ map_username_to_mxid_localpart(attrval), self.server_name
).to_string()
- logger.info(
+
+ logger.debug(
"Looking for existing account based on mapped %s %s",
self._grandfathered_mxid_source_attribute,
user_id,
)
- users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
- await self._datastore.record_user_external_id(
+ await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
- # Map saml response to user attributes using the configured mapping provider
- for i in range(1000):
- attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
- saml2_auth, i, client_redirect_url=client_redirect_url,
- )
-
- logger.debug(
- "Retrieved SAML attributes from user mapping provider: %s "
- "(attempt %d)",
- attribute_dict,
- i,
- )
-
- localpart = attribute_dict.get("mxid_localpart")
- if not localpart:
- raise MappingException(
- "Error parsing SAML2 response: SAML mapping provider plugin "
- "did not return a mxid_localpart value"
- )
-
- displayname = attribute_dict.get("displayname")
- emails = attribute_dict.get("emails", [])
-
- # Check if this mxid already exists
- if not await self._datastore.get_users_by_id_case_insensitive(
- UserID(localpart, self._hostname).to_string()
- ):
- # This mxid is free
- break
- else:
- # Unable to generate a username in 1000 iterations
- # Break and return error to the user
- raise MappingException(
- "Unable to generate a Matrix ID from the SAML response"
- )
-
- logger.info("Mapped SAML user to local part %s", localpart)
-
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=displayname,
- bind_emails=emails,
- user_agent_ips=(user_agent, ip_address),
- )
-
- await self._datastore.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ saml_response_to_remapped_user_attributes,
)
- return registered_user_id
def expire_sessions(self):
- expire_before = self._clock.time_msec() - self._saml2_session_lifetime
+ expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
for reqid, data in self._outstanding_requests_dict.items():
if data.creation_time < expire_before:
@@ -474,11 +416,11 @@ class DefaultSamlMappingProvider:
)
# Use the configured mapper for this mxid_source
- base_mxid_localpart = self._mxid_mapper(mxid_source)
+ localpart = self._mxid_mapper(mxid_source)
# Append suffix integer if last call to this function failed to produce
- # a usable mxid
- localpart = base_mxid_localpart + (str(failures) if failures else "")
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
# Retrieve the display name from the saml response
# If displayname is None, the mxid_localpart will be used instead
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
new file mode 100644
index 0000000000..d963082210
--- /dev/null
+++ b/synapse/handlers/sso.py
@@ -0,0 +1,249 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+
+import attr
+
+from synapse.handlers._base import BaseHandler
+from synapse.http.server import respond_with_html
+from synapse.types import UserID, contains_invalid_mxid_characters
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class MappingException(Exception):
+ """Used to catch errors when mapping the UserInfo object
+ """
+
+
+@attr.s
+class UserAttributes:
+ localpart = attr.ib(type=str)
+ display_name = attr.ib(type=Optional[str], default=None)
+ emails = attr.ib(type=List[str], default=attr.Factory(list))
+
+
+class SsoHandler(BaseHandler):
+ # The number of attempts to ask the mapping provider for when generating an MXID.
+ _MAP_USERNAME_RETRIES = 1000
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+ self._registration_handler = hs.get_registration_handler()
+ self._error_template = hs.config.sso_error_template
+
+ def render_error(
+ self, request, error: str, error_description: Optional[str] = None
+ ) -> None:
+ """Renders the error template and responds with it.
+
+ This is used to show errors to the user. The template of this page can
+ be found under `synapse/res/templates/sso_error.html`.
+
+ Args:
+ request: The incoming request from the browser.
+ We'll respond with an HTML page describing the error.
+ error: A technical identifier for this error.
+ error_description: A human-readable description of the error.
+ """
+ html = self._error_template.render(
+ error=error, error_description=error_description
+ )
+ respond_with_html(request, 400, html)
+
+ async def get_sso_user_by_remote_user_id(
+ self, auth_provider_id: str, remote_user_id: str
+ ) -> Optional[str]:
+ """
+ Maps the user ID of a remote IdP to a mxid for a previously seen user.
+
+ If the user has not been seen yet, this will return None.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The user ID according to the remote IdP. This might
+ be an e-mail address, a GUID, or some other form. It must be
+ unique and immutable.
+
+ Returns:
+ The mxid of a previously seen user.
+ """
+ logger.debug(
+ "Looking for existing mapping for user %s:%s",
+ auth_provider_id,
+ remote_user_id,
+ )
+
+ # Check if we already have a mapping for this user.
+ previously_registered_user_id = await self.store.get_user_by_external_id(
+ auth_provider_id, remote_user_id,
+ )
+
+ # A match was found, return the user ID.
+ if previously_registered_user_id is not None:
+ logger.info(
+ "Found existing mapping for IdP '%s' and remote_user_id '%s': %s",
+ auth_provider_id,
+ remote_user_id,
+ previously_registered_user_id,
+ )
+ return previously_registered_user_id
+
+ # No match.
+ return None
+
+ async def get_mxid_from_sso(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ user_agent: str,
+ ip_address: str,
+ sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+ allow_existing_users: bool = False,
+ ) -> str:
+ """
+ Given an SSO ID, retrieve the user ID for it and possibly register the user.
+
+ This first checks if the SSO ID has previously been linked to a matrix ID,
+ if it has that matrix ID is returned regardless of the current mapping
+ logic.
+
+ The mapping function is called (potentially multiple times) to generate
+ a localpart for the user.
+
+ If an unused localpart is generated, the user is registered from the
+ given user-agent and IP address and the SSO ID is linked to this matrix
+ ID for subsequent calls.
+
+ If allow_existing_users is true the mapping function is only called once
+ and results in:
+
+ 1. The use of a previously registered matrix ID. In this case, the
+ SSO ID is linked to the matrix ID. (Note it is possible that
+ other SSO IDs are linked to the same matrix ID.)
+ 2. An unused localpart, in which case the user is registered (as
+ discussed above).
+ 3. An error if the generated localpart matches multiple pre-existing
+ matrix IDs. Generally this should not happen.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The unique identifier from the SSO provider.
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
+ sso_to_matrix_id_mapper: A callable to generate the user attributes.
+ The only parameter is an integer which represents the amount of
+ times the returned mxid localpart mapping has failed.
+ allow_existing_users: True if the localpart returned from the
+ mapping provider can be linked to an existing matrix ID.
+
+ Returns:
+ The user ID associated with the SSO response.
+
+ Raises:
+ MappingException if there was a problem mapping the response to a user.
+ RedirectException: some mapping providers may raise this if they need
+ to redirect to an interstitial page.
+
+ """
+ # first of all, check if we already have a mapping for this user
+ previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
+ if previously_registered_user_id:
+ return previously_registered_user_id
+
+ # Otherwise, generate a new user.
+ for i in range(self._MAP_USERNAME_RETRIES):
+ try:
+ attributes = await sso_to_matrix_id_mapper(i)
+ except Exception as e:
+ raise MappingException(
+ "Could not extract user attributes from SSO response: " + str(e)
+ )
+
+ logger.debug(
+ "Retrieved user attributes from user mapping provider: %r (attempt %d)",
+ attributes,
+ i,
+ )
+
+ if not attributes.localpart:
+ raise MappingException(
+ "Error parsing SSO response: SSO mapping provider plugin "
+ "did not return a localpart value"
+ )
+
+ # Check if this mxid already exists
+ user_id = UserID(attributes.localpart, self.server_name).to_string()
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
+ # Note, if allow_existing_users is true then the loop is guaranteed
+ # to end on the first iteration: either by matching an existing user,
+ # raising an error, or registering a new user. See the docstring for
+ # more in-depth an explanation.
+ if users and allow_existing_users:
+ # If an existing matrix ID is returned, then use it.
+ if len(users) == 1:
+ previously_registered_user_id = next(iter(users))
+ elif user_id in users:
+ previously_registered_user_id = user_id
+ else:
+ # Do not attempt to continue generating Matrix IDs.
+ raise MappingException(
+ "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+ user_id, users
+ )
+ )
+
+ # Future logins should also match this user ID.
+ await self.store.record_user_external_id(
+ auth_provider_id, remote_user_id, previously_registered_user_id
+ )
+
+ return previously_registered_user_id
+
+ elif not users:
+ # This mxid is free
+ break
+ else:
+ # Unable to generate a username in 1000 iterations
+ # Break and return error to the user
+ raise MappingException(
+ "Unable to generate a Matrix ID from the SSO response"
+ )
+
+ # Since the localpart is provided via a potentially untrusted module,
+ # ensure the MXID is valid before registering.
+ if contains_invalid_mxid_characters(attributes.localpart):
+ raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
+
+ logger.debug("Mapped SSO user to local part %s", attributes.localpart)
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=attributes.localpart,
+ default_display_name=attributes.display_name,
+ bind_emails=attributes.emails,
+ user_agent_ips=[(user_agent, ip_address)],
+ )
+
+ await self.store.record_user_external_id(
+ auth_provider_id, remote_user_id, registered_user_id
+ )
+ return registered_user_id
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index ed67bcc4f5..b9ae70adbe 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -31,6 +31,7 @@ from synapse.types import (
Collection,
JsonDict,
MutableStateMap,
+ Requester,
RoomStreamToken,
StateMap,
StreamToken,
@@ -261,6 +262,7 @@ class SyncHandler:
async def wait_for_sync_for_user(
self,
+ requester: Requester,
sync_config: SyncConfig,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
@@ -274,7 +276,7 @@ class SyncHandler:
# not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur)
user_id = sync_config.user.to_string()
- await self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(requester=requester)
res = await self.response_cache.wrap(
sync_config.request_key,
|