diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index f4612a5b92..ebe75a9e9b 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -200,46 +200,13 @@ class AuthHandler:
self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
- # we can't use hs.get_module_api() here, because to do so will create an
- # import loop.
- #
- # TODO: refactor this class to separate the lower-level stuff that
- # ModuleApi can use from the higher-level stuff that uses ModuleApi, as
- # better way to break the loop
- account_handler = ModuleApi(hs, self)
-
- self.password_providers = [
- PasswordProvider.load(module, config, account_handler)
- for module, config in hs.config.authproviders.password_providers
- ]
-
- logger.info("Extra password_providers: %s", self.password_providers)
+ self.password_auth_provider = hs.get_password_auth_provider()
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.auth.password_enabled
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
- # start out by assuming PASSWORD is enabled; we will remove it later if not.
- login_types = set()
- if self._password_localdb_enabled:
- login_types.add(LoginType.PASSWORD)
-
- for provider in self.password_providers:
- login_types.update(provider.get_supported_login_types().keys())
-
- if not self._password_enabled:
- login_types.discard(LoginType.PASSWORD)
-
- # Some clients just pick the first type in the list. In this case, we want
- # them to use PASSWORD (rather than token or whatever), so we want to make sure
- # that comes first, where it's present.
- self._supported_login_types = []
- if LoginType.PASSWORD in login_types:
- self._supported_login_types.append(LoginType.PASSWORD)
- login_types.remove(LoginType.PASSWORD)
- self._supported_login_types.extend(login_types)
-
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
@@ -427,11 +394,10 @@ class AuthHandler:
ui_auth_types.add(LoginType.PASSWORD)
# also allow auth from password providers
- for provider in self.password_providers:
- for t in provider.get_supported_login_types().keys():
- if t == LoginType.PASSWORD and not self._password_enabled:
- continue
- ui_auth_types.add(t)
+ for t in self.password_auth_provider.get_supported_login_types().keys():
+ if t == LoginType.PASSWORD and not self._password_enabled:
+ continue
+ ui_auth_types.add(t)
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
# from sso to mxid.
@@ -1038,7 +1004,25 @@ class AuthHandler:
Returns:
login types
"""
- return self._supported_login_types
+ # Load any login types registered by modules
+ # This is stored in the password_auth_provider so this doesn't trigger
+ # any callbacks
+ types = list(self.password_auth_provider.get_supported_login_types().keys())
+
+ # This list should include PASSWORD if (either _password_localdb_enabled is
+ # true or if one of the modules registered it) AND _password_enabled is true
+ # Also:
+ # Some clients just pick the first type in the list. In this case, we want
+ # them to use PASSWORD (rather than token or whatever), so we want to make sure
+ # that comes first, where it's present.
+ if LoginType.PASSWORD in types:
+ types.remove(LoginType.PASSWORD)
+ if self._password_enabled:
+ types.insert(0, LoginType.PASSWORD)
+ elif self._password_localdb_enabled and self._password_enabled:
+ types.insert(0, LoginType.PASSWORD)
+
+ return types
async def validate_login(
self,
@@ -1217,15 +1201,20 @@ class AuthHandler:
known_login_type = False
- for provider in self.password_providers:
- supported_login_types = provider.get_supported_login_types()
- if login_type not in supported_login_types:
- # this password provider doesn't understand this login type
- continue
-
+ # Check if login_type matches a type registered by one of the modules
+ # We don't need to remove LoginType.PASSWORD from the list if password login is
+ # disabled, since if that were the case then by this point we know that the
+ # login_type is not LoginType.PASSWORD
+ supported_login_types = self.password_auth_provider.get_supported_login_types()
+ # check if the login type being used is supported by a module
+ if login_type in supported_login_types:
+ # Make a note that this login type is supported by the server
known_login_type = True
+            # Get all the fields expected for this login type
login_fields = supported_login_types[login_type]
+ # go through the login submission and keep track of which required fields are
+ # provided/not provided
missing_fields = []
login_dict = {}
for f in login_fields:
@@ -1233,6 +1222,7 @@ class AuthHandler:
missing_fields.append(f)
else:
login_dict[f] = login_submission[f]
+ # raise an error if any of the expected fields for that login type weren't provided
if missing_fields:
raise SynapseError(
400,
@@ -1240,10 +1230,15 @@ class AuthHandler:
% (login_type, missing_fields),
)
- result = await provider.check_auth(username, login_type, login_dict)
+ # call all of the check_auth hooks for that login_type
+ # it will return a result once the first success is found (or None otherwise)
+ result = await self.password_auth_provider.check_auth(
+ username, login_type, login_dict
+ )
if result:
return result
+        # if no module managed to authenticate the user, then fall back to built-in password-based auth
if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True
@@ -1282,11 +1277,16 @@ class AuthHandler:
completed login/registration, or `None`. If authentication was
unsuccessful, `user_id` and `callback` are both `None`.
"""
- for provider in self.password_providers:
- result = await provider.check_3pid_auth(medium, address, password)
- if result:
- return result
+ # call all of the check_3pid_auth callbacks
+ # Result will be from the first callback that returns something other than None
+ # If all the callbacks return None, then result is also set to None
+ result = await self.password_auth_provider.check_3pid_auth(
+ medium, address, password
+ )
+ if result:
+ return result
+ # if result is None then return (None, None)
return None, None
async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
@@ -1365,13 +1365,12 @@ class AuthHandler:
user_info = await self.auth.get_user_by_access_token(access_token)
await self.store.delete_access_token(access_token)
- # see if any of our auth providers want to know about this
- for provider in self.password_providers:
- await provider.on_logged_out(
- user_id=user_info.user_id,
- device_id=user_info.device_id,
- access_token=access_token,
- )
+ # see if any modules want to know about this
+ await self.password_auth_provider.on_logged_out(
+ user_id=user_info.user_id,
+ device_id=user_info.device_id,
+ access_token=access_token,
+ )
# delete pushers associated with this access token
if user_info.token_id is not None:
@@ -1398,12 +1397,11 @@ class AuthHandler:
user_id, except_token_id=except_token_id, device_id=device_id
)
- # see if any of our auth providers want to know about this
- for provider in self.password_providers:
- for token, _, device_id in tokens_and_devices:
- await provider.on_logged_out(
- user_id=user_id, device_id=device_id, access_token=token
- )
+ # see if any modules want to know about this
+ for token, _, device_id in tokens_and_devices:
+ await self.password_auth_provider.on_logged_out(
+ user_id=user_id, device_id=device_id, access_token=token
+ )
# delete pushers associated with the access tokens
await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1811,40 +1809,228 @@ class MacaroonGenerator:
return macaroon
-class PasswordProvider:
- """Wrapper for a password auth provider module
+def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
+ module_api = hs.get_module_api()
+ for module, config in hs.config.authproviders.password_providers:
+ load_single_legacy_password_auth_provider(
+ module=module, config=config, api=module_api
+ )
- This class abstracts out all of the backwards-compatibility hacks for
- password providers, to provide a consistent interface.
- """
- @classmethod
- def load(
- cls, module: Type, config: JsonDict, module_api: ModuleApi
- ) -> "PasswordProvider":
- try:
- pp = module(config=config, account_handler=module_api)
- except Exception as e:
- logger.error("Error while initializing %r: %s", module, e)
- raise
- return cls(pp, module_api)
+def load_single_legacy_password_auth_provider(
+ module: Type, config: JsonDict, api: ModuleApi
+) -> None:
+ try:
+ provider = module(config=config, account_handler=api)
+ except Exception as e:
+ logger.error("Error while initializing %r: %s", module, e)
+ raise
+
+    # The known hooks. If a module implements a method whose name appears in this set
+ # we'll want to register it
+ password_auth_provider_methods = {
+ "check_3pid_auth",
+ "on_logged_out",
+ }
+
+ # All methods that the module provides should be async, but this wasn't enforced
+ # in the old module system, so we wrap them if needed
+ def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+ # f might be None if the callback isn't implemented by the module. In this
+ # case we don't want to register a callback at all so we return None.
+ if f is None:
+ return None
+
+ # We need to wrap check_password because its old form would return a boolean
+ # but we now want it to behave just like check_auth() and return the matrix id of
+ # the user if authentication succeeded or None otherwise
+ if f.__name__ == "check_password":
+
+ async def wrapped_check_password(
+ username: str, login_type: str, login_dict: JsonDict
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ # We've already made sure f is not None above, but mypy doesn't do well
+ # across function boundaries so we need to tell it f is definitely not
+ # None.
+ assert f is not None
+
+ matrix_user_id = api.get_qualified_user_id(username)
+ password = login_dict["password"]
+
+ is_valid = await f(matrix_user_id, password)
+
+ if is_valid:
+ return matrix_user_id, None
+
+ return None
- def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
- self._pp = pp
- self._module_api = module_api
+ return wrapped_check_password
+
+ # We need to wrap check_auth as in the old form it could return
+    # just a str, but now it must return Optional[Tuple[str, Optional[Callable]]]
+ if f.__name__ == "check_auth":
+
+ async def wrapped_check_auth(
+ username: str, login_type: str, login_dict: JsonDict
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ # We've already made sure f is not None above, but mypy doesn't do well
+ # across function boundaries so we need to tell it f is definitely not
+ # None.
+ assert f is not None
+
+ result = await f(username, login_type, login_dict)
+
+ if isinstance(result, str):
+ return result, None
+
+ return result
+
+ return wrapped_check_auth
+
+ # We need to wrap check_3pid_auth as in the old form it could return
+    # just a str, but now it must return Optional[Tuple[str, Optional[Callable]]]
+ if f.__name__ == "check_3pid_auth":
+
+ async def wrapped_check_3pid_auth(
+ medium: str, address: str, password: str
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ # We've already made sure f is not None above, but mypy doesn't do well
+ # across function boundaries so we need to tell it f is definitely not
+ # None.
+ assert f is not None
+
+ result = await f(medium, address, password)
+
+ if isinstance(result, str):
+ return result, None
+
+ return result
- self._supported_login_types = {}
+ return wrapped_check_3pid_auth
- # grandfather in check_password support
- if hasattr(self._pp, "check_password"):
- self._supported_login_types[LoginType.PASSWORD] = ("password",)
+ def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
+ # mypy doesn't do well across function boundaries so we need to tell it
+ # f is definitely not None.
+ assert f is not None
- g = getattr(self._pp, "get_supported_login_types", None)
- if g:
- self._supported_login_types.update(g())
+ return maybe_awaitable(f(*args, **kwargs))
- def __str__(self) -> str:
- return str(self._pp)
+ return run
+
+ # populate hooks with the implemented methods, wrapped with async_wrapper
+ hooks = {
+ hook: async_wrapper(getattr(provider, hook, None))
+ for hook in password_auth_provider_methods
+ }
+
+ supported_login_types = {}
+ # call get_supported_login_types and add that to the dict
+ g = getattr(provider, "get_supported_login_types", None)
+ if g is not None:
+ # Note the old module style also called get_supported_login_types at loading time
+ # and it is synchronous
+ supported_login_types.update(g())
+
+ auth_checkers = {}
+ # Legacy modules have a check_auth method which expects to be called with one of
+ # the keys returned by get_supported_login_types. New style modules register a
+ # dictionary of login_type->check_auth_method mappings
+ check_auth = async_wrapper(getattr(provider, "check_auth", None))
+ if check_auth is not None:
+ for login_type, fields in supported_login_types.items():
+ # need tuple(fields) since fields can be any Iterable type (so may not be hashable)
+ auth_checkers[(login_type, tuple(fields))] = check_auth
+
+ # if it has a "check_password" method then it should handle all auth checks
+ # with login type of LoginType.PASSWORD
+ check_password = async_wrapper(getattr(provider, "check_password", None))
+ if check_password is not None:
+ # need to use a tuple here for ("password",) not a list since lists aren't hashable
+ auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
+
+ api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
+
+
+CHECK_3PID_AUTH_CALLBACK = Callable[
+ [str, str, str],
+ Awaitable[
+ Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
+ ],
+]
+ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
+CHECK_AUTH_CALLBACK = Callable[
+ [str, str, JsonDict],
+ Awaitable[
+ Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
+ ],
+]
+
+
+class PasswordAuthProvider:
+ """
+ A class that the AuthHandler calls when authenticating users
+ It allows modules to provide alternative methods for authentication
+ """
+
+ def __init__(self) -> None:
+ # lists of callbacks
+ self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
+ self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
+
+ # Mapping from login type to login parameters
+ self._supported_login_types: Dict[str, Iterable[str]] = {}
+
+ # Mapping from login type to auth checker callbacks
+ self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
+
+ def register_password_auth_provider_callbacks(
+ self,
+ check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
+ on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
+ auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None,
+ ) -> None:
+ # Register check_3pid_auth callback
+ if check_3pid_auth is not None:
+ self.check_3pid_auth_callbacks.append(check_3pid_auth)
+
+ # register on_logged_out callback
+ if on_logged_out is not None:
+ self.on_logged_out_callbacks.append(on_logged_out)
+
+ if auth_checkers is not None:
+ # register a new supported login_type
+ # Iterate through all of the types being registered
+ for (login_type, fields), callback in auth_checkers.items():
+                # Note: fields may be empty here. This would allow a module's auth checker to
+ # be called with just 'login_type' and no password or other secrets
+
+ # Need to check that all the field names are strings or may get nasty errors later
+ for f in fields:
+ if not isinstance(f, str):
+ raise RuntimeError(
+ "A module tried to register support for login type: %s with parameters %s"
+ " but all parameter names must be strings"
+ % (login_type, fields)
+ )
+
+ # 2 modules supporting the same login type must expect the same fields
+ # e.g. 1 can't expect "pass" if the other expects "password"
+ # so throw an exception if that happens
+                if login_type not in self._supported_login_types:
+ self._supported_login_types[login_type] = fields
+ else:
+ fields_currently_supported = self._supported_login_types.get(
+ login_type
+ )
+ if fields_currently_supported != fields:
+ raise RuntimeError(
+ "A module tried to register support for login type: %s with parameters %s"
+ " but another module had already registered support for that type with parameters %s"
+ % (login_type, fields, fields_currently_supported)
+ )
+
+ # Add the new method to the list of auth_checker_callbacks for this login type
+ self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider
@@ -1852,20 +2038,15 @@ class PasswordProvider:
Returns a map from a login type identifier (such as m.login.password) to an
iterable giving the fields which must be provided by the user in the submission
to the /login API.
-
- This wrapper adds m.login.password to the list if the underlying password
- provider supports the check_password() api.
"""
+
return self._supported_login_types
async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict
- ) -> Optional[Tuple[str, Optional[Callable]]]:
+ ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
"""Check if the user has presented valid login credentials
- This wrapper also calls check_password() if the underlying password provider
- supports the check_password() api and the login type is m.login.password.
-
Args:
username: user id presented by the client. Either an MXID or an unqualified
username.
@@ -1879,63 +2060,130 @@ class PasswordProvider:
user, and `callback` is an optional callback which will be called with the
result from the /login call (including access_token, device_id, etc.)
"""
- # first grandfather in a call to check_password
- if login_type == LoginType.PASSWORD:
- check_password = getattr(self._pp, "check_password", None)
- if check_password:
- qualified_user_id = self._module_api.get_qualified_user_id(username)
- is_valid = await check_password(
- qualified_user_id, login_dict["password"]
- )
- if is_valid:
- return qualified_user_id, None
- check_auth = getattr(self._pp, "check_auth", None)
- if not check_auth:
- return None
- result = await check_auth(username, login_type, login_dict)
+ # Go through all callbacks for the login type until one returns with a value
+ # other than None (i.e. until a callback returns a success)
+ for callback in self.auth_checker_callbacks[login_type]:
+ try:
+ result = await callback(username, login_type, login_dict)
+ except Exception as e:
+ logger.warning("Failed to run module API callback %s: %s", callback, e)
+ continue
- # Check if the return value is a str or a tuple
- if isinstance(result, str):
- # If it's a str, set callback function to None
- return result, None
+ if result is not None:
+ # Check that the callback returned a Tuple[str, Optional[Callable]]
+ # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
+ # result is always the right type, but as it is 3rd party code it might not be
+
+ if not isinstance(result, tuple) or len(result) != 2:
+ logger.warning(
+ "Wrong type returned by module API callback %s: %s, expected"
+ " Optional[Tuple[str, Optional[Callable]]]",
+ callback,
+ result,
+ )
+ continue
- return result
+ # pull out the two parts of the tuple so we can do type checking
+ str_result, callback_result = result
+
+ # the 1st item in the tuple should be a str
+ if not isinstance(str_result, str):
+ logger.warning( # type: ignore[unreachable]
+ "Wrong type returned by module API callback %s: %s, expected"
+ " Optional[Tuple[str, Optional[Callable]]]",
+ callback,
+ result,
+ )
+ continue
+
+ # the second should be Optional[Callable]
+ if callback_result is not None:
+ if not callable(callback_result):
+ logger.warning( # type: ignore[unreachable]
+ "Wrong type returned by module API callback %s: %s, expected"
+ " Optional[Tuple[str, Optional[Callable]]]",
+ callback,
+ result,
+ )
+ continue
+
+ # The result is a (str, Optional[callback]) tuple so return the successful result
+ return result
+
+ # If this point has been reached then none of the callbacks successfully authenticated
+ # the user so return None
+ return None
async def check_3pid_auth(
self, medium: str, address: str, password: str
- ) -> Optional[Tuple[str, Optional[Callable]]]:
- g = getattr(self._pp, "check_3pid_auth", None)
- if not g:
- return None
-
+ ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
- result = await g(medium, address, password)
- # Check if the return value is a str or a tuple
- if isinstance(result, str):
- # If it's a str, set callback function to None
- return result, None
+ for callback in self.check_3pid_auth_callbacks:
+ try:
+ result = await callback(medium, address, password)
+ except Exception as e:
+ logger.warning("Failed to run module API callback %s: %s", callback, e)
+ continue
- return result
+ if result is not None:
+ # Check that the callback returned a Tuple[str, Optional[Callable]]
+ # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
+ # result is always the right type, but as it is 3rd party code it might not be
+
+ if not isinstance(result, tuple) or len(result) != 2:
+ logger.warning(
+ "Wrong type returned by module API callback %s: %s, expected"
+ " Optional[Tuple[str, Optional[Callable]]]",
+ callback,
+ result,
+ )
+ continue
+
+ # pull out the two parts of the tuple so we can do type checking
+ str_result, callback_result = result
+
+ # the 1st item in the tuple should be a str
+ if not isinstance(str_result, str):
+ logger.warning( # type: ignore[unreachable]
+ "Wrong type returned by module API callback %s: %s, expected"
+ " Optional[Tuple[str, Optional[Callable]]]",
+ callback,
+ result,
+ )
+ continue
+
+ # the second should be Optional[Callable]
+ if callback_result is not None:
+ if not callable(callback_result):
+ logger.warning( # type: ignore[unreachable]
+ "Wrong type returned by module API callback %s: %s, expected"
+ " Optional[Tuple[str, Optional[Callable]]]",
+ callback,
+ result,
+ )
+ continue
+
+ # The result is a (str, Optional[callback]) tuple so return the successful result
+ return result
+
+ # If this point has been reached then none of the callbacks successfully authenticated
+ # the user so return None
+ return None
async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str
) -> None:
- g = getattr(self._pp, "on_logged_out", None)
- if not g:
- return
- # This might return an awaitable, if it does block the log out
- # until it completes.
- await maybe_awaitable(
- g(
- user_id=user_id,
- device_id=device_id,
- access_token=access_token,
- )
- )
+ # call all of the on_logged_out callbacks
+ for callback in self.on_logged_out_callbacks:
+ try:
+                await callback(user_id, device_id, access_token)
+ except Exception as e:
+ logger.warning("Failed to run module API callback %s: %s", callback, e)
+ continue
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 75e6019760..6eafbea25d 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,7 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+)
from synapse.api import errors
from synapse.api.constants import EventTypes
@@ -595,7 +606,7 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips(
- device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
+ device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index d089c56286..365063ebdf 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -55,8 +55,7 @@ class EventAuthHandler:
"""Check an event passes the auth rules at its own auth events"""
auth_event_ids = event.auth_event_ids()
auth_events_by_id = await self._store.get_events(auth_event_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
- check_auth_rules_for_event(room_version_obj, event, auth_events)
+ check_auth_rules_for_event(room_version_obj, event, auth_events_by_id.values())
def compute_auth_events(
self,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3e341bd287..3112cc88b1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -15,7 +15,6 @@
"""Contains handlers for federation events."""
-import itertools
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
@@ -27,12 +26,7 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer
from synapse import event_auth
-from synapse.api.constants import (
- EventContentFields,
- EventTypes,
- Membership,
- RejectedReason,
-)
+from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.errors import (
AuthError,
CodeMessageException,
@@ -43,12 +37,9 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.crypto.event_signing import compute_event_signature
-from synapse.event_auth import (
- check_auth_rules_for_event,
- validate_event_for_room_version,
-)
+from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
@@ -238,18 +229,10 @@ class FederationHandler:
)
return False
- logger.debug(
- "room_id: %s, backfill: current_depth: %s, max_depth: %s, extrems: %s",
- room_id,
- current_depth,
- max_depth,
- sorted_extremeties_tuple,
- )
-
# We ignore extremities that have a greater depth than our current depth
# as:
# 1. we don't really care about getting events that have happened
- # before our current position; and
+ # after our current position; and
# 2. we have likely previously tried and failed to backfill from that
# extremity, so to avoid getting "stuck" requesting the same
# backfill repeatedly we drop those extremities.
@@ -257,9 +240,19 @@ class FederationHandler:
t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth
]
+ logger.debug(
+ "room_id: %s, backfill: current_depth: %s, limit: %s, max_depth: %s, extrems: %s filtered_sorted_extremeties_tuple: %s",
+ room_id,
+ current_depth,
+ limit,
+ max_depth,
+ sorted_extremeties_tuple,
+ filtered_sorted_extremeties_tuple,
+ )
+
# However, we need to check that the filtered extremities are non-empty.
# If they are empty then either we can a) bail or b) still attempt to
- # backill. We opt to try backfilling anyway just in case we do get
+ # backfill. We opt to try backfilling anyway just in case we do get
# relevant events.
if filtered_sorted_extremeties_tuple:
sorted_extremeties_tuple = filtered_sorted_extremeties_tuple
@@ -389,7 +382,7 @@ class FederationHandler:
for key, state_dict in states.items()
}
- for e_id, _ in sorted_extremeties_tuple:
+ for e_id in event_ids:
likely_extremeties_domains = get_domains_from_state(states[e_id])
success = await try_backfill(
@@ -517,7 +510,7 @@ class FederationHandler:
auth_events=auth_chain,
)
- max_stream_id = await self._persist_auth_tree(
+ max_stream_id = await self._federation_event_handler.process_remote_join(
origin, room_id, auth_chain, state, event, room_version_obj
)
@@ -1093,119 +1086,6 @@ class FederationHandler:
else:
return None
- async def _persist_auth_tree(
- self,
- origin: str,
- room_id: str,
- auth_events: List[EventBase],
- state: List[EventBase],
- event: EventBase,
- room_version: RoomVersion,
- ) -> int:
- """Checks the auth chain is valid (and passes auth checks) for the
- state and event. Then persists the auth chain and state atomically.
- Persists the event separately. Notifies about the persisted events
- where appropriate.
-
- Will attempt to fetch missing auth events.
-
- Args:
- origin: Where the events came from
- room_id,
- auth_events
- state
- event
- room_version: The room version we expect this room to have, and
- will raise if it doesn't match the version in the create event.
- """
- events_to_context = {}
- for e in itertools.chain(auth_events, state):
- e.internal_metadata.outlier = True
- events_to_context[e.event_id] = EventContext.for_outlier()
-
- event_map = {
- e.event_id: e for e in itertools.chain(auth_events, state, [event])
- }
-
- create_event = None
- for e in auth_events:
- if (e.type, e.state_key) == (EventTypes.Create, ""):
- create_event = e
- break
-
- if create_event is None:
- # If the state doesn't have a create event then the room is
- # invalid, and it would fail auth checks anyway.
- raise SynapseError(400, "No create event in state")
-
- room_version_id = create_event.content.get(
- "room_version", RoomVersions.V1.identifier
- )
-
- if room_version.identifier != room_version_id:
- raise SynapseError(400, "Room version mismatch")
-
- missing_auth_events = set()
- for e in itertools.chain(auth_events, state, [event]):
- for e_id in e.auth_event_ids():
- if e_id not in event_map:
- missing_auth_events.add(e_id)
-
- for e_id in missing_auth_events:
- m_ev = await self.federation_client.get_pdu(
- [origin],
- e_id,
- room_version=room_version,
- outlier=True,
- timeout=10000,
- )
- if m_ev and m_ev.event_id == e_id:
- event_map[e_id] = m_ev
- else:
- logger.info("Failed to find auth event %r", e_id)
-
- for e in itertools.chain(auth_events, state, [event]):
- auth_for_e = {
- (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
- for e_id in e.auth_event_ids()
- if e_id in event_map
- }
- if create_event:
- auth_for_e[(EventTypes.Create, "")] = create_event
-
- try:
- validate_event_for_room_version(room_version, e)
- check_auth_rules_for_event(room_version, e, auth_for_e)
- except SynapseError as err:
- # we may get SynapseErrors here as well as AuthErrors. For
- # instance, there are a couple of (ancient) events in some
- # rooms whose senders do not have the correct sigil; these
- # cause SynapseErrors in auth.check. We don't want to give up
- # the attempt to federate altogether in such cases.
-
- logger.warning("Rejecting %s because %s", e.event_id, err.msg)
-
- if e == event:
- raise
- events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
-
- if auth_events or state:
- await self._federation_event_handler.persist_events_and_notify(
- room_id,
- [
- (e, events_to_context[e.event_id])
- for e in itertools.chain(auth_events, state)
- ],
- )
-
- new_event_context = await self.state_handler.compute_event_context(
- event, old_state=state
- )
-
- return await self._federation_event_handler.persist_events_and_notify(
- room_id, [(event, new_event_context)]
- )
-
async def on_get_missing_events(
self,
origin: str,
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index f640b417b3..5a2f2e5ebb 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import logging
from http import HTTPStatus
from typing import (
@@ -45,7 +46,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
from synapse.event_auth import (
auth_types_for_event,
check_auth_rules_for_event,
@@ -214,7 +215,7 @@ class FederationEventHandler:
if missing_prevs:
# We only backfill backwards to the min depth.
- min_depth = await self.get_min_depth_for_context(pdu.room_id)
+ min_depth = await self._store.get_min_depth(pdu.room_id)
logger.debug("min_depth: %d", min_depth)
if min_depth is not None and pdu.depth > min_depth:
@@ -390,9 +391,122 @@ class FederationEventHandler:
prev_member_event,
)
+ async def process_remote_join(
+ self,
+ origin: str,
+ room_id: str,
+ auth_events: List[EventBase],
+ state: List[EventBase],
+ event: EventBase,
+ room_version: RoomVersion,
+ ) -> int:
+ """Persists the events returned by a send_join
+
+ Checks the auth chain is valid (and passes auth checks) for the
+ state and event. Then persists the auth chain and state atomically.
+ Persists the event separately. Notifies about the persisted events
+ where appropriate.
+
+ Will attempt to fetch missing auth events.
+
+ Args:
+ origin: Where the events came from
+ room_id,
+ auth_events
+ state
+ event
+ room_version: The room version we expect this room to have, and
+ will raise if it doesn't match the version in the create event.
+ """
+ events_to_context = {}
+ for e in itertools.chain(auth_events, state):
+ e.internal_metadata.outlier = True
+ events_to_context[e.event_id] = EventContext.for_outlier()
+
+ event_map = {
+ e.event_id: e for e in itertools.chain(auth_events, state, [event])
+ }
+
+ create_event = None
+ for e in auth_events:
+ if (e.type, e.state_key) == (EventTypes.Create, ""):
+ create_event = e
+ break
+
+ if create_event is None:
+ # If the state doesn't have a create event then the room is
+ # invalid, and it would fail auth checks anyway.
+ raise SynapseError(400, "No create event in state")
+
+ room_version_id = create_event.content.get(
+ "room_version", RoomVersions.V1.identifier
+ )
+
+ if room_version.identifier != room_version_id:
+ raise SynapseError(400, "Room version mismatch")
+
+ missing_auth_events = set()
+ for e in itertools.chain(auth_events, state, [event]):
+ for e_id in e.auth_event_ids():
+ if e_id not in event_map:
+ missing_auth_events.add(e_id)
+
+ for e_id in missing_auth_events:
+ m_ev = await self._federation_client.get_pdu(
+ [origin],
+ e_id,
+ room_version=room_version,
+ outlier=True,
+ timeout=10000,
+ )
+ if m_ev and m_ev.event_id == e_id:
+ event_map[e_id] = m_ev
+ else:
+ logger.info("Failed to find auth event %r", e_id)
+
+ for e in itertools.chain(auth_events, state, [event]):
+ auth_for_e = [
+ event_map[e_id] for e_id in e.auth_event_ids() if e_id in event_map
+ ]
+ if create_event:
+ auth_for_e.append(create_event)
+
+ try:
+ validate_event_for_room_version(room_version, e)
+ check_auth_rules_for_event(room_version, e, auth_for_e)
+ except SynapseError as err:
+ # we may get SynapseErrors here as well as AuthErrors. For
+ # instance, there are a couple of (ancient) events in some
+ # rooms whose senders do not have the correct sigil; these
+ # cause SynapseErrors in auth.check. We don't want to give up
+ # the attempt to federate altogether in such cases.
+
+ logger.warning("Rejecting %s because %s", e.event_id, err.msg)
+
+ if e == event:
+ raise
+ events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
+
+ if auth_events or state:
+ await self.persist_events_and_notify(
+ room_id,
+ [
+ (e, events_to_context[e.event_id])
+ for e in itertools.chain(auth_events, state)
+ ],
+ )
+
+ new_event_context = await self._state_handler.compute_event_context(
+ event, old_state=state
+ )
+
+ return await self.persist_events_and_notify(
+ room_id, [(event, new_event_context)]
+ )
+
@log_function
async def backfill(
- self, dest: str, room_id: str, limit: int, extremities: List[str]
+ self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
) -> None:
"""Trigger a backfill request to `dest` for the given `room_id`
@@ -1116,14 +1230,12 @@ class FederationEventHandler:
await concurrently_execute(get_event, event_ids, 5)
logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
- await self._auth_and_persist_fetched_events(destination, room_id, events)
+ await self._auth_and_persist_outliers(room_id, events)
- async def _auth_and_persist_fetched_events(
- self, origin: str, room_id: str, events: Iterable[EventBase]
+ async def _auth_and_persist_outliers(
+ self, room_id: str, events: Iterable[EventBase]
) -> None:
- """Persist the events fetched by _get_events_and_persist or _get_remote_auth_chain_for_event
-
- The events to be persisted must be outliers.
+ """Persist a batch of outlier events fetched from remote servers.
We first sort the events to make sure that we process each event's auth_events
before the event itself, and then auth and persist them.
@@ -1131,7 +1243,6 @@ class FederationEventHandler:
Notifies about the events where appropriate.
Params:
- origin: where the events came from
room_id: the room that the events are meant to be in (though this has
not yet been checked)
events: the events that have been fetched
@@ -1167,15 +1278,15 @@ class FederationEventHandler:
shortstr(e.event_id for e in roots),
)
- await self._auth_and_persist_fetched_events_inner(origin, room_id, roots)
+ await self._auth_and_persist_outliers_inner(room_id, roots)
for ev in roots:
del event_map[ev.event_id]
- async def _auth_and_persist_fetched_events_inner(
- self, origin: str, room_id: str, fetched_events: Collection[EventBase]
+ async def _auth_and_persist_outliers_inner(
+ self, room_id: str, fetched_events: Collection[EventBase]
) -> None:
- """Helper for _auth_and_persist_fetched_events
+ """Helper for _auth_and_persist_outliers
Persists a batch of events where we have (theoretically) already persisted all
of their auth events.
@@ -1203,7 +1314,7 @@ class FederationEventHandler:
def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
with nested_logging_context(suffix=event.event_id):
- auth = {}
+ auth = []
for auth_event_id in event.auth_event_ids():
ae = persisted_events.get(auth_event_id)
if not ae:
@@ -1216,7 +1327,7 @@ class FederationEventHandler:
# exist, which means it is premature to reject `event`. Instead we
# just ignore it for now.
return None
- auth[(ae.type, ae.state_key)] = ae
+ auth.append(ae)
context = EventContext.for_outlier()
try:
@@ -1256,6 +1367,10 @@ class FederationEventHandler:
Returns:
The updated context object.
+
+ Raises:
+ AuthError if we were unable to find copies of the event's auth events.
+ (Most other failures just cause us to set `context.rejected`.)
"""
# This method should only be used for non-outliers
assert not event.internal_metadata.outlier
@@ -1272,7 +1387,26 @@ class FederationEventHandler:
context.rejected = RejectedReason.AUTH_ERROR
return context
- # calculate what the auth events *should* be, to use as a basis for auth.
+ # next, check that we have all of the event's auth events.
+ #
+ # Note that this can raise AuthError, which we want to propagate to the
+ # caller rather than swallow with `context.rejected` (since we cannot be
+ # certain that there is a permanent problem with the event).
+ claimed_auth_events = await self._load_or_fetch_auth_events_for_event(
+ origin, event
+ )
+
+ # ... and check that the event passes auth at those auth events.
+ try:
+ check_auth_rules_for_event(room_version_obj, event, claimed_auth_events)
+ except AuthError as e:
+ logger.warning(
+ "While checking auth of %r against auth_events: %s", event, e
+ )
+ context.rejected = RejectedReason.AUTH_ERROR
+ return context
+
+ # now check auth against what we think the auth events *should* be.
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
@@ -1305,7 +1439,9 @@ class FederationEventHandler:
auth_events_for_auth = calculated_auth_event_map
try:
- check_auth_rules_for_event(room_version_obj, event, auth_events_for_auth)
+ check_auth_rules_for_event(
+ room_version_obj, event, auth_events_for_auth.values()
+ )
except AuthError as e:
logger.warning("Failed auth resolution for %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR
@@ -1403,11 +1539,9 @@ class FederationEventHandler:
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]
-
- auth_events_map = await self._store.get_events(current_state_ids_list)
- current_auth_events = {
- (e.type, e.state_key): e for e in auth_events_map.values()
- }
+ current_auth_events = await self._store.get_events_as_list(
+ current_state_ids_list
+ )
try:
check_auth_rules_for_event(room_version_obj, event, current_auth_events)
@@ -1472,6 +1606,9 @@ class FederationEventHandler:
# if we have missing events, we need to fetch those events from somewhere.
#
# we start by checking if they are in the store, and then try calling /event_auth/.
+ #
+ # TODO: this code is now redundant, since it should be impossible for us to
+ # get here without already having the auth events.
if missing_auth:
have_events = await self._store.have_seen_events(
event.room_id, missing_auth
@@ -1575,7 +1712,7 @@ class FederationEventHandler:
logger.info(
"After state res: updating auth_events with new state %s",
{
- (d.type, d.state_key): d.event_id
+ d
for d in new_state.values()
if auth_events.get((d.type, d.state_key)) != d
},
@@ -1589,6 +1726,75 @@ class FederationEventHandler:
return context, auth_events
+ async def _load_or_fetch_auth_events_for_event(
+ self, destination: str, event: EventBase
+ ) -> Collection[EventBase]:
+ """Fetch this event's auth_events, from database or remote
+
+ Loads any of the auth_events that we already have from the database/cache. If
+ there are any that are missing, calls /event_auth to get the complete auth
+ chain for the event (and then attempts to load the auth_events again).
+
+ If any of the auth_events cannot be found, raises an AuthError. This can happen
+ for a number of reasons; eg: the events don't exist, or we were unable to talk
+ to `destination`, or we couldn't validate the signature on the event (which
+ in turn has multiple potential causes).
+
+ Args:
+ destination: where to send the /event_auth request. Typically the server
+ that sent us `event` in the first place.
+ event: the event whose auth_events we want
+
+ Returns:
+ all of the events in `event.auth_events`, after deduplication
+
+ Raises:
+ AuthError if we were unable to fetch the auth_events for any reason.
+ """
+ event_auth_event_ids = set(event.auth_event_ids())
+ event_auth_events = await self._store.get_events(
+ event_auth_event_ids, allow_rejected=True
+ )
+ missing_auth_event_ids = event_auth_event_ids.difference(
+ event_auth_events.keys()
+ )
+ if not missing_auth_event_ids:
+ return event_auth_events.values()
+
+ logger.info(
+ "Event %s refers to unknown auth events %s: fetching auth chain",
+ event,
+ missing_auth_event_ids,
+ )
+ try:
+ await self._get_remote_auth_chain_for_event(
+ destination, event.room_id, event.event_id
+ )
+ except Exception as e:
+ logger.warning("Failed to get auth chain for %s: %s", event, e)
+ # in this case, it's very likely we still won't have all the auth
+ # events - but we pick that up below.
+
+ # try to fetch the auth events we missed last time.
+ extra_auth_events = await self._store.get_events(
+ missing_auth_event_ids, allow_rejected=True
+ )
+ missing_auth_event_ids.difference_update(extra_auth_events.keys())
+ event_auth_events.update(extra_auth_events)
+ if not missing_auth_event_ids:
+ return event_auth_events.values()
+
+ # we still don't have all the auth events.
+ logger.warning(
+ "Missing auth events for %s: %s",
+ event,
+ shortstr(missing_auth_event_ids),
+ )
+ # the fact we can't find the auth event doesn't mean it doesn't
+ # exist, which means it is premature to store `event` as rejected.
+ # instead we raise an AuthError, which will make the caller ignore it.
+ raise AuthError(code=HTTPStatus.FORBIDDEN, msg="Auth events could not be found")
+
async def _get_remote_auth_chain_for_event(
self, destination: str, room_id: str, event_id: str
) -> None:
@@ -1624,9 +1830,7 @@ class FederationEventHandler:
for s in seen_remotes:
remote_event_map.pop(s, None)
- await self._auth_and_persist_fetched_events(
- destination, room_id, remote_event_map.values()
- )
+ await self._auth_and_persist_outliers(room_id, remote_event_map.values())
async def _update_context_for_auth_events(
self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
@@ -1696,16 +1900,27 @@ class FederationEventHandler:
# persist_events_and_notify directly.)
assert not event.internal_metadata.outlier
- try:
- if (
- not backfilled
- and not context.rejected
- and (await self._store.get_min_depth(event.room_id)) <= event.depth
- ):
+ if not backfilled and not context.rejected:
+ min_depth = await self._store.get_min_depth(event.room_id)
+ if min_depth is None or min_depth > event.depth:
+ # XXX richvdh 2021/10/07: I don't really understand what this
+ # condition is doing. I think it's trying not to send pushes
+ # for events that predate our join - but that's not really what
+ # min_depth means, and anyway ancient events are a more general
+ # problem.
+ #
+ # for now I'm just going to log about it.
+ logger.info(
+ "Skipping push actions for old event with depth %s < %s",
+ event.depth,
+ min_depth,
+ )
+ else:
await self._action_generator.handle_push_actions_for_event(
event, context
)
+ try:
await self.persist_events_and_notify(
event.room_id, [(event, context)], backfilled=backfilled
)
@@ -1837,6 +2052,3 @@ class FederationEventHandler:
len(ev.auth_event_ids()),
)
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
-
- async def get_min_depth_for_context(self, context: str) -> int:
- return await self._store.get_min_depth(context)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4de9f4b828..2e024b551f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -607,29 +607,6 @@ class EventCreationHandler:
builder.internal_metadata.historical = historical
- # Strip down the auth_event_ids to only what we need to auth the event.
- # For example, we don't need extra m.room.member that don't match event.sender
- if auth_event_ids is not None:
- # If auth events are provided, prev events must be also.
- assert prev_event_ids is not None
-
- temp_event = await builder.build(
- prev_event_ids=prev_event_ids,
- auth_event_ids=auth_event_ids,
- depth=depth,
- )
- auth_events = await self.store.get_events_as_list(auth_event_ids)
- # Create a StateMap[str]
- auth_event_state_map = {
- (e.type, e.state_key): e.event_id for e in auth_events
- }
- # Actually strip down and use the necessary auth events
- auth_event_ids = self._event_auth_handler.compute_auth_events(
- event=temp_event,
- current_state_ids=auth_event_state_map,
- for_verification=False,
- )
-
event, context = await self.create_new_client_event(
builder=builder,
requester=requester,
@@ -936,6 +913,33 @@ class EventCreationHandler:
Tuple of created event, context
"""
+ # Strip down the auth_event_ids to only what we need to auth the event.
+ # For example, we don't need extra m.room.member that don't match event.sender
+ full_state_ids_at_event = None
+ if auth_event_ids is not None:
+ # If auth events are provided, prev events must be also.
+ assert prev_event_ids is not None
+
+ # Copy the full auth state before it is stripped down
+ full_state_ids_at_event = auth_event_ids.copy()
+
+ temp_event = await builder.build(
+ prev_event_ids=prev_event_ids,
+ auth_event_ids=auth_event_ids,
+ depth=depth,
+ )
+ auth_events = await self.store.get_events_as_list(auth_event_ids)
+ # Create a StateMap[str]
+ auth_event_state_map = {
+ (e.type, e.state_key): e.event_id for e in auth_events
+ }
+ # Actually strip down and use the necessary auth events
+ auth_event_ids = self._event_auth_handler.compute_auth_events(
+ event=temp_event,
+ current_state_ids=auth_event_state_map,
+ for_verification=False,
+ )
+
if prev_event_ids is not None:
assert (
len(prev_event_ids) <= 10
@@ -965,6 +969,13 @@ class EventCreationHandler:
if builder.internal_metadata.outlier:
event.internal_metadata.outlier = True
context = EventContext.for_outlier()
+ elif (
+ event.type == EventTypes.MSC2716_INSERTION
+ and full_state_ids_at_event
+ and builder.internal_metadata.is_historical()
+ ):
+ old_state = await self.store.get_events_as_list(full_state_ids_at_event)
+ context = await self.state.compute_event_context(event, old_state=old_state)
else:
context = await self.state.compute_event_context(event)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 176e4dfdd4..60ff896386 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -86,19 +86,22 @@ class PaginationHandler:
self._event_serializer = hs.get_event_client_serializer()
self._retention_default_max_lifetime = (
- hs.config.server.retention_default_max_lifetime
+ hs.config.retention.retention_default_max_lifetime
)
self._retention_allowed_lifetime_min = (
- hs.config.server.retention_allowed_lifetime_min
+ hs.config.retention.retention_allowed_lifetime_min
)
self._retention_allowed_lifetime_max = (
- hs.config.server.retention_allowed_lifetime_max
+ hs.config.retention.retention_allowed_lifetime_max
)
- if hs.config.worker.run_background_tasks and hs.config.server.retention_enabled:
+ if (
+ hs.config.worker.run_background_tasks
+ and hs.config.retention.retention_enabled
+ ):
# Run the purge jobs described in the configuration file.
- for job in hs.config.server.retention_purge_jobs:
+ for job in hs.config.retention.retention_purge_jobs:
logger.info("Setting up purge job with config: %s", job)
self.clock.looping_call(
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 404afb9402..b5968e047b 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1489,7 +1489,7 @@ def format_user_presence_state(
The "user_id" is optional so that this function can be used to format presence
updates for client /sync responses and for federation /send requests.
"""
- content = {"presence": state.state}
+ content: JsonDict = {"presence": state.state}
if include_user_id:
content["user_id"] = state.user_id
if state.last_active_ts:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7072bca1fc..6f39e9446f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -465,17 +465,35 @@ class RoomCreationHandler:
# the room has been created
# Calculate the minimum power level needed to clone the room
event_power_levels = power_levels.get("events", {})
+ if not isinstance(event_power_levels, dict):
+ event_power_levels = {}
state_default = power_levels.get("state_default", 50)
+ try:
+ state_default_int = int(state_default) # type: ignore[arg-type]
+ except (TypeError, ValueError):
+ state_default_int = 50
ban = power_levels.get("ban", 50)
- needed_power_level = max(state_default, ban, max(event_power_levels.values()))
+ try:
+ ban = int(ban) # type: ignore[arg-type]
+ except (TypeError, ValueError):
+ ban = 50
+ needed_power_level = max(
+ state_default_int, ban, max(event_power_levels.values())
+ )
# Get the user's current power level, this matches the logic in get_user_power_level,
# but without the entire state map.
user_power_levels = power_levels.setdefault("users", {})
+ if not isinstance(user_power_levels, dict):
+ user_power_levels = {}
users_default = power_levels.get("users_default", 0)
current_power_level = user_power_levels.get(user_id, users_default)
+ try:
+ current_power_level_int = int(current_power_level) # type: ignore[arg-type]
+ except (TypeError, ValueError):
+ current_power_level_int = 0
# Raise the requester's power level in the new room if necessary
- if current_power_level < needed_power_level:
+ if current_power_level_int < needed_power_level:
user_power_levels[user_id] = needed_power_level
await self._send_events_for_new_room(
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 51dd4e7555..2f5a3e4d19 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -13,6 +13,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+def generate_fake_event_id() -> str:
+ return "$fake_" + random_string(43)
+
+
class RoomBatchHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
@@ -177,6 +181,11 @@ class RoomBatchHandler:
state_event_ids_at_start = []
auth_event_ids = initial_auth_event_ids.copy()
+
+ # Make the state events float off on their own so we don't have a
+ # bunch of `@mxid joined the room` noise between each batch
+ prev_event_id_for_state_chain = generate_fake_event_id()
+
for state_event in state_events_at_start:
assert_params_in_dict(
state_event, ["type", "origin_server_ts", "content", "sender"]
@@ -200,10 +209,6 @@ class RoomBatchHandler:
# Mark all events as historical
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
- # Make the state events float off on their own so we don't have a
- # bunch of `@mxid joined the room` noise between each batch
- fake_prev_event_id = "$" + random_string(43)
-
# TODO: This is pretty much the same as some other code to handle inserting state in this file
if event_dict["type"] == EventTypes.Member:
membership = event_dict["content"].get("membership", None)
@@ -216,7 +221,7 @@ class RoomBatchHandler:
action=membership,
content=event_dict["content"],
outlier=True,
- prev_event_ids=[fake_prev_event_id],
+ prev_event_ids=[prev_event_id_for_state_chain],
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
# reference and also update in the event when we append later.
@@ -235,7 +240,7 @@ class RoomBatchHandler:
),
event_dict,
outlier=True,
- prev_event_ids=[fake_prev_event_id],
+ prev_event_ids=[prev_event_id_for_state_chain],
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
# reference and also update in the event when we append later.
@@ -245,6 +250,8 @@ class RoomBatchHandler:
state_event_ids_at_start.append(event_id)
auth_event_ids.append(event_id)
+ # Connect all the state in a floating chain
+ prev_event_id_for_state_chain = event_id
return state_event_ids_at_start
@@ -289,6 +296,10 @@ class RoomBatchHandler:
for ev in events_to_create:
assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"])
+ assert self.hs.is_mine_id(ev["sender"]), "User must be our own: %s" % (
+ ev["sender"],
+ )
+
event_dict = {
"type": ev["type"],
"origin_server_ts": ev["origin_server_ts"],
@@ -311,6 +322,19 @@ class RoomBatchHandler:
historical=True,
depth=inherited_depth,
)
+
+ assert context._state_group
+
+ # Normally this is done when persisting the event but we have to
+ # pre-emptively do it here because we create all the events first,
+ # then persist them in another pass below. And we want to share
+ # state_groups across the whole batch so this lookup needs to work
+ # for the next event in the batch in this loop.
+ await self.store.store_state_group_id_for_event_id(
+ event_id=event.event_id,
+ state_group_id=context._state_group,
+ )
+
logger.debug(
"RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s",
event,
@@ -318,10 +342,6 @@ class RoomBatchHandler:
auth_event_ids,
)
- assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
- event.sender,
- )
-
events_to_persist.append((event, context))
event_id = event.event_id
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 8810f048ba..991fee7e58 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -196,63 +196,12 @@ class UserDirectoryHandler(StateDeltasHandler):
room_id, prev_event_id, event_id, typ
)
elif typ == EventTypes.Member:
- change = await self._get_key_change(
+ await self._handle_room_membership_event(
+ room_id,
prev_event_id,
event_id,
- key_name="membership",
- public_value=Membership.JOIN,
+ state_key,
)
-
- is_remote = not self.is_mine_id(state_key)
- if change is MatchChange.now_false:
- # Need to check if the server left the room entirely, if so
- # we might need to remove all the users in that room
- is_in_room = await self.store.is_host_joined(
- room_id, self.server_name
- )
- if not is_in_room:
- logger.debug("Server left room: %r", room_id)
- # Fetch all the users that we marked as being in user
- # directory due to being in the room and then check if
- # need to remove those users or not
- user_ids = await self.store.get_users_in_dir_due_to_room(
- room_id
- )
-
- for user_id in user_ids:
- await self._handle_remove_user(room_id, user_id)
- continue
- else:
- logger.debug("Server is still in room: %r", room_id)
-
- include_in_dir = (
- is_remote
- or await self.store.should_include_local_user_in_dir(state_key)
- )
- if include_in_dir:
- if change is MatchChange.no_change:
- # Handle any profile changes for remote users.
- # (For local users we are not forced to scan membership
- # events; instead the rest of the application calls
- # `handle_local_profile_change`.)
- if is_remote:
- await self._handle_profile_change(
- state_key, room_id, prev_event_id, event_id
- )
- continue
-
- if change is MatchChange.now_true: # The user joined
- # This may be the first time we've seen a remote user. If
- # so, ensure we have a directory entry for them. (We don't
- # need to do this for local users: their directory entry
- # is created at the point of registration.
- if is_remote:
- await self._upsert_directory_entry_for_remote_user(
- state_key, event_id
- )
- await self._track_user_joined_room(room_id, state_key)
- else: # The user left
- await self._handle_remove_user(room_id, state_key)
else:
logger.debug("Ignoring irrelevant type: %r", typ)
@@ -317,14 +266,83 @@ class UserDirectoryHandler(StateDeltasHandler):
for user_id in users_in_room:
await self.store.remove_user_who_share_room(user_id, room_id)
- # Then, re-add them to the tables.
+ # Then, re-add all remote users and some local users to the tables.
# NOTE: this is not the most efficient method, as _track_user_joined_room sets
# up local_user -> other_user and other_user_whos_local -> local_user,
# which when ran over an entire room, will result in the same values
# being added multiple times. The batching upserts shouldn't make this
# too bad, though.
for user_id in users_in_room:
- await self._track_user_joined_room(room_id, user_id)
+ if not self.is_mine_id(
+ user_id
+ ) or await self.store.should_include_local_user_in_dir(user_id):
+ await self._track_user_joined_room(room_id, user_id)
+
+ async def _handle_room_membership_event(
+ self,
+ room_id: str,
+ prev_event_id: str,
+ event_id: str,
+ state_key: str,
+ ) -> None:
+ """Process a single room membershp event.
+
+ We have to do two things:
+
+ 1. Update the room-sharing tables.
+ This applies to remote users and non-excluded local users.
+ 2. Update the user_directory and user_directory_search tables.
+ This applies to remote users only, because we only become aware of
+ them (and any profile changes) by listening to these events.
+ The rest of the application knows exactly when local users are
+ created or their profile changed---it will directly call methods
+ on this class.
+ """
+ joined = await self._get_key_change(
+ prev_event_id,
+ event_id,
+ key_name="membership",
+ public_value=Membership.JOIN,
+ )
+
+ # Both cases ignore excluded local users, so start by discarding them.
+ is_remote = not self.is_mine_id(state_key)
+ if not is_remote and not await self.store.should_include_local_user_in_dir(
+ state_key
+ ):
+ return
+
+ if joined is MatchChange.now_false:
+ # Need to check if the server left the room entirely, if so
+ # we might need to remove all the users in that room
+ is_in_room = await self.store.is_host_joined(room_id, self.server_name)
+ if not is_in_room:
+ logger.debug("Server left room: %r", room_id)
+ # Fetch all the users that we marked as being in user
+ # directory due to being in the room and then check if
+ # need to remove those users or not
+ user_ids = await self.store.get_users_in_dir_due_to_room(room_id)
+
+ for user_id in user_ids:
+ await self._handle_remove_user(room_id, user_id)
+ else:
+ logger.debug("Server is still in room: %r", room_id)
+ await self._handle_remove_user(room_id, state_key)
+ elif joined is MatchChange.no_change:
+ # Handle any profile changes for remote users.
+ # (For local users the rest of the application calls
+ # `handle_local_profile_change`.)
+ if is_remote:
+ await self._handle_possible_remote_profile_change(
+ state_key, room_id, prev_event_id, event_id
+ )
+ elif joined is MatchChange.now_true: # The user joined
+ # This may be the first time we've seen a remote user. If
+ # so, ensure we have a directory entry for them. (For local users,
+ # the rest of the application calls `handle_local_profile_change`.)
+ if is_remote:
+ await self._upsert_directory_entry_for_remote_user(state_key, event_id)
+ await self._track_user_joined_room(room_id, state_key)
async def _upsert_directory_entry_for_remote_user(
self, user_id: str, event_id: str
@@ -349,8 +367,8 @@ class UserDirectoryHandler(StateDeltasHandler):
"""Someone's just joined a room. Update `users_in_public_rooms` or
`users_who_share_private_rooms` as appropriate.
- The caller is responsible for ensuring that the given user is not excluded
- from the user directory.
+ The caller is responsible for ensuring that the given user should be
+ included in the user directory.
"""
is_public = await self.store.is_room_world_readable_or_publicly_joinable(
room_id
@@ -386,24 +404,32 @@ class UserDirectoryHandler(StateDeltasHandler):
await self.store.add_users_who_share_private_room(room_id, to_insert)
async def _handle_remove_user(self, room_id: str, user_id: str) -> None:
- """Called when we might need to remove user from directory
+ """Called when when someone leaves a room. The user may be local or remote.
+
+ (If the person who left was the last local user in this room, the server
+ is no longer in the room. We call this function to forget that the remaining
+ remote users are in the room, even though they haven't left. So the name is
+ a little misleading!)
Args:
room_id: The room ID that user left or stopped being public that
user_id
"""
- logger.debug("Removing user %r", user_id)
+ logger.debug("Removing user %r from room %r", user_id, room_id)
# Remove user from sharing tables
await self.store.remove_user_who_share_room(user_id, room_id)
- # Are they still in any rooms? If not, remove them entirely.
- rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id)
+ # Additionally, if they're a remote user and we're no longer joined
+ # to any rooms they're in, remove them from the user directory.
+ if not self.is_mine_id(user_id):
+ rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id)
- if len(rooms_user_is_in) == 0:
- await self.store.remove_from_user_dir(user_id)
+ if len(rooms_user_is_in) == 0:
+ logger.debug("Removing user %r from directory", user_id)
+ await self.store.remove_from_user_dir(user_id)
- async def _handle_profile_change(
+ async def _handle_possible_remote_profile_change(
self,
user_id: str,
room_id: str,
@@ -411,7 +437,8 @@ class UserDirectoryHandler(StateDeltasHandler):
event_id: Optional[str],
) -> None:
"""Check member event changes for any profile changes and update the
- database if there are.
+ database if there are. This is intended for remote users only. The caller
+ is responsible for checking that the given user is remote.
"""
if not prev_event_id or not event_id:
return
|