From ab7a24cc6bbffa5ba67b42731c45b1d4d33f3ae3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 8 Dec 2020 14:04:35 +0000 Subject: Better formatting for config errors from modules (#8874) The idea is that the parse_config method of extension modules can raise either a ConfigError or a JsonValidationError, and it will be magically turned into a legible error message. There's a few components to it: * Separating the "path" and the "message" parts of a ConfigError, so that we can fiddle with the path bit to turn it into an absolute path. * Generally improving the way ConfigErrors get printed. * Passing in the config path to load_module so that it can wrap any exceptions that get caught appropriately. --- synapse/util/module_loader.py | 64 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 6 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py index 94b59afb38..1ee61851e4 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py @@ -15,28 +15,56 @@ import importlib import importlib.util +import itertools +from typing import Any, Iterable, Tuple, Type + +import jsonschema from synapse.config._base import ConfigError +from synapse.config._util import json_error_to_config_error -def load_module(provider): +def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]: """ Loads a synapse module with its config - Take a dict with keys 'module' (the module name) and 'config' - (the config dict). + + Args: + provider: a dict with keys 'module' (the module name) and 'config' + (the config dict). + config_path: the path within the config file. This will be used as a basis + for any error message. Returns Tuple of (provider class, parsed config object) """ + + modulename = provider.get("module") + if not isinstance(modulename, str): + raise ConfigError( + "expected a string", path=itertools.chain(config_path, ("module",)) + ) + # We need to import the module, and then pick the class out of # that, so we split based on the last dot. - module, clz = provider["module"].rsplit(".", 1) + module, clz = modulename.rsplit(".", 1) module = importlib.import_module(module) provider_class = getattr(module, clz) + module_config = provider.get("config") try: - provider_config = provider_class.parse_config(provider.get("config")) + provider_config = provider_class.parse_config(module_config) + except jsonschema.ValidationError as e: + raise json_error_to_config_error(e, itertools.chain(config_path, ("config",))) + except ConfigError as e: + raise _wrap_config_error( + "Failed to parse config for module %r" % (modulename,), + prefix=itertools.chain(config_path, ("config",)), + e=e, + ) except Exception as e: - raise ConfigError("Failed to parse config for %r: %s" % (provider["module"], e)) + raise ConfigError( + "Failed to parse config for module %r" % (modulename,), + path=itertools.chain(config_path, ("config",)), + ) from e return provider_class, provider_config @@ -56,3 +84,27 @@ def load_python_module(location: str): mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) # type: ignore return mod + + +def _wrap_config_error( + msg: str, prefix: Iterable[str], e: ConfigError +) -> "ConfigError": + """Wrap a relative ConfigError with a new path + + This is useful when we have a ConfigError with a relative path due to a problem + parsing part of the config, and we now need to set it in context. + """ + path = prefix + if e.path: + path = itertools.chain(prefix, e.path) + + e1 = ConfigError(msg, path) + + # ideally we would set the 'cause' of the new exception to the original exception; + # however now that we have merged the path into our own, the stringification of + # e will be incorrect, so instead we create a new exception with just the "msg" + # part. + + e1.__cause__ = Exception(e.msg) + e1.__cause__.__cause__ = e.__cause__ + return e1 -- cgit 1.5.1 From f14428b25c37e44675edac4a80d7bd1e47112586 Mon Sep 17 00:00:00 2001 From: David Teller Date: Fri, 11 Dec 2020 20:05:15 +0100 Subject: Allow spam-checker modules to be provide async methods. (#8890) Spam checker modules can now provide async methods. This is implemented in a backwards-compatible manner. --- changelog.d/8890.feature | 1 + docs/spam_checker.md | 19 ++++++--- synapse/events/spamcheck.py | 55 +++++++++++++++++++-------- synapse/federation/federation_base.py | 7 +++- synapse/handlers/auth.py | 8 ++-- synapse/handlers/directory.py | 6 ++- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 2 +- synapse/handlers/receipts.py | 7 +--- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 4 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/user_directory.py | 10 ++--- synapse/metrics/background_process_metrics.py | 9 +---- synapse/rest/media/v1/storage_provider.py | 16 +++----- synapse/server.py | 2 +- synapse/util/async_helpers.py | 8 ++-- synapse/util/distributor.py | 7 +--- tests/handlers/test_user_directory.py | 4 +- 19 files changed, 98 insertions(+), 73 deletions(-) create mode 100644 changelog.d/8890.feature (limited to 'synapse/util') diff --git a/changelog.d/8890.feature b/changelog.d/8890.feature new file mode 100644 index 0000000000..97aa72a76e --- /dev/null +++ b/changelog.d/8890.feature @@ -0,0 +1 @@ +Spam-checkers may now define their methods as `async`. diff --git a/docs/spam_checker.md b/docs/spam_checker.md index 7fc08f1b70..5b4f6428e6 100644 --- a/docs/spam_checker.md +++ b/docs/spam_checker.md @@ -22,6 +22,8 @@ well as some specific methods: * `user_may_create_room` * `user_may_create_room_alias` * `user_may_publish_room` +* `check_username_for_spam` +* `check_registration_for_spam` The details of the each of these methods (as well as their inputs and outputs) are documented in the `synapse.events.spamcheck.SpamChecker` class. @@ -32,28 +34,33 @@ call back into the homeserver internals. ### Example ```python +from synapse.spam_checker_api import RegistrationBehaviour + class ExampleSpamChecker: def __init__(self, config, api): self.config = config self.api = api - def check_event_for_spam(self, foo): + async def check_event_for_spam(self, foo): return False # allow all events - def user_may_invite(self, inviter_userid, invitee_userid, room_id): + async def user_may_invite(self, inviter_userid, invitee_userid, room_id): return True # allow all invites - def user_may_create_room(self, userid): + async def user_may_create_room(self, userid): return True # allow all room creations - def user_may_create_room_alias(self, userid, room_alias): + async def user_may_create_room_alias(self, userid, room_alias): return True # allow all room aliases - def user_may_publish_room(self, userid, room_id): + async def user_may_publish_room(self, userid, room_id): return True # allow publishing of all rooms - def check_username_for_spam(self, user_profile): + async def check_username_for_spam(self, user_profile): return False # allow all usernames + + async def check_registration_for_spam(self, email_threepid, username, request_info): + return RegistrationBehaviour.ALLOW # allow all registrations ``` ## Configuration diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 936896656a..e7e3a7b9a4 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,10 +15,11 @@ # limitations under the License. import inspect -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import Collection +from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: import synapse.events @@ -39,7 +40,9 @@ class SpamChecker: else: self.spam_checkers.append(module(config=config)) - def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool: + async def check_event_for_spam( + self, event: "synapse.events.EventBase" + ) -> Union[bool, str]: """Checks if a given event is considered "spammy" by this server. If the server considers an event spammy, then it will be rejected if @@ -50,15 +53,16 @@ class SpamChecker: event: the event to be checked Returns: - True if the event is spammy. + True or a string if the event is spammy. If a string is returned it + will be used as the error message returned to the user. """ for spam_checker in self.spam_checkers: - if spam_checker.check_event_for_spam(event): + if await maybe_awaitable(spam_checker.check_event_for_spam(event)): return True return False - def user_may_invite( + async def user_may_invite( self, inviter_userid: str, invitee_userid: str, room_id: str ) -> bool: """Checks if a given user may send an invite @@ -75,14 +79,18 @@ class SpamChecker: """ for spam_checker in self.spam_checkers: if ( - spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id) + await maybe_awaitable( + spam_checker.user_may_invite( + inviter_userid, invitee_userid, room_id + ) + ) is False ): return False return True - def user_may_create_room(self, userid: str) -> bool: + async def user_may_create_room(self, userid: str) -> bool: """Checks if a given user may create a room If this method returns false, the creation request will be rejected. @@ -94,12 +102,15 @@ class SpamChecker: True if the user may create a room, otherwise False """ for spam_checker in self.spam_checkers: - if spam_checker.user_may_create_room(userid) is False: + if ( + await maybe_awaitable(spam_checker.user_may_create_room(userid)) + is False + ): return False return True - def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool: + async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool: """Checks if a given user may create a room alias If this method returns false, the association request will be rejected. @@ -112,12 +123,17 @@ class SpamChecker: True if the user may create a room alias, otherwise False """ for spam_checker in self.spam_checkers: - if spam_checker.user_may_create_room_alias(userid, room_alias) is False: + if ( + await maybe_awaitable( + spam_checker.user_may_create_room_alias(userid, room_alias) + ) + is False + ): return False return True - def user_may_publish_room(self, userid: str, room_id: str) -> bool: + async def user_may_publish_room(self, userid: str, room_id: str) -> bool: """Checks if a given user may publish a room to the directory If this method returns false, the publish request will be rejected. @@ -130,12 +146,17 @@ class SpamChecker: True if the user may publish the room, otherwise False """ for spam_checker in self.spam_checkers: - if spam_checker.user_may_publish_room(userid, room_id) is False: + if ( + await maybe_awaitable( + spam_checker.user_may_publish_room(userid, room_id) + ) + is False + ): return False return True - def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: + async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: """Checks if a user ID or display name are considered "spammy" by this server. If the server considers a username spammy, then it will not be included in @@ -157,12 +178,12 @@ class SpamChecker: if checker: # Make a copy of the user profile object to ensure the spam checker # cannot modify it. - if checker(user_profile.copy()): + if await maybe_awaitable(checker(user_profile.copy())): return True return False - def check_registration_for_spam( + async def check_registration_for_spam( self, email_threepid: Optional[dict], username: Optional[str], @@ -185,7 +206,9 @@ class SpamChecker: # spam checker checker = getattr(spam_checker, "check_registration_for_spam", None) if checker: - behaviour = checker(email_threepid, username, request_info) + behaviour = await maybe_awaitable( + checker(email_threepid, username, request_info) + ) assert isinstance(behaviour, RegistrationBehaviour) if behaviour != RegistrationBehaviour.ALLOW: return behaviour diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 38aa47963f..383737520a 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -78,6 +78,7 @@ class FederationBase: ctx = current_context() + @defer.inlineCallbacks def callback(_, pdu: EventBase): with PreserveLoggingContext(ctx): if not check_event_content_hash(pdu): @@ -105,7 +106,11 @@ class FederationBase: ) return redacted_event - if self.spam_checker.check_event_for_spam(pdu): + result = yield defer.ensureDeferred( + self.spam_checker.check_event_for_spam(pdu) + ) + + if result: logger.warning( "Event contains spam, redacting %s: %s", pdu.event_id, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 62f98dabc0..8deec4cd0c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import time import unicodedata @@ -59,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils +from synapse.util.async_helpers import maybe_awaitable from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email @@ -1639,6 +1639,6 @@ class PasswordProvider: # This might return an awaitable, if it does block the log out # until it completes. - result = g(user_id=user_id, device_id=device_id, access_token=access_token,) - if inspect.isawaitable(result): - await result + await maybe_awaitable( + g(user_id=user_id, device_id=device_id, access_token=access_token,) + ) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index ad5683d251..abcf86352d 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler): 403, "You must be in the room to create an alias for it" ) - if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): + if not await self.spam_checker.user_may_create_room_alias( + user_id, room_alias + ): raise AuthError(403, "This user is not permitted to create this alias") if not self.config.is_alias_creation_allowed( @@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler): """ user_id = requester.user.to_string() - if not self.spam_checker.user_may_publish_room(user_id, room_id): + if not await self.spam_checker.user_may_publish_room(user_id, room_id): raise AuthError( 403, "This user is not permitted to publish rooms to the room list" ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index df82e60b33..fd8de8696d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler): if self.hs.config.block_non_admin_invites: raise SynapseError(403, "This server does not accept room invites") - if not self.spam_checker.user_may_invite( + if not await self.spam_checker.user_may_invite( event.sender, event.state_key, event.room_id ): raise SynapseError( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 96843338ae..2b8aa9443d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -744,7 +744,7 @@ class EventCreationHandler: event.sender, ) - spam_error = self.spam_checker.check_event_for_spam(event) + spam_error = await self.spam_checker.check_event_for_spam(event) if spam_error: if not isinstance(spam_error, str): spam_error = "Spam is not permitted here" diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 153cbae7b9..e850e45e46 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,7 +18,6 @@ from typing import List, Tuple from synapse.appservice import ApplicationService from synapse.handlers._base import BaseHandler from synapse.types import JsonDict, ReadReceipt, get_domain_from_id -from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -98,10 +97,8 @@ class ReceiptsHandler(BaseHandler): self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. - await maybe_awaitable( - self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids - ) + await self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids ) return True diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0d85fd0868..94b5610acd 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -187,7 +187,7 @@ class RegistrationHandler(BaseHandler): """ self.check_registration_ratelimit(address) - result = self.spam_checker.check_registration_for_spam( + result = await self.spam_checker.check_registration_for_spam( threepid, localpart, user_agent_ips or [], ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 82fb72b381..7583418946 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -358,7 +358,7 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - if not self.spam_checker.user_may_create_room(user_id): + if not await self.spam_checker.user_may_create_room(user_id): raise SynapseError(403, "You are not permitted to create rooms") creation_content = { @@ -609,7 +609,7 @@ class RoomCreationHandler(BaseHandler): 403, "You are not permitted to create rooms", Codes.FORBIDDEN ) - if not is_requester_admin and not self.spam_checker.user_may_create_room( + if not is_requester_admin and not await self.spam_checker.user_may_create_room( user_id ): raise SynapseError(403, "You are not permitted to create rooms") diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index d85110a35e..cb5a29bc7e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -408,7 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) block_invite = True - if not self.spam_checker.user_may_invite( + if not await self.spam_checker.user_may_invite( requester.user.to_string(), target.to_string(), room_id ): logger.info("Blocking invite due to spam checker") diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index afbebfc200..f263a638f8 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -81,11 +81,11 @@ class UserDirectoryHandler(StateDeltasHandler): results = await self.store.search_user_dir(user_id, search_term, limit) # Remove any spammy users from the results. - results["results"] = [ - user - for user in results["results"] - if not self.spam_checker.check_username_for_spam(user) - ] + non_spammy_users = [] + for user in results["results"]: + if not await self.spam_checker.check_username_for_spam(user): + non_spammy_users.append(user) + results["results"] = non_spammy_users return results diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 658f6ecd72..76b7decf26 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import threading from functools import wraps @@ -25,6 +24,7 @@ from twisted.internet import defer from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.logging.opentracing import noop_context_manager, start_active_span +from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: import resource @@ -206,12 +206,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar if bg_start_span: ctx = start_active_span(desc, tags={"request_id": context.request}) with ctx: - result = func(*args, **kwargs) - - if inspect.isawaitable(result): - result = await result - - return result + return await maybe_awaitable(func(*args, **kwargs)) except Exception: logger.exception( "Background process '%s' threw an exception", desc, diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 18c9ed48d6..67f67efde7 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import os import shutil @@ -21,6 +20,7 @@ from typing import Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background +from synapse.util.async_helpers import maybe_awaitable from ._base import FileInfo, Responder from .media_storage import FileResponder @@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider): if self.store_synchronous: # store_file is supposed to return an Awaitable, but guard # against improper implementations. - result = self.backend.store_file(path, file_info) - if inspect.isawaitable(result): - return await result + return await maybe_awaitable(self.backend.store_file(path, file_info)) else: # TODO: Handle errors. async def store(): try: - result = self.backend.store_file(path, file_info) - if inspect.isawaitable(result): - return await result + return await maybe_awaitable( + self.backend.store_file(path, file_info) + ) except Exception: logger.exception("Error storing file") @@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider): async def fetch(self, path, file_info): # store_file is supposed to return an Awaitable, but guard # against improper implementations. - result = self.backend.fetch(path, file_info) - if inspect.isawaitable(result): - return await result + return await maybe_awaitable(self.backend.fetch(path, file_info)) class FileStorageProviderBackend(StorageProvider): diff --git a/synapse/server.py b/synapse/server.py index 043810ad31..a198b0eb46 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -618,7 +618,7 @@ class HomeServer(metaclass=abc.ABCMeta): return StatsHandler(self) @cache_in_self - def get_spam_checker(self): + def get_spam_checker(self) -> SpamChecker: return SpamChecker(self) @cache_in_self diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 382f0cf3f0..9a873c8e8e 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -15,10 +15,12 @@ # limitations under the License. import collections +import inspect import logging from contextlib import contextmanager from typing import ( Any, + Awaitable, Callable, Dict, Hashable, @@ -542,11 +544,11 @@ class DoneAwaitable: raise StopIteration(self.value) -def maybe_awaitable(value): +def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: """Convert a value to an awaitable if not already an awaitable. """ - - if hasattr(value, "__await__"): + if inspect.isawaitable(value): + assert isinstance(value, Awaitable) return value return DoneAwaitable(value) diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index f73e95393c..a6ee9edaec 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -12,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -105,10 +105,7 @@ class Signal: async def do(observer): try: - result = observer(*args, **kwargs) - if inspect.isawaitable(result): - result = await result - return result + return await maybe_awaitable(observer(*args, **kwargs)) except Exception as e: logger.warning( "%s signal observer %s failed: %r", self.name, observer, e, diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 98e5af2072..647a17cb90 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -270,7 +270,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): spam_checker = self.hs.get_spam_checker() class AllowAll: - def check_username_for_spam(self, user_profile): + async def check_username_for_spam(self, user_profile): # Allow all users. return False @@ -283,7 +283,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Configure a spam checker that filters all users. class BlockAll: - def check_username_for_spam(self, user_profile): + async def check_username_for_spam(self, user_profile): # All users are spammy. return True -- cgit 1.5.1 From 06fefe0bb19d5ef0a5873ea5697e2018ce9e6026 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 5 Jan 2021 08:06:55 -0500 Subject: Add type hints to the logging context code. (#8939) --- changelog.d/8939.misc | 1 + synapse/logging/context.py | 50 +++++++++++++++++++++++++++++---------------- synapse/storage/database.py | 8 +++++--- synapse/util/metrics.py | 10 ++++++++- 4 files changed, 47 insertions(+), 22 deletions(-) create mode 100644 changelog.d/8939.misc (limited to 'synapse/util') diff --git a/changelog.d/8939.misc b/changelog.d/8939.misc new file mode 100644 index 0000000000..bf94135fd5 --- /dev/null +++ b/changelog.d/8939.misc @@ -0,0 +1 @@ +Various clean-ups to the structured logging and logging context code. diff --git a/synapse/logging/context.py b/synapse/logging/context.py index a507a83e93..c2db8b45f3 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -252,7 +252,12 @@ class LoggingContext: "scope", ] - def __init__(self, name=None, parent_context=None, request=None) -> None: + def __init__( + self, + name: Optional[str] = None, + parent_context: "Optional[LoggingContext]" = None, + request: Optional[str] = None, + ) -> None: self.previous_context = current_context() self.name = name @@ -536,20 +541,20 @@ class LoggingContextFilter(logging.Filter): def __init__(self, request: str = ""): self._default_request = request - def filter(self, record) -> Literal[True]: + def filter(self, record: logging.LogRecord) -> Literal[True]: """Add each fields from the logging contexts to the record. Returns: True to include the record in the log output. """ context = current_context() - record.request = self._default_request + record.request = self._default_request # type: ignore # context should never be None, but if it somehow ends up being, then # we end up in a death spiral of infinite loops, so let's check, for # robustness' sake. if context is not None: # Logging is interested in the request. - record.request = context.request + record.request = context.request # type: ignore return True @@ -616,9 +621,7 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe return current -def nested_logging_context( - suffix: str, parent_context: Optional[LoggingContext] = None -) -> LoggingContext: +def nested_logging_context(suffix: str) -> LoggingContext: """Creates a new logging context as a child of another. The nested logging context will have a 'request' made up of the parent context's @@ -632,20 +635,23 @@ def nested_logging_context( # ... do stuff Args: - suffix (str): suffix to add to the parent context's 'request'. - parent_context (LoggingContext|None): parent context. Will use the current context - if None. + suffix: suffix to add to the parent context's 'request'. Returns: LoggingContext: new logging context. """ - if parent_context is not None: - context = parent_context # type: LoggingContextOrSentinel + curr_context = current_context() + if not curr_context: + logger.warning( + "Starting nested logging context from sentinel context: metrics will be lost" + ) + parent_context = None + prefix = "" else: - context = current_context() - return LoggingContext( - parent_context=context, request=str(context.request) + "-" + suffix - ) + assert isinstance(curr_context, LoggingContext) + parent_context = curr_context + prefix = str(parent_context.request) + return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix) def preserve_fn(f): @@ -822,10 +828,18 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): Deferred: A Deferred which fires a callback with the result of `f`, or an errback if `f` throws an exception. """ - logcontext = current_context() + curr_context = current_context() + if not curr_context: + logger.warning( + "Calling defer_to_threadpool from sentinel context: metrics will be lost" + ) + parent_context = None + else: + assert isinstance(curr_context, LoggingContext) + parent_context = curr_context def g(): - with LoggingContext(parent_context=logcontext): + with LoggingContext(parent_context=parent_context): return f(*args, **kwargs) return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g)) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index d1b5760c2c..b70ca3087b 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -42,7 +42,6 @@ from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig from synapse.logging.context import ( LoggingContext, - LoggingContextOrSentinel, current_context, make_deferred_yieldable, ) @@ -671,12 +670,15 @@ class DatabasePool: Returns: The result of func """ - parent_context = current_context() # type: Optional[LoggingContextOrSentinel] - if not parent_context: + curr_context = current_context() + if not curr_context: logger.warning( "Starting db connection from sentinel context: metrics will be lost" ) parent_context = None + else: + assert isinstance(curr_context, LoggingContext) + parent_context = curr_context start_time = monotonic_time() diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index ffdea0de8d..24123d5cc4 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -108,7 +108,15 @@ class Measure: def __init__(self, clock, name): self.clock = clock self.name = name - parent_context = current_context() + curr_context = current_context() + if not curr_context: + logger.warning( + "Starting metrics collection from sentinel context: metrics will be lost" + ) + parent_context = None + else: + assert isinstance(curr_context, LoggingContext) + parent_context = curr_context self._logging_context = LoggingContext( "Measure[%s]" % (self.name,), parent_context ) -- cgit 1.5.1 From 1b4d5d6acf8cfbe65601b881360ea730f9693d80 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 6 Jan 2021 12:33:20 -0500 Subject: Empty iterables should count towards cache usage. (#9028) --- changelog.d/9028.bugfix | 1 + synapse/util/caches/deferred_cache.py | 2 +- tests/util/caches/test_deferred_cache.py | 73 ++++++++++++++++++++++---------- 3 files changed, 52 insertions(+), 24 deletions(-) create mode 100644 changelog.d/9028.bugfix (limited to 'synapse/util') diff --git a/changelog.d/9028.bugfix b/changelog.d/9028.bugfix new file mode 100644 index 0000000000..66666886a4 --- /dev/null +++ b/changelog.d/9028.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where some caches could grow larger than configured. diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 601305487c..1adc92eb90 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -105,7 +105,7 @@ class DeferredCache(Generic[KT, VT]): keylen=keylen, cache_name=name, cache_type=cache_type, - size_callback=(lambda d: len(d)) if iterable else None, + size_callback=(lambda d: len(d) or 1) if iterable else None, metrics_collection_callback=metrics_cb, apply_cache_factor_from_config=apply_cache_factor_from_config, ) # type: LruCache[KT, VT] diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index dadfabd46d..ecd9efc4df 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -25,13 +25,8 @@ from tests.unittest import TestCase class DeferredCacheTestCase(TestCase): def test_empty(self): cache = DeferredCache("test") - failed = False - try: + with self.assertRaises(KeyError): cache.get("foo") - except KeyError: - failed = True - - self.assertTrue(failed) def test_hit(self): cache = DeferredCache("test") @@ -155,13 +150,8 @@ class DeferredCacheTestCase(TestCase): cache.prefill(("foo",), 123) cache.invalidate(("foo",)) - failed = False - try: + with self.assertRaises(KeyError): cache.get(("foo",)) - except KeyError: - failed = True - - self.assertTrue(failed) def test_invalidate_all(self): cache = DeferredCache("testcache") @@ -215,13 +205,8 @@ class DeferredCacheTestCase(TestCase): cache.prefill(2, "two") cache.prefill(3, "three") # 1 will be evicted - failed = False - try: + with self.assertRaises(KeyError): cache.get(1) - except KeyError: - failed = True - - self.assertTrue(failed) cache.get(2) cache.get(3) @@ -239,13 +224,55 @@ class DeferredCacheTestCase(TestCase): cache.prefill(3, "three") - failed = False - try: + with self.assertRaises(KeyError): cache.get(2) - except KeyError: - failed = True - self.assertTrue(failed) + cache.get(1) + cache.get(3) + + def test_eviction_iterable(self): + cache = DeferredCache( + "test", max_entries=3, apply_cache_factor_from_config=False, iterable=True, + ) + + cache.prefill(1, ["one", "two"]) + cache.prefill(2, ["three"]) + # Now access 1 again, thus causing 2 to be least-recently used + cache.get(1) + + # Now add an item to the cache, which evicts 2. + cache.prefill(3, ["four"]) + with self.assertRaises(KeyError): + cache.get(2) + + # Ensure 1 & 3 are in the cache. cache.get(1) cache.get(3) + + # Now access 1 again, thus causing 3 to be least-recently used + cache.get(1) + + # Now add an item with multiple elements to the cache + cache.prefill(4, ["five", "six"]) + + # Both 1 and 3 are evicted since there's too many elements. + with self.assertRaises(KeyError): + cache.get(1) + with self.assertRaises(KeyError): + cache.get(3) + + # Now add another item to fill the cache again. + cache.prefill(5, ["seven"]) + + # Now access 4, thus causing 5 to be least-recently used + cache.get(4) + + # Add an empty item. + cache.prefill(6, []) + + # 5 gets evicted and replaced since an empty element counts as an item. + with self.assertRaises(KeyError): + cache.get(5) + cache.get(4) + cache.get(6) -- cgit 1.5.1 From a03d71dc9d60251b8b753cc223b704a4095231da Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 8 Jan 2021 14:33:53 +0000 Subject: Fix "Starting metrics collection from sentinel context" errors (#9053) --- changelog.d/9053.bugfix | 1 + synapse/notifier.py | 39 +++++++++++++++++++-------------------- synapse/util/metrics.py | 3 ++- 3 files changed, 22 insertions(+), 21 deletions(-) create mode 100644 changelog.d/9053.bugfix (limited to 'synapse/util') diff --git a/changelog.d/9053.bugfix b/changelog.d/9053.bugfix new file mode 100644 index 0000000000..3d8bbf11a1 --- /dev/null +++ b/changelog.d/9053.bugfix @@ -0,0 +1 @@ +Fix bug where we didn't correctly record CPU time spent in 'on_new_event' block. diff --git a/synapse/notifier.py b/synapse/notifier.py index c4c8bb271d..0745899b48 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -396,31 +396,30 @@ class Notifier: Will wake up all listeners for the given users and rooms. """ - with PreserveLoggingContext(): - with Measure(self.clock, "on_new_event"): - user_streams = set() + with Measure(self.clock, "on_new_event"): + user_streams = set() - for user in users: - user_stream = self.user_to_user_stream.get(str(user)) - if user_stream is not None: - user_streams.add(user_stream) + for user in users: + user_stream = self.user_to_user_stream.get(str(user)) + if user_stream is not None: + user_streams.add(user_stream) - for room in rooms: - user_streams |= self.room_to_user_streams.get(room, set()) + for room in rooms: + user_streams |= self.room_to_user_streams.get(room, set()) - time_now_ms = self.clock.time_msec() - for user_stream in user_streams: - try: - user_stream.notify(stream_key, new_token, time_now_ms) - except Exception: - logger.exception("Failed to notify listener") + time_now_ms = self.clock.time_msec() + for user_stream in user_streams: + try: + user_stream.notify(stream_key, new_token, time_now_ms) + except Exception: + logger.exception("Failed to notify listener") - self.notify_replication() + self.notify_replication() - # Notify appservices - self._notify_app_services_ephemeral( - stream_key, new_token, users, - ) + # Notify appservices + self._notify_app_services_ephemeral( + stream_key, new_token, users, + ) def on_new_replication_data(self) -> None: """Used to inform replication listeners that something has happened diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 24123d5cc4..f4de6b9f54 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -111,7 +111,8 @@ class Measure: curr_context = current_context() if not curr_context: logger.warning( - "Starting metrics collection from sentinel context: metrics will be lost" + "Starting metrics collection %r from sentinel context: metrics will be lost", + name, ) parent_context = None else: -- cgit 1.5.1 From 1315a2e8be702a513d49c1142e9e52b642286635 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 11 Jan 2021 16:09:22 +0000 Subject: Use a chain cover index to efficiently calculate auth chain difference (#8868) --- changelog.d/8868.misc | 1 + docs/auth_chain_diff.dot | 32 ++ docs/auth_chain_diff.dot.png | Bin 0 -> 42427 bytes docs/auth_chain_difference_algorithm.md | 108 +++++ synapse/storage/database.py | 22 +- synapse/storage/databases/main/event_federation.py | 185 +++++++ synapse/storage/databases/main/events.py | 535 ++++++++++++++++++++- synapse/storage/databases/main/room.py | 51 +- .../main/schema/delta/59/04_event_auth_chains.sql | 52 ++ .../delta/59/04_event_auth_chains.sql.postgres | 16 + synapse/util/iterutils.py | 53 +- tests/storage/test_event_chain.py | 472 ++++++++++++++++++ tests/storage/test_event_federation.py | 249 +++++++++- tests/util/test_itertools.py | 41 +- 14 files changed, 1769 insertions(+), 48 deletions(-) create mode 100644 changelog.d/8868.misc create mode 100644 docs/auth_chain_diff.dot create mode 100644 docs/auth_chain_diff.dot.png create mode 100644 docs/auth_chain_difference_algorithm.md create mode 100644 synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql create mode 100644 synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres create mode 100644 tests/storage/test_event_chain.py (limited to 'synapse/util') diff --git a/changelog.d/8868.misc b/changelog.d/8868.misc new file mode 100644 index 0000000000..1a11e30944 --- /dev/null +++ b/changelog.d/8868.misc @@ -0,0 +1 @@ +Improve efficiency of large state resolutions for new rooms. diff --git a/docs/auth_chain_diff.dot b/docs/auth_chain_diff.dot new file mode 100644 index 0000000000..978d579ada --- /dev/null +++ b/docs/auth_chain_diff.dot @@ -0,0 +1,32 @@ +digraph auth { + nodesep=0.5; + rankdir="RL"; + + C [label="Create (1,1)"]; + + BJ [label="Bob's Join (2,1)", color=red]; + BJ2 [label="Bob's Join (2,2)", color=red]; + BJ2 -> BJ [color=red, dir=none]; + + subgraph cluster_foo { + A1 [label="Alice's invite (4,1)", color=blue]; + A2 [label="Alice's Join (4,2)", color=blue]; + A3 [label="Alice's Join (4,3)", color=blue]; + A3 -> A2 -> A1 [color=blue, dir=none]; + color=none; + } + + PL1 [label="Power Level (3,1)", color=darkgreen]; + PL2 [label="Power Level (3,2)", color=darkgreen]; + PL2 -> PL1 [color=darkgreen, dir=none]; + + {rank = same; C; BJ; PL1; A1;} + + A1 -> C [color=grey]; + A1 -> BJ [color=grey]; + PL1 -> C [color=grey]; + BJ2 -> PL1 [penwidth=2]; + + A3 -> PL2 [penwidth=2]; + A1 -> PL1 -> BJ -> C [penwidth=2]; +} diff --git a/docs/auth_chain_diff.dot.png b/docs/auth_chain_diff.dot.png new file mode 100644 index 0000000000..771c07308f Binary files /dev/null and b/docs/auth_chain_diff.dot.png differ diff --git a/docs/auth_chain_difference_algorithm.md b/docs/auth_chain_difference_algorithm.md new file mode 100644 index 0000000000..30f72a70da --- /dev/null +++ b/docs/auth_chain_difference_algorithm.md @@ -0,0 +1,108 @@ +# Auth Chain Difference Algorithm + +The auth chain difference algorithm is used by V2 state resolution, where a +naive implementation can be a significant source of CPU and DB usage. + +### Definitions + +A *state set* is a set of state events; e.g. the input of a state resolution +algorithm is a collection of state sets. + +The *auth chain* of a set of events are all the events' auth events and *their* +auth events, recursively (i.e. the events reachable by walking the graph induced +by an event's auth events links). + +The *auth chain difference* of a collection of state sets is the union minus the +intersection of the sets of auth chains corresponding to the state sets, i.e an +event is in the auth chain difference if it is reachable by walking the auth +event graph from at least one of the state sets but not from *all* of the state +sets. + +## Breadth First Walk Algorithm + +A way of calculating the auth chain difference without calculating the full auth +chains for each state set is to do a parallel breadth first walk (ordered by +depth) of each state set's auth chain. By tracking which events are reachable +from each state set we can finish early if every pending event is reachable from +every state set. + +This can work well for state sets that have a small auth chain difference, but +can be very inefficient for larger differences. However, this algorithm is still +used if we don't have a chain cover index for the room (e.g. because we're in +the process of indexing it). + +## Chain Cover Index + +Synapse computes auth chain differences by pre-computing a "chain cover" index +for the auth chain in a room, allowing efficient reachability queries like "is +event A in the auth chain of event B". This is done by assigning every event a +*chain ID* and *sequence number* (e.g. `(5,3)`), and having a map of *links* +between chains (e.g. `(5,3) -> (2,4)`) such that A is reachable by B (i.e. `A` +is in the auth chain of `B`) if and only if either: + +1. A and B have the same chain ID and `A`'s sequence number is less than `B`'s + sequence number; or +2. there is a link `L` between `B`'s chain ID and `A`'s chain ID such that + `L.start_seq_no` <= `B.seq_no` and `A.seq_no` <= `L.end_seq_no`. + +There are actually two potential implementations, one where we store links from +each chain to every other reachable chain (the transitive closure of the links +graph), and one where we remove redundant links (the transitive reduction of the +links graph) e.g. if we have chains `C3 -> C2 -> C1` then the link `C3 -> C1` +would not be stored. Synapse uses the former implementations so that it doesn't +need to recurse to test reachability between chains. + +### Example + +An example auth graph would look like the following, where chains have been +formed based on type/state_key and are denoted by colour and are labelled with +`(chain ID, sequence number)`. Links are denoted by the arrows (links in grey +are those that would be remove in the second implementation described above). + +![Example](auth_chain_diff.dot.png) + +Note that we don't include all links between events and their auth events, as +most of those links would be redundant. For example, all events point to the +create event, but each chain only needs the one link from it's base to the +create event. + +## Using the Index + +This index can be used to calculate the auth chain difference of the state sets +by looking at the chain ID and sequence numbers reachable from each state set: + +1. For every state set lookup the chain ID/sequence numbers of each state event +2. Use the index to find all chains and the maximum sequence number reachable + from each state set. +3. The auth chain difference is then all events in each chain that have sequence + numbers between the maximum sequence number reachable from *any* state set and + the minimum reachable by *all* state sets (if any). + +Note that steps 2 is effectively calculating the auth chain for each state set +(in terms of chain IDs and sequence numbers), and step 3 is calculating the +difference between the union and intersection of the auth chains. + +### Worked Example + +For example, given the above graph, we can calculate the difference between +state sets consisting of: + +1. `S1`: Alice's invite `(4,1)` and Bob's second join `(2,2)`; and +2. `S2`: Alice's second join `(4,3)` and Bob's first join `(2,1)`. + +Using the index we see that the following auth chains are reachable from each +state set: + +1. `S1`: `(1,1)`, `(2,2)`, `(3,1)` & `(4,1)` +2. `S2`: `(1,1)`, `(2,1)`, `(3,2)` & `(4,3)` + +And so, for each the ranges that are in the auth chain difference: +1. Chain 1: None, (since everything can reach the create event). +2. Chain 2: The range `(1, 2]` (i.e. just `2`), as `1` is reachable by all state + sets and the maximum reachable is `2` (corresponding to Bob's second join). +3. Chain 3: Similarly the range `(1, 2]` (corresponding to the second power + level). +4. Chain 4: The range `(1, 3]` (corresponding to both of Alice's joins). + +So the final result is: Bob's second join `(2,2)`, the second power level +`(3,2)` and both of Alice's joins `(4,2)` & `(4,3)`. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index b70ca3087b..6cfadc2b4e 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -179,6 +179,9 @@ class LoggingDatabaseConnection: _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]] +R = TypeVar("R") + + class LoggingTransaction: """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() @@ -266,6 +269,20 @@ class LoggingTransaction: for val in args: self.execute(sql, val) + def execute_values(self, sql: str, *args: Any) -> List[Tuple]: + """Corresponds to psycopg2.extras.execute_values. Only available when + using postgres. + + Always sets fetch=True when caling `execute_values`, so will return the + results. + """ + assert isinstance(self.database_engine, PostgresEngine) + from psycopg2.extras import execute_values # type: ignore + + return self._do_execute( + lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args + ) + def execute(self, sql: str, *args: Any) -> None: self._do_execute(self.txn.execute, sql, *args) @@ -276,7 +293,7 @@ class LoggingTransaction: "Strip newlines out of SQL so that the loggers in the DB are on one line" return " ".join(line.strip() for line in sql.splitlines() if line.strip()) - def _do_execute(self, func, sql: str, *args: Any) -> None: + def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R: sql = self._make_sql_one_line(sql) # TODO(paul): Maybe use 'info' and 'debug' for values? @@ -347,9 +364,6 @@ class PerformanceCounters: return top_n_counters -R = TypeVar("R") - - class DatabasePool: """Wraps a single physical database and connection pool. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ebffd89251..8326640d20 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.storage.engines import PostgresEngine +from synapse.storage.types import Cursor from synapse.types import Collection from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache @@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) +class _NoChainCoverIndex(Exception): + def __init__(self, room_id: str): + super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,)) + + class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -151,15 +158,193 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas The set of the difference in auth chains. """ + # Check if we have indexed the room so we can use the chain cover + # algorithm. + room = await self.get_room(room_id) + if room["has_auth_chain_index"]: + try: + return await self.db_pool.runInteraction( + "get_auth_chain_difference_chains", + self._get_auth_chain_difference_using_cover_index_txn, + room_id, + state_sets, + ) + except _NoChainCoverIndex: + # For whatever reason we don't actually have a chain cover index + # for the events in question, so we fall back to the old method. + pass + return await self.db_pool.runInteraction( "get_auth_chain_difference", self._get_auth_chain_difference_txn, state_sets, ) + def _get_auth_chain_difference_using_cover_index_txn( + self, txn: Cursor, room_id: str, state_sets: List[Set[str]] + ) -> Set[str]: + """Calculates the auth chain difference using the chain index. + + See docs/auth_chain_difference_algorithm.md for details + """ + + # First we look up the chain ID/sequence numbers for all the events, and + # work out the chain/sequence numbers reachable from each state set. + + initial_events = set(state_sets[0]).union(*state_sets[1:]) + + # Map from event_id -> (chain ID, seq no) + chain_info = {} # type: Dict[str, Tuple[int, int]] + + # Map from chain ID -> seq no -> event Id + chain_to_event = {} # type: Dict[int, Dict[int, str]] + + # All the chains that we've found that are reachable from the state + # sets. + seen_chains = set() # type: Set[int] + + sql = """ + SELECT event_id, chain_id, sequence_number + FROM event_auth_chains + WHERE %s + """ + for batch in batch_iter(initial_events, 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", batch + ) + txn.execute(sql % (clause,), args) + + for event_id, chain_id, sequence_number in txn: + chain_info[event_id] = (chain_id, sequence_number) + seen_chains.add(chain_id) + chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id + + # Check that we actually have a chain ID for all the events. + events_missing_chain_info = initial_events.difference(chain_info) + if events_missing_chain_info: + # This can happen due to e.g. downgrade/upgrade of the server. We + # raise an exception and fall back to the previous algorithm. + logger.info( + "Unexpectedly found that events don't have chain IDs in room %s: %s", + room_id, + events_missing_chain_info, + ) + raise _NoChainCoverIndex(room_id) + + # Corresponds to `state_sets`, except as a map from chain ID to max + # sequence number reachable from the state set. + set_to_chain = [] # type: List[Dict[int, int]] + for state_set in state_sets: + chains = {} # type: Dict[int, int] + set_to_chain.append(chains) + + for event_id in state_set: + chain_id, seq_no = chain_info[event_id] + + chains[chain_id] = max(seq_no, chains.get(chain_id, 0)) + + # Now we look up all links for the chains we have, adding chains to + # set_to_chain that are reachable from each set. + sql = """ + SELECT + origin_chain_id, origin_sequence_number, + target_chain_id, target_sequence_number + FROM event_auth_chain_links + WHERE %s + """ + + # (We need to take a copy of `seen_chains` as we want to mutate it in + # the loop) + for batch in batch_iter(set(seen_chains), 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "origin_chain_id", batch + ) + txn.execute(sql % (clause,), args) + + for ( + origin_chain_id, + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) in txn: + for chains in set_to_chain: + # chains are only reachable if the origin sequence number of + # the link is less than the max sequence number in the + # origin chain. + if origin_sequence_number <= chains.get(origin_chain_id, 0): + chains[target_chain_id] = max( + target_sequence_number, chains.get(target_chain_id, 0), + ) + + seen_chains.add(target_chain_id) + + # Now for each chain we figure out the maximum sequence number reachable + # from *any* state set and the minimum sequence number reachable from + # *all* state sets. Events in that range are in the auth chain + # difference. + result = set() + + # Mapping from chain ID to the range of sequence numbers that should be + # pulled from the database. + chain_to_gap = {} # type: Dict[int, Tuple[int, int]] + + for chain_id in seen_chains: + min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain) + max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain) + + if min_seq_no < max_seq_no: + # We have a non empty gap, try and fill it from the events that + # we have, otherwise add them to the list of gaps to pull out + # from the DB. + for seq_no in range(min_seq_no + 1, max_seq_no + 1): + event_id = chain_to_event.get(chain_id, {}).get(seq_no) + if event_id: + result.add(event_id) + else: + chain_to_gap[chain_id] = (min_seq_no, max_seq_no) + break + + if not chain_to_gap: + # If there are no gaps to fetch, we're done! + return result + + if isinstance(self.database_engine, PostgresEngine): + # We can use `execute_values` to efficiently fetch the gaps when + # using postgres. + sql = """ + SELECT event_id + FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq) + WHERE + c.chain_id = l.chain_id + AND min_seq < sequence_number AND sequence_number <= max_seq + """ + + args = [ + (chain_id, min_no, max_no) + for chain_id, (min_no, max_no) in chain_to_gap.items() + ] + + rows = txn.execute_values(sql, args) + result.update(r for r, in rows) + else: + # For SQLite we just fall back to doing a noddy for loop. + sql = """ + SELECT event_id FROM event_auth_chains + WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ? + """ + for chain_id, (min_no, max_no) in chain_to_gap.items(): + txn.execute(sql, (chain_id, min_no, max_no)) + result.update(r for r, in txn) + + return result + def _get_auth_chain_difference_txn( self, txn, state_sets: List[Set[str]] ) -> Set[str]: + """Calculates the auth chain difference using a breadth first search. + + This is used when we don't have a cover index for the room. + """ # Algorithm Description # ~~~~~~~~~~~~~~~~~~~~~ diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5e7753e09b..186f064036 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -17,7 +17,17 @@ import itertools import logging from collections import OrderedDict, namedtuple -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generator, + Iterable, + List, + Optional, + Set, + Tuple, +) import attr from prometheus_client import Counter @@ -33,9 +43,10 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder -from synapse.util.iterutils import batch_iter +from synapse.util.iterutils import batch_iter, sorted_topologically if TYPE_CHECKING: from synapse.server import HomeServer @@ -89,6 +100,14 @@ class PersistEventsStore: self._clock = hs.get_clock() self._instance_name = hs.get_instance_name() + def get_chain_id_txn(txn): + txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") + return txn.fetchone()[0] + + self._event_chain_id_gen = build_sequence_generator( + db.engine, get_chain_id_txn, "event_auth_chain_id" + ) + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id @@ -366,6 +385,36 @@ class PersistEventsStore: # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) + self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts]) + + # _store_rejected_events_txn filters out any events which were + # rejected, and returns the filtered list. + events_and_contexts = self._store_rejected_events_txn( + txn, events_and_contexts=events_and_contexts + ) + + # From this point onwards the events are only ones that weren't + # rejected. + + self._update_metadata_tables_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + backfilled=backfilled, + ) + + # We call this last as it assumes we've inserted the events into + # room_memberships, where applicable. + self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + + def _persist_event_auth_chain_txn( + self, txn: LoggingTransaction, events: List[EventBase], + ) -> None: + + # We only care about state events, so this if there are no state events. + if not any(e.is_state() for e in events): + return + # We want to store event_auth mappings for rejected events, as they're # used in state res v2. # This is only necessary if the rejected event appears in an accepted @@ -381,31 +430,357 @@ class PersistEventsStore: "room_id": event.room_id, "auth_id": auth_id, } - for event, _ in events_and_contexts + for event in events for auth_id in event.auth_event_ids() if event.is_state() ], ) - # _store_rejected_events_txn filters out any events which were - # rejected, and returns the filtered list. - events_and_contexts = self._store_rejected_events_txn( - txn, events_and_contexts=events_and_contexts + # We now calculate chain ID/sequence numbers for any state events we're + # persisting. We ignore out of band memberships as we're not in the room + # and won't have their auth chain (we'll fix it up later if we join the + # room). + # + # See: docs/auth_chain_difference_algorithm.md + + # We ignore legacy rooms that we aren't filling the chain cover index + # for. + rows = self.db_pool.simple_select_many_txn( + txn, + table="rooms", + column="room_id", + iterable={event.room_id for event in events if event.is_state()}, + keyvalues={}, + retcols=("room_id", "has_auth_chain_index"), ) + rooms_using_chain_index = { + row["room_id"] for row in rows if row["has_auth_chain_index"] + } - # From this point onwards the events are only ones that weren't - # rejected. + state_events = { + event.event_id: event + for event in events + if event.is_state() and event.room_id in rooms_using_chain_index + } - self._update_metadata_tables_txn( + if not state_events: + return + + # Map from event ID to chain ID/sequence number. + chain_map = {} # type: Dict[str, Tuple[int, int]] + + # We need to know the type/state_key and auth events of the events we're + # calculating chain IDs for. We don't rely on having the full Event + # instances as we'll potentially be pulling more events from the DB and + # we don't need the overhead of fetching/parsing the full event JSON. + event_to_types = { + e.event_id: (e.type, e.state_key) for e in state_events.values() + } + event_to_auth_chain = { + e.event_id: e.auth_event_ids() for e in state_events.values() + } + + # Set of event IDs to calculate chain ID/seq numbers for. + events_to_calc_chain_id_for = set(state_events) + + # We check if there are any events that need to be handled in the rooms + # we're looking at. These should just be out of band memberships, where + # we didn't have the auth chain when we first persisted. + rows = self.db_pool.simple_select_many_txn( txn, - events_and_contexts=events_and_contexts, - all_events_and_contexts=all_events_and_contexts, - backfilled=backfilled, + table="event_auth_chain_to_calculate", + keyvalues={}, + column="room_id", + iterable={e.room_id for e in state_events.values()}, + retcols=("event_id", "type", "state_key"), ) + for row in rows: + event_id = row["event_id"] + event_type = row["type"] + state_key = row["state_key"] + + # (We could pull out the auth events for all rows at once using + # simple_select_many, but this case happens rarely and almost always + # with a single row.) + auth_events = self.db_pool.simple_select_onecol_txn( + txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id", + ) - # We call this last as it assumes we've inserted the events into - # room_memberships, where applicable. - self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + events_to_calc_chain_id_for.add(event_id) + event_to_types[event_id] = (event_type, state_key) + event_to_auth_chain[event_id] = auth_events + + # First we get the chain ID and sequence numbers for the events' + # auth events (that aren't also currently being persisted). + # + # Note that there there is an edge case here where we might not have + # calculated chains and sequence numbers for events that were "out + # of band". We handle this case by fetching the necessary info and + # adding it to the set of events to calculate chain IDs for. + + missing_auth_chains = { + a_id + for auth_events in event_to_auth_chain.values() + for a_id in auth_events + if a_id not in events_to_calc_chain_id_for + } + + # We loop here in case we find an out of band membership and need to + # fetch their auth event info. + while missing_auth_chains: + sql = """ + SELECT event_id, events.type, state_key, chain_id, sequence_number + FROM events + INNER JOIN state_events USING (event_id) + LEFT JOIN event_auth_chains USING (event_id) + WHERE + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", missing_auth_chains, + ) + txn.execute(sql + clause, args) + + missing_auth_chains.clear() + + for auth_id, event_type, state_key, chain_id, sequence_number in txn: + event_to_types[auth_id] = (event_type, state_key) + + if chain_id is None: + # No chain ID, so the event was persisted out of band. + # We add to list of events to calculate auth chains for. + + events_to_calc_chain_id_for.add(auth_id) + + event_to_auth_chain[ + auth_id + ] = self.db_pool.simple_select_onecol_txn( + txn, + "event_auth", + keyvalues={"event_id": auth_id}, + retcol="auth_id", + ) + + missing_auth_chains.update( + e + for e in event_to_auth_chain[auth_id] + if e not in event_to_types + ) + else: + chain_map[auth_id] = (chain_id, sequence_number) + + # Now we check if we have any events where we don't have auth chain, + # this should only be out of band memberships. + for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain): + for auth_id in event_to_auth_chain[event_id]: + if ( + auth_id not in chain_map + and auth_id not in events_to_calc_chain_id_for + ): + events_to_calc_chain_id_for.discard(event_id) + + # If this is an event we're trying to persist we add it to + # the list of events to calculate chain IDs for next time + # around. (Otherwise we will have already added it to the + # table). + event = state_events.get(event_id) + if event: + self.db_pool.simple_insert_txn( + txn, + table="event_auth_chain_to_calculate", + values={ + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + }, + ) + + # We stop checking the event's auth events since we've + # discarded it. + break + + if not events_to_calc_chain_id_for: + return + + # We now calculate the chain IDs/sequence numbers for the events. We + # do this by looking at the chain ID and sequence number of any auth + # event with the same type/state_key and incrementing the sequence + # number by one. If there was no match or the chain ID/sequence + # number is already taken we generate a new chain. + # + # We need to do this in a topologically sorted order as we want to + # generate chain IDs/sequence numbers of an event's auth events + # before the event itself. + chains_tuples_allocated = set() # type: Set[Tuple[int, int]] + new_chain_tuples = {} # type: Dict[str, Tuple[int, int]] + for event_id in sorted_topologically( + events_to_calc_chain_id_for, event_to_auth_chain + ): + existing_chain_id = None + for auth_id in event_to_auth_chain[event_id]: + if event_to_types.get(event_id) == event_to_types.get(auth_id): + existing_chain_id = chain_map[auth_id] + break + + new_chain_tuple = None + if existing_chain_id: + # We found a chain ID/sequence number candidate, check its + # not already taken. + proposed_new_id = existing_chain_id[0] + proposed_new_seq = existing_chain_id[1] + 1 + if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated: + already_allocated = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_auth_chains", + keyvalues={ + "chain_id": proposed_new_id, + "sequence_number": proposed_new_seq, + }, + retcol="event_id", + allow_none=True, + ) + if already_allocated: + # Mark it as already allocated so we don't need to hit + # the DB again. + chains_tuples_allocated.add((proposed_new_id, proposed_new_seq)) + else: + new_chain_tuple = ( + proposed_new_id, + proposed_new_seq, + ) + + if not new_chain_tuple: + new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1) + + chains_tuples_allocated.add(new_chain_tuple) + + chain_map[event_id] = new_chain_tuple + new_chain_tuples[event_id] = new_chain_tuple + + self.db_pool.simple_insert_many_txn( + txn, + table="event_auth_chains", + values=[ + {"event_id": event_id, "chain_id": c_id, "sequence_number": seq} + for event_id, (c_id, seq) in new_chain_tuples.items() + ], + ) + + self.db_pool.simple_delete_many_txn( + txn, + table="event_auth_chain_to_calculate", + keyvalues={}, + column="event_id", + iterable=new_chain_tuples, + ) + + # Now we need to calculate any new links between chains caused by + # the new events. + # + # Links are pairs of chain ID/sequence numbers such that for any + # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain + # if and only if there is at least one link (CA, S1) -> (CB, S2) + # where SA >= S1 and S2 >= SB. + # + # We try and avoid adding redundant links to the table, e.g. if we + # have two links between two chains which both start/end at the + # sequence number event (or cross) then one can be safely dropped. + # + # To calculate new links we look at every new event and: + # 1. Fetch the chain ID/sequence numbers of its auth events, + # discarding any that are reachable by other auth events, or + # that have the same chain ID as the event. + # 2. For each retained auth event we: + # a. Add a link from the event's to the auth event's chain + # ID/sequence number; and + # b. Add a link from the event to every chain reachable by the + # auth event. + + # Step 1, fetch all existing links from all the chains we've seen + # referenced. + chain_links = _LinkMap() + rows = self.db_pool.simple_select_many_txn( + txn, + table="event_auth_chain_links", + column="origin_chain_id", + iterable={chain_id for chain_id, _ in chain_map.values()}, + keyvalues={}, + retcols=( + "origin_chain_id", + "origin_sequence_number", + "target_chain_id", + "target_sequence_number", + ), + ) + for row in rows: + chain_links.add_link( + (row["origin_chain_id"], row["origin_sequence_number"]), + (row["target_chain_id"], row["target_sequence_number"]), + new=False, + ) + + # We do this in toplogical order to avoid adding redundant links. + for event_id in sorted_topologically( + events_to_calc_chain_id_for, event_to_auth_chain + ): + chain_id, sequence_number = chain_map[event_id] + + # Filter out auth events that are reachable by other auth + # events. We do this by looking at every permutation of pairs of + # auth events (A, B) to check if B is reachable from A. + reduction = { + a_id + for a_id in event_to_auth_chain[event_id] + if chain_map[a_id][0] != chain_id + } + for start_auth_id, end_auth_id in itertools.permutations( + event_to_auth_chain[event_id], r=2, + ): + if chain_links.exists_path_from( + chain_map[start_auth_id], chain_map[end_auth_id] + ): + reduction.discard(end_auth_id) + + # Step 2, figure out what the new links are from the reduced + # list of auth events. + for auth_id in reduction: + auth_chain_id, auth_sequence_number = chain_map[auth_id] + + # Step 2a, add link between the event and auth event + chain_links.add_link( + (chain_id, sequence_number), (auth_chain_id, auth_sequence_number) + ) + + # Step 2b, add a link to chains reachable from the auth + # event. + for target_id, target_seq in chain_links.get_links_from( + (auth_chain_id, auth_sequence_number) + ): + if target_id == chain_id: + continue + + chain_links.add_link( + (chain_id, sequence_number), (target_id, target_seq) + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="event_auth_chain_links", + values=[ + { + "origin_chain_id": source_id, + "origin_sequence_number": source_seq, + "target_chain_id": target_id, + "target_sequence_number": target_seq, + } + for ( + source_id, + source_seq, + target_id, + target_seq, + ) in chain_links.get_additions() + ], + ) def _persist_transaction_ids_txn( self, @@ -1521,3 +1896,131 @@ class PersistEventsStore: if not ev.internal_metadata.is_outlier() ], ) + + +@attr.s(slots=True) +class _LinkMap: + """A helper type for tracking links between chains. + """ + + # Stores the set of links as nested maps: source chain ID -> target chain ID + # -> source sequence number -> target sequence number. + maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict) + + # Stores the links that have been added (with new set to true), as tuples of + # `(source chain ID, source sequence no, target chain ID, target sequence no.)` + additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set) + + def add_link( + self, + src_tuple: Tuple[int, int], + target_tuple: Tuple[int, int], + new: bool = True, + ) -> bool: + """Add a new link between two chains, ensuring no redundant links are added. + + New links should be added in topological order. + + Args: + src_tuple: The chain ID/sequence number of the source of the link. + target_tuple: The chain ID/sequence number of the target of the link. + new: Whether this is a "new" link, i.e. should it be returned + by `get_additions`. + + Returns: + True if a link was added, false if the given link was dropped as redundant + """ + src_chain, src_seq = src_tuple + target_chain, target_seq = target_tuple + + current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {}) + + assert src_chain != target_chain + + if new: + # Check if the new link is redundant + for current_seq_src, current_seq_target in current_links.items(): + # If a link "crosses" another link then its redundant. For example + # in the following link 1 (L1) is redundant, as any event reachable + # via L1 is *also* reachable via L2. + # + # Chain A Chain B + # | | + # L1 |------ | + # | | | + # L2 |---- | -->| + # | | | + # | |--->| + # | | + # | | + # + # So we only need to keep links which *do not* cross, i.e. links + # that both start and end above or below an existing link. + # + # Note, since we add links in topological ordering we should never + # see `src_seq` less than `current_seq_src`. + + if current_seq_src <= src_seq and target_seq <= current_seq_target: + # This new link is redundant, nothing to do. + return False + + self.additions.add((src_chain, src_seq, target_chain, target_seq)) + + current_links[src_seq] = target_seq + return True + + def get_links_from( + self, src_tuple: Tuple[int, int] + ) -> Generator[Tuple[int, int], None, None]: + """Gets the chains reachable from the given chain/sequence number. + + Yields: + The chain ID and sequence number the link points to. + """ + src_chain, src_seq = src_tuple + for target_id, sequence_numbers in self.maps.get(src_chain, {}).items(): + for link_src_seq, target_seq in sequence_numbers.items(): + if link_src_seq <= src_seq: + yield target_id, target_seq + + def get_links_between( + self, source_chain: int, target_chain: int + ) -> Generator[Tuple[int, int], None, None]: + """Gets the links between two chains. + + Yields: + The source and target sequence numbers. + """ + + yield from self.maps.get(source_chain, {}).get(target_chain, {}).items() + + def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]: + """Gets any newly added links. + + Yields: + The source chain ID/sequence number and target chain ID/sequence number + """ + + for src_chain, src_seq, target_chain, _ in self.additions: + target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq) + if target_seq is not None: + yield (src_chain, src_seq, target_chain, target_seq) + + def exists_path_from( + self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int], + ) -> bool: + """Checks if there is a path between the source chain ID/sequence and + target chain ID/sequence. + """ + src_chain, src_seq = src_tuple + target_chain, target_seq = target_tuple + + if src_chain == target_chain: + return target_seq <= src_seq + + links = self.get_links_between(src_chain, target_chain) + for link_start_seq, link_end_seq in links: + if link_start_seq <= src_seq and target_seq <= link_end_seq: + return True + + return False diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 4650d0689b..284f2ce77c 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -84,7 +84,7 @@ class RoomWorkerStore(SQLBaseStore): return await self.db_pool.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, - retcols=("room_id", "is_public", "creator"), + retcols=("room_id", "is_public", "creator", "has_auth_chain_index"), desc="get_room", allow_none=True, ) @@ -1166,6 +1166,37 @@ class RoomBackgroundUpdateStore(SQLBaseStore): # It's overridden by RoomStore for the synapse master. raise NotImplementedError() + async def has_auth_chain_index(self, room_id: str) -> bool: + """Check if the room has (or can have) a chain cover index. + + Defaults to True if we don't have an entry in `rooms` table nor any + events for the room. + """ + + has_auth_chain_index = await self.db_pool.simple_select_one_onecol( + table="rooms", + keyvalues={"room_id": room_id}, + retcol="has_auth_chain_index", + desc="has_auth_chain_index", + allow_none=True, + ) + + if has_auth_chain_index: + return True + + # It's possible that we already have events for the room in our DB + # without a corresponding room entry. If we do then we don't want to + # mark the room as having an auth chain cover index. + max_ordering = await self.db_pool.simple_select_one_onecol( + table="events", + keyvalues={"room_id": room_id}, + retcol="MAX(stream_ordering)", + allow_none=True, + desc="upsert_room_on_join", + ) + + return max_ordering is None + class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -1179,12 +1210,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): Called when we join a room over federation, and overwrites any room version currently in the table. """ + # It's possible that we already have events for the room in our DB + # without a corresponding room entry. If we do then we don't want to + # mark the room as having an auth chain cover index. + has_auth_chain_index = await self.has_auth_chain_index(room_id) + await self.db_pool.simple_upsert( desc="upsert_room_on_join", table="rooms", keyvalues={"room_id": room_id}, values={"room_version": room_version.identifier}, - insertion_values={"is_public": False, "creator": ""}, + insertion_values={ + "is_public": False, + "creator": "", + "has_auth_chain_index": has_auth_chain_index, + }, # rooms has a unique constraint on room_id, so no need to lock when doing an # emulated upsert. lock=False, @@ -1219,6 +1259,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "creator": room_creator_user_id, "is_public": is_public, "room_version": room_version.identifier, + "has_auth_chain_index": True, }, ) if is_public: @@ -1247,6 +1288,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): When we receive an invite or any other event over federation that may relate to a room we are not in, store the version of the room if we don't already know the room version. """ + # It's possible that we already have events for the room in our DB + # without a corresponding room entry. If we do then we don't want to + # mark the room as having an auth chain cover index. + has_auth_chain_index = await self.has_auth_chain_index(room_id) + await self.db_pool.simple_upsert( desc="maybe_store_room_on_outlier_membership", table="rooms", @@ -1256,6 +1302,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "room_version": room_version.identifier, "is_public": False, "creator": "", + "has_auth_chain_index": has_auth_chain_index, }, # rooms has a unique constraint on room_id, so no need to lock when doing an # emulated upsert. diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql new file mode 100644 index 0000000000..729196cfd5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql @@ -0,0 +1,52 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- See docs/auth_chain_difference_algorithm.md + +CREATE TABLE event_auth_chains ( + event_id TEXT PRIMARY KEY, + chain_id BIGINT NOT NULL, + sequence_number BIGINT NOT NULL +); + +CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number); + + +CREATE TABLE event_auth_chain_links ( + origin_chain_id BIGINT NOT NULL, + origin_sequence_number BIGINT NOT NULL, + + target_chain_id BIGINT NOT NULL, + target_sequence_number BIGINT NOT NULL +); + + +CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id); + + +-- Events that we have persisted but not calculated auth chains for, +-- e.g. out of band memberships (where we don't have the auth chain) +CREATE TABLE event_auth_chain_to_calculate ( + event_id TEXT PRIMARY KEY, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL +); + +CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id); + + +-- Whether we've calculated the above index for a room. +ALTER TABLE rooms ADD COLUMN has_auth_chain_index BOOLEAN; diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres new file mode 100644 index 0000000000..e8a035bbeb --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres @@ -0,0 +1,16 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS event_auth_chain_id; diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index 06faeebe7f..f7b4857a84 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -13,8 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import heapq from itertools import islice -from typing import Iterable, Iterator, Sequence, Tuple, TypeVar +from typing import ( + Dict, + Generator, + Iterable, + Iterator, + Mapping, + Sequence, + Set, + Tuple, + TypeVar, +) + +from synapse.types import Collection T = TypeVar("T") @@ -46,3 +59,41 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]: If the input is empty, no chunks are returned. """ return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen)) + + +def sorted_topologically( + nodes: Iterable[T], graph: Mapping[T, Collection[T]], +) -> Generator[T, None, None]: + """Given a set of nodes and a graph, yield the nodes in toplogical order. + + For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`. + """ + + # This is implemented by Kahn's algorithm. + + degree_map = {node: 0 for node in nodes} + reverse_graph = {} # type: Dict[T, Set[T]] + + for node, edges in graph.items(): + if node not in degree_map: + continue + + for edge in edges: + if edge in degree_map: + degree_map[node] += 1 + + reverse_graph.setdefault(edge, set()).add(node) + reverse_graph.setdefault(node, set()) + + zero_degree = [node for node, degree in degree_map.items() if degree == 0] + heapq.heapify(zero_degree) + + while zero_degree: + node = heapq.heappop(zero_degree) + yield node + + for edge in reverse_graph[node]: + if edge in degree_map: + degree_map[edge] -= 1 + if degree_map[edge] == 0: + heapq.heappush(zero_degree, edge) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py new file mode 100644 index 0000000000..83c377824b --- /dev/null +++ b/tests/storage/test_event_chain.py @@ -0,0 +1,472 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Tuple + +from twisted.trial import unittest + +from synapse.api.constants import EventTypes +from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase +from synapse.storage.databases.main.events import _LinkMap + +from tests.unittest import HomeserverTestCase + + +class EventChainStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self._next_stream_ordering = 1 + + def test_simple(self): + """Test that the example in `docs/auth_chain_difference_algorithm.md` + works. + """ + + event_factory = self.hs.get_event_builder_factory() + bob = "@creator:test" + alice = "@alice:test" + room_id = "!room:test" + + # Ensure that we have a rooms entry so that we generate the chain index. + self.get_success( + self.store.store_room( + room_id=room_id, + room_creator_user_id="", + is_public=True, + room_version=RoomVersions.V6, + ) + ) + + create = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Create, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "create"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[]) + ) + + bob_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": bob, + "sender": bob, + "room_id": room_id, + "content": {"tag": "bob_join"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[create.event_id]) + ) + + power = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.PowerLevels, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "power"}, + }, + ).build( + prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id], + ) + ) + + alice_invite = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": bob, + "room_id": room_id, + "content": {"tag": "alice_invite"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + alice_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id], + ) + ) + + power_2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.PowerLevels, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "power_2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + bob_join_2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": bob, + "sender": bob, + "room_id": room_id, + "content": {"tag": "bob_join_2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + alice_join2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[ + create.event_id, + alice_join.event_id, + power_2.event_id, + ], + ) + ) + + events = [ + create, + bob_join, + power, + alice_invite, + alice_join, + bob_join_2, + power_2, + alice_join2, + ] + + expected_links = [ + (bob_join, create), + (power, create), + (power, bob_join), + (alice_invite, create), + (alice_invite, power), + (alice_invite, bob_join), + (bob_join_2, power), + (alice_join2, power_2), + ] + + self.persist(events) + chain_map, link_map = self.fetch_chains(events) + + # Check that the expected links and only the expected links have been + # added. + self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) + + for start, end in expected_links: + start_id, start_seq = chain_map[start.event_id] + end_id, end_seq = chain_map[end.event_id] + + self.assertIn( + (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) + ) + + # Test that everything can reach the create event, but the create event + # can't reach anything. + for event in events[1:]: + self.assertTrue( + link_map.exists_path_from( + chain_map[event.event_id], chain_map[create.event_id] + ), + ) + + self.assertFalse( + link_map.exists_path_from( + chain_map[create.event_id], chain_map[event.event_id], + ), + ) + + def test_out_of_order_events(self): + """Test that we handle persisting events that we don't have the full + auth chain for yet (which should only happen for out of band memberships). + """ + event_factory = self.hs.get_event_builder_factory() + bob = "@creator:test" + alice = "@alice:test" + room_id = "!room:test" + + # Ensure that we have a rooms entry so that we generate the chain index. + self.get_success( + self.store.store_room( + room_id=room_id, + room_creator_user_id="", + is_public=True, + room_version=RoomVersions.V6, + ) + ) + + # First persist the base room. + create = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Create, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "create"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[]) + ) + + bob_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": bob, + "sender": bob, + "room_id": room_id, + "content": {"tag": "bob_join"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[create.event_id]) + ) + + power = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.PowerLevels, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "power"}, + }, + ).build( + prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id], + ) + ) + + self.persist([create, bob_join, power]) + + # Now persist an invite and a couple of memberships out of order. + alice_invite = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": bob, + "room_id": room_id, + "content": {"tag": "alice_invite"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + alice_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id], + ) + ) + + alice_join2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, alice_join.event_id, power.event_id], + ) + ) + + self.persist([alice_join]) + self.persist([alice_join2]) + self.persist([alice_invite]) + + # The end result should be sane. + events = [create, bob_join, power, alice_invite, alice_join] + + chain_map, link_map = self.fetch_chains(events) + + expected_links = [ + (bob_join, create), + (power, create), + (power, bob_join), + (alice_invite, create), + (alice_invite, power), + (alice_invite, bob_join), + ] + + # Check that the expected links and only the expected links have been + # added. + self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) + + for start, end in expected_links: + start_id, start_seq = chain_map[start.event_id] + end_id, end_seq = chain_map[end.event_id] + + self.assertIn( + (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) + ) + + def persist( + self, events: List[EventBase], + ): + """Persist the given events and check that the links generated match + those given. + """ + + persist_events_store = self.hs.get_datastores().persist_events + + for e in events: + e.internal_metadata.stream_ordering = self._next_stream_ordering + self._next_stream_ordering += 1 + + def _persist(txn): + # We need to persist the events to the events and state_events + # tables. + persist_events_store._store_event_txn(txn, [(e, {}) for e in events]) + + # Actually call the function that calculates the auth chain stuff. + persist_events_store._persist_event_auth_chain_txn(txn, events) + + self.get_success( + persist_events_store.db_pool.runInteraction("_persist", _persist,) + ) + + def fetch_chains( + self, events: List[EventBase] + ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]: + + # Fetch the map from event ID -> (chain ID, sequence number) + rows = self.get_success( + self.store.db_pool.simple_select_many_batch( + table="event_auth_chains", + column="event_id", + iterable=[e.event_id for e in events], + retcols=("event_id", "chain_id", "sequence_number"), + keyvalues={}, + ) + ) + + chain_map = { + row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows + } + + # Fetch all the links and pass them to the _LinkMap. + rows = self.get_success( + self.store.db_pool.simple_select_many_batch( + table="event_auth_chain_links", + column="origin_chain_id", + iterable=[chain_id for chain_id, _ in chain_map.values()], + retcols=( + "origin_chain_id", + "origin_sequence_number", + "target_chain_id", + "target_sequence_number", + ), + keyvalues={}, + ) + ) + + link_map = _LinkMap() + for row in rows: + added = link_map.add_link( + (row["origin_chain_id"], row["origin_sequence_number"]), + (row["target_chain_id"], row["target_sequence_number"]), + ) + + # We shouldn't have persisted any redundant links + self.assertTrue(added) + + return chain_map, link_map + + +class LinkMapTestCase(unittest.TestCase): + def test_simple(self): + """Basic tests for the LinkMap. + """ + link_map = _LinkMap() + + link_map.add_link((1, 1), (2, 1), new=False) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) + self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)]) + self.assertCountEqual(link_map.get_additions(), []) + self.assertTrue(link_map.exists_path_from((1, 5), (2, 1))) + self.assertFalse(link_map.exists_path_from((1, 5), (2, 2))) + self.assertTrue(link_map.exists_path_from((1, 5), (1, 1))) + self.assertFalse(link_map.exists_path_from((1, 1), (1, 5))) + + # Attempting to add a redundant link is ignored. + self.assertFalse(link_map.add_link((1, 4), (2, 1))) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) + + # Adding new non-redundant links works + self.assertTrue(link_map.add_link((1, 3), (2, 3))) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) + + self.assertTrue(link_map.add_link((2, 5), (1, 3))) + self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)]) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) + + self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)]) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 482506d731..9d04a066d8 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -13,6 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import attr +from parameterized import parameterized + +from synapse.events import _EventInternalMetadata + import tests.unittest import tests.utils @@ -113,7 +118,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) self.assertTrue(r == [room2] or r == [room3]) - def test_auth_difference(self): + @parameterized.expand([(True,), (False,)]) + def test_auth_difference(self, use_chain_cover_index: bool): room_id = "@ROOM:local" # The silly auth graph we use to test the auth difference algorithm, @@ -159,46 +165,223 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): "j": 1, } + # Mark the room as not having a cover index + + def store_room(txn): + self.store.db_pool.simple_insert_txn( + txn, + "rooms", + { + "room_id": room_id, + "creator": "room_creator_user_id", + "is_public": True, + "room_version": "6", + "has_auth_chain_index": use_chain_cover_index, + }, + ) + + self.get_success(self.store.db_pool.runInteraction("store_room", store_room)) + # We rudely fiddle with the appropriate tables directly, as that's much # easier than constructing events properly. - def insert_event(txn, event_id, stream_ordering): + def insert_event(txn): + stream_ordering = 0 + + for event_id in auth_graph: + stream_ordering += 1 + depth = depth_map[event_id] + + self.store.db_pool.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + "stream_ordering": stream_ordering, + }, + ) + + self.hs.datastores.persist_events._persist_event_auth_chain_txn( + txn, + [ + FakeEvent(event_id, room_id, auth_graph[event_id]) + for event_id in auth_graph + ], + ) + + self.get_success(self.store.db_pool.runInteraction("insert", insert_event,)) + + # Now actually test that various combinations give the right result: + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "d", "e"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}]) + ) + self.assertSetEqual(difference, set()) + + def test_auth_difference_partial_cover(self): + """Test that we correctly handle rooms where not all events have a chain + cover calculated. This can happen in some obscure edge cases, including + during the background update that calculates the chain cover for old + rooms. + """ + + room_id = "@ROOM:local" + + # The silly auth graph we use to test the auth difference algorithm, + # where the top are the most recent events. + # + # A B + # \ / + # D E + # \ | + # ` F C + # | /| + # G ´ | + # | \ | + # H I + # | | + # K J + + auth_graph = { + "a": ["e"], + "b": ["e"], + "c": ["g", "i"], + "d": ["f"], + "e": ["f"], + "f": ["g"], + "g": ["h", "i"], + "h": ["k"], + "i": ["j"], + "k": [], + "j": [], + } + + depth_map = { + "a": 7, + "b": 7, + "c": 4, + "d": 6, + "e": 6, + "f": 5, + "g": 3, + "h": 2, + "i": 2, + "k": 1, + "j": 1, + } - depth = depth_map[event_id] + # We rudely fiddle with the appropriate tables directly, as that's much + # easier than constructing events properly. + def insert_event(txn): + # First insert the room and mark it as having a chain cover. self.store.db_pool.simple_insert_txn( txn, - table="events", - values={ - "event_id": event_id, + "rooms", + { "room_id": room_id, - "depth": depth, - "topological_ordering": depth, - "type": "m.test", - "processed": True, - "outlier": False, - "stream_ordering": stream_ordering, + "creator": "room_creator_user_id", + "is_public": True, + "room_version": "6", + "has_auth_chain_index": True, }, ) - self.store.db_pool.simple_insert_many_txn( + stream_ordering = 0 + + for event_id in auth_graph: + stream_ordering += 1 + depth = depth_map[event_id] + + self.store.db_pool.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + "stream_ordering": stream_ordering, + }, + ) + + # Insert all events apart from 'B' + self.hs.datastores.persist_events._persist_event_auth_chain_txn( txn, - table="event_auth", - values=[ - {"event_id": event_id, "room_id": room_id, "auth_id": a} - for a in auth_graph[event_id] + [ + FakeEvent(event_id, room_id, auth_graph[event_id]) + for event_id in auth_graph + if event_id != "b" ], ) - next_stream_ordering = 0 - for event_id in auth_graph: - next_stream_ordering += 1 - self.get_success( - self.store.db_pool.runInteraction( - "insert", insert_event, event_id, next_stream_ordering - ) + # Now we insert the event 'B' without a chain cover, by temporarily + # pretending the room doesn't have a chain cover. + + self.store.db_pool.simple_update_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"has_auth_chain_index": False}, + ) + + self.hs.datastores.persist_events._persist_event_auth_chain_txn( + txn, [FakeEvent("b", room_id, auth_graph["b"])], + ) + + self.store.db_pool.simple_update_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"has_auth_chain_index": True}, ) + self.get_success(self.store.db_pool.runInteraction("insert", insert_event,)) + # Now actually test that various combinations give the right result: difference = self.get_success( @@ -240,3 +423,21 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.store.get_auth_chain_difference(room_id, [{"a"}]) ) self.assertSetEqual(difference, set()) + + +@attr.s +class FakeEvent: + event_id = attr.ib() + room_id = attr.ib() + auth_events = attr.ib() + + type = "foo" + state_key = "foo" + + internal_metadata = _EventInternalMetadata({}) + + def auth_event_ids(self): + return self.auth_events + + def is_state(self): + return True diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index 0ab0a91483..1184cea5a3 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -12,7 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.iterutils import chunk_seq +from typing import Dict, List + +from synapse.util.iterutils import chunk_seq, sorted_topologically from tests.unittest import TestCase @@ -45,3 +47,40 @@ class ChunkSeqTests(TestCase): self.assertEqual( list(parts), [], ) + + +class SortTopologically(TestCase): + def test_empty(self): + "Test that an empty graph works correctly" + + graph = {} # type: Dict[int, List[int]] + self.assertEqual(list(sorted_topologically([], graph)), []) + + def test_disconnected(self): + "Test that a graph with no edges work" + + graph = {1: [], 2: []} # type: Dict[int, List[int]] + + # For disconnected nodes the output is simply sorted. + self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) + + def test_linear(self): + "Test that a simple `4 -> 3 -> 2 -> 1` graph works" + + graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] + + self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) + + def test_subset(self): + "Test that only sorting a subset of the graph works" + graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] + + self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4]) + + def test_fork(self): + "Test that a forked graph works" + graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]] + + # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should + # always get the same one. + self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) -- cgit 1.5.1 From 1a08e0cdab0b3475fd4189aa1e3b6f9aaa823ccf Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 14 Jan 2021 18:57:32 +0000 Subject: Fix event chain bg update. (#9118) We passed in a graph to `sorted_topologically` which didn't have an entry for each node (as we dropped nodes with no edges). --- changelog.d/9118.misc | 1 + synapse/util/iterutils.py | 2 +- tests/util/test_itertools.py | 8 ++++++++ 3 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9118.misc (limited to 'synapse/util') diff --git a/changelog.d/9118.misc b/changelog.d/9118.misc new file mode 100644 index 0000000000..346741d982 --- /dev/null +++ b/changelog.d/9118.misc @@ -0,0 +1 @@ +Improve efficiency of large state resolutions. diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index f7b4857a84..6ef2b008a4 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -92,7 +92,7 @@ def sorted_topologically( node = heapq.heappop(zero_degree) yield node - for edge in reverse_graph[node]: + for edge in reverse_graph.get(node, []): if edge in degree_map: degree_map[edge] -= 1 if degree_map[edge] == 0: diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index 1184cea5a3..522c8061f9 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -56,6 +56,14 @@ class SortTopologically(TestCase): graph = {} # type: Dict[int, List[int]] self.assertEqual(list(sorted_topologically([], graph)), []) + def test_handle_empty_graph(self): + "Test that a graph where a node doesn't have an entry is treated as empty" + + graph = {} # type: Dict[int, List[int]] + + # For disconnected nodes the output is simply sorted. + self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) + def test_disconnected(self): "Test that a graph with no edges work" -- cgit 1.5.1 From 9ffac2bef1cbf74694280e4976605f3563f97074 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 15 Jan 2021 15:59:20 +0000 Subject: Remote dependency on distutils (#9125) `distutils` is pretty much deprecated these days, and replaced with `setuptools`. It's also annoying because it's you can't `pip install` it, and it's hard to figure out which debian package we should depend on to make sure it's there. Since we only use it for a tiny function anyway, let's just vendor said function into our codebase. --- changelog.d/9125.misc | 1 + debian/changelog | 6 ++++++ debian/control | 1 - synapse/config/registration.py | 11 +++++------ synapse/events/__init__.py | 3 ++- synapse/util/stringutils.py | 19 +++++++++++++++++++ 6 files changed, 33 insertions(+), 8 deletions(-) create mode 100644 changelog.d/9125.misc (limited to 'synapse/util') diff --git a/changelog.d/9125.misc b/changelog.d/9125.misc new file mode 100644 index 0000000000..08459caf5a --- /dev/null +++ b/changelog.d/9125.misc @@ -0,0 +1 @@ +Remove dependency on `distutils`. diff --git a/debian/changelog b/debian/changelog index 609436bf75..1c6308e3a2 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.25.0ubuntu1) UNRELEASED; urgency=medium + + * Remove dependency on `python3-distutils`. + + -- Richard van der Hoff Fri, 15 Jan 2021 12:44:19 +0000 + matrix-synapse-py3 (1.25.0) stable; urgency=medium [ Dan Callahan ] diff --git a/debian/control b/debian/control index b10401be43..8167a901a4 100644 --- a/debian/control +++ b/debian/control @@ -31,7 +31,6 @@ Pre-Depends: dpkg (>= 1.16.1) Depends: adduser, debconf, - python3-distutils|libpython3-stdlib (<< 3.6), ${misc:Depends}, ${shlibs:Depends}, ${synapse:pydepends}, diff --git a/synapse/config/registration.py b/synapse/config/registration.py index cc5f75123c..740c3fc1b1 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -14,14 +14,13 @@ # limitations under the License. import os -from distutils.util import strtobool import pkg_resources from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError from synapse.types import RoomAlias, UserID -from synapse.util.stringutils import random_string_with_symbols +from synapse.util.stringutils import random_string_with_symbols, strtobool class AccountValidityConfig(Config): @@ -86,12 +85,12 @@ class RegistrationConfig(Config): section = "registration" def read_config(self, config, **kwargs): - self.enable_registration = bool( - strtobool(str(config.get("enable_registration", False))) + self.enable_registration = strtobool( + str(config.get("enable_registration", False)) ) if "disable_registration" in config: - self.enable_registration = not bool( - strtobool(str(config["disable_registration"])) + self.enable_registration = not strtobool( + str(config["disable_registration"]) ) self.account_validity = AccountValidityConfig( diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 8028663fa8..3ec4120f85 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -17,7 +17,6 @@ import abc import os -from distutils.util import strtobool from typing import Dict, Optional, Tuple, Type from unpaddedbase64 import encode_base64 @@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVers from synapse.types import JsonDict, RoomStreamToken from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze +from synapse.util.stringutils import strtobool # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents # bugs where we accidentally share e.g. signature dicts. However, converting a @@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze # NOTE: This is overridden by the configuration by the Synapse worker apps, but # for the sake of tests, it is set here while it cannot be configured on the # homeserver object itself. + USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0")) diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 61d96a6c28..b103c8694c 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -75,3 +75,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str: if len(items) <= maxitems: return str(items) return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" + + +def strtobool(val: str) -> bool: + """Convert a string representation of truth to True or False + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if + 'val' is anything else. + + This is lifted from distutils.util.strtobool, with the exception that it actually + returns a bool, rather than an int. + """ + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + elif val in ("n", "no", "f", "false", "off", "0"): + return False + else: + raise ValueError("invalid truth value %r" % (val,)) -- cgit 1.5.1 From 0cd2938bc854d947ae8102ded688a626c9fac5b5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Jan 2021 13:15:14 +0000 Subject: Support icons for Identity Providers (#9154) --- changelog.d/9154.feature | 1 + docs/sample_config.yaml | 4 ++ mypy.ini | 1 + synapse/config/oidc_config.py | 20 ++++++ synapse/config/server.py | 2 +- synapse/federation/federation_server.py | 2 +- synapse/federation/transport/server.py | 2 +- synapse/handlers/cas_handler.py | 4 ++ synapse/handlers/oidc_handler.py | 3 + synapse/handlers/room.py | 2 +- synapse/handlers/saml_handler.py | 4 ++ synapse/handlers/sso.py | 5 ++ synapse/http/endpoint.py | 79 --------------------- synapse/res/templates/sso_login_idp_picker.html | 3 + synapse/rest/client/v1/room.py | 3 +- synapse/storage/databases/main/room.py | 6 +- synapse/types.py | 2 +- synapse/util/stringutils.py | 92 +++++++++++++++++++++++++ tests/http/test_endpoint.py | 2 +- 19 files changed, 146 insertions(+), 91 deletions(-) create mode 100644 changelog.d/9154.feature delete mode 100644 synapse/http/endpoint.py (limited to 'synapse/util') diff --git a/changelog.d/9154.feature b/changelog.d/9154.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9154.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 7fdd798d70..b49a5da8cc 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1726,6 +1726,10 @@ saml2_config: # idp_name: A user-facing name for this identity provider, which is used to # offer the user a choice of login mechanisms. # +# idp_icon: An optional icon for this identity provider, which is presented +# by identity picker pages. If given, must be an MXC URI of the format +# mxc:/// +# # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. # diff --git a/mypy.ini b/mypy.ini index b996867121..bd99069c81 100644 --- a/mypy.ini +++ b/mypy.ini @@ -100,6 +100,7 @@ files = synapse/util/async_helpers.py, synapse/util/caches, synapse/util/metrics.py, + synapse/util/stringutils.py, tests/replication, tests/test_utils, tests/handlers/test_password_providers.py, diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index df55367434..f257fcd412 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -23,6 +23,7 @@ from synapse.config._util import validate_config from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import Collection, JsonDict from synapse.util.module_loader import load_module +from synapse.util.stringutils import parse_and_validate_mxc_uri from ._base import Config, ConfigError @@ -66,6 +67,10 @@ class OIDCConfig(Config): # idp_name: A user-facing name for this identity provider, which is used to # offer the user a choice of login mechanisms. # + # idp_icon: An optional icon for this identity provider, which is presented + # by identity picker pages. If given, must be an MXC URI of the format + # mxc:/// + # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. # @@ -207,6 +212,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "properties": { "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, "idp_name": {"type": "string"}, + "idp_icon": {"type": "string"}, "discover": {"type": "boolean"}, "issuer": {"type": "string"}, "client_id": {"type": "string"}, @@ -336,9 +342,20 @@ def _parse_oidc_config_dict( config_path + ("idp_id",), ) + # MSC2858 also specifies that the idp_icon must be a valid MXC uri + idp_icon = oidc_config.get("idp_icon") + if idp_icon is not None: + try: + parse_and_validate_mxc_uri(idp_icon) + except ValueError as e: + raise ConfigError( + "idp_icon must be a valid MXC URI", config_path + ("idp_icon",) + ) from e + return OidcProviderConfig( idp_id=idp_id, idp_name=oidc_config.get("idp_name", "OIDC"), + idp_icon=idp_icon, discover=oidc_config.get("discover", True), issuer=oidc_config["issuer"], client_id=oidc_config["client_id"], @@ -366,6 +383,9 @@ class OidcProviderConfig: # user-facing name for this identity provider. idp_name = attr.ib(type=str) + # Optional MXC URI for icon for this IdP. + idp_icon = attr.ib(type=Optional[str]) + # whether the OIDC discovery mechanism is used to discover endpoints discover = attr.ib(type=bool) diff --git a/synapse/config/server.py b/synapse/config/server.py index 75ba161f35..47a0370173 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -26,7 +26,7 @@ import yaml from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.http.endpoint import parse_and_validate_server_name +from synapse.util.stringutils import parse_and_validate_server_name from ._base import Config, ConfigError diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e5339aca23..171d25c945 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -49,7 +49,6 @@ from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction -from synapse.http.endpoint import parse_server_name from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, @@ -66,6 +65,7 @@ from synapse.types import JsonDict, get_domain_from_id from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import parse_server_name if TYPE_CHECKING: from synapse.server import HomeServer diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index cfd094e58f..95c64510a9 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -28,7 +28,6 @@ from synapse.api.urls import ( FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX, ) -from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( parse_boolean_from_args, @@ -45,6 +44,7 @@ from synapse.logging.opentracing import ( ) from synapse.server import HomeServer from synapse.types import ThirdPartyInstanceID, get_domain_from_id +from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index f3430c6713..0f342c607b 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -80,6 +80,10 @@ class CasHandler: # user-facing name of this auth provider self.idp_name = "CAS" + # we do not currently support icons for CAS auth, but this is required by + # the SsoIdentityProvider protocol type. + self.idp_icon = None + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index ba686d74b2..1607e12935 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -271,6 +271,9 @@ class OidcProvider: # user-facing name of this auth provider self.idp_name = provider.idp_name + # MXC URI for icon for this auth provider + self.idp_icon = provider.idp_icon + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3bece6d668..ee27d99135 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -38,7 +38,6 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents -from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, @@ -55,6 +54,7 @@ from synapse.types import ( from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import parse_and_validate_server_name from synapse.visibility import filter_events_for_client from ._base import BaseHandler diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index a8376543c9..38461cf79d 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -78,6 +78,10 @@ class SamlHandler(BaseHandler): # user-facing name of this auth provider self.idp_name = "SAML" + # we do not currently support icons for SAML auth, but this is required by + # the SsoIdentityProvider protocol type. + self.idp_icon = None + # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index dcc85e9871..d493327a10 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -75,6 +75,11 @@ class SsoIdentityProvider(Protocol): def idp_name(self) -> str: """User-facing name for this provider""" + @property + def idp_icon(self) -> Optional[str]: + """Optional MXC URI for user-facing icon""" + return None + @abc.abstractmethod async def handle_redirect_request( self, diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py deleted file mode 100644 index 92a5b606c8..0000000000 --- a/synapse/http/endpoint.py +++ /dev/null @@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import re - -logger = logging.getLogger(__name__) - - -def parse_server_name(server_name): - """Split a server name into host/port parts. - - Args: - server_name (str): server name to parse - - Returns: - Tuple[str, int|None]: host/port parts. - - Raises: - ValueError if the server name could not be parsed. - """ - try: - if server_name[-1] == "]": - # ipv6 literal, hopefully - return server_name, None - - domain_port = server_name.rsplit(":", 1) - domain = domain_port[0] - port = int(domain_port[1]) if domain_port[1:] else None - return domain, port - except Exception: - raise ValueError("Invalid server name '%s'" % server_name) - - -VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") - - -def parse_and_validate_server_name(server_name): - """Split a server name into host/port parts and do some basic validation. - - Args: - server_name (str): server name to parse - - Returns: - Tuple[str, int|None]: host/port parts. - - Raises: - ValueError if the server name could not be parsed. - """ - host, port = parse_server_name(server_name) - - # these tests don't need to be bulletproof as we'll find out soon enough - # if somebody is giving us invalid data. What we *do* need is to be sure - # that nobody is sneaking IP literals in that look like hostnames, etc. - - # look for ipv6 literals - if host[0] == "[": - if host[-1] != "]": - raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) - return host, port - - # otherwise it should only be alphanumerics. - if not VALID_HOST_REGEX.match(host): - raise ValueError( - "Server name '%s' contains invalid characters" % (server_name,) - ) - - return host, port diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html index f53c9cd679..5b38481012 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -17,6 +17,9 @@
  • +{% if p.idp_icon %} + +{% endif %}
  • {% endfor %} diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index e6725b03b0..f95627ee61 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -32,7 +32,6 @@ from synapse.api.errors import ( ) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 -from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -47,7 +46,7 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.util import json_decoder -from synapse.util.stringutils import random_string +from synapse.util.stringutils import parse_and_validate_server_name, random_string if TYPE_CHECKING: import synapse.server diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 284f2ce77c..a9fcb5f59c 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -16,7 +16,6 @@ import collections import logging -import re from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple @@ -30,6 +29,7 @@ from synapse.storage.databases.main.search import SearchStore from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached +from synapse.util.stringutils import MXC_REGEX logger = logging.getLogger(__name__) @@ -660,8 +660,6 @@ class RoomWorkerStore(SQLBaseStore): The local and remote media as a lists of tuples where the key is the hostname and the value is the media ID. """ - mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") - sql = """ SELECT stream_ordering, json FROM events JOIN event_json USING (room_id, event_id) @@ -688,7 +686,7 @@ class RoomWorkerStore(SQLBaseStore): for url in (content_url, thumbnail_url): if not url: continue - matches = mxc_re.match(url) + matches = MXC_REGEX.match(url) if matches: hostname = matches.group(1) media_id = matches.group(2) diff --git a/synapse/types.py b/synapse/types.py index 20a43d05bf..eafe729dfe 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -37,7 +37,7 @@ from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError -from synapse.http.endpoint import parse_and_validate_server_name +from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.appservice.api import ApplicationService diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index b103c8694c..f8038bf861 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -18,6 +18,7 @@ import random import re import string from collections.abc import Iterable +from typing import Optional, Tuple from synapse.api.errors import Codes, SynapseError @@ -26,6 +27,15 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") +# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris, +# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically +# says "there is no grammar for media ids" +# +# The server_name part of this is purposely lax: use parse_and_validate_mxc for +# additional validation. +# +MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") + # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure # we get cryptographically-secure randoms. @@ -59,6 +69,88 @@ def assert_valid_client_secret(client_secret): ) +def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]: + """Split a server name into host/port parts. + + Args: + server_name: server name to parse + + Returns: + host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + try: + if server_name[-1] == "]": + # ipv6 literal, hopefully + return server_name, None + + domain_port = server_name.rsplit(":", 1) + domain = domain_port[0] + port = int(domain_port[1]) if domain_port[1:] else None + return domain, port + except Exception: + raise ValueError("Invalid server name '%s'" % server_name) + + +VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") + + +def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]: + """Split a server name into host/port parts and do some basic validation. + + Args: + server_name: server name to parse + + Returns: + host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + host, port = parse_server_name(server_name) + + # these tests don't need to be bulletproof as we'll find out soon enough + # if somebody is giving us invalid data. What we *do* need is to be sure + # that nobody is sneaking IP literals in that look like hostnames, etc. + + # look for ipv6 literals + if host[0] == "[": + if host[-1] != "]": + raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) + return host, port + + # otherwise it should only be alphanumerics. + if not VALID_HOST_REGEX.match(host): + raise ValueError( + "Server name '%s' contains invalid characters" % (server_name,) + ) + + return host, port + + +def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]: + """Parse the given string as an MXC URI + + Checks that the "server name" part is a valid server name + + Args: + mxc: the (alleged) MXC URI to be checked + Returns: + hostname, port, media id + Raises: + ValueError if the URI cannot be parsed + """ + m = MXC_REGEX.match(mxc) + if not m: + raise ValueError("mxc URI %r did not match expected format" % (mxc,)) + server_name = m.group(1) + media_id = m.group(2) + host, port = parse_and_validate_server_name(server_name) + return host, port, media_id + + def shortstr(iterable: Iterable, maxitems: int = 5) -> str: """If iterable has maxitems or fewer, return the stringification of a list containing those items. diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index b2e9533b07..d06ea518ce 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name +from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name from tests import unittest -- cgit 1.5.1 From 056327457ff471495741a539e99c840ed54afccd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 22 Jan 2021 19:44:08 +0000 Subject: Fix chain cover update to handle events with duplicate auth events (#9210) --- changelog.d/9210.bugfix | 1 + synapse/util/iterutils.py | 2 +- tests/util/test_itertools.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9210.bugfix (limited to 'synapse/util') diff --git a/changelog.d/9210.bugfix b/changelog.d/9210.bugfix new file mode 100644 index 0000000000..f9e0765570 --- /dev/null +++ b/changelog.d/9210.bugfix @@ -0,0 +1 @@ +Fix chain cover update to handle events with duplicate auth events. Introduced in v1.26.0rc1. diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index 6ef2b008a4..8d2411513f 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -78,7 +78,7 @@ def sorted_topologically( if node not in degree_map: continue - for edge in edges: + for edge in set(edges): if edge in degree_map: degree_map[node] += 1 diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index 522c8061f9..1ef0af8e8f 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -92,3 +92,15 @@ class SortTopologically(TestCase): # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should # always get the same one. self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) + + def test_duplicates(self): + "Test that a graph with duplicate edges work" + graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} # type: Dict[int, List[int]] + + self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) + + def test_multiple_paths(self): + "Test that a graph with multiple paths between two nodes work" + graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} # type: Dict[int, List[int]] + + self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) -- cgit 1.5.1 From a64c29926efd8460dfc9561d761898197638973d Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 27 Jan 2021 11:49:31 +0000 Subject: Pass a dict, instead of None, to modules if a None config is specified in the homeserver config (#9229) If a Synapse module's config block were empty in YAML, thus being translated to a `Nonetype` in Python, then some modules could fail as that None ends up getting passed to their `parse_config` method. Modules are expected to accept a `dict` instead. This PR ensures that if the user does end up specifying an empty config block (such as what [the default oidc config in the sample config](https://github.com/matrix-org/synapse/blob/5310808d3bebd17275355ecd474bc013e8c7462d/docs/sample_config.yaml#L1816-L1845) states) then `None` is not passed to the module. An empty dict is passed instead. This code assumes that no existing modules are relying on receiving a `None` config block, but I'd really hope that they aren't. --- changelog.d/9229.bugfix | 1 + synapse/util/module_loader.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9229.bugfix (limited to 'synapse/util') diff --git a/changelog.d/9229.bugfix b/changelog.d/9229.bugfix new file mode 100644 index 0000000000..3ed32291de --- /dev/null +++ b/changelog.d/9229.bugfix @@ -0,0 +1 @@ +Fix a bug where `None` was passed to Synapse modules instead of an empty dictionary if an empty module `config` block was provided in the homeserver config. \ No newline at end of file diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py index 1ee61851e4..09b094ded7 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py @@ -49,7 +49,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]: module = importlib.import_module(module) provider_class = getattr(module, clz) - module_config = provider.get("config") + # Load the module config. If None, pass an empty dictionary instead + module_config = provider.get("config") or {} try: provider_config = provider_class.parse_config(module_config) except jsonschema.ValidationError as e: -- cgit 1.5.1 From 4167494c90bc0477bdf4855a79e81dc81bba1377 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:52:50 +0000 Subject: Replace username picker with a template (#9275) There's some prelimiary work here to pull out the construction of a jinja environment to a separate function. I wanted to load the template at display time rather than load time, so that it's easy to update on the fly. Honestly, I think we should do this with all our templates: the risk of ending up with malformed templates is far outweighed by the improved turnaround time for an admin trying to update them. --- changelog.d/9275.feature | 1 + docs/sample_config.yaml | 32 +++++- synapse/config/_base.py | 39 +------ synapse/config/oidc_config.py | 3 +- synapse/config/sso.py | 33 +++++- synapse/handlers/sso.py | 2 +- .../res/templates/sso_auth_account_details.html | 115 +++++++++++++++++++++ synapse/res/templates/sso_auth_account_details.js | 76 ++++++++++++++ synapse/res/username_picker/index.html | 19 ---- synapse/res/username_picker/script.js | 95 ----------------- synapse/res/username_picker/style.css | 27 ----- synapse/rest/consent/consent_resource.py | 1 + synapse/rest/synapse/client/pick_username.py | 79 ++++++++++---- synapse/util/templates.py | 106 +++++++++++++++++++ tests/rest/client/v1/test_login.py | 5 +- 15 files changed, 429 insertions(+), 204 deletions(-) create mode 100644 changelog.d/9275.feature create mode 100644 synapse/res/templates/sso_auth_account_details.html create mode 100644 synapse/res/templates/sso_auth_account_details.js delete mode 100644 synapse/res/username_picker/index.html delete mode 100644 synapse/res/username_picker/script.js delete mode 100644 synapse/res/username_picker/style.css create mode 100644 synapse/util/templates.py (limited to 'synapse/util') diff --git a/changelog.d/9275.feature b/changelog.d/9275.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9275.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 05506a7787..a6fbcc6080 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1801,7 +1801,8 @@ saml2_config: # # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their -# own username. +# own username (see 'sso_auth_account_details.html' in the 'sso' +# section of this file). # # display_name_template: Jinja2 template for the display name to set # on first login. If unset, no displayname will be set. @@ -1968,6 +1969,35 @@ sso: # # * idp: the 'idp_id' of the chosen IDP. # + # * HTML page to prompt new users to enter a userid and confirm other + # details: 'sso_auth_account_details.html'. This is only shown if the + # SSO implementation (with any user_mapping_provider) does not return + # a localpart. + # + # When rendering, this template is given the following variables: + # + # * server_name: the homeserver's name. + # + # * idp: details of the SSO Identity Provider that the user logged in + # with: an object with the following attributes: + # + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # * idp_icon: if specified in the IdP config, an MXC URI for an icon + # for the IdP + # * idp_brand: if specified in the IdP config, a textual identifier + # for the brand of the IdP + # + # * user_attributes: an object containing details about the user that + # we received from the IdP. May have the following attributes: + # + # * display_name: the user's display_name + # * emails: a list of email addresses + # + # The template should render a form which submits the following fields: + # + # * username: the localpart of the user's chosen user id + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 94144efc87..35e5594b73 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -18,18 +18,18 @@ import argparse import errno import os -import time -import urllib.parse from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, Callable, Iterable, List, MutableMapping, Optional +from typing import Any, Iterable, List, MutableMapping, Optional import attr import jinja2 import pkg_resources import yaml +from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter + class ConfigError(Exception): """Represents a problem parsing the configuration @@ -248,6 +248,7 @@ class Config: # Search the custom template directory as well search_directories.insert(0, custom_template_directory) + # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(search_directories) env = jinja2.Environment(loader=loader, autoescape=autoescape) @@ -267,38 +268,6 @@ class Config: return templates -def _format_ts_filter(value: int, format: str): - return time.strftime(format, time.localtime(value / 1000)) - - -def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: - """Create and return a jinja2 filter that converts MXC urls to HTTP - - Args: - public_baseurl: The public, accessible base URL of the homeserver - """ - - def mxc_to_http_filter(value, width, height, resize_method="crop"): - if value[0:6] != "mxc://": - return "" - - server_and_media_id = value[6:] - fragment = None - if "#" in server_and_media_id: - server_and_media_id, fragment = server_and_media_id.split("#", 1) - fragment = "#" + fragment - - params = {"width": width, "height": height, "method": resize_method} - return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( - public_baseurl, - server_and_media_id, - urllib.parse.urlencode(params), - fragment or "", - ) - - return mxc_to_http_filter - - class RootConfig: """ Holder of an application's configuration. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index f31511e039..784b416f95 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -152,7 +152,8 @@ class OIDCConfig(Config): # # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their - # own username. + # own username (see 'sso_auth_account_details.html' in the 'sso' + # section of this file). # # display_name_template: Jinja2 template for the display name to set # on first login. If unset, no displayname will be set. diff --git a/synapse/config/sso.py b/synapse/config/sso.py index a470112ed4..e308fc9333 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -27,7 +27,7 @@ class SSOConfig(Config): sso_config = config.get("sso") or {} # type: Dict[str, Any] # The sso-specific template_dir - template_dir = sso_config.get("template_dir") + self.sso_template_dir = sso_config.get("template_dir") # Read templates from disk ( @@ -48,7 +48,7 @@ class SSOConfig(Config): "sso_auth_success.html", "sso_auth_bad_user.html", ], - template_dir, + self.sso_template_dir, ) # These templates have no placeholders, so render them here @@ -124,6 +124,35 @@ class SSOConfig(Config): # # * idp: the 'idp_id' of the chosen IDP. # + # * HTML page to prompt new users to enter a userid and confirm other + # details: 'sso_auth_account_details.html'. This is only shown if the + # SSO implementation (with any user_mapping_provider) does not return + # a localpart. + # + # When rendering, this template is given the following variables: + # + # * server_name: the homeserver's name. + # + # * idp: details of the SSO Identity Provider that the user logged in + # with: an object with the following attributes: + # + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # * idp_icon: if specified in the IdP config, an MXC URI for an icon + # for the IdP + # * idp_brand: if specified in the IdP config, a textual identifier + # for the brand of the IdP + # + # * user_attributes: an object containing details about the user that + # we received from the IdP. May have the following attributes: + # + # * display_name: the user's display_name + # * emails: a list of email addresses + # + # The template should render a form which submits the following fields: + # + # * username: the localpart of the user's chosen user id + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index ceaeb5a376..ff4750999a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -530,7 +530,7 @@ class SsoHandler: logger.info("Recorded registration session id %s", session_id) # Set the cookie and redirect to the username picker - e = RedirectException(b"/_synapse/client/pick_username") + e = RedirectException(b"/_synapse/client/pick_username/account_details") e.cookies.append( b"%s=%s; path=/" % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii")) diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html new file mode 100644 index 0000000000..f22b09aec1 --- /dev/null +++ b/synapse/res/templates/sso_auth_account_details.html @@ -0,0 +1,115 @@ + + + + Synapse Login + + + + + +
    +

    Your account is nearly ready

    +

    Check your details before creating an account on {{ server_name }}

    +
    +
    +
    +
    + +
    @
    + +
    :{{ server_name }}
    +
    + + {% if user_attributes %} +
    +

    Information from {{ idp.idp_name }}

    + {% if user_attributes.avatar_url %} +
    + +
    + {% endif %} + {% if user_attributes.display_name %} +
    +

    {{ user_attributes.display_name }}

    +
    + {% endif %} + {% for email in user_attributes.emails %} +
    +

    {{ email }}

    +
    + {% endfor %} +
    + {% endif %} +
    +
    + + + diff --git a/synapse/res/templates/sso_auth_account_details.js b/synapse/res/templates/sso_auth_account_details.js new file mode 100644 index 0000000000..deef419bb6 --- /dev/null +++ b/synapse/res/templates/sso_auth_account_details.js @@ -0,0 +1,76 @@ +const usernameField = document.getElementById("field-username"); + +function throttle(fn, wait) { + let timeout; + return function() { + const args = Array.from(arguments); + if (timeout) { + clearTimeout(timeout); + } + timeout = setTimeout(fn.bind.apply(fn, [null].concat(args)), wait); + } +} + +function checkUsernameAvailable(username) { + let check_uri = 'check?username=' + encodeURIComponent(username); + return fetch(check_uri, { + // include the cookie + "credentials": "same-origin", + }).then((response) => { + if(!response.ok) { + // for non-200 responses, raise the body of the response as an exception + return response.text().then((text) => { throw new Error(text); }); + } else { + return response.json(); + } + }).then((json) => { + if(json.error) { + return {message: json.error}; + } else if(json.available) { + return {available: true}; + } else { + return {message: username + " is not available, please choose another."}; + } + }); +} + +function validateUsername(username) { + usernameField.setCustomValidity(""); + if (usernameField.validity.valueMissing) { + usernameField.setCustomValidity("Please provide a username"); + return; + } + if (usernameField.validity.patternMismatch) { + usernameField.setCustomValidity("Invalid username, please only use " + allowedCharactersString); + return; + } + usernameField.setCustomValidity("Checking if username is available …"); + throttledCheckUsernameAvailable(username); +} + +const throttledCheckUsernameAvailable = throttle(function(username) { + const handleError = function(err) { + // don't prevent form submission on error + usernameField.setCustomValidity(""); + console.log(err.message); + }; + try { + checkUsernameAvailable(username).then(function(result) { + if (!result.available) { + usernameField.setCustomValidity(result.message); + usernameField.reportValidity(); + } else { + usernameField.setCustomValidity(""); + } + }, handleError); + } catch (err) { + handleError(err); + } +}, 500); + +usernameField.addEventListener("input", function(evt) { + validateUsername(usernameField.value); +}); +usernameField.addEventListener("change", function(evt) { + validateUsername(usernameField.value); +}); diff --git a/synapse/res/username_picker/index.html b/synapse/res/username_picker/index.html deleted file mode 100644 index 37ea8bb6d8..0000000000 --- a/synapse/res/username_picker/index.html +++ /dev/null @@ -1,19 +0,0 @@ - - - - Synapse Login - - - -
    -
    - - - -
    - - - -
    - - diff --git a/synapse/res/username_picker/script.js b/synapse/res/username_picker/script.js deleted file mode 100644 index 416a7c6f41..0000000000 --- a/synapse/res/username_picker/script.js +++ /dev/null @@ -1,95 +0,0 @@ -let inputField = document.getElementById("field-username"); -let inputForm = document.getElementById("form"); -let submitButton = document.getElementById("button-submit"); -let message = document.getElementById("message"); - -// Submit username and receive response -function showMessage(messageText) { - // Unhide the message text - message.classList.remove("hidden"); - - message.textContent = messageText; -}; - -function doSubmit() { - showMessage("Success. Please wait a moment for your browser to redirect."); - - // remove the event handler before re-submitting the form. - delete inputForm.onsubmit; - inputForm.submit(); -} - -function onResponse(response) { - // Display message - showMessage(response); - - // Enable submit button and input field - submitButton.classList.remove('button--disabled'); - submitButton.value = "Submit"; -}; - -let allowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]"); -function usernameIsValid(username) { - return !allowedUsernameCharacters.test(username); -} -let allowedCharactersString = "lowercase letters, digits, ., _, -, /, ="; - -function buildQueryString(params) { - return Object.keys(params) - .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k])) - .join('&'); -} - -function submitUsername(username) { - if(username.length == 0) { - onResponse("Please enter a username."); - return; - } - if(!usernameIsValid(username)) { - onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString); - return; - } - - // if this browser doesn't support fetch, skip the availability check. - if(!window.fetch) { - doSubmit(); - return; - } - - let check_uri = 'check?' + buildQueryString({"username": username}); - fetch(check_uri, { - // include the cookie - "credentials": "same-origin", - }).then((response) => { - if(!response.ok) { - // for non-200 responses, raise the body of the response as an exception - return response.text().then((text) => { throw text; }); - } else { - return response.json(); - } - }).then((json) => { - if(json.error) { - throw json.error; - } else if(json.available) { - doSubmit(); - } else { - onResponse("This username is not available, please choose another."); - } - }).catch((err) => { - onResponse("Error checking username availability: " + err); - }); -} - -function clickSubmit() { - event.preventDefault(); - if(submitButton.classList.contains('button--disabled')) { return; } - - // Disable submit button and input field - submitButton.classList.add('button--disabled'); - - // Submit username - submitButton.value = "Checking..."; - submitUsername(inputField.value); -}; - -inputForm.onsubmit = clickSubmit; diff --git a/synapse/res/username_picker/style.css b/synapse/res/username_picker/style.css deleted file mode 100644 index 745bd4c684..0000000000 --- a/synapse/res/username_picker/style.css +++ /dev/null @@ -1,27 +0,0 @@ -input[type="text"] { - font-size: 100%; - background-color: #ededf0; - border: 1px solid #fff; - border-radius: .2em; - padding: .5em .9em; - display: block; - width: 26em; -} - -.button--disabled { - border-color: #fff; - background-color: transparent; - color: #000; - text-transform: none; -} - -.hidden { - display: none; -} - -.tooltip { - background-color: #f9f9fa; - padding: 1em; - margin: 1em 0; -} - diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index b3e4d5612e..8b9ef26cf2 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -100,6 +100,7 @@ class ConsentResource(DirectServeHtmlResource): consent_template_directory = hs.config.user_consent_template_dir + # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(consent_template_directory) self._jinja_env = jinja2.Environment( loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"]) diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index 1bc737bad0..27540d3bbe 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -13,41 +13,41 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import TYPE_CHECKING -import pkg_resources - from twisted.web.http import Request from twisted.web.resource import Resource -from twisted.web.static import File +from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request -from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource +from synapse.http.server import ( + DirectServeHtmlResource, + DirectServeJsonResource, + respond_with_html, +) from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest +from synapse.util.templates import build_jinja_env if TYPE_CHECKING: from synapse.server import HomeServer +logger = logging.getLogger(__name__) + def pick_username_resource(hs: "HomeServer") -> Resource: """Factory method to generate the username picker resource. - This resource gets mounted under /_synapse/client/pick_username. The top-level - resource is just a File resource which serves up the static files in the resources - "res" directory, but it has a couple of children: + This resource gets mounted under /_synapse/client/pick_username and has two + children: - * "submit", which does the mechanics of registering the new user, and redirects the - browser back to the client URL - - * "check": checks if a userid is free. + * "account_details": renders the form and handles the POSTed response + * "check": a JSON endpoint which checks if a userid is free. """ - # XXX should we make this path customisable so that admins can restyle it? - base_path = pkg_resources.resource_filename("synapse", "res/username_picker") - - res = File(base_path) - res.putChild(b"submit", SubmitResource(hs)) + res = Resource() + res.putChild(b"account_details", AccountDetailsResource(hs)) res.putChild(b"check", AvailabilityCheckResource(hs)) return res @@ -69,15 +69,54 @@ class AvailabilityCheckResource(DirectServeJsonResource): return 200, {"available": is_available} -class SubmitResource(DirectServeHtmlResource): +class AccountDetailsResource(DirectServeHtmlResource): def __init__(self, hs: "HomeServer"): super().__init__() self._sso_handler = hs.get_sso_handler() - async def _async_render_POST(self, request: SynapseRequest): - localpart = parse_string(request, "username", required=True) + def template_search_dirs(): + if hs.config.sso.sso_template_dir: + yield hs.config.sso.sso_template_dir + yield hs.config.sso.default_template_dir + + self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) + + async def _async_render_GET(self, request: Request) -> None: + try: + session_id = get_username_mapping_session_cookie_from_request(request) + session = self._sso_handler.get_mapping_session(session_id) + except SynapseError as e: + logger.warning("Error fetching session: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + idp_id = session.auth_provider_id + template_params = { + "idp": self._sso_handler.get_identity_providers()[idp_id], + "user_attributes": { + "display_name": session.display_name, + "emails": session.emails, + }, + } + + template = self._jinja_env.get_template("sso_auth_account_details.html") + html = template.render(template_params) + respond_with_html(request, 200, html) - session_id = get_username_mapping_session_cookie_from_request(request) + async def _async_render_POST(self, request: SynapseRequest): + try: + session_id = get_username_mapping_session_cookie_from_request(request) + except SynapseError as e: + logger.warning("Error fetching session cookie: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + try: + localpart = parse_string(request, "username", required=True) + except SynapseError as e: + logger.warning("[session %s] bad param: %s", session_id, e) + self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code) + return await self._sso_handler.handle_submit_username_request( request, localpart, session_id diff --git a/synapse/util/templates.py b/synapse/util/templates.py new file mode 100644 index 0000000000..7e5109d206 --- /dev/null +++ b/synapse/util/templates.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for dealing with jinja2 templates""" + +import time +import urllib.parse +from typing import TYPE_CHECKING, Callable, Iterable, Union + +import jinja2 + +if TYPE_CHECKING: + from synapse.config.homeserver import HomeServerConfig + + +def build_jinja_env( + template_search_directories: Iterable[str], + config: "HomeServerConfig", + autoescape: Union[bool, Callable[[str], bool], None] = None, +) -> jinja2.Environment: + """Set up a Jinja2 environment to load templates from the given search path + + The returned environment defines the following filters: + - format_ts: formats timestamps as strings in the server's local timezone + (XXX: why is that useful??) + - mxc_to_http: converts mxc: uris to http URIs. Args are: + (uri, width, height, resize_method="crop") + + and the following global variables: + - server_name: matrix server name + + Args: + template_search_directories: directories to search for templates + + config: homeserver config, for things like `server_name` and `public_baseurl` + + autoescape: whether template variables should be autoescaped. bool, or + a function mapping from template name to bool. Defaults to escaping templates + whose names end in .html, .xml or .htm. + + Returns: + jinja environment + """ + + if autoescape is None: + autoescape = jinja2.select_autoescape() + + loader = jinja2.FileSystemLoader(template_search_directories) + env = jinja2.Environment(loader=loader, autoescape=autoescape) + + # Update the environment with our custom filters + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(config.public_baseurl), + } + ) + + # common variables for all templates + env.globals.update({"server_name": config.server_name}) + + return env + + +def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: + """Create and return a jinja2 filter that converts MXC urls to HTTP + + Args: + public_baseurl: The public, accessible base URL of the homeserver + """ + + def mxc_to_http_filter(value, width, height, resize_method="crop"): + if value[0:6] != "mxc://": + return "" + + server_and_media_id = value[6:] + fragment = None + if "#" in server_and_media_id: + server_and_media_id, fragment = server_and_media_id.split("#", 1) + fragment = "#" + fragment + + params = {"width": width, "height": height, "method": resize_method} + return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( + public_baseurl, + server_and_media_id, + urllib.parse.urlencode(params), + fragment or "", + ) + + return mxc_to_http_filter + + +def _format_ts_filter(value: int, format: str): + return time.strftime(format, time.localtime(value / 1000)) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index ded22a9767..66dfdaffbc 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1222,7 +1222,7 @@ class UsernamePickerTestCase(HomeserverTestCase): # that should redirect to the username picker self.assertEqual(channel.code, 302, channel.result) picker_url = channel.headers.getRawHeaders("Location")[0] - self.assertEqual(picker_url, "/_synapse/client/pick_username") + self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") # ... with a username_mapping_session cookie cookies = {} # type: Dict[str,str] @@ -1247,11 +1247,10 @@ class UsernamePickerTestCase(HomeserverTestCase): # Now, submit a username to the username picker, which should serve a redirect # to the completion page - submit_path = picker_url + "/submit" content = urlencode({b"username": b"bobby"}).encode("utf8") chan = self.make_request( "POST", - path=submit_path, + path=picker_url, content=content, content_is_form=True, custom_headers=[ -- cgit 1.5.1 From e40d88cff3cca3d5186d5f623ad1107bc403d69b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 11 Feb 2021 11:16:54 -0500 Subject: Backout changes for automatically calculating the public baseurl. (#9313) This breaks some people's configurations (if their Client-Server API is not accessed via port 443). --- changelog.d/9313.bugfix | 1 + docs/sample_config.yaml | 20 +++++++++----------- synapse/api/urls.py | 2 ++ synapse/config/cas.py | 16 +++++++++------- synapse/config/emailconfig.py | 8 ++++++++ synapse/config/oidc_config.py | 5 ++++- synapse/config/registration.py | 21 +++++++++++++++++---- synapse/config/saml2_config.py | 2 ++ synapse/config/server.py | 13 ++++--------- synapse/config/sso.py | 13 ++++++++----- synapse/handlers/identity.py | 4 ++++ synapse/rest/well_known.py | 4 ++++ synapse/util/templates.py | 15 ++++++++++++--- tests/rest/client/v1/test_login.py | 4 +++- tests/rest/test_well_known.py | 9 +++++++++ tests/utils.py | 1 + 16 files changed, 97 insertions(+), 41 deletions(-) create mode 100644 changelog.d/9313.bugfix (limited to 'synapse/util') diff --git a/changelog.d/9313.bugfix b/changelog.d/9313.bugfix new file mode 100644 index 0000000000..f578fd13dd --- /dev/null +++ b/changelog.d/9313.bugfix @@ -0,0 +1 @@ +Do not automatically calculate `public_baseurl` since it can be wrong in some situations. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 236abd9a3f..d395da11b4 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -74,10 +74,6 @@ pid_file: DATADIR/homeserver.pid # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see # 'listeners' below). # -# If this is left unset, it defaults to 'https:///'. (Note that -# that will not work unless you configure Synapse or a reverse-proxy to listen -# on port 443.) -# #public_baseurl: https://example.com/ # Set the soft limit on the number of file descriptors synapse can use @@ -1169,9 +1165,8 @@ account_validity: # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' - # configuration section. You should also check that 'public_baseurl' is set - # correctly. + # If you enable this setting, you will also need to fill out the 'email' and + # 'public_baseurl' configuration sections. # #renew_at: 1w @@ -1262,7 +1257,8 @@ account_validity: # The identity server which we suggest that clients should use when users log # in on this server. # -# (By default, no suggestion is made, so it is left up to the client.) +# (By default, no suggestion is made, so it is left up to the client. +# This setting is ignored unless public_baseurl is also set.) # #default_identity_server: https://matrix.org @@ -1287,6 +1283,8 @@ account_validity: # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # +# If a delegate is specified, the config option public_baseurl must also be filled out. +# account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process @@ -1938,9 +1936,9 @@ sso: # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # The login fallback page (used by clients that don't natively support the - # required login flows) is automatically whitelisted in addition to any URLs - # in this list. + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. # # By default, this list is empty. # diff --git a/synapse/api/urls.py b/synapse/api/urls.py index e36aeef31f..6379c86dde 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -42,6 +42,8 @@ class ConsentURIBuilder: """ if hs_config.form_secret is None: raise ConfigError("form_secret not set in config") + if hs_config.public_baseurl is None: + raise ConfigError("public_baseurl not set in config") self._hmac_secret = hs_config.form_secret.encode("utf-8") self._public_baseurl = hs_config.public_baseurl diff --git a/synapse/config/cas.py b/synapse/config/cas.py index b226890c2a..aaa7eba110 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +from ._base import Config, ConfigError class CasConfig(Config): @@ -30,13 +30,15 @@ class CasConfig(Config): if self.cas_enabled: self.cas_server_url = cas_config["server_url"] - public_base_url = cas_config.get("service_url") or self.public_baseurl - if public_base_url[-1] != "/": - public_base_url += "/" + + # The public baseurl is required because it is used by the redirect + # template. + public_baseurl = self.public_baseurl + if not public_baseurl: + raise ConfigError("cas_config requires a public_baseurl to be set") + # TODO Update this to a _synapse URL. - self.cas_service_url = ( - public_base_url + "_matrix/client/r0/login/cas/ticket" - ) + self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket" self.cas_displayname_attribute = cas_config.get("displayname_attribute") self.cas_required_attributes = cas_config.get("required_attributes") or {} else: diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 6a487afd34..d4328c46b9 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -166,6 +166,11 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") + # public_baseurl is required to build password reset and validation links that + # will be emailed to users + if config.get("public_baseurl") is None: + missing.append("public_baseurl") + if missing: raise ConfigError( MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),) @@ -264,6 +269,9 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") + if config.get("public_baseurl") is None: + missing.append("public_baseurl") + if missing: raise ConfigError( "email.enable_notifs is True but required keys are missing: %s" diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 4c24c50629..4d0f24a9d5 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -53,7 +53,10 @@ class OIDCConfig(Config): "Multiple OIDC providers have the idp_id %r." % idp_id ) - self.oidc_callback_url = self.public_baseurl + "_synapse/client/oidc/callback" + public_baseurl = self.public_baseurl + if public_baseurl is None: + raise ConfigError("oidc_config requires a public_baseurl to be set") + self.oidc_callback_url = public_baseurl + "_synapse/client/oidc/callback" @property def oidc_enabled(self) -> bool: diff --git a/synapse/config/registration.py b/synapse/config/registration.py index ac48913a0b..eb650af7fb 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -49,6 +49,10 @@ class AccountValidityConfig(Config): self.startup_job_max_delta = self.period * 10.0 / 100.0 + if self.renew_by_email_enabled: + if "public_baseurl" not in synapse_config: + raise ConfigError("Can't send renewal emails without 'public_baseurl'") + template_dir = config.get("template_dir") if not template_dir: @@ -105,6 +109,13 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") + if self.account_threepid_delegate_msisdn and not self.public_baseurl: + raise ConfigError( + "The configuration option `public_baseurl` is required if " + "`account_threepid_delegate.msisdn` is set, such that " + "clients know where to submit validation tokens to. Please " + "configure `public_baseurl`." + ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) @@ -227,9 +238,8 @@ class RegistrationConfig(Config): # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' - # configuration section. You should also check that 'public_baseurl' is set - # correctly. + # If you enable this setting, you will also need to fill out the 'email' and + # 'public_baseurl' configuration sections. # #renew_at: 1w @@ -320,7 +330,8 @@ class RegistrationConfig(Config): # The identity server which we suggest that clients should use when users log # in on this server. # - # (By default, no suggestion is made, so it is left up to the client.) + # (By default, no suggestion is made, so it is left up to the client. + # This setting is ignored unless public_baseurl is also set.) # #default_identity_server: https://matrix.org @@ -345,6 +356,8 @@ class RegistrationConfig(Config): # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # + # If a delegate is specified, the config option public_baseurl must also be filled out. + # account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index ad865a667f..7226abd829 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -189,6 +189,8 @@ class SAML2Config(Config): import saml2 public_baseurl = self.public_baseurl + if public_baseurl is None: + raise ConfigError("saml2_config requires a public_baseurl to be set") if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) diff --git a/synapse/config/server.py b/synapse/config/server.py index 47a0370173..5d72cf2d82 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -161,11 +161,7 @@ class ServerConfig(Config): self.print_pidfile = config.get("print_pidfile") self.user_agent_suffix = config.get("user_agent_suffix") self.use_frozen_dicts = config.get("use_frozen_dicts", False) - self.public_baseurl = config.get("public_baseurl") or "https://%s/" % ( - self.server_name, - ) - if self.public_baseurl[-1] != "/": - self.public_baseurl += "/" + self.public_baseurl = config.get("public_baseurl") # Whether to enable user presence. self.use_presence = config.get("use_presence", True) @@ -321,6 +317,9 @@ class ServerConfig(Config): # Always blacklist 0.0.0.0, :: self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + if self.public_baseurl is not None: + if self.public_baseurl[-1] != "/": + self.public_baseurl += "/" self.start_pushers = config.get("start_pushers", True) # (undocumented) option for torturing the worker-mode replication a bit, @@ -748,10 +747,6 @@ class ServerConfig(Config): # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see # 'listeners' below). # - # If this is left unset, it defaults to 'https:///'. (Note that - # that will not work unless you configure Synapse or a reverse-proxy to listen - # on port 443.) - # #public_baseurl: https://example.com/ # Set the soft limit on the number of file descriptors synapse can use diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 6c60c6fea4..19bdfd462b 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -64,8 +64,11 @@ class SSOConfig(Config): # gracefully to the client). This would make it pointless to ask the user for # confirmation, since the URL the confirmation page would be showing wouldn't be # the client's. - login_fallback_url = self.public_baseurl + "_matrix/static/client/login" - self.sso_client_whitelist.append(login_fallback_url) + # public_baseurl is an optional setting, so we only add the fallback's URL to the + # list if it's provided (because we can't figure out what that URL is otherwise). + if self.public_baseurl: + login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + self.sso_client_whitelist.append(login_fallback_url) def generate_config_section(self, **kwargs): return """\ @@ -83,9 +86,9 @@ class SSOConfig(Config): # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # The login fallback page (used by clients that don't natively support the - # required login flows) is automatically whitelisted in addition to any URLs - # in this list. + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. # # By default, this list is empty. # diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 4f7137539b..8fc1e8b91c 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -504,6 +504,10 @@ class IdentityHandler(BaseHandler): except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") + # It is already checked that public_baseurl is configured since this code + # should only be used if account_threepid_delegate_msisdn is true. + assert self.hs.config.public_baseurl + # we need to tell the client to send the token back to us, since it doesn't # otherwise know where to send it, so add submit_url response parameter # (see also MSC2078) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 241fe746d9..f591cc6c5c 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -34,6 +34,10 @@ class WellKnownBuilder: self._config = hs.config def get_well_known(self): + # if we don't have a public_baseurl, we can't help much here. + if self._config.public_baseurl is None: + return None + result = {"m.homeserver": {"base_url": self._config.public_baseurl}} if self._config.default_identity_server: diff --git a/synapse/util/templates.py b/synapse/util/templates.py index 7e5109d206..392dae4a40 100644 --- a/synapse/util/templates.py +++ b/synapse/util/templates.py @@ -17,7 +17,7 @@ import time import urllib.parse -from typing import TYPE_CHECKING, Callable, Iterable, Union +from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union import jinja2 @@ -74,14 +74,23 @@ def build_jinja_env( return env -def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: +def _create_mxc_to_http_filter( + public_baseurl: Optional[str], +) -> Callable[[str, int, int, str], str]: """Create and return a jinja2 filter that converts MXC urls to HTTP Args: public_baseurl: The public, accessible base URL of the homeserver """ - def mxc_to_http_filter(value, width, height, resize_method="crop"): + def mxc_to_http_filter( + value: str, width: int, height: int, resize_method: str = "crop" + ) -> str: + if not public_baseurl: + raise RuntimeError( + "public_baseurl must be set in the homeserver config to convert MXC URLs to HTTP URLs." + ) + if value[0:6] != "mxc://": return "" diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 66dfdaffbc..bfcb786af8 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -672,10 +672,12 @@ class CASTestCase(unittest.HomeserverTestCase): self.redirect_path = "_synapse/client/login/sso/redirect/confirm" config = self.default_config() + config["public_baseurl"] = ( + config.get("public_baseurl") or "https://matrix.goodserver.com:8448" + ) config["cas_config"] = { "enabled": True, "server_url": CAS_SERVER, - "service_url": "https://matrix.goodserver.com:8448", } cas_user_id = "username" diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index c5e44af9f7..14de0921be 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -40,3 +40,12 @@ class WellKnownTests(unittest.HomeserverTestCase): "m.identity_server": {"base_url": "https://testis"}, }, ) + + def test_well_known_no_public_baseurl(self): + self.hs.config.public_baseurl = None + + channel = self.make_request( + "GET", "/.well-known/matrix/client", shorthand=False + ) + + self.assertEqual(channel.code, 404) diff --git a/tests/utils.py b/tests/utils.py index 68033d7535..840b657f82 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -159,6 +159,7 @@ def default_config(name, parse=False): }, "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000}, "saml2_enabled": False, + "public_baseurl": None, "default_identity_server": None, "key_refresh_interval": 24 * 60 * 60 * 1000, "old_signing_keys": {}, -- cgit 1.5.1