From 36224e056a0ba91b4541607c5ad5cd5152d0e672 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 12 Oct 2021 13:50:34 +0100 Subject: Add type hints to `synapse.storage.databases.main.client_ips` (#10972) --- synapse/module_api/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/module_api/__init__.py') diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 8ae21bc43c..b2a228c231 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -773,9 +773,9 @@ class ModuleApi: # Sanitize some of the data. We don't want to return tokens. return [ UserIpAndAgent( - ip=str(data["ip"]), - user_agent=str(data["user_agent"]), - last_seen=int(data["last_seen"]), + ip=data["ip"], + user_agent=data["user_agent"], + last_seen=data["last_seen"], ) for data in raw_data ] -- cgit 1.5.1 From cdd308845ba22fef22a39ed5bf904b438e48b491 Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Wed, 13 Oct 2021 12:21:52 +0100 Subject: Port the Password Auth Providers module interface to the new generic interface (#10548) Co-authored-by: Azrenbeth <7782548+Azrenbeth@users.noreply.github.com> Co-authored-by: Brendan Abolivier --- changelog.d/10548.feature | 1 + docs/SUMMARY.md | 1 + docs/modules/password_auth_provider_callbacks.md | 153 +++++++ docs/modules/porting_legacy_module.md | 3 + docs/password_auth_providers.md | 6 + docs/sample_config.yaml | 28 -- synapse/app/_base.py | 2 + synapse/config/password_auth_providers.py | 53 +-- synapse/handlers/auth.py | 528 +++++++++++++++++------ synapse/module_api/__init__.py | 9 + synapse/server.py | 6 +- synapse/storage/prepare_database.py | 2 + tests/handlers/test_password_providers.py | 223 ++++++++-- 13 files changed, 790 insertions(+), 225 deletions(-) create mode 100644 changelog.d/10548.feature create mode 100644 docs/modules/password_auth_provider_callbacks.md (limited to 'synapse/module_api/__init__.py') diff --git a/changelog.d/10548.feature b/changelog.d/10548.feature new file mode 100644 index 0000000000..263a811faf --- /dev/null +++ b/changelog.d/10548.feature @@ -0,0 +1 @@ +Port the Password Auth Providers module interface to the new generic interface. \ No newline at end of file diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index bdb44543b8..35412ea92c 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -43,6 +43,7 @@ - [Third-party rules callbacks](modules/third_party_rules_callbacks.md) - [Presence router callbacks](modules/presence_router_callbacks.md) - [Account validity callbacks](modules/account_validity_callbacks.md) + - [Password auth provider callbacks](modules/password_auth_provider_callbacks.md) - [Porting a legacy module to the new interface](modules/porting_legacy_module.md) - [Workers](workers.md) - [Using `synctl` with Workers](synctl_workers.md) diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md new file mode 100644 index 0000000000..36417dd39e --- /dev/null +++ b/docs/modules/password_auth_provider_callbacks.md @@ -0,0 +1,153 @@ +# Password auth provider callbacks + +Password auth providers offer a way for server administrators to integrate +their Synapse installation with an external authentication system. The callbacks can be +registered by using the Module API's `register_password_auth_provider_callbacks` method. + +## Callbacks + +### `auth_checkers` + +``` + auth_checkers: Dict[Tuple[str,Tuple], Callable] +``` + +A dict mapping from tuples of a login type identifier (such as `m.login.password`) and a +tuple of field names (such as `("password", "secret_thing")`) to authentication checking +callbacks, which should be of the following form: + +```python +async def check_auth( + user: str, + login_type: str, + login_dict: "synapse.module_api.JsonDict", +) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]] + ] +] +``` + +The login type and field names should be provided by the user in the +request to the `/login` API. [The Matrix specification](https://matrix.org/docs/spec/client_server/latest#authentication-types) +defines some types, however user defined ones are also allowed. + +The callback is passed the `user` field provided by the client (which might not be in +`@username:server` form), the login type, and a dictionary of login secrets passed by +the client. + +If the authentication is successful, the module must return the user's Matrix ID (e.g. +`@alice:example.com`) and optionally a callback to be called with the response to the +`/login` request. If the module doesn't wish to return a callback, it must return `None` +instead. + +If the authentication is unsuccessful, the module must return `None`. + +### `check_3pid_auth` + +```python +async def check_3pid_auth( + medium: str, + address: str, + password: str, +) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]] + ] +] +``` + +Called when a user attempts to register or log in with a third party identifier, +such as email. It is passed the medium (eg. `email`), an address (eg. `jdoe@example.com`) +and the user's password. + +If the authentication is successful, the module must return the user's Matrix ID (e.g. +`@alice:example.com`) and optionally a callback to be called with the response to the `/login` request. +If the module doesn't wish to return a callback, it must return None instead. + +If the authentication is unsuccessful, the module must return None. + +### `on_logged_out` + +```python +async def on_logged_out( + user_id: str, + device_id: Optional[str], + access_token: str +) -> None +``` +Called during a logout request for a user. It is passed the qualified user ID, the ID of the +deactivated device (if any: access tokens are occasionally created without an associated +device ID), and the (now deactivated) access token. + +## Example + +The example module below implements authentication checkers for two different login types: +- `my.login.type` + - Expects a `my_field` field to be sent to `/login` + - Is checked by the method: `self.check_my_login` +- `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based)) + - Expects a `password` field to be sent to `/login` + - Is checked by the method: `self.check_pass` + + +```python +from typing import Awaitable, Callable, Optional, Tuple + +import synapse +from synapse import module_api + + +class MyAuthProvider: + def __init__(self, config: dict, api: module_api): + + self.api = api + + self.credentials = { + "bob": "building", + "@scoop:matrix.org": "digging", + } + + api.register_password_auth_provider_callbacks( + auth_checkers={ + ("my.login_type", ("my_field",)): self.check_my_login, + ("m.login.password", ("password",)): self.check_pass, + }, + ) + + async def check_my_login( + self, + username: str, + login_type: str, + login_dict: "synapse.module_api.JsonDict", + ) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], + ] + ]: + if login_type != "my.login_type": + return None + + if self.credentials.get(username) == login_dict.get("my_field"): + return self.api.get_qualified_user_id(username) + + async def check_pass( + self, + username: str, + login_type: str, + login_dict: "synapse.module_api.JsonDict", + ) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], + ] + ]: + if login_type != "m.login.password": + return None + + if self.credentials.get(username) == login_dict.get("password"): + return self.api.get_qualified_user_id(username) +``` diff --git a/docs/modules/porting_legacy_module.md b/docs/modules/porting_legacy_module.md index a7a251e535..89084eb7b3 100644 --- a/docs/modules/porting_legacy_module.md +++ b/docs/modules/porting_legacy_module.md @@ -12,6 +12,9 @@ should register this resource in its `__init__` method using the `register_web_r method from the `ModuleApi` class (see [this section](writing_a_module.html#registering-a-web-resource) for more info). +There is no longer a `get_db_schema_files` callback provided for password auth provider modules. Any +changes to the database should now be made by the module using the module API class. + The module's author should also update any example in the module's configuration to only use the new `modules` section in Synapse's configuration file (see [this section](index.html#using-modules) for more info). diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md index d2cdb9b2f4..d7beacfff3 100644 --- a/docs/password_auth_providers.md +++ b/docs/password_auth_providers.md @@ -1,3 +1,9 @@ +

+This page of the Synapse documentation is now deprecated. For up to date +documentation on setting up or writing a password auth provider module, please see +this page. +

+ # Password auth provider modules Password auth providers offer a way for server administrators to diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 166cec38d3..7bfaed483b 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2260,34 +2260,6 @@ email: #email_validation: "[%(server_name)s] Validate your email" -# Password providers allow homeserver administrators to integrate -# their Synapse installation with existing authentication methods -# ex. LDAP, external tokens, etc. -# -# For more information and known implementations, please see -# https://matrix-org.github.io/synapse/latest/password_auth_providers.html -# -# Note: instances wishing to use SAML or CAS authentication should -# instead use the `saml2_config` or `cas_config` options, -# respectively. -# -password_providers: -# # Example config for an LDAP auth provider -# - module: "ldap_auth_provider.LdapAuthProvider" -# config: -# enabled: true -# uri: "ldap://ldap.example.com:389" -# start_tls: true -# base: "ou=users,dc=example,dc=com" -# attributes: -# uid: "cn" -# mail: "email" -# name: "givenName" -# #bind_dn: -# #bind_password: -# #filter: "(objectClass=posixAccount)" - - ## Push ## diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 4a204a5823..bb4d53d778 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -42,6 +42,7 @@ from synapse.crypto import context_factory from synapse.events.presence_router import load_legacy_presence_router from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.third_party_rules import load_legacy_third_party_event_rules +from synapse.handlers.auth import load_legacy_password_auth_providers from synapse.logging.context import PreserveLoggingContext from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats @@ -379,6 +380,7 @@ async def start(hs: "HomeServer"): load_legacy_spam_checkers(hs) load_legacy_third_party_event_rules(hs) load_legacy_presence_router(hs) + load_legacy_password_auth_providers(hs) # If we've configured an expiry time for caches, start the background job now. setup_expire_lru_cache_entries(hs) diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 83994df798..f980102b45 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -25,6 +25,29 @@ class PasswordAuthProviderConfig(Config): section = "authproviders" def read_config(self, config, **kwargs): + """Parses the old password auth providers config. The config format looks like this: + + password_providers: + # Example config for an LDAP auth provider + - module: "ldap_auth_provider.LdapAuthProvider" + config: + enabled: true + uri: "ldap://ldap.example.com:389" + start_tls: true + base: "ou=users,dc=example,dc=com" + attributes: + uid: "cn" + mail: "email" + name: "givenName" + #bind_dn: + #bind_password: + #filter: "(objectClass=posixAccount)" + + We expect admins to use modules for this feature (which is why it doesn't appear + in the sample config file), but we want to keep support for it around for a bit + for backwards compatibility. + """ + self.password_providers: List[Tuple[Type, Any]] = [] providers = [] @@ -49,33 +72,3 @@ class PasswordAuthProviderConfig(Config): ) self.password_providers.append((provider_class, provider_config)) - - def generate_config_section(self, **kwargs): - return """\ - # Password providers allow homeserver administrators to integrate - # their Synapse installation with existing authentication methods - # ex. LDAP, external tokens, etc. - # - # For more information and known implementations, please see - # https://matrix-org.github.io/synapse/latest/password_auth_providers.html - # - # Note: instances wishing to use SAML or CAS authentication should - # instead use the `saml2_config` or `cas_config` options, - # respectively. - # - password_providers: - # # Example config for an LDAP auth provider - # - module: "ldap_auth_provider.LdapAuthProvider" - # config: - # enabled: true - # uri: "ldap://ldap.example.com:389" - # start_tls: true - # base: "ou=users,dc=example,dc=com" - # attributes: - # uid: "cn" - # mail: "email" - # name: "givenName" - # #bind_dn: - # #bind_password: - # #filter: "(objectClass=posixAccount)" - """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f4612a5b92..ebe75a9e9b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -200,46 +200,13 @@ class AuthHandler: self.bcrypt_rounds = hs.config.registration.bcrypt_rounds - # we can't use hs.get_module_api() here, because to do so will create an - # import loop. - # - # TODO: refactor this class to separate the lower-level stuff that - # ModuleApi can use from the higher-level stuff that uses ModuleApi, as - # better way to break the loop - account_handler = ModuleApi(hs, self) - - self.password_providers = [ - PasswordProvider.load(module, config, account_handler) - for module, config in hs.config.authproviders.password_providers - ] - - logger.info("Extra password_providers: %s", self.password_providers) + self.password_auth_provider = hs.get_password_auth_provider() self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.auth.password_enabled self._password_localdb_enabled = hs.config.auth.password_localdb_enabled - # start out by assuming PASSWORD is enabled; we will remove it later if not. - login_types = set() - if self._password_localdb_enabled: - login_types.add(LoginType.PASSWORD) - - for provider in self.password_providers: - login_types.update(provider.get_supported_login_types().keys()) - - if not self._password_enabled: - login_types.discard(LoginType.PASSWORD) - - # Some clients just pick the first type in the list. In this case, we want - # them to use PASSWORD (rather than token or whatever), so we want to make sure - # that comes first, where it's present. - self._supported_login_types = [] - if LoginType.PASSWORD in login_types: - self._supported_login_types.append(LoginType.PASSWORD) - login_types.remove(LoginType.PASSWORD) - self._supported_login_types.extend(login_types) - # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. self._failed_uia_attempts_ratelimiter = Ratelimiter( @@ -427,11 +394,10 @@ class AuthHandler: ui_auth_types.add(LoginType.PASSWORD) # also allow auth from password providers - for provider in self.password_providers: - for t in provider.get_supported_login_types().keys(): - if t == LoginType.PASSWORD and not self._password_enabled: - continue - ui_auth_types.add(t) + for t in self.password_auth_provider.get_supported_login_types().keys(): + if t == LoginType.PASSWORD and not self._password_enabled: + continue + ui_auth_types.add(t) # if sso is enabled, allow the user to log in via SSO iff they have a mapping # from sso to mxid. @@ -1038,7 +1004,25 @@ class AuthHandler: Returns: login types """ - return self._supported_login_types + # Load any login types registered by modules + # This is stored in the password_auth_provider so this doesn't trigger + # any callbacks + types = list(self.password_auth_provider.get_supported_login_types().keys()) + + # This list should include PASSWORD if (either _password_localdb_enabled is + # true or if one of the modules registered it) AND _password_enabled is true + # Also: + # Some clients just pick the first type in the list. In this case, we want + # them to use PASSWORD (rather than token or whatever), so we want to make sure + # that comes first, where it's present. + if LoginType.PASSWORD in types: + types.remove(LoginType.PASSWORD) + if self._password_enabled: + types.insert(0, LoginType.PASSWORD) + elif self._password_localdb_enabled and self._password_enabled: + types.insert(0, LoginType.PASSWORD) + + return types async def validate_login( self, @@ -1217,15 +1201,20 @@ class AuthHandler: known_login_type = False - for provider in self.password_providers: - supported_login_types = provider.get_supported_login_types() - if login_type not in supported_login_types: - # this password provider doesn't understand this login type - continue - + # Check if login_type matches a type registered by one of the modules + # We don't need to remove LoginType.PASSWORD from the list if password login is + # disabled, since if that were the case then by this point we know that the + # login_type is not LoginType.PASSWORD + supported_login_types = self.password_auth_provider.get_supported_login_types() + # check if the login type being used is supported by a module + if login_type in supported_login_types: + # Make a note that this login type is supported by the server known_login_type = True + # Get all the fields expected for this login types login_fields = supported_login_types[login_type] + # go through the login submission and keep track of which required fields are + # provided/not provided missing_fields = [] login_dict = {} for f in login_fields: @@ -1233,6 +1222,7 @@ class AuthHandler: missing_fields.append(f) else: login_dict[f] = login_submission[f] + # raise an error if any of the expected fields for that login type weren't provided if missing_fields: raise SynapseError( 400, @@ -1240,10 +1230,15 @@ class AuthHandler: % (login_type, missing_fields), ) - result = await provider.check_auth(username, login_type, login_dict) + # call all of the check_auth hooks for that login_type + # it will return a result once the first success is found (or None otherwise) + result = await self.password_auth_provider.check_auth( + username, login_type, login_dict + ) if result: return result + # if no module managed to authenticate the user, then fallback to built in password based auth if login_type == LoginType.PASSWORD and self._password_localdb_enabled: known_login_type = True @@ -1282,11 +1277,16 @@ class AuthHandler: completed login/registration, or `None`. If authentication was unsuccessful, `user_id` and `callback` are both `None`. """ - for provider in self.password_providers: - result = await provider.check_3pid_auth(medium, address, password) - if result: - return result + # call all of the check_3pid_auth callbacks + # Result will be from the first callback that returns something other than None + # If all the callbacks return None, then result is also set to None + result = await self.password_auth_provider.check_3pid_auth( + medium, address, password + ) + if result: + return result + # if result is None then return (None, None) return None, None async def _check_local_password(self, user_id: str, password: str) -> Optional[str]: @@ -1365,13 +1365,12 @@ class AuthHandler: user_info = await self.auth.get_user_by_access_token(access_token) await self.store.delete_access_token(access_token) - # see if any of our auth providers want to know about this - for provider in self.password_providers: - await provider.on_logged_out( - user_id=user_info.user_id, - device_id=user_info.device_id, - access_token=access_token, - ) + # see if any modules want to know about this + await self.password_auth_provider.on_logged_out( + user_id=user_info.user_id, + device_id=user_info.device_id, + access_token=access_token, + ) # delete pushers associated with this access token if user_info.token_id is not None: @@ -1398,12 +1397,11 @@ class AuthHandler: user_id, except_token_id=except_token_id, device_id=device_id ) - # see if any of our auth providers want to know about this - for provider in self.password_providers: - for token, _, device_id in tokens_and_devices: - await provider.on_logged_out( - user_id=user_id, device_id=device_id, access_token=token - ) + # see if any modules want to know about this + for token, _, device_id in tokens_and_devices: + await self.password_auth_provider.on_logged_out( + user_id=user_id, device_id=device_id, access_token=token + ) # delete pushers associated with the access tokens await self.hs.get_pusherpool().remove_pushers_by_access_token( @@ -1811,40 +1809,228 @@ class MacaroonGenerator: return macaroon -class PasswordProvider: - """Wrapper for a password auth provider module +def load_legacy_password_auth_providers(hs: "HomeServer") -> None: + module_api = hs.get_module_api() + for module, config in hs.config.authproviders.password_providers: + load_single_legacy_password_auth_provider( + module=module, config=config, api=module_api + ) - This class abstracts out all of the backwards-compatibility hacks for - password providers, to provide a consistent interface. - """ - @classmethod - def load( - cls, module: Type, config: JsonDict, module_api: ModuleApi - ) -> "PasswordProvider": - try: - pp = module(config=config, account_handler=module_api) - except Exception as e: - logger.error("Error while initializing %r: %s", module, e) - raise - return cls(pp, module_api) +def load_single_legacy_password_auth_provider( + module: Type, config: JsonDict, api: ModuleApi +) -> None: + try: + provider = module(config=config, account_handler=api) + except Exception as e: + logger.error("Error while initializing %r: %s", module, e) + raise + + # The known hooks. If a module implements a method who's name appears in this set + # we'll want to register it + password_auth_provider_methods = { + "check_3pid_auth", + "on_logged_out", + } + + # All methods that the module provides should be async, but this wasn't enforced + # in the old module system, so we wrap them if needed + def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: + # f might be None if the callback isn't implemented by the module. In this + # case we don't want to register a callback at all so we return None. + if f is None: + return None + + # We need to wrap check_password because its old form would return a boolean + # but we now want it to behave just like check_auth() and return the matrix id of + # the user if authentication succeeded or None otherwise + if f.__name__ == "check_password": + + async def wrapped_check_password( + username: str, login_type: str, login_dict: JsonDict + ) -> Optional[Tuple[str, Optional[Callable]]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + matrix_user_id = api.get_qualified_user_id(username) + password = login_dict["password"] + + is_valid = await f(matrix_user_id, password) + + if is_valid: + return matrix_user_id, None + + return None - def __init__(self, pp: "PasswordProvider", module_api: ModuleApi): - self._pp = pp - self._module_api = module_api + return wrapped_check_password + + # We need to wrap check_auth as in the old form it could return + # just a str, but now it must return Optional[Tuple[str, Optional[Callable]] + if f.__name__ == "check_auth": + + async def wrapped_check_auth( + username: str, login_type: str, login_dict: JsonDict + ) -> Optional[Tuple[str, Optional[Callable]]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + result = await f(username, login_type, login_dict) + + if isinstance(result, str): + return result, None + + return result + + return wrapped_check_auth + + # We need to wrap check_3pid_auth as in the old form it could return + # just a str, but now it must return Optional[Tuple[str, Optional[Callable]] + if f.__name__ == "check_3pid_auth": + + async def wrapped_check_3pid_auth( + medium: str, address: str, password: str + ) -> Optional[Tuple[str, Optional[Callable]]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + result = await f(medium, address, password) + + if isinstance(result, str): + return result, None + + return result - self._supported_login_types = {} + return wrapped_check_3pid_auth - # grandfather in check_password support - if hasattr(self._pp, "check_password"): - self._supported_login_types[LoginType.PASSWORD] = ("password",) + def run(*args: Tuple, **kwargs: Dict) -> Awaitable: + # mypy doesn't do well across function boundaries so we need to tell it + # f is definitely not None. + assert f is not None - g = getattr(self._pp, "get_supported_login_types", None) - if g: - self._supported_login_types.update(g()) + return maybe_awaitable(f(*args, **kwargs)) - def __str__(self) -> str: - return str(self._pp) + return run + + # populate hooks with the implemented methods, wrapped with async_wrapper + hooks = { + hook: async_wrapper(getattr(provider, hook, None)) + for hook in password_auth_provider_methods + } + + supported_login_types = {} + # call get_supported_login_types and add that to the dict + g = getattr(provider, "get_supported_login_types", None) + if g is not None: + # Note the old module style also called get_supported_login_types at loading time + # and it is synchronous + supported_login_types.update(g()) + + auth_checkers = {} + # Legacy modules have a check_auth method which expects to be called with one of + # the keys returned by get_supported_login_types. New style modules register a + # dictionary of login_type->check_auth_method mappings + check_auth = async_wrapper(getattr(provider, "check_auth", None)) + if check_auth is not None: + for login_type, fields in supported_login_types.items(): + # need tuple(fields) since fields can be any Iterable type (so may not be hashable) + auth_checkers[(login_type, tuple(fields))] = check_auth + + # if it has a "check_password" method then it should handle all auth checks + # with login type of LoginType.PASSWORD + check_password = async_wrapper(getattr(provider, "check_password", None)) + if check_password is not None: + # need to use a tuple here for ("password",) not a list since lists aren't hashable + auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password + + api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers) + + +CHECK_3PID_AUTH_CALLBACK = Callable[ + [str, str, str], + Awaitable[ + Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] + ], +] +ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable] +CHECK_AUTH_CALLBACK = Callable[ + [str, str, JsonDict], + Awaitable[ + Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] + ], +] + + +class PasswordAuthProvider: + """ + A class that the AuthHandler calls when authenticating users + It allows modules to provide alternative methods for authentication + """ + + def __init__(self) -> None: + # lists of callbacks + self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = [] + self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = [] + + # Mapping from login type to login parameters + self._supported_login_types: Dict[str, Iterable[str]] = {} + + # Mapping from login type to auth checker callbacks + self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {} + + def register_password_auth_provider_callbacks( + self, + check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, + on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, + auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None, + ) -> None: + # Register check_3pid_auth callback + if check_3pid_auth is not None: + self.check_3pid_auth_callbacks.append(check_3pid_auth) + + # register on_logged_out callback + if on_logged_out is not None: + self.on_logged_out_callbacks.append(on_logged_out) + + if auth_checkers is not None: + # register a new supported login_type + # Iterate through all of the types being registered + for (login_type, fields), callback in auth_checkers.items(): + # Note: fields may be empty here. This would allow a modules auth checker to + # be called with just 'login_type' and no password or other secrets + + # Need to check that all the field names are strings or may get nasty errors later + for f in fields: + if not isinstance(f, str): + raise RuntimeError( + "A module tried to register support for login type: %s with parameters %s" + " but all parameter names must be strings" + % (login_type, fields) + ) + + # 2 modules supporting the same login type must expect the same fields + # e.g. 1 can't expect "pass" if the other expects "password" + # so throw an exception if that happens + if login_type not in self._supported_login_types.get(login_type, []): + self._supported_login_types[login_type] = fields + else: + fields_currently_supported = self._supported_login_types.get( + login_type + ) + if fields_currently_supported != fields: + raise RuntimeError( + "A module tried to register support for login type: %s with parameters %s" + " but another module had already registered support for that type with parameters %s" + % (login_type, fields, fields_currently_supported) + ) + + # Add the new method to the list of auth_checker_callbacks for this login type + self.auth_checker_callbacks.setdefault(login_type, []).append(callback) def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: """Get the login types supported by this password provider @@ -1852,20 +2038,15 @@ class PasswordProvider: Returns a map from a login type identifier (such as m.login.password) to an iterable giving the fields which must be provided by the user in the submission to the /login API. - - This wrapper adds m.login.password to the list if the underlying password - provider supports the check_password() api. """ + return self._supported_login_types async def check_auth( self, username: str, login_type: str, login_dict: JsonDict - ) -> Optional[Tuple[str, Optional[Callable]]]: + ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]: """Check if the user has presented valid login credentials - This wrapper also calls check_password() if the underlying password provider - supports the check_password() api and the login type is m.login.password. - Args: username: user id presented by the client. Either an MXID or an unqualified username. @@ -1879,63 +2060,130 @@ class PasswordProvider: user, and `callback` is an optional callback which will be called with the result from the /login call (including access_token, device_id, etc.) """ - # first grandfather in a call to check_password - if login_type == LoginType.PASSWORD: - check_password = getattr(self._pp, "check_password", None) - if check_password: - qualified_user_id = self._module_api.get_qualified_user_id(username) - is_valid = await check_password( - qualified_user_id, login_dict["password"] - ) - if is_valid: - return qualified_user_id, None - check_auth = getattr(self._pp, "check_auth", None) - if not check_auth: - return None - result = await check_auth(username, login_type, login_dict) + # Go through all callbacks for the login type until one returns with a value + # other than None (i.e. until a callback returns a success) + for callback in self.auth_checker_callbacks[login_type]: + try: + result = await callback(username, login_type, login_dict) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue - # Check if the return value is a str or a tuple - if isinstance(result, str): - # If it's a str, set callback function to None - return result, None + if result is not None: + # Check that the callback returned a Tuple[str, Optional[Callable]] + # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks + # result is always the right type, but as it is 3rd party code it might not be + + if not isinstance(result, tuple) or len(result) != 2: + logger.warning( + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue - return result + # pull out the two parts of the tuple so we can do type checking + str_result, callback_result = result + + # the 1st item in the tuple should be a str + if not isinstance(str_result, str): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # the second should be Optional[Callable] + if callback_result is not None: + if not callable(callback_result): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # The result is a (str, Optional[callback]) tuple so return the successful result + return result + + # If this point has been reached then none of the callbacks successfully authenticated + # the user so return None + return None async def check_3pid_auth( self, medium: str, address: str, password: str - ) -> Optional[Tuple[str, Optional[Callable]]]: - g = getattr(self._pp, "check_3pid_auth", None) - if not g: - return None - + ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]: # This function is able to return a deferred that either # resolves None, meaning authentication failure, or upon # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = await g(medium, address, password) - # Check if the return value is a str or a tuple - if isinstance(result, str): - # If it's a str, set callback function to None - return result, None + for callback in self.check_3pid_auth_callbacks: + try: + result = await callback(medium, address, password) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue - return result + if result is not None: + # Check that the callback returned a Tuple[str, Optional[Callable]] + # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks + # result is always the right type, but as it is 3rd party code it might not be + + if not isinstance(result, tuple) or len(result) != 2: + logger.warning( + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # pull out the two parts of the tuple so we can do type checking + str_result, callback_result = result + + # the 1st item in the tuple should be a str + if not isinstance(str_result, str): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # the second should be Optional[Callable] + if callback_result is not None: + if not callable(callback_result): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # The result is a (str, Optional[callback]) tuple so return the successful result + return result + + # If this point has been reached then none of the callbacks successfully authenticated + # the user so return None + return None async def on_logged_out( self, user_id: str, device_id: Optional[str], access_token: str ) -> None: - g = getattr(self._pp, "on_logged_out", None) - if not g: - return - # This might return an awaitable, if it does block the log out - # until it completes. - await maybe_awaitable( - g( - user_id=user_id, - device_id=device_id, - access_token=access_token, - ) - ) + # call all of the on_logged_out callbacks + for callback in self.on_logged_out_callbacks: + try: + callback(user_id, device_id, access_token) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index b2a228c231..ab7ef8f950 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -45,6 +45,7 @@ from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.rest.client.login import LoginResponse from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter @@ -83,6 +84,8 @@ __all__ = [ "DirectServeJsonResource", "ModuleApi", "PRESENCE_ALL_USERS", + "LoginResponse", + "JsonDict", ] logger = logging.getLogger(__name__) @@ -139,6 +142,7 @@ class ModuleApi: self._spam_checker = hs.get_spam_checker() self._account_validity_handler = hs.get_account_validity_handler() self._third_party_event_rules = hs.get_third_party_event_rules() + self._password_auth_provider = hs.get_password_auth_provider() self._presence_router = hs.get_presence_router() ################################################################################# @@ -164,6 +168,11 @@ class ModuleApi: """Registers callbacks for presence router capabilities.""" return self._presence_router.register_presence_router_callbacks + @property + def register_password_auth_provider_callbacks(self): + """Registers callbacks for password auth provider capabilities.""" + return self._password_auth_provider.register_password_auth_provider_callbacks + def register_web_resource(self, path: str, resource: IResource): """Registers a web resource to be served at the given path. diff --git a/synapse/server.py b/synapse/server.py index 5bc045d615..a64c846d1c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -65,7 +65,7 @@ from synapse.handlers.account_data import AccountDataHandler from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.admin import AdminHandler from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.handlers.auth import AuthHandler, MacaroonGenerator +from synapse.handlers.auth import AuthHandler, MacaroonGenerator, PasswordAuthProvider from synapse.handlers.cas import CasHandler from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler @@ -687,6 +687,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_third_party_event_rules(self) -> ThirdPartyEventRules: return ThirdPartyEventRules(self) + @cache_in_self + def get_password_auth_provider(self) -> PasswordAuthProvider: + return PasswordAuthProvider() + @cache_in_self def get_room_member_handler(self) -> RoomMemberHandler: if self.config.worker.worker_app: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 11ca47ea28..1629d2a53c 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -549,6 +549,8 @@ def _apply_module_schemas( database_engine: config: application config """ + # This is the old way for password_auth_provider modules to make changes + # to the database. This should instead be done using the module API for (mod, _config) in config.authproviders.password_providers: if not hasattr(mod, "get_db_schema_files"): continue diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 38e6d9f536..7dd4a5a367 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -20,6 +20,8 @@ from unittest.mock import Mock from twisted.internet import defer import synapse +from synapse.handlers.auth import load_legacy_password_auth_providers +from synapse.module_api import ModuleApi from synapse.rest.client import devices, login from synapse.types import JsonDict @@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi mock_password_provider = Mock() -class PasswordOnlyAuthProvider: - """A password_provider which only implements `check_password`.""" +class LegacyPasswordOnlyAuthProvider: + """A legacy password_provider which only implements `check_password`.""" @staticmethod def parse_config(self): @@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider: return mock_password_provider.check_password(*args) -class CustomAuthProvider: - """A password_provider which implements a custom login type.""" +class LegacyCustomAuthProvider: + """A legacy password_provider which implements a custom login type.""" @staticmethod def parse_config(self): @@ -67,7 +69,23 @@ class CustomAuthProvider: return mock_password_provider.check_auth(*args) -class PasswordCustomAuthProvider: +class CustomAuthProvider: + """A module which registers password_auth_provider callbacks for a custom login type.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, api: ModuleApi): + api.register_password_auth_provider_callbacks( + auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, + ) + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + +class LegacyPasswordCustomAuthProvider: """A password_provider which implements password login via `check_auth`, as well as a custom type.""" @@ -85,8 +103,32 @@ class PasswordCustomAuthProvider: return mock_password_provider.check_auth(*args) -def providers_config(*providers: Type[Any]) -> dict: - """Returns a config dict that will enable the given password auth providers""" +class PasswordCustomAuthProvider: + """A module which registers password_auth_provider callbacks for a custom login type. + as well as a password login""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, api: ModuleApi): + api.register_password_auth_provider_callbacks( + auth_checkers={ + ("test.login_type", ("test_field",)): self.check_auth, + ("m.login.password", ("password",)): self.check_auth, + }, + ) + pass + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + def check_pass(self, *args): + return mock_password_provider.check_password(*args) + + +def legacy_providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given legacy password auth providers""" return { "password_providers": [ {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} @@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict: } +def providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given modules""" + return { + "modules": [ + {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} + for provider in providers + ] + } + + class PasswordAuthProviderTests(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, @@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() super().setUp() - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_password_only_auth_provider_login(self): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + # Load the modules into the homeserver + module_api = hs.get_module_api() + for module, config in hs.config.modules.loaded_modules: + module(config=config, api=module_api) + load_legacy_password_auth_providers(hs) + + return hs + + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_password_only_auth_progiver_login_legacy(self): + self.password_only_auth_provider_login_test_body() + + def password_only_auth_provider_login_test_body(self): # login flows should only have m.login.password flows = self._get_login_flows() self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) @@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "@ USER🙂NAME :test", " pASS😢word " ) - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_password_only_auth_provider_ui_auth(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_password_only_auth_provider_ui_auth_legacy(self): + self.password_only_auth_provider_ui_auth_test_body() + + def password_only_auth_provider_ui_auth_test_body(self): """UI Auth should delegate correctly to the password provider""" # create the user, otherwise access doesn't work @@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_local_user_fallback_login(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_local_user_fallback_login_legacy(self): + self.local_user_fallback_login_test_body() + + def local_user_fallback_login_test_body(self): """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@localuser:test", channel.json_body["user_id"]) - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_local_user_fallback_ui_auth(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_local_user_fallback_ui_auth_legacy(self): + self.local_user_fallback_ui_auth_test_body() + + def local_user_fallback_ui_auth_test_body(self): """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_login(self): + def test_no_local_user_fallback_login_legacy(self): + self.no_local_user_fallback_login_test_body() + + def no_local_user_fallback_login_test_body(self): """localdb_enabled can block login with the local password""" self.register_user("localuser", "localpass") @@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_ui_auth(self): + def test_no_local_user_fallback_ui_auth_legacy(self): + self.no_local_user_fallback_ui_auth_test_body() + + def no_local_user_fallback_ui_auth_test_body(self): """localdb_enabled can block ui auth with the local password""" self.register_user("localuser", "localpass") @@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"enabled": False}, } ) - def test_password_auth_disabled(self): + def test_password_auth_disabled_legacy(self): + self.password_auth_disabled_test_body() + + def password_auth_disabled_test_body(self): """password auth doesn't work if it's disabled across the board""" # login flows should be empty flows = self._get_login_flows() @@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_password.assert_not_called() + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_login_legacy(self): + self.custom_auth_provider_login_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_login(self): + self.custom_auth_provider_login_test_body() + + def custom_auth_provider_login_test_body(self): # login flows should have the custom flow and m.login.password, since we # haven't disabled local password lookup. # (password must come first, because reasons) @@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() - mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", None) + ) channel = self._send_login("test.login_type", "u", test_field="y") self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@user:bz", channel.json_body["user_id"]) @@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # in these cases, but at least we can guard against the API changing # unexpectedly mock_password_provider.check_auth.return_value = defer.succeed( - "@ MALFORMED! :bz" + ("@ MALFORMED! :bz", None) ) channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") self.assertEqual(channel.code, 200, channel.result) @@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): " USER🙂NAME ", "test.login_type", {"test_field": " abc "} ) + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_ui_auth_legacy(self): + self.custom_auth_provider_ui_auth_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_ui_auth(self): + self.custom_auth_provider_ui_auth_test_body() + + def custom_auth_provider_ui_auth_test_body(self): # register the user and log in twice, to get two devices self.register_user("localuser", "localpass") tok1 = self.login("localuser", "localpass") @@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # right params, but authing as the wrong user - mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", None) + ) body["auth"]["test_field"] = "foo" channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 403) @@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # and finally, succeed mock_password_provider.check_auth.return_value = defer.succeed( - "@localuser:test" + ("@localuser:test", None) ) channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 200) @@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "localuser", "test.login_type", {"test_field": "foo"} ) + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_callback_legacy(self): + self.custom_auth_provider_callback_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_callback(self): + self.custom_auth_provider_callback_test_body() + + def custom_auth_provider_callback_test_body(self): callback = Mock(return_value=defer.succeed(None)) mock_password_provider.check_auth.return_value = defer.succeed( @@ -410,10 +518,22 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): for p in ["user_id", "access_token", "device_id", "home_server"]: self.assertIn(p, call_args[0]) + @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_custom_auth_password_disabled_legacy(self): + self.custom_auth_password_disabled_test_body() + @override_config( {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} ) def test_custom_auth_password_disabled(self): + self.custom_auth_password_disabled_test_body() + + def custom_auth_password_disabled_test_body(self): """Test login with a custom auth provider where password login is disabled""" self.register_user("localuser", "localpass") @@ -425,6 +545,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"enabled": False, "localdb_enabled": False}, + } + ) + def test_custom_auth_password_disabled_localdb_enabled_legacy(self): + self.custom_auth_password_disabled_localdb_enabled_test_body() + @override_config( { **providers_config(CustomAuthProvider), @@ -432,6 +561,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_custom_auth_password_disabled_localdb_enabled(self): + self.custom_auth_password_disabled_localdb_enabled_test_body() + + def custom_auth_password_disabled_localdb_enabled_test_body(self): """Check the localdb_enabled == enabled == False Regression test for https://github.com/matrix-org/synapse/issues/8914: check @@ -448,6 +580,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + @override_config( + { + **legacy_providers_config(LegacyPasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_login_legacy(self): + self.password_custom_auth_password_disabled_login_test_body() + @override_config( { **providers_config(PasswordCustomAuthProvider), @@ -455,6 +596,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_password_custom_auth_password_disabled_login(self): + self.password_custom_auth_password_disabled_login_test_body() + + def password_custom_auth_password_disabled_login_test_body(self): """log in with a custom auth provider which implements password, but password login is disabled""" self.register_user("localuser", "localpass") @@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + mock_password_provider.check_password.assert_not_called() + + @override_config( + { + **legacy_providers_config(LegacyPasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_ui_auth_legacy(self): + self.password_custom_auth_password_disabled_ui_auth_test_body() @override_config( { @@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_password_custom_auth_password_disabled_ui_auth(self): + self.password_custom_auth_password_disabled_ui_auth_test_body() + + def password_custom_auth_password_disabled_ui_auth_test_body(self): """UI Auth with a custom auth provider which implements password, but password login is disabled""" # register the user and log in twice via the test login type to get two devices, self.register_user("localuser", "localpass") mock_password_provider.check_auth.return_value = defer.succeed( - "@localuser:test" + ("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") self.assertEqual(channel.code, 200, channel.result) @@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "Password login has been disabled.", channel.json_body["error"] ) mock_password_provider.check_auth.assert_not_called() + mock_password_provider.check_password.assert_not_called() mock_password_provider.reset_mock() # successful auth @@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_auth.assert_called_once_with( "localuser", "test.login_type", {"test_field": "x"} ) + mock_password_provider.check_password.assert_not_called() + + @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_custom_auth_no_local_user_fallback_legacy(self): + self.custom_auth_no_local_user_fallback_test_body() @override_config( { @@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_custom_auth_no_local_user_fallback(self): + self.custom_auth_no_local_user_fallback_test_body() + + def custom_auth_no_local_user_fallback_test_body(self): """Test login with a custom auth provider where the local db is disabled""" self.register_user("localuser", "localpass") -- cgit 1.5.1 From 85a09f8b8ba7c8023c0d28a526d32111fc704197 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 25 Oct 2021 13:01:04 +0100 Subject: Fix module API's `get_user_ip_and_agents` function when run on workers (#11112) --- changelog.d/11112.bugfix | 1 + synapse/module_api/__init__.py | 6 +- synapse/storage/databases/main/client_ips.py | 124 ++++++++++++++++++--------- 3 files changed, 91 insertions(+), 40 deletions(-) create mode 100644 changelog.d/11112.bugfix (limited to 'synapse/module_api/__init__.py') diff --git a/changelog.d/11112.bugfix b/changelog.d/11112.bugfix new file mode 100644 index 0000000000..c8e22da8cf --- /dev/null +++ b/changelog.d/11112.bugfix @@ -0,0 +1 @@ +Fix a bug which caused the module API's `get_user_ip_and_agents` function to always fail on workers. `get_user_ip_and_agents` was introduced in 1.44.0 and did not function correctly on worker processes at the time. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index ab7ef8f950..d37252b6b3 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -46,6 +46,7 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client.login import LoginResponse +from synapse.storage import DataStore from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter @@ -61,6 +62,7 @@ from synapse.util import Clock from synapse.util.caches.descriptors import cached if TYPE_CHECKING: + from synapse.app.generic_worker import GenericWorkerSlavedStore from synapse.server import HomeServer """ @@ -111,7 +113,9 @@ class ModuleApi: def __init__(self, hs: "HomeServer", auth_handler): self._hs = hs - self._store = hs.get_datastore() + # TODO: Fix this type hint once the types for the data stores have been ironed + # out. + self._store: Union[DataStore, "GenericWorkerSlavedStore"] = hs.get_datastore() self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index b81d9218ce..1dc7f0ebe3 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -478,6 +478,58 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): return {(d["user_id"], d["device_id"]): d for d in res} + async def get_user_ip_and_agents( + self, user: UserID, since_ts: int = 0 + ) -> List[LastConnectionInfo]: + """Fetch the IPs and user agents for a user since the given timestamp. + + The result might be slightly out of date as client IPs are inserted in batches. + + Args: + user: The user for which to fetch IP addresses and user agents. + since_ts: The timestamp after which to fetch IP addresses and user agents, + in milliseconds. + + Returns: + A list of dictionaries, each containing: + * `access_token`: The access token used. + * `ip`: The IP address used. + * `user_agent`: The last user agent seen for this access token and IP + address combination. + * `last_seen`: The timestamp at which this access token and IP address + combination was last seen, in milliseconds. + + Only the latest user agent for each access token and IP address combination + is available. + """ + user_id = user.to_string() + + def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]: + txn.execute( + """ + SELECT access_token, ip, user_agent, last_seen FROM user_ips + WHERE last_seen >= ? AND user_id = ? + ORDER BY last_seen + DESC + """, + (since_ts, user_id), + ) + return cast(List[Tuple[str, str, str, int]], txn.fetchall()) + + rows = await self.db_pool.runInteraction( + desc="get_user_ip_and_agents", func=get_recent + ) + + return [ + { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } + for access_token, ip, user_agent, last_seen in rows + ] + class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): @@ -622,49 +674,43 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): async def get_user_ip_and_agents( self, user: UserID, since_ts: int = 0 ) -> List[LastConnectionInfo]: + """Fetch the IPs and user agents for a user since the given timestamp. + + Args: + user: The user for which to fetch IP addresses and user agents. + since_ts: The timestamp after which to fetch IP addresses and user agents, + in milliseconds. + + Returns: + A list of dictionaries, each containing: + * `access_token`: The access token used. + * `ip`: The IP address used. + * `user_agent`: The last user agent seen for this access token and IP + address combination. + * `last_seen`: The timestamp at which this access token and IP address + combination was last seen, in milliseconds. + + Only the latest user agent for each access token and IP address combination + is available. """ - Fetch IP/User Agent connection since a given timestamp. - """ - user_id = user.to_string() - results: Dict[Tuple[str, str], Tuple[str, int]] = {} + results: Dict[Tuple[str, str], LastConnectionInfo] = { + (connection["access_token"], connection["ip"]): connection + for connection in await super().get_user_ip_and_agents(user, since_ts) + } + # Overlay data that is pending insertion on top of the results from the + # database. + user_id = user.to_string() for key in self._batch_row_update: - ( - uid, - access_token, - ip, - ) = key + uid, access_token, ip = key if uid == user_id: user_agent, _, last_seen = self._batch_row_update[key] if last_seen >= since_ts: - results[(access_token, ip)] = (user_agent, last_seen) - - def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]: - txn.execute( - """ - SELECT access_token, ip, user_agent, last_seen FROM user_ips - WHERE last_seen >= ? AND user_id = ? - ORDER BY last_seen - DESC - """, - (since_ts, user_id), - ) - return cast(List[Tuple[str, str, str, int]], txn.fetchall()) - - rows = await self.db_pool.runInteraction( - desc="get_user_ip_and_agents", func=get_recent - ) + results[(access_token, ip)] = { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } - results.update( - ((access_token, ip), (user_agent, last_seen)) - for access_token, ip, user_agent, last_seen in rows - ) - return [ - { - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "last_seen": last_seen, - } - for (access_token, ip), (user_agent, last_seen) in results.items() - ] + return list(results.values()) -- cgit 1.5.1 From 8c8e36af0d6c3855de7bd786be14b85f5dae4ea7 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 26 Oct 2021 11:09:10 +0200 Subject: Document the version each module API method was added to Synapse (#11183) --- changelog.d/11183.doc | 1 + synapse/module_api/__init__.py | 99 +++++++++++++++++++++++++++++++++++++----- 2 files changed, 89 insertions(+), 11 deletions(-) create mode 100644 changelog.d/11183.doc (limited to 'synapse/module_api/__init__.py') diff --git a/changelog.d/11183.doc b/changelog.d/11183.doc new file mode 100644 index 0000000000..a171a107af --- /dev/null +++ b/changelog.d/11183.doc @@ -0,0 +1 @@ +Document the version of Synapse that introduced each module API method. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d37252b6b3..d707a9325d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -154,27 +154,42 @@ class ModuleApi: @property def register_spam_checker_callbacks(self): - """Registers callbacks for spam checking capabilities.""" + """Registers callbacks for spam checking capabilities. + + Added in Synapse v1.37.0. + """ return self._spam_checker.register_callbacks @property def register_account_validity_callbacks(self): - """Registers callbacks for account validity capabilities.""" + """Registers callbacks for account validity capabilities. + + Added in Synapse v1.39.0. + """ return self._account_validity_handler.register_account_validity_callbacks @property def register_third_party_rules_callbacks(self): - """Registers callbacks for third party event rules capabilities.""" + """Registers callbacks for third party event rules capabilities. + + Added in Synapse v1.39.0. + """ return self._third_party_event_rules.register_third_party_rules_callbacks @property def register_presence_router_callbacks(self): - """Registers callbacks for presence router capabilities.""" + """Registers callbacks for presence router capabilities. + + Added in Synapse v1.42.0. + """ return self._presence_router.register_presence_router_callbacks @property def register_password_auth_provider_callbacks(self): - """Registers callbacks for password auth provider capabilities.""" + """Registers callbacks for password auth provider capabilities. + + Added in Synapse v1.46.0. + """ return self._password_auth_provider.register_password_auth_provider_callbacks def register_web_resource(self, path: str, resource: IResource): @@ -185,6 +200,8 @@ class ModuleApi: If multiple modules register a resource for the same path, the module that appears the highest in the configuration file takes priority. + Added in Synapse v1.37.0. + Args: path: The path to register the resource for. resource: The resource to attach to this path. @@ -199,6 +216,8 @@ class ModuleApi: """Allows making outbound HTTP requests to remote resources. An instance of synapse.http.client.SimpleHttpClient + + Added in Synapse v1.22.0. """ return self._http_client @@ -208,22 +227,32 @@ class ModuleApi: public room list. An instance of synapse.module_api.PublicRoomListManager + + Added in Synapse v1.22.0. """ return self._public_room_list_manager @property def public_baseurl(self) -> str: - """The configured public base URL for this homeserver.""" + """The configured public base URL for this homeserver. + + Added in Synapse v1.39.0. + """ return self._hs.config.server.public_baseurl @property def email_app_name(self) -> str: - """The application name configured in the homeserver's configuration.""" + """The application name configured in the homeserver's configuration. + + Added in Synapse v1.39.0. + """ return self._hs.config.email.email_app_name async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: """Get user info by user_id + Added in Synapse v1.41.0. + Args: user_id: Fully qualified user id. Returns: @@ -239,6 +268,8 @@ class ModuleApi: ) -> Requester: """Check the access_token provided for a request + Added in Synapse v1.39.0. + Args: req: Incoming HTTP request allow_guest: True if guest users should be allowed. If this @@ -264,6 +295,8 @@ class ModuleApi: async def is_user_admin(self, user_id: str) -> bool: """Checks if a user is a server admin. + Added in Synapse v1.39.0. + Args: user_id: The Matrix ID of the user to check. @@ -278,6 +311,8 @@ class ModuleApi: Takes a user id provided by the user and adds the @ and :domain to qualify it, if necessary + Added in Synapse v0.25.0. + Args: username (str): provided user id @@ -291,6 +326,8 @@ class ModuleApi: async def get_profile_for_user(self, localpart: str) -> ProfileInfo: """Look up the profile info for the user with the given localpart. + Added in Synapse v1.39.0. + Args: localpart: The localpart to look up profile information for. @@ -303,6 +340,8 @@ class ModuleApi: """Look up the threepids (email addresses and phone numbers) associated with the given Matrix user ID. + Added in Synapse v1.39.0. + Args: user_id: The Matrix user ID to look up threepids for. @@ -317,6 +356,8 @@ class ModuleApi: def check_user_exists(self, user_id): """Check if user exists. + Added in Synapse v0.25.0. + Args: user_id (str): Complete @user:id @@ -336,6 +377,8 @@ class ModuleApi: return that device to the user. Prefer separate calls to register_user and register_device. + Added in Synapse v0.25.0. + Args: localpart (str): The localpart of the new user. displayname (str|None): The displayname of the new user. @@ -356,6 +399,8 @@ class ModuleApi: ): """Registers a new user with given localpart and optional displayname, emails. + Added in Synapse v1.2.0. + Args: localpart (str): The localpart of the new user. displayname (str|None): The displayname of the new user. @@ -379,6 +424,8 @@ class ModuleApi: def register_device(self, user_id, device_id=None, initial_display_name=None): """Register a device for a user and generate an access token. + Added in Synapse v1.2.0. + Args: user_id (str): full canonical @user:id device_id (str|None): The device ID to check, or None to generate @@ -402,6 +449,8 @@ class ModuleApi: ) -> defer.Deferred: """Record a mapping from an external user id to a mxid + Added in Synapse v1.9.0. + Args: auth_provider: identifier for the remote auth provider external_id: id on that system @@ -421,6 +470,8 @@ class ModuleApi: ) -> str: """Generate a login token suitable for m.login.token authentication + Added in Synapse v1.9.0. + Args: user_id: gives the ID of the user that the token is for @@ -440,6 +491,8 @@ class ModuleApi: def invalidate_access_token(self, access_token): """Invalidate an access token for a user + Added in Synapse v0.25.0. + Args: access_token(str): access token @@ -470,6 +523,8 @@ class ModuleApi: def run_db_interaction(self, desc, func, *args, **kwargs): """Run a function with a database connection + Added in Synapse v0.25.0. + Args: desc (str): description for the transaction, for metrics etc func (func): function to be run. Passed a database cursor object @@ -493,6 +548,8 @@ class ModuleApi: This is deprecated in favor of complete_sso_login_async. + Added in Synapse v1.11.1. + Args: registered_user_id: The MXID that has been registered as a previous step of of this SSO login. @@ -519,6 +576,8 @@ class ModuleApi: want their access token sent to `client_redirect_url`, or redirect them to that URL with a token directly if the URL matches with one of the whitelisted clients. + Added in Synapse v1.13.0. + Args: registered_user_id: The MXID that has been registered as a previous step of of this SSO login. @@ -547,6 +606,8 @@ class ModuleApi: (This is exposed for compatibility with the old SpamCheckerApi. We should probably deprecate it and replace it with an async method in a subclass.) + Added in Synapse v1.22.0. + Args: room_id: The room ID to get state events in. types: The event type and state key (using None @@ -567,6 +628,8 @@ class ModuleApi: async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBase: """Create and send an event into a room. Membership events are currently not supported. + Added in Synapse v1.22.0. + Args: event_dict: A dictionary representing the event to send. Required keys are `type`, `room_id`, `sender` and `content`. @@ -607,6 +670,8 @@ class ModuleApi: Note that this method can only be run on the process that is configured to write to the presence stream. By default this is the main process. + + Added in Synapse v1.32.0. """ if self._hs._instance_name not in self._hs.config.worker.writers.presence: raise Exception( @@ -661,6 +726,8 @@ class ModuleApi: Waits `msec` initially before calling `f` for the first time. + Added in Synapse v1.39.0. + Args: f: The function to call repeatedly. f can be either synchronous or asynchronous, and must follow Synapse's logcontext rules. @@ -700,6 +767,8 @@ class ModuleApi: ): """Send an email on behalf of the homeserver. + Added in Synapse v1.39.0. + Args: recipient: The email address for the recipient. subject: The email's subject. @@ -723,6 +792,8 @@ class ModuleApi: By default, Synapse will look for these templates in its configured template directory, but another directory to search in can be provided. + Added in Synapse v1.39.0. + Args: filenames: The name of the template files to look for. custom_template_directory: An additional directory to look for the files in. @@ -740,13 +811,13 @@ class ModuleApi: """ Checks whether an ID (user id, room, ...) comes from this homeserver. + Added in Synapse v1.44.0. + Args: id: any Matrix id (e.g. user id, room id, ...), either as a raw id, e.g. string "@user:example.com" or as a parsed UserID, RoomID, ... Returns: True if id comes from this homeserver, False otherwise. - - Added in Synapse v1.44.0. """ if isinstance(id, DomainSpecificString): return self._hs.is_mine(id) @@ -759,6 +830,8 @@ class ModuleApi: """ Return the list of user IPs and agents for a user. + Added in Synapse v1.44.0. + Args: user_id: the id of a user, local or remote since_ts: a timestamp in seconds since the epoch, @@ -767,8 +840,6 @@ class ModuleApi: The list of all UserIpAndAgent that the user has used to connect to this homeserver since `since_ts`. If the user is remote, this list is empty. - - Added in Synapse v1.44.0. """ # Don't hit the db if this is not a local user. is_mine = False @@ -807,6 +878,8 @@ class PublicRoomListManager: async def room_is_in_public_room_list(self, room_id: str) -> bool: """Checks whether a room is in the public room list. + Added in Synapse v1.22.0. + Args: room_id: The ID of the room. @@ -823,6 +896,8 @@ class PublicRoomListManager: async def add_room_to_public_room_list(self, room_id: str) -> None: """Publishes a room to the public room list. + Added in Synapse v1.22.0. + Args: room_id: The ID of the room. """ @@ -831,6 +906,8 @@ class PublicRoomListManager: async def remove_room_from_public_room_list(self, room_id: str) -> None: """Removes a room from the public room list. + Added in Synapse v1.22.0. + Args: room_id: The ID of the room. """ -- cgit 1.5.1