diff --git a/changelog.d/8617.feature b/changelog.d/8617.feature
new file mode 100644
index 0000000000..4f1e788506
--- /dev/null
+++ b/changelog.d/8617.feature
@@ -0,0 +1 @@
+Add admin API for logging in as a user.
diff --git a/changelog.d/8765.misc b/changelog.d/8765.misc
new file mode 100644
index 0000000000..053f9acc9c
--- /dev/null
+++ b/changelog.d/8765.misc
@@ -0,0 +1 @@
+Consolidate logic between the OpenID Connect and SAML code.
diff --git a/changelog.d/8770.misc b/changelog.d/8770.misc
new file mode 100644
index 0000000000..b5876a82f9
--- /dev/null
+++ b/changelog.d/8770.misc
@@ -0,0 +1 @@
+Use `TYPE_CHECKING` instead of magic `MYPY` variable.
diff --git a/changelog.d/8771.doc b/changelog.d/8771.doc
new file mode 100644
index 0000000000..297cf61e98
--- /dev/null
+++ b/changelog.d/8771.doc
@@ -0,0 +1 @@
+Remove extraneous comma from JSON example in User Admin API docs.
\ No newline at end of file
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index d4051d0257..84863296e3 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -254,7 +254,7 @@ with a body of:
{
"new_password": "<secret>",
- "logout_devices": true,
+ "logout_devices": true
}
To use it, you will need to authenticate by providing an ``access_token`` for a
@@ -424,6 +424,41 @@ The following fields are returned in the JSON response body:
- ``next_token``: integer - Indication for pagination. See above.
- ``total`` - integer - Total number of media.
+Login as a user
+===============
+
+Get an access token that can be used to authenticate as that user. Useful for
+when admins wish to do actions on behalf of a user.
+
+The API is::
+
+ POST /_synapse/admin/v1/users/<user_id>/login
+ {}
+
+An optional ``valid_until_ms`` field can be specified in the request body as an
+integer timestamp that specifies when the token should expire. By default tokens
+do not expire.
+
+A response body like the following is returned:
+
+.. code:: json
+
+ {
+ "access_token": "<opaque_access_token_string>"
+ }
+
+
+This API does *not* generate a new device for the user, and so will not appear
+their ``/devices`` list, and in general the target user should not be able to
+tell they have been logged in as.
+
+To expire the token call the standard ``/logout`` API with the token.
+
+Note: The token will expire if the *admin* user calls ``/logout/all`` from any
+of their devices, but the token will *not* expire if the target user does the
+same.
+
+
User devices
============
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index d8fafd7cb8..9c227218e0 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -14,10 +14,12 @@
# limitations under the License.
import logging
+from typing import Optional
from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
+from synapse.types import Requester
logger = logging.getLogger(__name__)
@@ -33,24 +35,47 @@ class AuthBlocking:
self._max_mau_value = hs.config.max_mau_value
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
+ self._server_name = hs.hostname
- async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
+ async def check_auth_blocking(
+ self,
+ user_id: Optional[str] = None,
+ threepid: Optional[dict] = None,
+ user_type: Optional[str] = None,
+ requester: Optional[Requester] = None,
+ ):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
Args:
- user_id(str|None): If present, checks for presence against existing
+ user_id: If present, checks for presence against existing
MAU cohort
- threepid(dict|None): If present, checks for presence against configured
+ threepid: If present, checks for presence against configured
reserved threepid. Used in cases where the user is trying register
with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and
threepid should never be set at the same time.
- user_type(str|None): If present, is used to decide whether to check against
+ user_type: If present, is used to decide whether to check against
certain blocking reasons like MAU.
+
+ requester: If present, and the authenticated entity is a user, checks for
+ presence against existing MAU cohort. Passing in both a `user_id` and
+ `requester` is an error.
"""
+ if requester and user_id:
+ raise Exception(
+ "Passed in both 'user_id' and 'requester' to 'check_auth_blocking'"
+ )
+
+ if requester:
+ if requester.authenticated_entity.startswith("@"):
+ user_id = requester.authenticated_entity
+ elif requester.authenticated_entity == self._server_name:
+ # We never block the server from doing actions on behalf of
+ # users.
+ return
# Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index bad18f7fdf..936896656a 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,13 +15,12 @@
# limitations under the License.
import inspect
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
import synapse.events
import synapse.server
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bd8e71ae56..bb81c0e81d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -169,7 +169,9 @@ class BaseHandler:
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
- requester = synapse.types.create_requester(target_user, is_guest=True)
+ requester = synapse.types.create_requester(
+ target_user, is_guest=True, authenticated_entity=self.server_name
+ )
handler = self.hs.get_room_member_handler()
await handler.update_membership(
requester,
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 213baea2e3..5163afd86c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -698,8 +698,12 @@ class AuthHandler(BaseHandler):
}
async def get_access_token_for_user_id(
- self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
- ):
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ valid_until_ms: Optional[int],
+ puppets_user_id: Optional[str] = None,
+ ) -> str:
"""
Creates a new access token for the user with the given user ID.
@@ -725,13 +729,25 @@ class AuthHandler(BaseHandler):
fmt_expiry = time.strftime(
" until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
)
- logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
+
+ if puppets_user_id:
+ logger.info(
+ "Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry
+ )
+ else:
+ logger.info(
+ "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
+ )
await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
await self.store.add_access_token_to_user(
- user_id, access_token, device_id, valid_until_ms
+ user_id=user_id,
+ token=access_token,
+ device_id=device_id,
+ valid_until_ms=valid_until_ms,
+ puppets_user_id=puppets_user_id,
)
# the device *should* have been registered before we got here; however,
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 4efe6c530a..e808142365 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -39,6 +39,7 @@ class DeactivateAccountHandler(BaseHandler):
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_identity_handler()
self.user_directory_handler = hs.get_user_directory_handler()
+ self._server_name = hs.hostname
# Flag that indicates whether the process to part users from rooms is running
self._user_parter_running = False
@@ -152,7 +153,7 @@ class DeactivateAccountHandler(BaseHandler):
for room in pending_invites:
try:
await self._room_member_handler.update_membership(
- create_requester(user),
+ create_requester(user, authenticated_entity=self._server_name),
user,
room.room_id,
"leave",
@@ -208,7 +209,7 @@ class DeactivateAccountHandler(BaseHandler):
logger.info("User parter parting %r from %r", user_id, room_id)
try:
await self._room_member_handler.update_membership(
- create_requester(user),
+ create_requester(user, authenticated_entity=self._server_name),
user,
room_id,
"leave",
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c6791fb912..96843338ae 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -472,7 +472,7 @@ class EventCreationHandler:
Returns:
Tuple of created event, Context
"""
- await self.auth.check_auth_blocking(requester.user.to_string())
+ await self.auth.check_auth_blocking(requester=requester)
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version = event_dict["content"]["room_version"]
@@ -619,7 +619,13 @@ class EventCreationHandler:
if requester.app_service is not None:
return
- user_id = requester.user.to_string()
+ user_id = requester.authenticated_entity
+ if not user_id.startswith("@"):
+ # The authenticated entity might not be a user, e.g. if it's the
+ # server puppetting the user.
+ return
+
+ user = UserID.from_string(user_id)
# exempt the system notices user
if (
@@ -639,9 +645,7 @@ class EventCreationHandler:
if u["consent_version"] == self.config.user_consent_version:
return
- consent_uri = self._consent_uri_builder.build_user_consent_uri(
- requester.user.localpart
- )
+ consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@@ -1252,7 +1256,7 @@ class EventCreationHandler:
for user_id in members:
if not self.hs.is_mine_id(user_id):
continue
- requester = create_requester(user_id)
+ requester = create_requester(user_id, authenticated_entity=self.server_name)
try:
event, context = await self.create_event(
requester,
@@ -1273,11 +1277,6 @@ class EventCreationHandler:
requester, event, context, ratelimit=False, ignore_shadow_ban=True,
)
return True
- except ConsentNotGivenError:
- logger.info(
- "Failed to send dummy event into room %s for user %s due to "
- "lack of consent. Will try another user" % (room_id, user_id)
- )
except AuthError:
logger.info(
"Failed to send dummy event into room %s for user %s due to "
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 331d4e7e96..be8562d47b 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -34,7 +34,8 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@@ -83,17 +84,12 @@ class OidcError(Exception):
return self.error
-class MappingException(Exception):
- """Used to catch errors when mapping the UserInfo object
- """
-
-
-class OidcHandler:
+class OidcHandler(BaseHandler):
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: "HomeServer"):
- self.hs = hs
+ super().__init__(hs)
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
@@ -120,36 +116,13 @@ class OidcHandler:
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
- self._datastore = hs.get_datastore()
- self._clock = hs.get_clock()
- self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
- self._error_template = hs.config.sso_error_template
# identifier for the external_ids table
self._auth_provider_id = "oidc"
- def _render_error(
- self, request, error: str, error_description: Optional[str] = None
- ) -> None:
- """Render the error template and respond to the request with it.
-
- This is used to show errors to the user. The template of this page can
- be found under `synapse/res/templates/sso_error.html`.
-
- Args:
- request: The incoming request from the browser.
- We'll respond with an HTML page describing the error.
- error: A technical identifier for this error. Those include
- well-known OAuth2/OIDC error types like invalid_request or
- access_denied.
- error_description: A human-readable description of the error.
- """
- html = self._error_template.render(
- error=error, error_description=error_description
- )
- respond_with_html(request, 400, html)
+ self._sso_handler = hs.get_sso_handler()
def _validate_metadata(self):
"""Verifies the provider metadata.
@@ -571,7 +544,7 @@ class OidcHandler:
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
- ``self._render_error`` which displays an HTML page for the error.
+ ``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here:
@@ -609,7 +582,7 @@ class OidcHandler:
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
- self._render_error(request, error, description)
+ self._sso_handler.render_error(request, error, description)
return
# otherwise, it is presumably a successful response. see:
@@ -619,7 +592,9 @@ class OidcHandler:
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
- self._render_error(request, "missing_session", "No session cookie found")
+ self._sso_handler.render_error(
+ request, "missing_session", "No session cookie found"
+ )
return
# Remove the cookie. There is a good chance that if the callback failed
@@ -637,7 +612,9 @@ class OidcHandler:
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
- self._render_error(request, "invalid_request", "State parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "State parameter is missing"
+ )
return
state = request.args[b"state"][0].decode()
@@ -651,17 +628,19 @@ class OidcHandler:
) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
- self._render_error(request, "invalid_session", str(e))
+ self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
- self._render_error(request, "mismatching_session", str(e))
+ self._sso_handler.render_error(request, "mismatching_session", str(e))
return
# Exchange the code with the provider
if b"code" not in request.args:
logger.info("Code parameter is missing")
- self._render_error(request, "invalid_request", "Code parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "Code parameter is missing"
+ )
return
logger.debug("Exchanging code")
@@ -670,7 +649,7 @@ class OidcHandler:
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")
- self._render_error(request, e.error, e.error_description)
+ self._sso_handler.render_error(request, e.error, e.error_description)
return
logger.debug("Successfully obtained OAuth2 access token")
@@ -683,7 +662,7 @@ class OidcHandler:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
logger.exception("Could not fetch userinfo")
- self._render_error(request, "fetch_error", str(e))
+ self._sso_handler.render_error(request, "fetch_error", str(e))
return
else:
logger.debug("Extracting userinfo from id_token")
@@ -691,7 +670,7 @@ class OidcHandler:
userinfo = await self._parse_id_token(token, nonce=nonce)
except Exception as e:
logger.exception("Invalid id_token")
- self._render_error(request, "invalid_token", str(e))
+ self._sso_handler.render_error(request, "invalid_token", str(e))
return
# Pull out the user-agent and IP from the request.
@@ -705,7 +684,7 @@ class OidcHandler:
)
except MappingException as e:
logger.exception("Could not map user")
- self._render_error(request, "mapping_error", str(e))
+ self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Mapping providers might not have get_extra_attributes: only call this
@@ -770,7 +749,7 @@ class OidcHandler:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -845,7 +824,7 @@ class OidcHandler:
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
return now < expiry
async def _map_userinfo_to_user(
@@ -885,20 +864,14 @@ class OidcHandler:
# to be strings.
remote_user_id = str(remote_user_id)
- logger.info(
- "Looking for existing mapping for user %s:%s",
- self._auth_provider_id,
- remote_user_id,
- )
-
- registered_user_id = await self._datastore.get_user_by_external_id(
+ # first of all, check if we already have a mapping for this user
+ previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
self._auth_provider_id, remote_user_id,
)
+ if previously_registered_user_id:
+ return previously_registered_user_id
- if registered_user_id is not None:
- logger.info("Found existing mapping %s", registered_user_id)
- return registered_user_id
-
+ # Otherwise, generate a new user.
try:
attributes = await self._user_mapping_provider.map_user_attributes(
userinfo, token
@@ -917,8 +890,8 @@ class OidcHandler:
localpart = map_username_to_mxid_localpart(attributes["localpart"])
- user_id = UserID(localpart, self._hostname).to_string()
- users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+ user_id = UserID(localpart, self.server_name).to_string()
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
if self._allow_existing_users:
if len(users) == 1:
@@ -942,7 +915,8 @@ class OidcHandler:
default_display_name=attributes["display_name"],
user_agent_ips=(user_agent, ip_address),
)
- await self._datastore.record_user_external_id(
+
+ await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
)
return registered_user_id
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 8e014c9bb5..22d1e9d35c 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,7 +25,7 @@ The methods that define policy are:
import abc
import logging
from contextlib import contextmanager
-from typing import Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
from prometheus_client import Counter
from typing_extensions import ContextManager
@@ -46,8 +46,7 @@ from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 74a1ddd780..dee0ef45e7 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -206,7 +206,9 @@ class ProfileHandler(BaseHandler):
# the join event to update the displayname in the rooms.
# This must be done by the target user himself.
if by_admin:
- requester = create_requester(target_user)
+ requester = create_requester(
+ target_user, authenticated_entity=requester.authenticated_entity,
+ )
await self.store.set_profile_displayname(
target_user.localpart, displayname_to_set
@@ -286,7 +288,9 @@ class ProfileHandler(BaseHandler):
# Same like set_displayname
if by_admin:
- requester = create_requester(target_user)
+ requester = create_requester(
+ target_user, authenticated_entity=requester.authenticated_entity
+ )
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ed1ff62599..252f700786 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -52,6 +52,7 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid
+ self._server_name = hs.hostname
self.spam_checker = hs.get_spam_checker()
@@ -317,7 +318,8 @@ class RegistrationHandler(BaseHandler):
requires_join = False
if self.hs.config.registration.auto_join_user_id:
fake_requester = create_requester(
- self.hs.config.registration.auto_join_user_id
+ self.hs.config.registration.auto_join_user_id,
+ authenticated_entity=self._server_name,
)
# If the room requires an invite, add the user to the list of invites.
@@ -329,7 +331,9 @@ class RegistrationHandler(BaseHandler):
# being necessary this will occur after the invite was sent.
requires_join = True
else:
- fake_requester = create_requester(user_id)
+ fake_requester = create_requester(
+ user_id, authenticated_entity=self._server_name
+ )
# Choose whether to federate the new room.
if not self.hs.config.registration.autocreate_auto_join_rooms_federated:
@@ -362,7 +366,9 @@ class RegistrationHandler(BaseHandler):
# created it, then ensure the first user joins it.
if requires_join:
await room_member_handler.update_membership(
- requester=create_requester(user_id),
+ requester=create_requester(
+ user_id, authenticated_entity=self._server_name
+ ),
target=UserID.from_string(user_id),
room_id=info["room_id"],
# Since it was just created, there are no remote hosts.
@@ -370,11 +376,6 @@ class RegistrationHandler(BaseHandler):
action="join",
ratelimit=False,
)
-
- except ConsentNotGivenError as e:
- # Technically not necessary to pull out this error though
- # moving away from bare excepts is a good thing to do.
- logger.error("Failed to join new user to %r: %r", r, e)
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
@@ -426,7 +427,8 @@ class RegistrationHandler(BaseHandler):
if requires_invite:
await room_member_handler.update_membership(
requester=create_requester(
- self.hs.config.registration.auto_join_user_id
+ self.hs.config.registration.auto_join_user_id,
+ authenticated_entity=self._server_name,
),
target=UserID.from_string(user_id),
room_id=room_id,
@@ -437,7 +439,9 @@ class RegistrationHandler(BaseHandler):
# Send the join.
await room_member_handler.update_membership(
- requester=create_requester(user_id),
+ requester=create_requester(
+ user_id, authenticated_entity=self._server_name
+ ),
target=UserID.from_string(user_id),
room_id=room_id,
remote_room_hosts=remote_room_hosts,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e73031475f..930047e730 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -587,7 +587,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- await self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(requester=requester)
if (
self._server_notices_mxid is not None
@@ -1257,7 +1257,9 @@ class RoomShutdownHandler:
400, "User must be our own: %s" % (new_room_user_id,)
)
- room_creator_requester = create_requester(new_room_user_id)
+ room_creator_requester = create_requester(
+ new_room_user_id, authenticated_entity=requester_user_id
+ )
info, stream_id = await self._room_creation_handler.create_room(
room_creator_requester,
@@ -1297,7 +1299,9 @@ class RoomShutdownHandler:
try:
# Kick users from room
- target_requester = create_requester(user_id)
+ target_requester = create_requester(
+ user_id, authenticated_entity=requester_user_id
+ )
_, stream_id = await self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index fd85e08973..70f8966267 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -965,6 +965,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor = hs.get_distributor()
self.distributor.declare("user_left_room")
+ self._server_name = hs.hostname
async def _is_remote_room_too_complex(
self, room_id: str, remote_room_hosts: List[str]
@@ -1059,7 +1060,9 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return event_id, stream_id
# The room is too large. Leave.
- requester = types.create_requester(user, None, False, False, None)
+ requester = types.create_requester(
+ user, authenticated_entity=self._server_name
+ )
await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave"
)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index fd6c5e9ea8..aee772239a 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -24,7 +24,8 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
@@ -42,10 +43,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class MappingException(Exception):
- """Used to catch errors when mapping the SAML2 response to a user."""
-
-
@attr.s(slots=True)
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
@@ -57,17 +54,13 @@ class Saml2SessionData:
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
-class SamlHandler:
+class SamlHandler(BaseHandler):
def __init__(self, hs: "synapse.server.HomeServer"):
- self.hs = hs
+ super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
- self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
- self._clock = hs.get_clock()
- self._datastore = hs.get_datastore()
- self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
@@ -88,26 +81,9 @@ class SamlHandler:
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings
- self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
-
- def _render_error(
- self, request, error: str, error_description: Optional[str] = None
- ) -> None:
- """Render the error template and respond to the request with it.
+ self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
- This is used to show errors to the user. The template of this page can
- be found under `synapse/res/templates/sso_error.html`.
-
- Args:
- request: The incoming request from the browser.
- We'll respond with an HTML page describing the error.
- error: A technical identifier for this error.
- error_description: A human-readable description of the error.
- """
- html = self._error_template.render(
- error=error, error_description=error_description
- )
- respond_with_html(request, 400, html)
+ self._sso_handler = hs.get_sso_handler()
def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
@@ -130,7 +106,7 @@ class SamlHandler:
# Since SAML sessions timeout it is useful to log when they were created.
logger.info("Initiating a new SAML session: %s" % (reqid,))
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData(
creation_time=now, ui_auth_session_id=ui_auth_session_id,
)
@@ -171,12 +147,12 @@ class SamlHandler:
# in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later.
logger.warning(str(e))
- self._render_error(
+ self._sso_handler.render_error(
request, "unsolicited_response", "Unexpected SAML2 login."
)
return
except Exception as e:
- self._render_error(
+ self._sso_handler.render_error(
request,
"invalid_response",
"Unable to parse SAML2 response: %s." % (e,),
@@ -184,7 +160,7 @@ class SamlHandler:
return
if saml2_auth.not_signed:
- self._render_error(
+ self._sso_handler.render_error(
request, "unsigned_respond", "SAML2 response was not signed."
)
return
@@ -210,7 +186,7 @@ class SamlHandler:
# attributes.
for requirement in self._saml2_attribute_requirements:
if not _check_attribute_requirement(saml2_auth.ava, requirement):
- self._render_error(
+ self._sso_handler.render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return
@@ -226,7 +202,7 @@ class SamlHandler:
)
except MappingException as e:
logger.exception("Could not map user")
- self._render_error(request, "mapping_error", str(e))
+ self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Complete the interactive auth session or the login.
@@ -274,17 +250,11 @@ class SamlHandler:
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
- logger.info(
- "Looking for existing mapping for user %s:%s",
- self._auth_provider_id,
- remote_user_id,
+ previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
+ self._auth_provider_id, remote_user_id,
)
- registered_user_id = await self._datastore.get_user_by_external_id(
- self._auth_provider_id, remote_user_id
- )
- if registered_user_id is not None:
- logger.info("Found existing mapping %s", registered_user_id)
- return registered_user_id
+ if previously_registered_user_id:
+ return previously_registered_user_id
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
@@ -294,7 +264,7 @@ class SamlHandler:
):
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
user_id = UserID(
- map_username_to_mxid_localpart(attrval), self._hostname
+ map_username_to_mxid_localpart(attrval), self.server_name
).to_string()
logger.info(
"Looking for existing account based on mapped %s %s",
@@ -302,11 +272,11 @@ class SamlHandler:
user_id,
)
- users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
- await self._datastore.record_user_external_id(
+ await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
@@ -335,8 +305,8 @@ class SamlHandler:
emails = attribute_dict.get("emails", [])
# Check if this mxid already exists
- if not await self._datastore.get_users_by_id_case_insensitive(
- UserID(localpart, self._hostname).to_string()
+ if not await self.store.get_users_by_id_case_insensitive(
+ UserID(localpart, self.server_name).to_string()
):
# This mxid is free
break
@@ -348,7 +318,6 @@ class SamlHandler:
)
logger.info("Mapped SAML user to local part %s", localpart)
-
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=displayname,
@@ -356,13 +325,13 @@ class SamlHandler:
user_agent_ips=(user_agent, ip_address),
)
- await self._datastore.record_user_external_id(
+ await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
def expire_sessions(self):
- expire_before = self._clock.time_msec() - self._saml2_session_lifetime
+ expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
for reqid, data in self._outstanding_requests_dict.items():
if data.creation_time < expire_before:
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
new file mode 100644
index 0000000000..9cb1866a71
--- /dev/null
+++ b/synapse/handlers/sso.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Optional
+
+from synapse.handlers._base import BaseHandler
+from synapse.http.server import respond_with_html
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class MappingException(Exception):
+ """Used to catch errors when mapping the UserInfo object
+ """
+
+
+class SsoHandler(BaseHandler):
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+ self._error_template = hs.config.sso_error_template
+
+ def render_error(
+ self, request, error: str, error_description: Optional[str] = None
+ ) -> None:
+ """Renders the error template and responds with it.
+
+ This is used to show errors to the user. The template of this page can
+ be found under `synapse/res/templates/sso_error.html`.
+
+ Args:
+ request: The incoming request from the browser.
+ We'll respond with an HTML page describing the error.
+ error: A technical identifier for this error.
+ error_description: A human-readable description of the error.
+ """
+ html = self._error_template.render(
+ error=error, error_description=error_description
+ )
+ respond_with_html(request, 400, html)
+
+ async def get_sso_user_by_remote_user_id(
+ self, auth_provider_id: str, remote_user_id: str
+ ) -> Optional[str]:
+ """
+ Maps the user ID of a remote IdP to a mxid for a previously seen user.
+
+ If the user has not been seen yet, this will return None.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The user ID according to the remote IdP. This might
+ be an e-mail address, a GUID, or some other form. It must be
+ unique and immutable.
+
+ Returns:
+ The mxid of a previously seen user.
+ """
+ # Check if we already have a mapping for this user.
+ logger.info(
+ "Looking for existing mapping for user %s:%s",
+ auth_provider_id,
+ remote_user_id,
+ )
+ previously_registered_user_id = await self.store.get_user_by_external_id(
+ auth_provider_id, remote_user_id,
+ )
+
+ # A match was found, return the user ID.
+ if previously_registered_user_id is not None:
+ logger.info("Found existing mapping %s", previously_registered_user_id)
+ return previously_registered_user_id
+
+ # No match.
+ return None
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 32e53c2d25..9827c7eb8d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -31,6 +31,7 @@ from synapse.types import (
Collection,
JsonDict,
MutableStateMap,
+ Requester,
RoomStreamToken,
StateMap,
StreamToken,
@@ -260,6 +261,7 @@ class SyncHandler:
async def wait_for_sync_for_user(
self,
+ requester: Requester,
sync_config: SyncConfig,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
@@ -273,7 +275,7 @@ class SyncHandler:
# not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur)
user_id = sync_config.user.to_string()
- await self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(requester=requester)
res = await self.response_cache.wrap(
sync_config.request_key,
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 0142542852..72ab5750cc 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -49,6 +49,7 @@ class ModuleApi:
self._store = hs.get_datastore()
self._auth = hs.get_auth()
self._auth_handler = auth_handler
+ self._server_name = hs.hostname
# We expose these as properties below in order to attach a helpful docstring.
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
@@ -336,7 +337,9 @@ class ModuleApi:
SynapseError if the event was not allowed.
"""
# Create a requester object
- requester = create_requester(event_dict["sender"])
+ requester = create_requester(
+ event_dict["sender"], authenticated_entity=self._server_name
+ )
# Create and send the event
(
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 2a4f7a1740..7a3a5c46ca 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -61,6 +61,7 @@ from synapse.rest.admin.users import (
UserRestServletV2,
UsersRestServlet,
UsersRestServletV2,
+ UserTokenRestServlet,
WhoisRestServlet,
)
from synapse.types import RoomStreamToken
@@ -223,6 +224,7 @@ def register_servlets(hs, http_server):
UserAdminServlet(hs).register(http_server)
UserMediaRestServlet(hs).register(http_server)
UserMembershipRestServlet(hs).register(http_server)
+ UserTokenRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index f5304ff43d..ee345e12ce 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -309,7 +309,9 @@ class JoinRoomAliasServlet(RestServlet):
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
- fake_requester = create_requester(target_user)
+ fake_requester = create_requester(
+ target_user, authenticated_entity=requester.authenticated_entity
+ )
# send invite if room has "JoinRules.INVITE"
room_state = await self.state_handler.get_current_state(room_id)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 3638e219f2..fa8d8e6d91 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -16,7 +16,7 @@ import hashlib
import hmac
import logging
from http import HTTPStatus
-from typing import Tuple
+from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -37,6 +37,9 @@ from synapse.rest.admin._base import (
)
from synapse.types import JsonDict, UserID
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
_GET_PUSHERS_ALLOWED_KEYS = {
@@ -828,3 +831,52 @@ class UserMediaRestServlet(RestServlet):
ret["next_token"] = start + len(media)
return 200, ret
+
+
+class UserTokenRestServlet(RestServlet):
+ """An admin API for logging in as a user.
+
+ Example:
+
+ POST /_synapse/admin/v1/users/@test:example.com/login
+ {}
+
+ 200 OK
+ {
+ "access_token": "<some_token>"
+ }
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.auth_handler = hs.get_auth_handler()
+
+ async def on_POST(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+ auth_user = requester.user
+
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Only local users can be logged in as")
+
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+
+ valid_until_ms = body.get("valid_until_ms")
+ if valid_until_ms and not isinstance(valid_until_ms, int):
+ raise SynapseError(400, "'valid_until_ms' parameter must be an int")
+
+ if auth_user.to_string() == user_id:
+ raise SynapseError(400, "Cannot use admin API to login as self")
+
+ token = await self.auth_handler.get_access_token_for_user_id(
+ user_id=auth_user.to_string(),
+ device_id=None,
+ valid_until_ms=valid_until_ms,
+ puppets_user_id=user_id,
+ )
+
+ return 200, {"access_token": token}
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 25d3cc6148..93c06afe27 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -18,7 +18,7 @@
import logging
import re
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
from urllib import parse as urlparse
from synapse.api.constants import EventTypes, Membership
@@ -48,8 +48,7 @@ from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID,
from synapse.util import json_decoder
from synapse.util.stringutils import random_string
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
import synapse.server
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 2b84eb89c0..8e52e4cca4 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -171,6 +171,7 @@ class SyncRestServlet(RestServlet):
)
with context:
sync_result = await self.sync_handler.wait_for_sync_for_user(
+ requester,
sync_config,
since_token=since_token,
timeout=timeout,
diff --git a/synapse/server.py b/synapse/server.py
index 21a232bbd9..12a783de17 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -89,6 +89,7 @@ from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.search import SearchHandler
from synapse.handlers.set_password import SetPasswordHandler
+from synapse.handlers.sso import SsoHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
@@ -391,6 +392,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return FollowerTypingHandler(self)
@cache_in_self
+ def get_sso_handler(self) -> SsoHandler:
+ return SsoHandler(self)
+
+ @cache_in_self
def get_sync_handler(self) -> SyncHandler:
return SyncHandler(self)
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index d464c75c03..100dbd5e2c 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -39,6 +39,7 @@ class ServerNoticesManager:
self._room_member_handler = hs.get_room_member_handler()
self._event_creation_handler = hs.get_event_creation_handler()
self._is_mine_id = hs.is_mine_id
+ self._server_name = hs.hostname
self._notifier = hs.get_notifier()
self.server_notices_mxid = self._config.server_notices_mxid
@@ -72,7 +73,9 @@ class ServerNoticesManager:
await self.maybe_invite_user_to_room(user_id, room_id)
system_mxid = self._config.server_notices_mxid
- requester = create_requester(system_mxid)
+ requester = create_requester(
+ system_mxid, authenticated_entity=self._server_name
+ )
logger.info("Sending server notice to %s", user_id)
@@ -145,7 +148,9 @@ class ServerNoticesManager:
"avatar_url": self._config.server_notices_mxid_avatar_url,
}
- requester = create_requester(self.server_notices_mxid)
+ requester = create_requester(
+ self.server_notices_mxid, authenticated_entity=self._server_name
+ )
info, _ = await self._room_creation_handler.create_room(
requester,
config={
@@ -174,7 +179,9 @@ class ServerNoticesManager:
user_id: The ID of the user to invite.
room_id: The ID of the room to invite the user to.
"""
- requester = create_requester(self.server_notices_mxid)
+ requester = create_requester(
+ self.server_notices_mxid, authenticated_entity=self._server_name
+ )
# Check whether the user has already joined or been invited to this room. If
# that's the case, there is no need to re-invite them.
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e5d07ce72a..fedb8a6c26 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1110,6 +1110,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
token: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
+ puppets_user_id: Optional[str] = None,
) -> int:
"""Adds an access token for the given user.
@@ -1133,6 +1134,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"token": token,
"device_id": device_id,
"valid_until_ms": valid_until_ms,
+ "puppets_user_id": puppets_user_id,
},
desc="add_access_token_to_user",
)
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 0fd55f428a..ee5217b074 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -282,7 +282,11 @@ class AuthTestCase(unittest.TestCase):
)
)
self.store.add_access_token_to_user.assert_called_with(
- USER_ID, token, "DEVICE", None
+ user_id=USER_ID,
+ token=token,
+ device_id="DEVICE",
+ valid_until_ms=None,
+ puppets_user_id=None,
)
def get_user(tok):
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 0d51705849..630e6da808 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -154,6 +154,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler = OidcHandler(hs)
+ # Mock the render error method.
+ self.render_error = Mock(return_value=None)
+ self.handler._sso_handler.render_error = self.render_error
return hs
@@ -161,12 +164,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
return patch.dict(self.handler._provider_metadata, values)
def assertRenderedError(self, error, error_description=None):
- args = self.handler._render_error.call_args[0]
+ args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
if error_description is not None:
self.assertEqual(args[2], error_description)
# Reset the render_error mock
- self.handler._render_error.reset_mock()
+ self.render_error.reset_mock()
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
@@ -356,7 +359,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
- self.handler._render_error = Mock()
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
self.get_success(self.handler.handle_oidc_callback(request))
@@ -387,7 +389,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
"preferred_username": "bar",
}
user_id = "@foo:domain.org"
- self.handler._render_error = Mock(return_value=None)
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
@@ -435,7 +436,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_not_called()
- self.handler._render_error.assert_not_called()
+ self.render_error.assert_not_called()
# Handle mapping errors
self.handler._map_userinfo_to_user = simple_async_mock(
@@ -469,7 +470,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_called_once_with(token)
- self.handler._render_error.assert_not_called()
+ self.render_error.assert_not_called()
# Handle userinfo fetching error
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
@@ -485,7 +486,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
- self.handler._render_error = Mock(return_value=None)
request = Mock(spec=["args", "getCookie", "addCookie"])
# Missing cookie
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e178d7765b..e62586142e 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -16,7 +16,7 @@
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
import tests.unittest
import tests.utils
@@ -38,6 +38,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
user_id1 = "@user1:test"
user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
+ requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time
self.auth_blocking._limit_usage_by_mau = True
@@ -45,21 +46,26 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Check that the happy case does not throw errors
self.get_success(self.store.upsert_monthly_active_user(user_id1))
- self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
+ self.get_success(
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config)
+ )
# Test that global lock works
self.auth_blocking._hs_disabled = True
e = self.get_failure(
- self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config),
+ ResourceLimitError,
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.auth_blocking._hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
+ requester = create_requester(user_id2)
e = self.get_failure(
- self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config),
+ ResourceLimitError,
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 9b573ac24d..27206ca3db 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -94,12 +94,13 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertFalse(hasattr(event, "state_key"))
self.assertDictEqual(event.content, content)
+ expected_requester = create_requester(
+ user_id, authenticated_entity=self.hs.hostname
+ )
+
# Check that the event was sent
self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
- create_requester(user_id),
- event_dict,
- ratelimit=False,
- ignore_shadow_ban=True,
+ expected_requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
)
# Create and send a state event
@@ -128,7 +129,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the event was sent
self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
- create_requester(user_id),
+ expected_requester,
{
"type": "m.room.power_levels",
"content": content,
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 5699c41eb4..cc8a70be04 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -24,8 +24,8 @@ from mock import Mock
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
-from synapse.rest.client.v1 import login, profile, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v1 import login, logout, profile, room
+from synapse.rest.client.v2_alpha import devices, sync
from tests import unittest
from tests.test_utils import make_awaitable
@@ -1543,3 +1543,244 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertIn("last_access_ts", m)
self.assertIn("quarantined_by", m)
self.assertIn("safe_from_quarantine", m)
+
+
+class UserTokenRestTestCase(unittest.HomeserverTestCase):
+ """Test for /_synapse/admin/v1/users/<user>/login
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ room.register_servlets,
+ devices.register_servlets,
+ logout.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+ self.url = "/_synapse/admin/v1/users/%s/login" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def _get_token(self) -> str:
+ request, channel = self.make_request(
+ "POST", self.url, b"{}", access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ return channel.json_body["access_token"]
+
+ def test_no_auth(self):
+ """Try to login as a user without authentication.
+ """
+ request, channel = self.make_request("POST", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_not_admin(self):
+ """Try to login as a user as a non-admin user.
+ """
+ request, channel = self.make_request(
+ "POST", self.url, b"{}", access_token=self.other_user_tok
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_send_event(self):
+ """Test that sending event as a user works.
+ """
+ # Create a room.
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
+
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Test that sending works, and generates the event as the right user.
+ resp = self.helper.send_event(room_id, "com.example.test", tok=puppet_token)
+ event_id = resp["event_id"]
+ event = self.get_success(self.store.get_event(event_id))
+ self.assertEqual(event.sender, self.other_user)
+
+ def test_devices(self):
+ """Tests that logging in as a user doesn't create a new device for them.
+ """
+ # Login in as the user
+ self._get_token()
+
+ # Check that we don't see a new device in our devices list
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=self.other_user_tok
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # We should only see the one device (from the login in `prepare`)
+ self.assertEqual(len(channel.json_body["devices"]), 1)
+
+ def test_logout(self):
+ """Test that calling `/logout` with the token works.
+ """
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Test that we can successfully make a request
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=puppet_token
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Logout with the puppet token
+ request, channel = self.make_request(
+ "POST", "logout", b"{}", access_token=puppet_token
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # The puppet token should no longer work
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=puppet_token
+ )
+ self.render(request)
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+
+ # .. but the real user's tokens should still work
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=self.other_user_tok
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_user_logout_all(self):
+ """Tests that the target user calling `/logout/all` does *not* expire
+ the token.
+ """
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Test that we can successfully make a request
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=puppet_token
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Logout all with the real user token
+ request, channel = self.make_request(
+ "POST", "logout/all", b"{}", access_token=self.other_user_tok
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # The puppet token should still work
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=puppet_token
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # .. but the real user's tokens shouldn't
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=self.other_user_tok
+ )
+ self.render(request)
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_admin_logout_all(self):
+ """Tests that the admin user calling `/logout/all` does expire the
+ token.
+ """
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Test that we can successfully make a request
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=puppet_token
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Logout all with the admin user token
+ request, channel = self.make_request(
+ "POST", "logout/all", b"{}", access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # The puppet token should no longer work
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=puppet_token
+ )
+ self.render(request)
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+
+ # .. but the real user's tokens should still work
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=self.other_user_tok
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ @unittest.override_config(
+ {
+ "public_baseurl": "https://example.org/",
+ "user_consent": {
+ "version": "1.0",
+ "policy_name": "My Cool Privacy Policy",
+ "template_dir": "/",
+ "require_at_registration": True,
+ "block_events_error": "You should accept the policy",
+ },
+ "form_secret": "123secret",
+ }
+ )
+ def test_consent(self):
+ """Test that sending a message is not subject to the privacy policies.
+ """
+ # Have the admin user accept the terms.
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0"))
+
+ # First, cheekily accept the terms and create a room
+ self.get_success(self.store.user_set_consent_version(self.other_user, "1.0"))
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
+ self.helper.send_event(room_id, "com.example.test", tok=self.other_user_tok)
+
+ # Now unaccept it and check that we can't send an event
+ self.get_success(self.store.user_set_consent_version(self.other_user, "0.0"))
+ self.helper.send_event(
+ room_id, "com.example.test", tok=self.other_user_tok, expect_code=403
+ )
+
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Sending an event on their behalf should work fine
+ self.helper.send_event(room_id, "com.example.test", tok=puppet_token)
+
+ @override_config(
+ {"limit_usage_by_mau": True, "max_mau_value": 1, "mau_trial_days": 0}
+ )
+ def test_mau_limit(self):
+ # Create a room as the admin user. This will bump the monthly active users to 1.
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Trying to join as the other user should fail due to reaching MAU limit.
+ self.helper.join(
+ room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403
+ )
+
+ # Logging in as the other user and joining a room should work, even
+ # though the MAU limit would stop the user doing so.
+ puppet_token = self._get_token()
+ self.helper.join(room_id, user=self.other_user, tok=puppet_token)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 5a1e5c4e66..c13a57dad1 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -309,36 +309,6 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
)
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
- @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
- def test_send_dummy_event_without_consent(self):
- self._create_extremity_rich_graph()
- self._enable_consent_checking()
-
- # Pump the reactor repeatedly so that the background updates have a
- # chance to run. Attempt to add dummy event with user that has not consented
- # Check that dummy event send fails.
- self.pump(10 * 60)
- latest_event_ids = self.get_success(
- self.store.get_latest_event_ids_in_room(self.room_id)
- )
- self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT)
-
- # Create new user, and add consent
- user2 = self.register_user("user2", "password")
- token2 = self.login("user2", "password")
- self.get_success(
- self.store.user_set_consent_version(user2, self.CONSENT_VERSION)
- )
- self.helper.join(self.room_id, user2, tok=token2)
-
- # Background updates should now cause a dummy event to be added to the graph
- self.pump(10 * 60)
-
- latest_event_ids = self.get_success(
- self.store.get_latest_event_ids_in_room(self.room_id)
- )
- self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
-
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
def test_expiry_logic(self):
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
diff --git a/tests/test_state.py b/tests/test_state.py
index 80b0ccbc40..6227a3ba95 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -169,6 +169,7 @@ class StateTestCase(unittest.TestCase):
"get_state_handler",
"get_clock",
"get_state_resolution_handler",
+ "hostname",
]
)
hs.config = default_config("tesths", True)
|