diff options
116 files changed, 1634 insertions, 974 deletions
diff --git a/changelog.d/10877.feature b/changelog.d/10877.feature new file mode 100644 index 0000000000..06a246c108 --- /dev/null +++ b/changelog.d/10877.feature @@ -0,0 +1 @@ +Ensure `(room_id, next_batch_id)` is unique across [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) insertion events in rooms to avoid cross-talk/conflicts between batches. diff --git a/changelog.d/10895.misc b/changelog.d/10895.misc new file mode 100644 index 0000000000..d1c8224980 --- /dev/null +++ b/changelog.d/10895.misc @@ -0,0 +1 @@ +Fix type hints to be compatible with an upcoming change to Twisted. \ No newline at end of file diff --git a/changelog.d/10902.misc b/changelog.d/10902.misc new file mode 100644 index 0000000000..2cd79887f6 --- /dev/null +++ b/changelog.d/10902.misc @@ -0,0 +1 @@ +Update utility code to handle C implementations of frozendict. \ No newline at end of file diff --git a/changelog.d/10903.misc b/changelog.d/10903.misc new file mode 100644 index 0000000000..2716ccb08c --- /dev/null +++ b/changelog.d/10903.misc @@ -0,0 +1 @@ +Drop old functionality which maintained database compatibility with Synapse versions before 1.31. diff --git a/changelog.d/10915.misc b/changelog.d/10915.misc new file mode 100644 index 0000000000..1ce2910ffa --- /dev/null +++ b/changelog.d/10915.misc @@ -0,0 +1 @@ +Clean-up configuration helper classes for the `ServerConfig` class. diff --git a/changelog.d/10916.misc b/changelog.d/10916.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10916.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/changelog.d/10922.bugfix b/changelog.d/10922.bugfix new file mode 100644 index 0000000000..b7315514e0 --- /dev/null +++ b/changelog.d/10922.bugfix @@ -0,0 +1 @@ +Fix a minor bug in the response to `/_matrix/client/r0/voip/turnServer`. Contributed by @lukaslihotzki. diff --git a/changelog.d/10924.bugfix b/changelog.d/10924.bugfix new file mode 100644 index 0000000000..c73a51e32f --- /dev/null +++ b/changelog.d/10924.bugfix @@ -0,0 +1 @@ +Fix a bug where empty `yyyy-mm-dd/` directories would be left behind in the media store's `url_cache_thumbnails/` directory. diff --git a/changelog.d/10926.misc b/changelog.d/10926.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10926.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/changelog.d/10927.bugfix b/changelog.d/10927.bugfix new file mode 100644 index 0000000000..fd24288c54 --- /dev/null +++ b/changelog.d/10927.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.40.0 where the signature checks for room version 8/9 could be applied to earlier room versions in some situations. diff --git a/changelog.d/10934.misc b/changelog.d/10934.misc new file mode 100644 index 0000000000..56c640ec9e --- /dev/null +++ b/changelog.d/10934.misc @@ -0,0 +1 @@ +Refactor various parts of the codebase to use `RoomVersion` objects instead of room version identifier strings. diff --git a/changelog.d/10935.misc b/changelog.d/10935.misc new file mode 100644 index 0000000000..80529c04ca --- /dev/null +++ b/changelog.d/10935.misc @@ -0,0 +1 @@ +Refactor user directory tests in preparation for upcoming changes. diff --git a/changelog.d/10936.misc b/changelog.d/10936.misc new file mode 100644 index 0000000000..9d1d6e5b02 --- /dev/null +++ b/changelog.d/10936.misc @@ -0,0 +1 @@ +Include the event id in the logcontext when handling PDUs received over federation. diff --git a/changelog.d/10939.misc b/changelog.d/10939.misc new file mode 100644 index 0000000000..a7cecf8a5b --- /dev/null +++ b/changelog.d/10939.misc @@ -0,0 +1 @@ +Fix logged errors in unit tests. diff --git a/changelog.d/10940.misc b/changelog.d/10940.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10940.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/changelog.d/10945.misc b/changelog.d/10945.misc new file mode 100644 index 0000000000..7cf1f02ad6 --- /dev/null +++ b/changelog.d/10945.misc @@ -0,0 +1 @@ +Fix a broken test to ensure that consent configuration works during registration. diff --git a/changelog.d/10958.misc b/changelog.d/10958.misc new file mode 100644 index 0000000000..409ecc35cb --- /dev/null +++ b/changelog.d/10958.misc @@ -0,0 +1 @@ +Add type hints to filtering classes. diff --git a/changelog.d/10959.misc b/changelog.d/10959.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10959.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/changelog.d/10960.bugfix b/changelog.d/10960.bugfix new file mode 100644 index 0000000000..b4f1c228ea --- /dev/null +++ b/changelog.d/10960.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where rebuilding the user directory wouldn't exclude support and disabled users. \ No newline at end of file diff --git a/changelog.d/10961.misc b/changelog.d/10961.misc new file mode 100644 index 0000000000..0e35813488 --- /dev/null +++ b/changelog.d/10961.misc @@ -0,0 +1 @@ +Add type-hint to `HomeserverTestcase.setup_test_homeserver`. \ No newline at end of file diff --git a/changelog.d/9655.feature b/changelog.d/9655.feature new file mode 100644 index 0000000000..70cac230d8 --- /dev/null +++ b/changelog.d/9655.feature @@ -0,0 +1 @@ +Add [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069) support to `/account/whoami`. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 437d0a46a5..568166db33 100644 --- a/mypy.ini +++ b/mypy.ini @@ -162,6 +162,12 @@ disallow_untyped_defs = True [mypy-synapse.util.wheel_timer] disallow_untyped_defs = True +[mypy-tests.handlers.test_user_directory] +disallow_untyped_defs = True + +[mypy-tests.storage.test_user_directory] +disallow_untyped_defs = True + [mypy-pymacaroons.*] ignore_missing_imports = True diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index ad1ff6a9df..20e91a115d 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -15,7 +15,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from typing import List +from typing import ( + TYPE_CHECKING, + Awaitable, + Container, + Iterable, + List, + Optional, + Set, + TypeVar, + Union, +) import jsonschema from jsonschema import FormatChecker @@ -23,7 +33,11 @@ from jsonschema import FormatChecker from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState -from synapse.types import RoomID, UserID +from synapse.events import EventBase +from synapse.types import JsonDict, RoomID, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer FILTER_SCHEMA = { "additionalProperties": False, @@ -120,25 +134,29 @@ USER_FILTER_SCHEMA = { @FormatChecker.cls_checks("matrix_room_id") -def matrix_room_id_validator(room_id_str): +def matrix_room_id_validator(room_id_str: str) -> RoomID: return RoomID.from_string(room_id_str) @FormatChecker.cls_checks("matrix_user_id") -def matrix_user_id_validator(user_id_str): +def matrix_user_id_validator(user_id_str: str) -> UserID: return UserID.from_string(user_id_str) class Filtering: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() - async def get_user_filter(self, user_localpart, filter_id): + async def get_user_filter( + self, user_localpart: str, filter_id: Union[int, str] + ) -> "FilterCollection": result = await self.store.get_user_filter(user_localpart, filter_id) return FilterCollection(result) - def add_user_filter(self, user_localpart, user_filter): + def add_user_filter( + self, user_localpart: str, user_filter: JsonDict + ) -> Awaitable[int]: self.check_valid_filter(user_filter) return self.store.add_user_filter(user_localpart, user_filter) @@ -146,13 +164,13 @@ class Filtering: # replace_user_filter at some point? There's no REST API specified for # them however - def check_valid_filter(self, user_filter_json): + def check_valid_filter(self, user_filter_json: JsonDict) -> None: """Check if the provided filter is valid. This inspects all definitions contained within the filter. Args: - user_filter_json(dict): The filter + user_filter_json: The filter Raises: SynapseError: If the filter is not valid. """ @@ -167,8 +185,12 @@ class Filtering: raise SynapseError(400, str(e)) +# Filters work across events, presence EDUs, and account data. +FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict) + + class FilterCollection: - def __init__(self, filter_json): + def __init__(self, filter_json: JsonDict): self._filter_json = filter_json room_filter_json = self._filter_json.get("room", {}) @@ -188,25 +210,25 @@ class FilterCollection: self.event_fields = filter_json.get("event_fields", []) self.event_format = filter_json.get("event_format", "client") - def __repr__(self): + def __repr__(self) -> str: return "<FilterCollection %s>" % (json.dumps(self._filter_json),) - def get_filter_json(self): + def get_filter_json(self) -> JsonDict: return self._filter_json - def timeline_limit(self): + def timeline_limit(self) -> int: return self._room_timeline_filter.limit() - def presence_limit(self): + def presence_limit(self) -> int: return self._presence_filter.limit() - def ephemeral_limit(self): + def ephemeral_limit(self) -> int: return self._room_ephemeral_filter.limit() - def lazy_load_members(self): + def lazy_load_members(self) -> bool: return self._room_state_filter.lazy_load_members() - def include_redundant_members(self): + def include_redundant_members(self) -> bool: return self._room_state_filter.include_redundant_members() def filter_presence(self, events): @@ -218,29 +240,31 @@ class FilterCollection: def filter_room_state(self, events): return self._room_state_filter.filter(self._room_filter.filter(events)) - def filter_room_timeline(self, events): + def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: return self._room_timeline_filter.filter(self._room_filter.filter(events)) - def filter_room_ephemeral(self, events): + def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: return self._room_ephemeral_filter.filter(self._room_filter.filter(events)) - def filter_room_account_data(self, events): + def filter_room_account_data( + self, events: Iterable[FilterEvent] + ) -> List[FilterEvent]: return self._room_account_data.filter(self._room_filter.filter(events)) - def blocks_all_presence(self): + def blocks_all_presence(self) -> bool: return ( self._presence_filter.filters_all_types() or self._presence_filter.filters_all_senders() ) - def blocks_all_room_ephemeral(self): + def blocks_all_room_ephemeral(self) -> bool: return ( self._room_ephemeral_filter.filters_all_types() or self._room_ephemeral_filter.filters_all_senders() or self._room_ephemeral_filter.filters_all_rooms() ) - def blocks_all_room_timeline(self): + def blocks_all_room_timeline(self) -> bool: return ( self._room_timeline_filter.filters_all_types() or self._room_timeline_filter.filters_all_senders() @@ -249,7 +273,7 @@ class FilterCollection: class Filter: - def __init__(self, filter_json): + def __init__(self, filter_json: JsonDict): self.filter_json = filter_json self.types = self.filter_json.get("types", None) @@ -266,20 +290,20 @@ class Filter: self.labels = self.filter_json.get("org.matrix.labels", None) self.not_labels = self.filter_json.get("org.matrix.not_labels", []) - def filters_all_types(self): + def filters_all_types(self) -> bool: return "*" in self.not_types - def filters_all_senders(self): + def filters_all_senders(self) -> bool: return "*" in self.not_senders - def filters_all_rooms(self): + def filters_all_rooms(self) -> bool: return "*" in self.not_rooms - def check(self, event): + def check(self, event: FilterEvent) -> bool: """Checks whether the filter matches the given event. Returns: - bool: True if the event matches + True if the event matches """ # We usually get the full "events" as dictionaries coming through, # except for presence which actually gets passed around as its own @@ -305,18 +329,25 @@ class Filter: room_id = event.get("room_id", None) ev_type = event.get("type", None) - content = event.get("content", {}) + content = event.get("content") or {} # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), str) labels = content.get(EventContentFields.LABELS, []) return self.check_fields(room_id, sender, ev_type, labels, contains_url) - def check_fields(self, room_id, sender, event_type, labels, contains_url): + def check_fields( + self, + room_id: Optional[str], + sender: Optional[str], + event_type: Optional[str], + labels: Container[str], + contains_url: bool, + ) -> bool: """Checks whether the filter matches the given event fields. Returns: - bool: True if the event fields match + True if the event fields match """ literal_keys = { "rooms": lambda v: room_id == v, @@ -343,14 +374,14 @@ class Filter: return True - def filter_rooms(self, room_ids): + def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]: """Apply the 'rooms' filter to a given list of rooms. Args: - room_ids (list): A list of room_ids. + room_ids: A list of room_ids. Returns: - list: A list of room_ids that match the filter + A list of room_ids that match the filter """ room_ids = set(room_ids) @@ -363,23 +394,23 @@ class Filter: return room_ids - def filter(self, events): + def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: return list(filter(self.check, events)) - def limit(self): + def limit(self) -> int: return self.filter_json.get("limit", 10) - def lazy_load_members(self): + def lazy_load_members(self) -> bool: return self.filter_json.get("lazy_load_members", False) - def include_redundant_members(self): + def include_redundant_members(self) -> bool: return self.filter_json.get("include_redundant_members", False) - def with_room_ids(self, room_ids): + def with_room_ids(self, room_ids: Iterable[str]) -> "Filter": """Returns a new filter with the given room IDs appended. Args: - room_ids (iterable[unicode]): The room_ids to add + room_ids: The room_ids to add Returns: filter: A new filter including the given rooms and the old @@ -390,8 +421,8 @@ class Filter: return newFilter -def _matches_wildcard(actual_value, filter_value): - if filter_value.endswith("*"): +def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool: + if filter_value.endswith("*") and isinstance(actual_value, str): type_prefix = filter_value[:-1] return actual_value.startswith(type_prefix) else: diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 548f6dcde9..749bc1deb9 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -86,11 +86,11 @@ def start_worker_reactor(appname, config, run_command=reactor.run): start_reactor( appname, - soft_file_limit=config.soft_file_limit, - gc_thresholds=config.gc_thresholds, + soft_file_limit=config.server.soft_file_limit, + gc_thresholds=config.server.gc_thresholds, pid_file=config.worker.worker_pid_file, daemonize=config.worker.worker_daemonize, - print_pidfile=config.print_pidfile, + print_pidfile=config.server.print_pidfile, logger=logger, run_command=run_command, ) @@ -298,7 +298,7 @@ def refresh_certificate(hs): Refresh the TLS certificates that Synapse is using by re-reading them from disk and updating the TLS context factories to use them. """ - if not hs.config.has_tls_listener(): + if not hs.config.server.has_tls_listener(): return hs.config.read_certificate_from_disk() diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index f2c5b75247..556bcc124e 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -195,14 +195,14 @@ def start(config_options): config.logging.no_redirect_stdio = True # Explicitly disable background processes - config.update_user_directory = False + config.server.update_user_directory = False config.worker.run_background_tasks = False config.start_pushers = False config.pusher_shard_config.instances = [] config.send_federation = False config.federation_shard_config.instances = [] - synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts + synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts ss = AdminCmdServer( config.server.server_name, diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 3036e1b4a0..7489f31d9a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -462,7 +462,7 @@ def start(config_options): # For other worker types we force this to off. config.server.update_user_directory = False - synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts + synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage if config.server.gc_seconds: diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 205831dcda..2b2d4bbf83 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -248,7 +248,7 @@ class SynapseHomeServer(HomeServer): resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) if name == "webclient": - webclient_loc = self.config.web_client_location + webclient_loc = self.config.server.web_client_location if webclient_loc is None: logger.warning( @@ -343,7 +343,7 @@ def setup(config_options): # generating config files and shouldn't try to continue. sys.exit(0) - events.USE_FROZEN_DICTS = config.use_frozen_dicts + events.USE_FROZEN_DICTS = config.server.use_frozen_dicts synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage if config.server.gc_seconds: @@ -439,11 +439,11 @@ def run(hs): _base.start_reactor( "synapse-homeserver", - soft_file_limit=hs.config.soft_file_limit, - gc_thresholds=hs.config.gc_thresholds, - pid_file=hs.config.pid_file, - daemonize=hs.config.daemonize, - print_pidfile=hs.config.print_pidfile, + soft_file_limit=hs.config.server.soft_file_limit, + gc_thresholds=hs.config.server.gc_thresholds, + pid_file=hs.config.server.pid_file, + daemonize=hs.config.server.daemonize, + print_pidfile=hs.config.server.print_pidfile, logger=logger, ) diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 49e7a45e5c..fcd01e833c 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -74,7 +74,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): store = hs.get_datastore() stats["homeserver"] = hs.config.server.server_name - stats["server_context"] = hs.config.server_context + stats["server_context"] = hs.config.server.server_context stats["timestamp"] = now stats["uptime_seconds"] = uptime version = sys.version_info @@ -171,7 +171,7 @@ def start_phone_stats_home(hs): current_mau_count_by_service = {} reserved_users = () store = hs.get_datastore() - if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: + if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: current_mau_count = await store.get_monthly_active_count() current_mau_count_by_service = ( await store.get_monthly_active_count_by_service() @@ -183,9 +183,9 @@ def start_phone_stats_home(hs): current_mau_by_service_gauge.labels(app_service).set(float(count)) registered_reserved_users_mau_gauge.set(float(len(reserved_users))) - max_mau_gauge.set(float(hs.config.max_mau_value)) + max_mau_gauge.set(float(hs.config.server.max_mau_value)) - if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: + if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: generate_monthly_active_users() clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) # End of monthly active user settings diff --git a/synapse/config/_base.py b/synapse/config/_base.py index d974a1a2a8..26152b0924 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -327,7 +327,7 @@ class RootConfig: """ Redirect lookups on this object either to config objects, or values on config objects, so that `config.tls.blah` works, as well as legacy uses - of things like `config.server_name`. It will first look up the config + of things like `config.server.server_name`. It will first look up the config section name, and then values on those config classes. """ if item in self._configs.keys(): diff --git a/synapse/config/server.py b/synapse/config/server.py index ad8715da29..818b806357 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -1,6 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +17,7 @@ import logging import os.path import re from textwrap import indent -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import attr import yaml @@ -184,49 +182,74 @@ KNOWN_RESOURCES = { @attr.s(frozen=True) class HttpResourceConfig: - names = attr.ib( - type=List[str], + names: List[str] = attr.ib( factory=list, validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)), # type: ignore ) - compress = attr.ib( - type=bool, + compress: bool = attr.ib( default=False, validator=attr.validators.optional(attr.validators.instance_of(bool)), # type: ignore[arg-type] ) -@attr.s(frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class HttpListenerConfig: """Object describing the http-specific parts of the config of a listener""" - x_forwarded = attr.ib(type=bool, default=False) - resources = attr.ib(type=List[HttpResourceConfig], factory=list) - additional_resources = attr.ib(type=Dict[str, dict], factory=dict) - tag = attr.ib(type=str, default=None) + x_forwarded: bool = False + resources: List[HttpResourceConfig] = attr.ib(factory=list) + additional_resources: Dict[str, dict] = attr.ib(factory=dict) + tag: Optional[str] = None -@attr.s(frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class ListenerConfig: """Object describing the configuration of a single listener.""" - port = attr.ib(type=int, validator=attr.validators.instance_of(int)) - bind_addresses = attr.ib(type=List[str]) - type = attr.ib(type=str, validator=attr.validators.in_(KNOWN_LISTENER_TYPES)) - tls = attr.ib(type=bool, default=False) + port: int = attr.ib(validator=attr.validators.instance_of(int)) + bind_addresses: List[str] + type: str = attr.ib(validator=attr.validators.in_(KNOWN_LISTENER_TYPES)) + tls: bool = False # http_options is only populated if type=http - http_options = attr.ib(type=Optional[HttpListenerConfig], default=None) + http_options: Optional[HttpListenerConfig] = None -@attr.s(frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class ManholeConfig: """Object describing the configuration of the manhole""" - username = attr.ib(type=str, validator=attr.validators.instance_of(str)) - password = attr.ib(type=str, validator=attr.validators.instance_of(str)) - priv_key = attr.ib(type=Optional[Key]) - pub_key = attr.ib(type=Optional[Key]) + username: str = attr.ib(validator=attr.validators.instance_of(str)) + password: str = attr.ib(validator=attr.validators.instance_of(str)) + priv_key: Optional[Key] + pub_key: Optional[Key] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RetentionConfig: + """Object describing the configuration of the manhole""" + + interval: int + shortest_max_lifetime: Optional[int] + longest_max_lifetime: Optional[int] + + +@attr.s(frozen=True) +class LimitRemoteRoomsConfig: + enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False) + complexity: Union[float, int] = attr.ib( + validator=attr.validators.instance_of( + (float, int) # type: ignore[arg-type] # noqa + ), + default=1.0, + ) + complexity_error: str = attr.ib( + validator=attr.validators.instance_of(str), + default=ROOM_COMPLEXITY_TOO_GREAT, + ) + admins_can_join: bool = attr.ib( + validator=attr.validators.instance_of(bool), default=False + ) class ServerConfig(Config): @@ -519,7 +542,7 @@ class ServerConfig(Config): " greater than 'allowed_lifetime_max'" ) - self.retention_purge_jobs: List[Dict[str, Optional[int]]] = [] + self.retention_purge_jobs: List[RetentionConfig] = [] for purge_job_config in retention_config.get("purge_jobs", []): interval_config = purge_job_config.get("interval") @@ -553,20 +576,12 @@ class ServerConfig(Config): ) self.retention_purge_jobs.append( - { - "interval": interval, - "shortest_max_lifetime": shortest_max_lifetime, - "longest_max_lifetime": longest_max_lifetime, - } + RetentionConfig(interval, shortest_max_lifetime, longest_max_lifetime) ) if not self.retention_purge_jobs: self.retention_purge_jobs = [ - { - "interval": self.parse_duration("1d"), - "shortest_max_lifetime": None, - "longest_max_lifetime": None, - } + RetentionConfig(self.parse_duration("1d"), None, None) ] self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])] @@ -591,25 +606,6 @@ class ServerConfig(Config): self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None)) self.gc_seconds = self.read_gc_intervals(config.get("gc_min_interval", None)) - @attr.s - class LimitRemoteRoomsConfig: - enabled = attr.ib( - validator=attr.validators.instance_of(bool), default=False - ) - complexity = attr.ib( - validator=attr.validators.instance_of( - (float, int) # type: ignore[arg-type] # noqa - ), - default=1.0, - ) - complexity_error = attr.ib( - validator=attr.validators.instance_of(str), - default=ROOM_COMPLEXITY_TOO_GREAT, - ) - admins_can_join = attr.ib( - validator=attr.validators.instance_of(bool), default=False - ) - self.limit_remote_rooms = LimitRemoteRoomsConfig( **(config.get("limit_remote_rooms") or {}) ) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 650402836c..7a1adc2750 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -41,42 +41,112 @@ from synapse.types import StateMap, UserID, get_domain_from_id logger = logging.getLogger(__name__) -def check( - room_version_obj: RoomVersion, - event: EventBase, - auth_events: StateMap[EventBase], - do_sig_check: bool = True, - do_size_check: bool = True, +def validate_event_for_room_version( + room_version_obj: RoomVersion, event: EventBase ) -> None: - """Checks if this event is correctly authed. + """Ensure that the event complies with the limits, and has the right signatures + + NB: does not *validate* the signatures - it assumes that any signatures present + have already been checked. + + NB: it does not check that the event satisfies the auth rules (that is done in + check_auth_rules_for_event) - these tests are independent of the rest of the state + in the room. + + NB: This is used to check events that have been received over federation. As such, + it can only enforce the checks specified in the relevant room version, to avoid + a split-brain situation where some servers accept such events, and others reject + them. + + TODO: consider moving this into EventValidator Args: - room_version_obj: the version of the room - event: the event being checked. - auth_events: the existing room state. - do_sig_check: True if it should be verified that the sending server - signed the event. - do_size_check: True if the size of the event fields should be verified. + room_version_obj: the version of the room which contains this event + event: the event to be checked Raises: - AuthError if the checks fail - - Returns: - if the auth checks pass. + SynapseError if there is a problem with the event """ - assert isinstance(auth_events, dict) - - if do_size_check: - _check_size_limits(event) + _check_size_limits(event) if not hasattr(event, "room_id"): raise AuthError(500, "Event has no room_id: %s" % event) - room_id = event.room_id + # check that the event has the correct signatures + sender_domain = get_domain_from_id(event.sender) + + is_invite_via_3pid = ( + event.type == EventTypes.Member + and event.membership == Membership.INVITE + and "third_party_invite" in event.content + ) + + # Check the sender's domain has signed the event + if not event.signatures.get(sender_domain): + # We allow invites via 3pid to have a sender from a different + # HS, as the sender must match the sender of the original + # 3pid invite. This is checked further down with the + # other dedicated membership checks. + if not is_invite_via_3pid: + raise AuthError(403, "Event not signed by sender's server") + + if event.format_version in (EventFormatVersions.V1,): + # Only older room versions have event IDs to check. + event_id_domain = get_domain_from_id(event.event_id) + + # Check the origin domain has signed the event + if not event.signatures.get(event_id_domain): + raise AuthError(403, "Event not signed by sending server") + + is_invite_via_allow_rule = ( + room_version_obj.msc3083_join_rules + and event.type == EventTypes.Member + and event.membership == Membership.JOIN + and EventContentFields.AUTHORISING_USER in event.content + ) + if is_invite_via_allow_rule: + authoriser_domain = get_domain_from_id( + event.content[EventContentFields.AUTHORISING_USER] + ) + if not event.signatures.get(authoriser_domain): + raise AuthError(403, "Event not signed by authorising server") + + +def check_auth_rules_for_event( + room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase] +) -> None: + """Check that an event complies with the auth rules + + Checks whether an event passes the auth rules with a given set of state events + + Assumes that we have already checked that the event is the right shape (it has + enough signatures, has a room ID, etc). In other words: + + - it's fine for use in state resolution, when we have already decided whether to + accept the event or not, and are now trying to decide whether it should make it + into the room state + + - when we're doing the initial event auth, it is only suitable in combination with + a bunch of other tests. + + Args: + room_version_obj: the version of the room + event: the event being checked. + auth_events: the room state to check the events against. + + Raises: + AuthError if the checks fail + """ + assert isinstance(auth_events, dict) # We need to ensure that the auth events are actually for the same room, to # stop people from using powers they've been granted in other rooms for # example. + # + # Arguably we don't need to do this when we're just doing state res, as presumably + # the state res algorithm isn't silly enough to give us events from different rooms. + # Still, it's easier to do it anyway. + room_id = event.room_id for auth_event in auth_events.values(): if auth_event.room_id != room_id: raise AuthError( @@ -86,44 +156,6 @@ def check( % (event.event_id, room_id, auth_event.event_id, auth_event.room_id), ) - if do_sig_check: - sender_domain = get_domain_from_id(event.sender) - - is_invite_via_3pid = ( - event.type == EventTypes.Member - and event.membership == Membership.INVITE - and "third_party_invite" in event.content - ) - - # Check the sender's domain has signed the event - if not event.signatures.get(sender_domain): - # We allow invites via 3pid to have a sender from a different - # HS, as the sender must match the sender of the original - # 3pid invite. This is checked further down with the - # other dedicated membership checks. - if not is_invite_via_3pid: - raise AuthError(403, "Event not signed by sender's server") - - if event.format_version in (EventFormatVersions.V1,): - # Only older room versions have event IDs to check. - event_id_domain = get_domain_from_id(event.event_id) - - # Check the origin domain has signed the event - if not event.signatures.get(event_id_domain): - raise AuthError(403, "Event not signed by sending server") - - is_invite_via_allow_rule = ( - event.type == EventTypes.Member - and event.membership == Membership.JOIN - and EventContentFields.AUTHORISING_USER in event.content - ) - if is_invite_via_allow_rule: - authoriser_domain = get_domain_from_id( - event.content[EventContentFields.AUTHORISING_USER] - ) - if not event.signatures.get(authoriser_domain): - raise AuthError(403, "Event not signed by authorising server") - # Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules # # 1. If type is m.room.create: diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 87e2bb123b..50f2a4c1f4 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -18,10 +18,8 @@ import attr from nacl.signing import SigningKey from synapse.api.constants import MAX_DEPTH -from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import ( KNOWN_EVENT_FORMAT_VERSIONS, - KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersion, ) @@ -197,24 +195,6 @@ class EventBuilderFactory: self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() - def new(self, room_version: str, key_values: dict) -> EventBuilder: - """Generate an event builder appropriate for the given room version - - Deprecated: use for_room_version with a RoomVersion object instead - - Args: - room_version: Version of the room that we're creating an event builder for - key_values: Fields used as the basis of the new event - - Returns: - EventBuilder - """ - v = KNOWN_ROOM_VERSIONS.get(room_version) - if not v: - # this can happen if support is withdrawn for a room version - raise UnsupportedRoomVersionError() - return self.for_room_version(v, key_values) - def for_room_version( self, room_version: RoomVersion, key_values: dict ) -> EventBuilder: diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py index eb4556cdc1..68b8b19024 100644 --- a/synapse/events/presence_router.py +++ b/synapse/events/presence_router.py @@ -45,11 +45,11 @@ def load_legacy_presence_router(hs: "HomeServer"): configuration, and registers the hooks they implement. """ - if hs.config.presence_router_module_class is None: + if hs.config.server.presence_router_module_class is None: return - module = hs.config.presence_router_module_class - config = hs.config.presence_router_config + module = hs.config.server.presence_router_module_class + config = hs.config.server.presence_router_config api = hs.get_module_api() presence_router = module(config=config, module_api=api) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 38fccd1efc..520edbbf61 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -372,7 +372,7 @@ class EventClientSerializer: def __init__(self, hs): self.store = hs.get_datastore() self.experimental_msc1849_support_enabled = ( - hs.config.experimental_msc1849_support_enabled + hs.config.server.experimental_msc1849_support_enabled ) async def serialize_event( diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 5f4383eebc..d8c0b86f23 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1008,7 +1008,10 @@ class FederationServer(FederationBase): async with lock: logger.info("handling received PDU: %s", event) try: - await self._federation_event_handler.on_receive_pdu(origin, event) + with nested_logging_context(event.event_id): + await self._federation_event_handler.on_receive_pdu( + origin, event + ) except FederationError as e: # XXX: Ideally we'd inform the remote we failed to process # the event, but we can't return an error in the transaction diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 95176ba6f9..c32539bf5a 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -117,7 +117,7 @@ class PublicRoomList(BaseFederationServlet): ): super().__init__(hs, authenticator, ratelimiter, server_name) self.handler = hs.get_room_list_handler() - self.allow_access = hs.config.allow_public_rooms_over_federation + self.allow_access = hs.config.server.allow_public_rooms_over_federation async def on_GET( self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]] diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a8c717efd5..2d0f3d566c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -198,7 +198,7 @@ class AuthHandler(BaseHandler): if inst.is_enabled(): self.checkers[inst.AUTH_TYPE] = inst # type: ignore - self.bcrypt_rounds = hs.config.bcrypt_rounds + self.bcrypt_rounds = hs.config.registration.bcrypt_rounds # we can't use hs.get_module_api() here, because to do so will create an # import loop. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 5cfba3c817..9078781d5a 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -49,7 +49,7 @@ class DirectoryHandler(BaseHandler): self.store = hs.get_datastore() self.config = hs.config self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search - self.require_membership = hs.config.require_membership_for_aliases + self.require_membership = hs.config.server.require_membership_for_aliases self.third_party_event_rules = hs.get_third_party_event_rules() self.federation = hs.get_federation_client() diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index cb81fa0986..d089c56286 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -22,7 +22,8 @@ from synapse.api.constants import ( RestrictedJoinRuleTypes, ) from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.api.room_versions import RoomVersion +from synapse.event_auth import check_auth_rules_for_event from synapse.events import EventBase from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext @@ -45,21 +46,17 @@ class EventAuthHandler: self._store = hs.get_datastore() self._server_name = hs.hostname - async def check_from_context( + async def check_auth_rules_from_context( self, - room_version: str, + room_version_obj: RoomVersion, event: EventBase, context: EventContext, - do_sig_check: bool = True, ) -> None: + """Check an event passes the auth rules at its own auth events""" auth_event_ids = event.auth_event_ids() auth_events_by_id = await self._store.get_events(auth_event_ids) auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()} - - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - event_auth.check( - room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check - ) + check_auth_rules_for_event(room_version_obj, event, auth_events) def compute_auth_events( self, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index adbd150e46..043ca4a224 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -45,6 +45,10 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions from synapse.crypto.event_signing import compute_event_signature +from synapse.event_auth import ( + check_auth_rules_for_event, + validate_event_for_room_version, +) from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator @@ -723,8 +727,8 @@ class FederationHandler(BaseHandler): state_ids, ) - builder = self.event_builder_factory.new( - room_version.identifier, + builder = self.event_builder_factory.for_room_version( + room_version, { "type": EventTypes.Member, "content": event_content, @@ -747,10 +751,9 @@ class FederationHandler(BaseHandler): # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` - await self._event_auth_handler.check_from_context( - room_version.identifier, event, context, do_sig_check=False + await self._event_auth_handler.check_auth_rules_from_context( + room_version, event, context ) - return event async def on_invite_request( @@ -767,7 +770,7 @@ class FederationHandler(BaseHandler): if is_blocked: raise SynapseError(403, "This room has been blocked on this server") - if self.hs.config.block_non_admin_invites: + if self.hs.config.server.block_non_admin_invites: raise SynapseError(403, "This server does not accept room invites") if not await self.spam_checker.user_may_invite( @@ -902,9 +905,9 @@ class FederationHandler(BaseHandler): ) raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - room_version = await self.store.get_room_version_id(room_id) - builder = self.event_builder_factory.new( - room_version, + room_version_obj = await self.store.get_room_version(room_id) + builder = self.event_builder_factory.for_room_version( + room_version_obj, { "type": EventTypes.Member, "content": {"membership": Membership.LEAVE}, @@ -921,8 +924,8 @@ class FederationHandler(BaseHandler): try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` - await self._event_auth_handler.check_from_context( - room_version, event, context, do_sig_check=False + await self._event_auth_handler.check_auth_rules_from_context( + room_version_obj, event, context ) except AuthError as e: logger.warning("Failed to create new leave %r because %s", event, e) @@ -954,10 +957,10 @@ class FederationHandler(BaseHandler): ) raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - room_version = await self.store.get_room_version_id(room_id) + room_version_obj = await self.store.get_room_version(room_id) - builder = self.event_builder_factory.new( - room_version, + builder = self.event_builder_factory.for_room_version( + room_version_obj, { "type": EventTypes.Member, "content": {"membership": Membership.KNOCK}, @@ -983,8 +986,8 @@ class FederationHandler(BaseHandler): try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_knock_request` - await self._event_auth_handler.check_from_context( - room_version, event, context, do_sig_check=False + await self._event_auth_handler.check_auth_rules_from_context( + room_version_obj, event, context ) except AuthError as e: logger.warning("Failed to create new knock %r because %s", event, e) @@ -1173,7 +1176,8 @@ class FederationHandler(BaseHandler): auth_for_e[(EventTypes.Create, "")] = create_event try: - event_auth.check(room_version, e, auth_events=auth_for_e) + validate_event_for_room_version(room_version, e) + check_auth_rules_for_event(room_version, e, auth_for_e) except SynapseError as err: # we may get SynapseErrors here as well as AuthErrors. For # instance, there are a couple of (ancient) events in some @@ -1250,8 +1254,10 @@ class FederationHandler(BaseHandler): } if await self._event_auth_handler.check_host_in_room(room_id, self.hs.hostname): - room_version = await self.store.get_room_version_id(room_id) - builder = self.event_builder_factory.new(room_version, event_dict) + room_version_obj = await self.store.get_room_version(room_id) + builder = self.event_builder_factory.for_room_version( + room_version_obj, event_dict + ) EventValidator().validate_builder(builder) event, context = await self.event_creation_handler.create_new_client_event( @@ -1259,7 +1265,7 @@ class FederationHandler(BaseHandler): ) event, context = await self.add_display_name_to_third_party_invite( - room_version, event_dict, event, context + room_version_obj, event_dict, event, context ) EventValidator().validate_new(event, self.config) @@ -1269,8 +1275,9 @@ class FederationHandler(BaseHandler): event.internal_metadata.send_on_behalf_of = self.hs.hostname try: - await self._event_auth_handler.check_from_context( - room_version, event, context + validate_event_for_room_version(room_version_obj, event) + await self._event_auth_handler.check_auth_rules_from_context( + room_version_obj, event, context ) except AuthError as e: logger.warning("Denying new third party invite %r because %s", event, e) @@ -1304,22 +1311,25 @@ class FederationHandler(BaseHandler): """ assert_params_in_dict(event_dict, ["room_id"]) - room_version = await self.store.get_room_version_id(event_dict["room_id"]) + room_version_obj = await self.store.get_room_version(event_dict["room_id"]) # NB: event_dict has a particular specced format we might need to fudge # if we change event formats too much. - builder = self.event_builder_factory.new(room_version, event_dict) + builder = self.event_builder_factory.for_room_version( + room_version_obj, event_dict + ) event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) event, context = await self.add_display_name_to_third_party_invite( - room_version, event_dict, event, context + room_version_obj, event_dict, event, context ) try: - await self._event_auth_handler.check_from_context( - room_version, event, context + validate_event_for_room_version(room_version_obj, event) + await self._event_auth_handler.check_auth_rules_from_context( + room_version_obj, event, context ) except AuthError as e: logger.warning("Denying third party invite %r because %s", event, e) @@ -1336,7 +1346,7 @@ class FederationHandler(BaseHandler): async def add_display_name_to_third_party_invite( self, - room_version: str, + room_version_obj: RoomVersion, event_dict: JsonDict, event: EventBase, context: EventContext, @@ -1368,7 +1378,9 @@ class FederationHandler(BaseHandler): # auth checks. If we need the invite and don't have it then the # auth check code will explode appropriately. - builder = self.event_builder_factory.new(room_version, event_dict) + builder = self.event_builder_factory.for_room_version( + room_version_obj, event_dict + ) EventValidator().validate_builder(builder) event, context = await self.event_creation_handler.create_new_client_event( builder=builder diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 01fd841122..e587b5b3b3 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -29,7 +29,6 @@ from typing import ( from prometheus_client import Counter -from synapse import event_auth from synapse.api.constants import ( EventContentFields, EventTypes, @@ -47,7 +46,11 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.event_auth import auth_types_for_event +from synapse.event_auth import ( + auth_types_for_event, + check_auth_rules_for_event, + validate_event_for_room_version, +) from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.federation.federation_client import InvalidResponseError @@ -68,11 +71,7 @@ from synapse.types import ( UserID, get_domain_from_id, ) -from synapse.util.async_helpers import ( - Linearizer, - concurrently_execute, - yieldable_gather_results, -) +from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.iterutils import batch_iter from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import shortstr @@ -1189,7 +1188,10 @@ class FederationEventHandler: allow_rejected=True, ) - async def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]: + room_version = await self._store.get_room_version_id(room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]: with nested_logging_context(suffix=event.event_id): auth = {} for auth_event_id in event.auth_event_ids(): @@ -1207,17 +1209,16 @@ class FederationEventHandler: auth[(ae.type, ae.state_key)] = ae context = EventContext.for_outlier() - context = await self._check_event_auth( - origin, - event, - context, - claimed_auth_event_map=auth, - ) + try: + validate_event_for_room_version(room_version_obj, event) + check_auth_rules_for_event(room_version_obj, event, auth) + except AuthError as e: + logger.warning("Rejecting %r because %s", event, e) + context.rejected = RejectedReason.AUTH_ERROR + return event, context - events_to_persist = ( - x for x in await yieldable_gather_results(prep, fetched_events) if x - ) + events_to_persist = (x for x in (prep(event) for event in fetched_events) if x) await self.persist_events_and_notify(room_id, tuple(events_to_persist)) async def _check_event_auth( @@ -1226,7 +1227,6 @@ class FederationEventHandler: event: EventBase, context: EventContext, state: Optional[Iterable[EventBase]] = None, - claimed_auth_event_map: Optional[StateMap[EventBase]] = None, backfilled: bool = False, ) -> EventContext: """ @@ -1242,42 +1242,36 @@ class FederationEventHandler: The state events used to check the event for soft-fail. If this is not provided the current state events will be used. - claimed_auth_event_map: - A map of (type, state_key) => event for the event's claimed auth_events. - Possibly including events that were rejected, or are in the wrong room. - - Only populated when populating outliers. - backfilled: True if the event was backfilled. Returns: The updated context object. """ - # claimed_auth_event_map should be given iff the event is an outlier - assert bool(claimed_auth_event_map) == event.internal_metadata.outlier + # This method should only be used for non-outliers + assert not event.internal_metadata.outlier room_version = await self._store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - if claimed_auth_event_map: - # if we have a copy of the auth events from the event, use that as the - # basis for auth. - auth_events = claimed_auth_event_map - else: - # otherwise, we calculate what the auth events *should* be, and use that - prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self._event_auth_handler.compute_auth_events( - event, prev_state_ids, for_verification=True - ) - auth_events_x = await self._store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} + # calculate what the auth events *should* be, to use as a basis for auth. + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = self._event_auth_handler.compute_auth_events( + event, prev_state_ids, for_verification=True + ) + auth_events_x = await self._store.get_events(auth_events_ids) + calculated_auth_event_map = { + (e.type, e.state_key): e for e in auth_events_x.values() + } try: ( context, auth_events_for_auth, ) = await self._update_auth_events_and_context_for_auth( - origin, event, context, auth_events + origin, + event, + context, + calculated_auth_event_map=calculated_auth_event_map, ) except Exception: # We don't really mind if the above fails, so lets not fail @@ -1289,10 +1283,11 @@ class FederationEventHandler: "Ignoring failure and continuing processing of event.", event.event_id, ) - auth_events_for_auth = auth_events + auth_events_for_auth = calculated_auth_event_map try: - event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth) + validate_event_for_room_version(room_version_obj, event) + check_auth_rules_for_event(room_version_obj, event, auth_events_for_auth) except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR @@ -1404,7 +1399,10 @@ class FederationEventHandler: } try: - event_auth.check(room_version_obj, event, auth_events=current_auth_events) + # TODO: skip the call to validate_event_for_room_version? we should already + # have validated the event. + validate_event_for_room_version(room_version_obj, event) + check_auth_rules_for_event(room_version_obj, event, current_auth_events) except AuthError as e: logger.warning( "Soft-failing %r (from %s) because %s", @@ -1425,7 +1423,7 @@ class FederationEventHandler: origin: str, event: EventBase, context: EventContext, - input_auth_events: StateMap[EventBase], + calculated_auth_event_map: StateMap[EventBase], ) -> Tuple[EventContext, StateMap[EventBase]]: """Helper for _check_event_auth. See there for docs. @@ -1443,19 +1441,17 @@ class FederationEventHandler: event: context: - input_auth_events: - Map from (event_type, state_key) to event - - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. + calculated_auth_event_map: + Our calculated auth_events based on the state of the room + at the event's position in the DAG. Returns: updated context, updated auth event map """ - # take a copy of input_auth_events before we modify it. - auth_events: MutableStateMap[EventBase] = dict(input_auth_events) + assert not event.internal_metadata.outlier + + # take a copy of calculated_auth_event_map before we modify it. + auth_events: MutableStateMap[EventBase] = dict(calculated_auth_event_map) event_auth_events = set(event.auth_event_ids()) @@ -1496,15 +1492,6 @@ class FederationEventHandler: } ) - if event.internal_metadata.is_outlier(): - # XXX: given that, for an outlier, we'll be working with the - # event's *claimed* auth events rather than those we calculated: - # (a) is there any point in this test, since different_auth below will - # obviously be empty - # (b) alternatively, why don't we do it earlier? - logger.info("Skipping auth_event fetch for outlier") - return context, auth_events - different_auth = event_auth_events.difference( e.event_id for e in auth_events.values() ) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index fe8a995892..c881475c25 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -57,7 +57,7 @@ class IdentityHandler(BaseHandler): self.http_client = SimpleHttpClient(hs) # An HTTP client for contacting identity servers specified by clients. self.blacklisting_http_client = SimpleHttpClient( - hs, ip_blacklist=hs.config.federation_ip_range_blacklist + hs, ip_blacklist=hs.config.server.federation_ip_range_blacklist ) self.federation_http_client = hs.get_federation_http_client() self.hs = hs @@ -573,9 +573,15 @@ class IdentityHandler(BaseHandler): # Try to validate as email if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + # Remote emails will only be used if a valid identity server is provided. + assert ( + self.hs.config.registration.account_threepid_delegate_email is not None + ) + # Ask our delegated email identity server validation_session = await self.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds + self.hs.config.registration.account_threepid_delegate_email, + threepid_creds, ) elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: # Get a validated session matching these details @@ -587,10 +593,11 @@ class IdentityHandler(BaseHandler): return validation_session # Try to validate as msisdn - if self.hs.config.account_threepid_delegate_msisdn: + if self.hs.config.registration.account_threepid_delegate_msisdn: # Ask our delegated msisdn identity server validation_session = await self.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds + self.hs.config.registration.account_threepid_delegate_msisdn, + threepid_creds, ) return validation_session diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index fd861e94f8..ccd7827207 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,6 +16,7 @@ # limitations under the License. import logging import random +from http import HTTPStatus from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple from canonicaljson import encode_canonical_json @@ -39,9 +40,11 @@ from synapse.api.errors import ( NotFoundError, ShadowBanError, SynapseError, + UnsupportedRoomVersionError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder +from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext @@ -79,7 +82,7 @@ class MessageHandler: self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() - self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages # The scheduled call to self._expire_event. None if no call is currently # scheduled. @@ -413,7 +416,9 @@ class EventCreationHandler: self.server_name = hs.hostname self.notifier = hs.get_notifier() self.config = hs.config - self.require_membership_for_aliases = hs.config.require_membership_for_aliases + self.require_membership_for_aliases = ( + hs.config.server.require_membership_for_aliases + ) self._events_shard_config = self.config.worker.events_shard_config self._instance_name = hs.get_instance_name() @@ -423,7 +428,7 @@ class EventCreationHandler: Membership.JOIN, Membership.KNOCK, } - if self.hs.config.include_profile_data_on_invite: + if self.hs.config.server.include_profile_data_on_invite: self.membership_types_to_include_profile_data_in.add(Membership.INVITE) self.send_event = ReplicationSendEventRestServlet.make_client(hs) @@ -459,11 +464,11 @@ class EventCreationHandler: # self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {} # The number of forward extremeities before a dummy event is sent. - self._dummy_events_threshold = hs.config.dummy_events_threshold + self._dummy_events_threshold = hs.config.server.dummy_events_threshold if ( self.config.worker.run_background_tasks - and self.config.cleanup_extremities_with_dummy_events + and self.config.server.cleanup_extremities_with_dummy_events ): self.clock.looping_call( lambda: run_as_background_process( @@ -475,7 +480,7 @@ class EventCreationHandler: self._message_handler = hs.get_message_handler() - self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages self._external_cache = hs.get_external_cache() @@ -549,16 +554,22 @@ class EventCreationHandler: await self.auth.check_auth_blocking(requester=requester) if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": - room_version = event_dict["content"]["room_version"] + room_version_id = event_dict["content"]["room_version"] + room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version_obj: + # this can happen if support is withdrawn for a room version + raise UnsupportedRoomVersionError(room_version_id) else: try: - room_version = await self.store.get_room_version_id( + room_version_obj = await self.store.get_room_version( event_dict["room_id"] ) except NotFoundError: raise AuthError(403, "Unknown room") - builder = self.event_builder_factory.new(room_version, event_dict) + builder = self.event_builder_factory.for_room_version( + room_version_obj, event_dict + ) self.validator.validate_builder(builder) @@ -1064,9 +1075,17 @@ class EventCreationHandler: EventTypes.Create, "", ): - room_version = event.content.get("room_version", RoomVersions.V1.identifier) + room_version_id = event.content.get( + "room_version", RoomVersions.V1.identifier + ) + room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version_obj: + raise UnsupportedRoomVersionError( + "Attempt to create a room with unsupported room version %s" + % (room_version_id,) + ) else: - room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = await self.store.get_room_version(event.room_id) if event.internal_metadata.is_out_of_band_membership(): # the only sort of out-of-band-membership events we expect to see here are @@ -1075,8 +1094,9 @@ class EventCreationHandler: assert event.content["membership"] == Membership.LEAVE else: try: - await self._event_auth_handler.check_from_context( - room_version, event, context + validate_event_for_room_version(room_version_obj, event) + await self._event_auth_handler.check_auth_rules_from_context( + room_version_obj, event, context ) except AuthError as err: logger.warning("Denying new event %r because %s", event, err) @@ -1456,6 +1476,39 @@ class EventCreationHandler: if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") + if event.type == EventTypes.MSC2716_INSERTION: + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + create_event = await self.store.get_create_event_for_room(event.room_id) + room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) + + # Only check an insertion event if the room version + # supports it or the event is from the room creator. + if room_version_obj.msc2716_historical or ( + self.config.experimental.msc2716_enabled + and event.sender == room_creator + ): + next_batch_id = event.content.get( + EventContentFields.MSC2716_NEXT_BATCH_ID + ) + conflicting_insertion_event_id = ( + await self.store.get_insertion_event_by_batch_id( + event.room_id, next_batch_id + ) + ) + if conflicting_insertion_event_id is not None: + # The current insertion event that we're processing is invalid + # because an insertion event already exists in the room with the + # same next_batch_id. We can't allow multiple because the batch + # pointing will get weird, e.g. we can't determine which insertion + # event the batch event is pointing to. + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Another insertion event already exists with the same next_batch_id", + errcode=Codes.INVALID_PARAM, + ) + # Mark any `m.historical` messages as backfilled so they don't appear # in `/sync` and have the proper decrementing `stream_ordering` as we import backfilled = False diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 08b93b3ec1..176e4dfdd4 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -85,23 +85,29 @@ class PaginationHandler: self._purges_by_id: Dict[str, PurgeStatus] = {} self._event_serializer = hs.get_event_client_serializer() - self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime + self._retention_default_max_lifetime = ( + hs.config.server.retention_default_max_lifetime + ) - self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min - self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max + self._retention_allowed_lifetime_min = ( + hs.config.server.retention_allowed_lifetime_min + ) + self._retention_allowed_lifetime_max = ( + hs.config.server.retention_allowed_lifetime_max + ) - if hs.config.worker.run_background_tasks and hs.config.retention_enabled: + if hs.config.worker.run_background_tasks and hs.config.server.retention_enabled: # Run the purge jobs described in the configuration file. - for job in hs.config.retention_purge_jobs: + for job in hs.config.server.retention_purge_jobs: logger.info("Setting up purge job with config: %s", job) self.clock.looping_call( run_as_background_process, - job["interval"], + job.interval, "purge_history_for_rooms_in_range", self.purge_history_for_rooms_in_range, - job["shortest_max_lifetime"], - job["longest_max_lifetime"], + job.shortest_max_lifetime, + job.longest_max_lifetime, ) async def purge_history_for_rooms_in_range( diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b23a1541bc..2e19706c69 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -178,7 +178,7 @@ class ProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") - if not by_admin and not self.hs.config.enable_set_displayname: + if not by_admin and not self.hs.config.registration.enable_set_displayname: profile = await self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( @@ -268,7 +268,7 @@ class ProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") - if not by_admin and not self.hs.config.enable_set_avatar_url: + if not by_admin and not self.hs.config.registration.enable_set_avatar_url: profile = await self.store.get_profileinfo(target_user.localpart) if profile.avatar_url: raise SynapseError( @@ -397,7 +397,7 @@ class ProfileHandler(BaseHandler): # when building a membership event. In this case, we must allow the # lookup. if ( - not self.hs.config.limit_profile_requests_to_users_who_share_rooms + not self.hs.config.server.limit_profile_requests_to_users_who_share_rooms or not requester ): return diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 4f99f137a2..441af7a848 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -116,8 +116,8 @@ class RegistrationHandler(BaseHandler): self._register_device_client = self.register_device_inner self.pusher_pool = hs.get_pusherpool() - self.session_lifetime = hs.config.session_lifetime - self.access_token_lifetime = hs.config.access_token_lifetime + self.session_lifetime = hs.config.registration.session_lifetime + self.access_token_lifetime = hs.config.registration.access_token_lifetime init_counters_for_auth_provider("") @@ -340,8 +340,13 @@ class RegistrationHandler(BaseHandler): auth_provider=(auth_provider_id or ""), ).inc() + # If the user does not need to consent at registration, auto-join any + # configured rooms. if not self.hs.config.consent.user_consent_at_registration: - if not self.hs.config.auto_join_rooms_for_guests and make_guest: + if ( + not self.hs.config.registration.auto_join_rooms_for_guests + and make_guest + ): logger.info( "Skipping auto-join for %s because auto-join for guests is disabled", user_id, @@ -387,7 +392,7 @@ class RegistrationHandler(BaseHandler): "preset": self.hs.config.registration.autocreate_auto_join_room_preset, } - # If the configuration providers a user ID to create rooms with, use + # If the configuration provides a user ID to create rooms with, use # that instead of the first user registered. requires_join = False if self.hs.config.registration.auto_join_user_id: @@ -854,7 +859,7 @@ class RegistrationHandler(BaseHandler): # Necessary due to auth checks prior to the threepid being # written to the db if is_threepid_reserved( - self.hs.config.mau_limits_reserved_threepids, threepid + self.hs.config.server.mau_limits_reserved_threepids, threepid ): await self.store.upsert_monthly_active_user(user_id) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 8fede5e935..873e08258e 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -52,6 +52,7 @@ from synapse.api.errors import ( ) from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents from synapse.rest.admin._base import assert_user_is_admin @@ -237,8 +238,9 @@ class RoomCreationHandler(BaseHandler): }, }, ) - old_room_version = await self.store.get_room_version_id(old_room_id) - await self._event_auth_handler.check_from_context( + old_room_version = await self.store.get_room_version(old_room_id) + validate_event_for_room_version(old_room_version, tombstone_event) + await self._event_auth_handler.check_auth_rules_from_context( old_room_version, tombstone_event, tombstone_context ) @@ -666,7 +668,7 @@ class RoomCreationHandler(BaseHandler): await self.ratelimit(requester) room_version_id = config.get( - "room_version", self.config.default_room_version.identifier + "room_version", self.config.server.default_room_version.identifier ) if not isinstance(room_version_id, str): diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index afa7e4727d..c8fb24a20c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -89,8 +89,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.spam_checker = hs.get_spam_checker() self.third_party_event_rules = hs.get_third_party_event_rules() self._server_notices_mxid = self.config.servernotices.server_notices_mxid - self._enable_lookup = hs.config.enable_3pid_lookup - self.allow_per_room_profiles = self.config.allow_per_room_profiles + self._enable_lookup = hs.config.registration.enable_3pid_lookup + self.allow_per_room_profiles = self.config.server.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( store=self.store, @@ -625,7 +625,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: - if self.config.block_non_admin_invites: + if self.config.server.block_non_admin_invites: logger.info( "Blocking invite: user is not admin and non-admin " "invites disabled" @@ -1230,7 +1230,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): Raises: ShadowBanError if the requester has been shadow-banned. """ - if self.config.block_non_admin_invites: + if self.config.server.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: raise SynapseError( @@ -1428,7 +1428,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): Returns: bool of whether the complexity is too great, or None if unable to be fetched """ - max_complexity = self.hs.config.limit_remote_rooms.complexity + max_complexity = self.hs.config.server.limit_remote_rooms.complexity complexity = await self.federation_handler.get_room_complexity( remote_room_hosts, room_id ) @@ -1444,7 +1444,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): Args: room_id: The room ID to check for complexity. """ - max_complexity = self.hs.config.limit_remote_rooms.complexity + max_complexity = self.hs.config.server.limit_remote_rooms.complexity complexity = await self.store.get_room_complexity(room_id) return complexity["v1"] > max_complexity @@ -1480,7 +1480,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): if too_complex is True: raise SynapseError( code=400, - msg=self.hs.config.limit_remote_rooms.complexity_error, + msg=self.hs.config.server.limit_remote_rooms.complexity_error, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) @@ -1515,7 +1515,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) raise SynapseError( code=400, - msg=self.hs.config.limit_remote_rooms.complexity_error, + msg=self.hs.config.server.limit_remote_rooms.complexity_error, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 8226d6f5a1..6d3333ee00 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -105,7 +105,7 @@ class SearchHandler(BaseHandler): dict to be returned to the client with results of search """ - if not self.hs.config.enable_search: + if not self.hs.config.server.enable_search: raise SynapseError(400, "Search is disabled on this homeserver") batch_group = None diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index 25e6b012b7..1a062a784c 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -105,8 +105,13 @@ async def _sendmail( # set to enable TLS. factory = build_sender_factory(hostname=smtphost if enable_tls else None) - # the IReactorTCP interface claims host has to be a bytes, which seems to be wrong - reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type] + reactor.connectTCP( + smtphost, # type: ignore[arg-type] + smtpport, + factory, + timeout=30, + bindAddress=None, + ) await make_deferred_yieldable(d) diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 8f5d465fa1..184730ebe8 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -153,21 +153,23 @@ class _BaseThreepidAuthChecker: # msisdns are currently always ThreepidBehaviour.REMOTE if medium == "msisdn": - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: raise SynapseError( 400, "Phone number verification is not enabled on this homeserver" ) threepid = await identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds + self.hs.config.registration.account_threepid_delegate_msisdn, + threepid_creds, ) elif medium == "email": if ( self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE ): - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email threepid = await identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds + self.hs.config.registration.account_threepid_delegate_email, + threepid_creds, ) elif ( self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL @@ -240,7 +242,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): _BaseThreepidAuthChecker.__init__(self, hs) def is_enabled(self) -> bool: - return bool(self.hs.config.account_threepid_delegate_msisdn) + return bool(self.hs.config.registration.account_threepid_delegate_msisdn) async def check_auth(self, authdict: dict, clientip: str) -> Any: return await self._check_threepid("msisdn", authdict) @@ -252,7 +254,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs - self._enabled = bool(hs.config.registration_requires_token) + self._enabled = bool(hs.config.registration.registration_requires_token) self.store = hs.get_datastore() def is_enabled(self) -> bool: diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index b91e7cb501..18d8c8744e 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -60,7 +60,7 @@ class UserDirectoryHandler(StateDeltasHandler): self.clock = hs.get_clock() self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id - self.update_user_directory = hs.config.update_user_directory + self.update_user_directory = hs.config.server.update_user_directory self.search_all_users = hs.config.userdirectory.user_directory_search_all_users self.spam_checker = hs.get_spam_checker() # The current position in the current_state_delta stream @@ -132,12 +132,7 @@ class UserDirectoryHandler(StateDeltasHandler): # FIXME(#3714): We should probably do this in the same worker as all # the other changes. - # Support users are for diagnostics and should not appear in the user directory. - is_support = await self.store.is_support_user(user_id) - # When change profile information of deactivated user it should not appear in the user directory. - is_deactivated = await self.store.get_user_deactivated_status(user_id) - - if not (is_support or is_deactivated): + if await self.store.should_include_local_user_in_dir(user_id): await self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) @@ -229,8 +224,10 @@ class UserDirectoryHandler(StateDeltasHandler): else: logger.debug("Server is still in room: %r", room_id) - is_support = await self.store.is_support_user(state_key) - if not is_support: + include_in_dir = not self.is_mine_id( + state_key + ) or await self.store.should_include_local_user_in_dir(state_key) + if include_in_dir: if change is MatchChange.no_change: # Handle any profile changes await self._handle_profile_change( @@ -356,13 +353,7 @@ class UserDirectoryHandler(StateDeltasHandler): # First, if they're our user then we need to update for every user if self.is_mine_id(user_id): - - is_appservice = self.store.get_if_app_services_interested_in_user( - user_id - ) - - # We don't care about appservice users. - if not is_appservice: + if await self.store.should_include_local_user_in_dir(user_id): for other_user_id in other_users_in_room: if user_id == other_user_id: continue @@ -374,10 +365,10 @@ class UserDirectoryHandler(StateDeltasHandler): if user_id == other_user_id: continue - is_appservice = self.store.get_if_app_services_interested_in_user( + include_other_user = self.is_mine_id( other_user_id - ) - if self.is_mine_id(other_user_id) and not is_appservice: + ) and await self.store.should_include_local_user_in_dir(other_user_id) + if include_other_user: to_insert.add((other_user_id, user_id)) if to_insert: diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index cdc36b8d25..4f59224686 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -327,23 +327,23 @@ class MatrixFederationHttpClient: self.reactor = hs.get_reactor() user_agent = hs.version_string - if hs.config.user_agent_suffix: - user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix) + if hs.config.server.user_agent_suffix: + user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix) user_agent = user_agent.encode("ascii") federation_agent = MatrixFederationAgent( self.reactor, tls_client_options_factory, user_agent, - hs.config.federation_ip_range_whitelist, - hs.config.federation_ip_range_blacklist, + hs.config.server.federation_ip_range_whitelist, + hs.config.server.federation_ip_range_blacklist, ) # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names self.agent = BlacklistingAgentWrapper( federation_agent, - ip_blacklist=hs.config.federation_ip_range_blacklist, + ip_blacklist=hs.config.server.federation_ip_range_blacklist, ) self.clock = hs.get_clock() diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1438a82b60..d64d1dbacd 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -315,7 +315,7 @@ class ReplicationCommandHandler: hs, outbound_redis_connection ) hs.get_reactor().connectTCP( - hs.config.redis.redis_host.encode(), + hs.config.redis.redis_host, # type: ignore[arg-type] hs.config.redis.redis_port, self._factory, ) @@ -324,7 +324,11 @@ class ReplicationCommandHandler: self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) host = hs.config.worker.worker_replication_host port = hs.config.worker.worker_replication_port - hs.get_reactor().connectTCP(host.encode(), port, self._factory) + hs.get_reactor().connectTCP( + host, # type: ignore[arg-type] + port, + self._factory, + ) def get_streams(self) -> Dict[str, Stream]: """Get a map from stream name to all streams.""" diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 8c0df627c8..062fe2f33e 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -364,6 +364,12 @@ def lazyConnection( factory.continueTrying = reconnect reactor = hs.get_reactor() - reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None) + reactor.connectTCP( + host, # type: ignore[arg-type] + port, + factory, + timeout=30, + bindAddress=None, + ) return factory.handler diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 030852cb5b..80f9b23bfd 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -71,7 +71,7 @@ class ReplicationStreamer: self.notifier = hs.get_notifier() self._instance_name = hs.get_instance_name() - self._replication_torture_level = hs.config.replication_torture_level + self._replication_torture_level = hs.config.server.replication_torture_level self.notifier.add_replication_callback(self.on_notifier_poke) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 46bfec4623..f20aa65301 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -442,7 +442,7 @@ class UserRegisterServlet(RestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: self._clear_old_nonces() - if not self.hs.config.registration_shared_secret: + if not self.hs.config.registration.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") body = parse_json_object_from_request(request) @@ -498,7 +498,7 @@ class UserRegisterServlet(RestServlet): got_mac = body["mac"] want_mac_builder = hmac.new( - key=self.hs.config.registration_shared_secret.encode(), + key=self.hs.config.registration.registration_shared_secret.encode(), digestmod=hashlib.sha1, ) want_mac_builder.update(nonce.encode("utf8")) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 6a7608d60b..6b272658fc 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -119,7 +119,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) if existing_user_id is None: - if self.config.request_token_inhibit_3pid_errors: + if self.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -130,11 +130,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -403,7 +403,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): existing_user_id = await self.store.get_user_id_by_threepid("email", email) if existing_user_id is not None: - if self.config.request_token_inhibit_3pid_errors: + if self.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -414,11 +414,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -486,7 +486,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) if existing_user_id is not None: - if self.hs.config.request_token_inhibit_3pid_errors: + if self.hs.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -496,7 +496,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" @@ -507,7 +507,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): ) ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, + self.hs.config.registration.account_threepid_delegate_msisdn, country, phone_number, client_secret, @@ -604,7 +604,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: - if not self.config.account_threepid_delegate_msisdn: + if not self.config.registration.account_threepid_delegate_msisdn: raise SynapseError( 400, "This homeserver is not validating phone numbers. Use an identity server " @@ -617,7 +617,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): # Proxy submit_token request to msisdn threepid delegate response = await self.identity_handler.proxy_msisdn_submit_token( - self.config.account_threepid_delegate_msisdn, + self.config.registration.account_threepid_delegate_msisdn, body["client_secret"], body["sid"], body["token"], @@ -644,7 +644,7 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) @@ -693,7 +693,7 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) @@ -801,7 +801,7 @@ class ThreepidDeleteRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) @@ -857,8 +857,8 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None: # If the domain whitelist is set, the domain must be in it if ( valid - and hs.config.next_link_domain_whitelist is not None - and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist + and hs.config.server.next_link_domain_whitelist is not None + and next_link_parsed.hostname not in hs.config.server.next_link_domain_whitelist ): valid = False @@ -878,9 +878,13 @@ class WhoamiRestServlet(RestServlet): self.auth = hs.get_auth() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request, allow_guest=True) - response = {"user_id": requester.user.to_string()} + response = { + "user_id": requester.user.to_string(), + # MSC: https://github.com/matrix-org/matrix-doc/pull/3069 + "org.matrix.msc3069.is_guest": bool(requester.is_guest), + } # Appservices and similar accounts do not have device IDs # that we can report on, so exclude them for compliance. diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 282861fae2..c9ad35a3ad 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -49,8 +49,10 @@ class AuthRestServlet(RestServlet): self.registration_handler = hs.get_registration_handler() self.recaptcha_template = hs.config.captcha.recaptcha_template self.terms_template = hs.config.terms_template - self.registration_token_template = hs.config.registration_token_template - self.success_template = hs.config.fallback_success_template + self.registration_token_template = ( + hs.config.registration.registration_token_template + ) + self.success_template = hs.config.registration.fallback_success_template async def on_GET(self, request: SynapseRequest, stagetype: str) -> None: session = parse_string(request, "session") diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 65b3b5ce2c..2a3e24ae7e 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -44,10 +44,10 @@ class CapabilitiesRestServlet(RestServlet): await self.auth.get_user_by_req(request, allow_guest=True) change_password = self.auth_handler.can_change_password() - response = { + response: JsonDict = { "capabilities": { "m.room_versions": { - "default": self.config.default_room_version.identifier, + "default": self.config.server.default_room_version.identifier, "available": { v.identifier: v.disposition for v in KNOWN_ROOM_VERSIONS.values() @@ -64,13 +64,13 @@ class CapabilitiesRestServlet(RestServlet): if self.config.experimental.msc3283_enabled: response["capabilities"]["org.matrix.msc3283.set_displayname"] = { - "enabled": self.config.enable_set_displayname + "enabled": self.config.registration.enable_set_displayname } response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = { - "enabled": self.config.enable_set_avatar_url + "enabled": self.config.registration.enable_set_avatar_url } response["capabilities"]["org.matrix.msc3283.3pid_changes"] = { - "enabled": self.config.enable_3pid_changes + "enabled": self.config.registration.enable_3pid_changes } return 200, response diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index 6ed60c7418..cc1c2f9731 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -90,7 +90,7 @@ class CreateFilterRestServlet(RestServlet): raise AuthError(403, "Can only create filters for local users") content = parse_json_object_from_request(request) - set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) + set_timeline_upper_limit(content, self.hs.config.server.filter_timeline_limit) filter_id = await self.filtering.add_user_filter( user_localpart=target_user.localpart, user_filter=content diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index fa5c173f4b..d49a647b03 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -79,7 +79,7 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled - self._msc2918_enabled = hs.config.access_token_lifetime is not None + self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None self.auth = hs.get_auth() @@ -447,7 +447,7 @@ class RefreshTokenServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() self._clock = hs.get_clock() - self.access_token_lifetime = hs.config.access_token_lifetime + self.access_token_lifetime = hs.config.registration.access_token_lifetime async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) @@ -556,7 +556,7 @@ class CasTicketServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LoginRestServlet(hs).register(http_server) - if hs.config.access_token_lifetime is not None: + if hs.config.registration.access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) if hs.config.cas.cas_enabled: diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index d0f20de569..c684636c0a 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -41,7 +41,7 @@ class ProfileDisplaynameRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester_user = None - if self.hs.config.require_auth_for_profile_requests: + if self.hs.config.server.require_auth_for_profile_requests: requester = await self.auth.get_user_by_req(request) requester_user = requester.user @@ -94,7 +94,7 @@ class ProfileAvatarURLRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester_user = None - if self.hs.config.require_auth_for_profile_requests: + if self.hs.config.server.require_auth_for_profile_requests: requester = await self.auth.get_user_by_req(request) requester_user = requester.user @@ -146,7 +146,7 @@ class ProfileRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester_user = None - if self.hs.config.require_auth_for_profile_requests: + if self.hs.config.server.require_auth_for_profile_requests: requester = await self.auth.get_user_by_req(request) requester_user = requester.user diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 48b0062cf4..bf3cb34146 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -129,7 +129,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) if existing_user_id is not None: - if self.hs.config.request_token_inhibit_3pid_errors: + if self.hs.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -140,11 +140,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -209,7 +209,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) if existing_user_id is not None: - if self.hs.config.request_token_inhibit_3pid_errors: + if self.hs.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -221,7 +221,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): 400, "Phone number is already in use", Codes.THREEPID_IN_USE ) - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" @@ -231,7 +231,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, + self.hs.config.registration.account_threepid_delegate_msisdn, country, phone_number, client_secret, @@ -341,7 +341,7 @@ class UsernameAvailabilityRestServlet(RestServlet): ) async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_registration: + if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) @@ -391,7 +391,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) - if not self.hs.config.enable_registration: + if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) @@ -419,8 +419,8 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() - self._registration_enabled = self.hs.config.enable_registration - self._msc2918_enabled = hs.config.access_token_lifetime is not None + self._registration_enabled = self.hs.config.registration.enable_registration + self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -682,7 +682,7 @@ class RegisterRestServlet(RestServlet): # written to the db if threepid: if is_threepid_reserved( - self.hs.config.mau_limits_reserved_threepids, threepid + self.hs.config.server.mau_limits_reserved_threepids, threepid ): await self.store.upsert_monthly_active_user(registered_user_id) @@ -800,7 +800,7 @@ class RegisterRestServlet(RestServlet): async def _do_guest_registration( self, params: JsonDict, address: Optional[str] = None ) -> Tuple[int, JsonDict]: - if not self.hs.config.allow_guest_access: + if not self.hs.config.registration.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id = await self.registration_handler.register_user( make_guest=True, address=address @@ -849,13 +849,13 @@ def _calculate_registration_flows( """ # FIXME: need a better error than "no auth flow found" for scenarios # where we required 3PID for registration but the user didn't give one - require_email = "email" in config.registrations_require_3pid - require_msisdn = "msisdn" in config.registrations_require_3pid + require_email = "email" in config.registration.registrations_require_3pid + require_msisdn = "msisdn" in config.registration.registrations_require_3pid show_msisdn = True show_email = True - if config.disable_msisdn_registration: + if config.registration.disable_msisdn_registration: show_msisdn = False require_msisdn = False @@ -909,7 +909,7 @@ def _calculate_registration_flows( flow.insert(0, LoginType.RECAPTCHA) # Prepend registration token to all flows if we're requiring a token - if config.registration_requires_token: + if config.registration.registration_requires_token: for flow in flows: flow.insert(0, LoginType.REGISTRATION_TOKEN) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index bf46dc60f2..ed95189b6d 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -369,7 +369,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): # Option to allow servers to require auth when accessing # /publicRooms via CS API. This is especially helpful in private # federations. - if not self.hs.config.allow_public_rooms_without_auth: + if not self.hs.config.server.allow_public_rooms_without_auth: raise # We allow people to not be authed if they're just looking at our diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index bf14ec384e..1dffcc3147 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -306,11 +306,13 @@ class RoomBatchSendEventRestServlet(RestServlet): # Verify the batch_id_from_query corresponds to an actual insertion event # and have the batch connected. corresponding_insertion_event_id = ( - await self.store.get_insertion_event_by_batch_id(batch_id_from_query) + await self.store.get_insertion_event_by_batch_id( + room_id, batch_id_from_query + ) ) if corresponding_insertion_event_id is None: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "No insertion event corresponds to the given ?batch_id", errcode=Codes.INVALID_PARAM, ) diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py index 1d90493eb0..09a46737de 100644 --- a/synapse/rest/client/shared_rooms.py +++ b/synapse/rest/client/shared_rooms.py @@ -42,7 +42,7 @@ class UserSharedRoomsServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - self.user_directory_active = hs.config.update_user_directory + self.user_directory_active = hs.config.server.update_user_directory async def on_GET( self, request: SynapseRequest, user_id: str diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 1259058b9b..913216a7c4 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -155,7 +155,7 @@ class SyncRestServlet(RestServlet): try: filter_object = json_decoder.decode(filter_id) set_timeline_upper_limit( - filter_object, self.hs.config.filter_timeline_limit + filter_object, self.hs.config.server.filter_timeline_limit ) except Exception: raise SynapseError(400, "Invalid filter JSON") diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py index ea2b8aa45f..ea7e025156 100644 --- a/synapse/rest/client/voip.py +++ b/synapse/rest/client/voip.py @@ -70,7 +70,7 @@ class VoipRestServlet(RestServlet): { "username": username, "password": password, - "ttl": userLifetime / 1000, + "ttl": userLifetime // 1000, "uris": turnUris, }, ) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 79a42b2455..044f44a397 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -73,6 +73,7 @@ OG_TAG_VALUE_MAXLEN = 1000 ONE_HOUR = 60 * 60 * 1000 ONE_DAY = 24 * ONE_HOUR +IMAGE_CACHE_EXPIRY_MS = 2 * ONE_DAY @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -496,6 +497,27 @@ class PreviewUrlResource(DirectServeJsonResource): logger.info("Still running DB updates; skipping expiry") return + def try_remove_parent_dirs(dirs: Iterable[str]) -> None: + """Attempt to remove the given chain of parent directories + + Args: + dirs: The list of directory paths to delete, with children appearing + before their parents. + """ + for dir in dirs: + try: + os.rmdir(dir) + except FileNotFoundError: + # Already deleted, continue with deleting the rest + pass + except OSError as e: + # Failed, skip deleting the rest of the parent dirs + if e.errno != errno.ENOTEMPTY: + logger.warning( + "Failed to remove media directory: %r: %s", dir, e + ) + break + # First we delete expired url cache entries media_ids = await self.store.get_expired_url_cache(now) @@ -504,20 +526,16 @@ class PreviewUrlResource(DirectServeJsonResource): fname = self.filepaths.url_cache_filepath(media_id) try: os.remove(fname) + except FileNotFoundError: + pass # If the path doesn't exist, meh except OSError as e: - # If the path doesn't exist, meh - if e.errno != errno.ENOENT: - logger.warning("Failed to remove media: %r: %s", media_id, e) - continue + logger.warning("Failed to remove media: %r: %s", media_id, e) + continue removed_media.append(media_id) - try: - dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) - for dir in dirs: - os.rmdir(dir) - except Exception: - pass + dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) + try_remove_parent_dirs(dirs) await self.store.delete_url_cache(removed_media) @@ -530,7 +548,7 @@ class PreviewUrlResource(DirectServeJsonResource): # These may be cached for a bit on the client (i.e., they # may have a room open with a preview url thing open). # So we wait a couple of days before deleting, just in case. - expire_before = now - 2 * ONE_DAY + expire_before = now - IMAGE_CACHE_EXPIRY_MS media_ids = await self.store.get_url_cache_media_before(expire_before) removed_media = [] @@ -538,36 +556,30 @@ class PreviewUrlResource(DirectServeJsonResource): fname = self.filepaths.url_cache_filepath(media_id) try: os.remove(fname) + except FileNotFoundError: + pass # If the path doesn't exist, meh except OSError as e: - # If the path doesn't exist, meh - if e.errno != errno.ENOENT: - logger.warning("Failed to remove media: %r: %s", media_id, e) - continue + logger.warning("Failed to remove media: %r: %s", media_id, e) + continue - try: - dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) - for dir in dirs: - os.rmdir(dir) - except Exception: - pass + dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) + try_remove_parent_dirs(dirs) thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id) try: shutil.rmtree(thumbnail_dir) + except FileNotFoundError: + pass # If the path doesn't exist, meh except OSError as e: - # If the path doesn't exist, meh - if e.errno != errno.ENOENT: - logger.warning("Failed to remove media: %r: %s", media_id, e) - continue + logger.warning("Failed to remove media: %r: %s", media_id, e) + continue removed_media.append(media_id) - try: - dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id) - for dir in dirs: - os.rmdir(dir) - except Exception: - pass + dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id) + # Note that one of the directories to be deleted has already been + # removed by the `rmtree` above. + try_remove_parent_dirs(dirs) await self.store.delete_url_cache_media(removed_media) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index c80a3a99aa..7ac01faab4 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -39,9 +39,9 @@ class WellKnownBuilder: result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}} - if self._config.default_identity_server: + if self._config.registration.default_identity_server: result["m.identity_server"] = { - "base_url": self._config.default_identity_server + "base_url": self._config.registration.default_identity_server } return result diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 073b0d754f..8522930b50 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -47,9 +47,9 @@ class ResourceLimitsServerNotices: self._notifier = hs.get_notifier() self._enabled = ( - hs.config.limit_usage_by_mau + hs.config.server.limit_usage_by_mau and self._server_notices_manager.is_enabled() - and not hs.config.hs_disabled + and not hs.config.server.hs_disabled ) async def maybe_send_server_notice_to_user(self, user_id: str) -> None: @@ -98,7 +98,7 @@ class ResourceLimitsServerNotices: try: if ( limit_type == LimitBlockingTypes.MONTHLY_ACTIVE_USER - and not self._config.mau_limit_alerting + and not self._config.server.mau_limit_alerting ): # We have hit the MAU limit, but MAU alerting is disabled: # reset room if necessary and return @@ -149,7 +149,7 @@ class ResourceLimitsServerNotices: "body": event_body, "msgtype": ServerNoticeMsgType, "server_notice_type": ServerNoticeLimitReached, - "admin_contact": self._config.admin_contact, + "admin_contact": self._config.server.admin_contact, "limit_type": event_limit_type, } event = await self._server_notices_manager.send_notice( diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 92336d7cc8..017e6fd92d 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -329,12 +329,10 @@ def _resolve_auth_events( auth_events[(prev_event.type, prev_event.state_key)] = prev_event try: # The signatures have already been checked at this point - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, event, auth_events, - do_sig_check=False, - do_size_check=False, ) prev_event = event except AuthError: @@ -349,12 +347,10 @@ def _resolve_normal_events( for event in _ordered_events(events): try: # The signatures have already been checked at this point - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, event, auth_events, - do_sig_check=False, - do_size_check=False, ) return event except AuthError: diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 7b1e8361de..586b0e12fe 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -546,12 +546,10 @@ async def _iterative_auth_checks( auth_events[key] = event_map[ev_id] try: - event_auth.check( + event_auth.check_auth_rules_for_event( room_version, event, auth_events, - do_sig_check=False, - do_size_check=False, ) resolved_state[(event.type, event.state_key)] = event_id diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index 6305414e3d..eee07227ef 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -36,7 +36,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase if ( hs.config.worker.run_background_tasks - and self.hs.config.redaction_retention_period is not None + and self.hs.config.server.redaction_retention_period is not None ): hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000) @@ -48,7 +48,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase By censor we mean update the event_json table with the redacted event. """ - if self.hs.config.redaction_retention_period is None: + if self.hs.config.server.redaction_retention_period is None: return if not ( @@ -60,7 +60,9 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase # created. return - before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period + before_ts = ( + self._clock.time_msec() - self.hs.config.server.redaction_retention_period + ) # We fetch all redactions that: # 1. point to an event we have, diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index cc192f5c87..c77acc7c84 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -353,7 +353,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self.user_ips_max_age = hs.config.user_ips_max_age + self.user_ips_max_age = hs.config.server.user_ips_max_age if hs.config.worker.run_background_tasks and self.user_ips_max_age: self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 584f818ff3..bc7d213fe2 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -104,7 +104,7 @@ class PersistEventsStore: self._clock = hs.get_clock() self._instance_name = hs.get_instance_name() - self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id # Ideally we'd move these ID gens here, unfortunately some other ID @@ -1276,13 +1276,6 @@ class PersistEventsStore: logger.exception("") raise - # update the stored internal_metadata to update the "outlier" flag. - # TODO: This is unused as of Synapse 1.31. Remove it once we are happy - # to drop backwards-compatibility with 1.30. - metadata_json = json_encoder.encode(event.internal_metadata.get_dict()) - sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?" - txn.execute(sql, (metadata_json, event.event_id)) - # Add an entry to the ex_outlier_stream table to replicate the # change in outlier status to our workers. stream_order = event.internal_metadata.stream_ordering @@ -1327,19 +1320,6 @@ class PersistEventsStore: d.pop("redacted_because", None) return d - def get_internal_metadata(event): - im = event.internal_metadata.get_dict() - - # temporary hack for database compatibility with Synapse 1.30 and earlier: - # store the `outlier` flag inside the internal_metadata json as well as in - # the `events` table, so that if anyone rolls back to an older Synapse, - # things keep working. This can be removed once we are happy to drop support - # for that - if event.internal_metadata.is_outlier(): - im["outlier"] = True - - return im - self.db_pool.simple_insert_many_txn( txn, table="event_json", @@ -1348,7 +1328,7 @@ class PersistEventsStore: "event_id": event.event_id, "room_id": event.room_id, "internal_metadata": json_encoder.encode( - get_internal_metadata(event) + event.internal_metadata.get_dict() ), "json": json_encoder.encode(event_dict(event)), "format_version": event.format_version, diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index bb244a03c0..434986fa64 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError @@ -22,7 +24,9 @@ from synapse.util.caches.descriptors import cached class FilteringStore(SQLBaseStore): @cached(num_args=2) - async def get_user_filter(self, user_localpart, filter_id): + async def get_user_filter( + self, user_localpart: str, filter_id: Union[int, str] + ) -> JsonDict: # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # with a coherent error message rather than 500 M_UNKNOWN. try: @@ -40,7 +44,7 @@ class FilteringStore(SQLBaseStore): return db_to_json(def_json) - async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str: + async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int: def_json = encode_canonical_json(user_filter) # Need an atomic transaction to SELECT the maximal ID so far then diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index b76ee51a9b..a14ac03d4b 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -32,8 +32,8 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs - self._limit_usage_by_mau = hs.config.limit_usage_by_mau - self._max_mau_value = hs.config.max_mau_value + self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau + self._max_mau_value = hs.config.server.max_mau_value @cached(num_args=0) async def get_monthly_active_count(self) -> int: @@ -96,8 +96,8 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """ users = [] - for tp in self.hs.config.mau_limits_reserved_threepids[ - : self.hs.config.max_mau_value + for tp in self.hs.config.server.mau_limits_reserved_threepids[ + : self.hs.config.server.max_mau_value ]: user_id = await self.hs.get_datastore().get_user_id_by_threepid( tp["medium"], tp["address"] @@ -212,7 +212,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self._mau_stats_only = hs.config.mau_stats_only + self._mau_stats_only = hs.config.server.mau_stats_only # Do not add more reserved users than the total allowable number self.db_pool.new_transaction( @@ -221,7 +221,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): [], [], self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[: self._max_mau_value], + hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], ) def _initialise_reserved_users(self, txn, threepids): diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index c83089ee63..de262fbf5a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -207,7 +207,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return False now = self._clock.time_msec() - trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 + trial_duration_ms = self.config.server.mau_trial_days * 24 * 60 * 60 * 1000 is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms return is_trial @@ -1710,7 +1710,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): We do this by grandfathering in existing user threepids assuming that they used one of the server configured trusted identity servers. """ - id_servers = set(self.config.trusted_third_party_id_servers) + id_servers = set(self.config.registration.trusted_third_party_id_servers) def _bg_user_threepids_grandfather_txn(txn): sql = """ diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 118b390e93..d69eaf80ce 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -679,8 +679,8 @@ class RoomWorkerStore(SQLBaseStore): # policy. if not ret: return { - "min_lifetime": self.config.retention_default_min_lifetime, - "max_lifetime": self.config.retention_default_max_lifetime, + "min_lifetime": self.config.server.retention_default_min_lifetime, + "max_lifetime": self.config.server.retention_default_max_lifetime, } row = ret[0] @@ -690,10 +690,10 @@ class RoomWorkerStore(SQLBaseStore): # The default values will be None if no default policy has been defined, or if one # of the attributes is missing from the default policy. if row["min_lifetime"] is None: - row["min_lifetime"] = self.config.retention_default_min_lifetime + row["min_lifetime"] = self.config.server.retention_default_min_lifetime if row["max_lifetime"] is None: - row["max_lifetime"] = self.config.retention_default_max_lifetime + row["max_lifetime"] = self.config.server.retention_default_max_lifetime return row diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py index a383388757..300a563c9e 100644 --- a/synapse/storage/databases/main/room_batch.py +++ b/synapse/storage/databases/main/room_batch.py @@ -18,7 +18,9 @@ from synapse.storage._base import SQLBaseStore class RoomBatchStore(SQLBaseStore): - async def get_insertion_event_by_batch_id(self, batch_id: str) -> Optional[str]: + async def get_insertion_event_by_batch_id( + self, room_id: str, batch_id: str + ) -> Optional[str]: """Retrieve a insertion event ID. Args: @@ -30,7 +32,7 @@ class RoomBatchStore(SQLBaseStore): """ return await self.db_pool.simple_select_one_onecol( table="insertion_events", - keyvalues={"next_batch_id": batch_id}, + keyvalues={"room_id": room_id, "next_batch_id": batch_id}, retcol="event_id", allow_none=True, ) diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 2a1e99e17a..c85383c975 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -51,7 +51,7 @@ class SearchWorkerStore(SQLBaseStore): txn: entries: entries to be added to the table """ - if not self.hs.config.enable_search: + if not self.hs.config.server.enable_search: return if isinstance(self.database_engine, PostgresEngine): sql = ( @@ -105,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - if not hs.config.enable_search: + if not hs.config.server.enable_search: return self.db_pool.updates.register_background_update_handler( diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 90d65edc42..5f538947ec 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -40,12 +40,10 @@ from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) - TEMP_TABLE = "_temp_populate_user_directory" class UserDirectoryBackgroundUpdateStore(StateDeltasStore): - # How many records do we calculate before sending it to # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 @@ -235,6 +233,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) users_with_profile = await self.get_users_in_room_with_profiles(room_id) + # Throw away users excluded from the directory. + users_with_profile = { + user_id: profile + for user_id, profile in users_with_profile.items() + if not self.hs.is_mine_id(user_id) + or await self.should_include_local_user_in_dir(user_id) + } # Update each user in the user directory. for user_id, profile in users_with_profile.items(): @@ -246,9 +251,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): if is_public: for user_id in users_with_profile: - if self.get_if_app_services_interested_in_user(user_id): - continue - to_insert.add(user_id) if to_insert: @@ -256,12 +258,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): to_insert.clear() else: for user_id in users_with_profile: + # We want the set of pairs (L, M) where L and M are + # in `users_with_profile` and L is local. + # Do so by looking for the local user L first. if not self.hs.is_mine_id(user_id): continue - if self.get_if_app_services_interested_in_user(user_id): - continue - for other_user_id in users_with_profile: if user_id == other_user_id: continue @@ -349,10 +351,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) for user_id in users_to_work_on: - profile = await self.get_profileinfo(get_localpart_from_id(user_id)) - await self.update_profile_in_user_dir( - user_id, profile.display_name, profile.avatar_url - ) + if await self.should_include_local_user_in_dir(user_id): + profile = await self.get_profileinfo(get_localpart_from_id(user_id)) + await self.update_profile_in_user_dir( + user_id, profile.display_name, profile.avatar_url + ) # We've finished processing a user. Delete it from the table. await self.db_pool.simple_delete_one( @@ -369,6 +372,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return len(users_to_work_on) + async def should_include_local_user_in_dir(self, user: str) -> bool: + """Certain classes of local user are omitted from the user directory. + Is this user one of them? + """ + # App service users aren't usually contactable, so exclude them. + if self.get_if_app_services_interested_in_user(user): + # TODO we might want to make this configurable for each app service + return False + + # Support users are for diagnostics and should not appear in the user directory. + if await self.is_support_user(user): + return False + + # Deactivated users aren't contactable, so should not appear in the user directory. + if await self.get_user_deactivated_status(user): + return False + return True + async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool: """Check if the room is either world_readable or publically joinable""" @@ -527,7 +548,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): desc="get_user_in_directory", ) - async def update_user_directory_stream_pos(self, stream_id: int) -> None: + async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None: await self.db_pool.simple_update_one( table="user_directory_stream_pos", keyvalues={}, @@ -537,7 +558,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): - # How many records do we calculate before sending it to # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index f31880b8ec..a63eaddfdc 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -366,7 +366,7 @@ def _upgrade_existing_database( + "new for the server to understand" ) - # some of the deltas assume that config.server_name is set correctly, so now + # some of the deltas assume that server_name is set correctly, so now # is a good time to run the sanity check. if not is_empty and "main" in databases: from synapse.storage.databases.main import check_database_before_upgrade diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 573e05a482..1aee741a8b 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# When updating these values, please leave a short summary of the changes below. - -SCHEMA_VERSION = 64 +SCHEMA_VERSION = 64 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -46,7 +44,7 @@ Changes in SCHEMA_VERSION = 64: """ -SCHEMA_COMPAT_VERSION = 59 +SCHEMA_COMPAT_VERSION = 60 # 60: "outlier" not in internal_metadata. """Limit on how far the synapse codebase can be rolled back without breaking db compat This value is stored in the database, and checked on startup. If the value in the diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index bd234549bd..64daff59df 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -50,7 +50,13 @@ def _handle_frozendict(obj: Any) -> Dict[Any, Any]: if type(obj) is frozendict: # fishing the protected dict out of the object is a bit nasty, # but we don't really want the overhead of copying the dict. - return obj._dict + try: + return obj._dict + except AttributeError: + # When the C implementation of frozendict is used, + # there isn't a `_dict` attribute with a dict + # so we resort to making a copy of the frozendict + return dict(obj) raise TypeError( "Object of type %s is not JSON serializable" % obj.__class__.__name__ ) diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index baa9190a9a..389adf00f6 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -44,8 +44,8 @@ def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool: bool: whether the 3PID medium/address is allowed to be added to this HS """ - if hs.config.allowed_local_3pids: - for constraint in hs.config.allowed_local_3pids: + if hs.config.registration.allowed_local_3pids: + for constraint in hs.config.registration.allowed_local_3pids: logger.debug( "Checking 3PID %s (%s) against %s (%s)", address, diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index cccff7af26..3aa9ba3c43 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -217,7 +217,7 @@ class AuthTestCase(unittest.HomeserverTestCase): user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( - location=self.hs.config.server_name, + location=self.hs.config.server.server_name, identifier="key", key=self.hs.config.key.macaroon_secret_key, ) @@ -239,7 +239,7 @@ class AuthTestCase(unittest.HomeserverTestCase): user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( - location=self.hs.config.server_name, + location=self.hs.config.server.server_name, identifier="key", key=self.hs.config.key.macaroon_secret_key, ) @@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_monthly_active_count = simple_async_mock(lots_of_users) e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403) @@ -303,7 +303,7 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] @@ -332,7 +332,7 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] @@ -372,7 +372,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403) @@ -387,7 +387,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index a2b5ed2030..55f0899bae 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -24,7 +24,7 @@ from synapse.appservice.scheduler import ( from synapse.logging.context import make_deferred_yieldable from tests import unittest -from tests.test_utils import make_awaitable +from tests.test_utils import simple_async_mock from ..utils import MockClock @@ -49,11 +49,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): txn = Mock(id=txn_id, service=service, events=events) # mock methods - self.store.get_appservice_state = Mock( - return_value=defer.succeed(ApplicationServiceState.UP) - ) - txn.send = Mock(return_value=make_awaitable(True)) - self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) + self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) + txn.send = simple_async_mock(True) + txn.complete = simple_async_mock(True) + self.store.create_appservice_txn = simple_async_mock(txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -71,10 +70,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events = [Mock(), Mock()] txn = Mock(id="idhere", service=service, events=events) - self.store.get_appservice_state = Mock( - return_value=defer.succeed(ApplicationServiceState.DOWN) + self.store.get_appservice_state = simple_async_mock( + ApplicationServiceState.DOWN ) - self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) + self.store.create_appservice_txn = simple_async_mock(txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -94,12 +93,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): txn = Mock(id=txn_id, service=service, events=events) # mock methods - self.store.get_appservice_state = Mock( - return_value=defer.succeed(ApplicationServiceState.UP) - ) - self.store.set_appservice_state = Mock(return_value=defer.succeed(True)) - txn.send = Mock(return_value=make_awaitable(False)) # fails to send - self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) + self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) + self.store.set_appservice_state = simple_async_mock(True) + txn.send = simple_async_mock(False) # fails to send + self.store.create_appservice_txn = simple_async_mock(txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -122,7 +119,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.as_api = Mock() self.store = Mock() self.service = Mock() - self.callback = Mock() + self.callback = simple_async_mock() self.recoverer = _Recoverer( clock=self.clock, as_api=self.as_api, @@ -144,8 +141,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() # shouldn't have called anything prior to waiting for exp backoff self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = Mock(return_value=make_awaitable(True)) - txn.complete.return_value = make_awaitable(None) + txn.send = simple_async_mock(True) + txn.complete = simple_async_mock(None) # wait for exp backoff self.clock.advance_time(2) self.assertEquals(1, txn.send.call_count) @@ -170,8 +167,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = Mock(return_value=make_awaitable(False)) - txn.complete.return_value = make_awaitable(None) + txn.send = simple_async_mock(False) + txn.complete = simple_async_mock(None) self.clock.advance_time(2) self.assertEquals(1, txn.send.call_count) self.assertEquals(0, txn.complete.call_count) @@ -184,7 +181,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.assertEquals(3, txn.send.call_count) self.assertEquals(0, txn.complete.call_count) self.assertEquals(0, self.callback.call_count) - txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn + txn.send = simple_async_mock(True) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) self.assertEquals(1, txn.send.call_count) # new mock reset call count @@ -195,6 +192,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): def setUp(self): self.txn_ctrl = Mock() + self.txn_ctrl.send = simple_async_mock() self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock()) def test_send_single_event_no_queue(self): diff --git a/tests/config/test_load.py b/tests/config/test_load.py index ef6c2beec7..8e49ca26d9 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -84,16 +84,16 @@ class ConfigLoadingTestCase(unittest.TestCase): ) # Check that disable_registration clobbers enable_registration. config = HomeServerConfig.load_config("", ["-c", self.file]) - self.assertFalse(config.enable_registration) + self.assertFalse(config.registration.enable_registration) config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) - self.assertFalse(config.enable_registration) + self.assertFalse(config.registration.enable_registration) # Check that either config value is clobbered by the command line. config = HomeServerConfig.load_or_generate_config( "", ["-c", self.file, "--enable-registration"] ) - self.assertTrue(config.enable_registration) + self.assertTrue(config.registration.enable_registration) def test_stats_enabled(self): self.generate_config_and_remove_lines_containing("enable_metrics") diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 3b3866bff8..3deb14c308 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -26,6 +26,7 @@ from synapse.rest.client import login, presence, room from synapse.types import JsonDict, StreamToken, create_requester from tests.handlers.test_sync import generate_sync_config +from tests.test_utils import simple_async_mock from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config @@ -133,8 +134,12 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ] def make_homeserver(self, reactor, clock): + # Mock out the calls over federation. + fed_transport_client = Mock(spec=["send_transaction"]) + fed_transport_client.send_transaction = simple_async_mock({}) + hs = self.setup_test_homeserver( - federation_transport_client=Mock(spec=["send_transaction"]), + federation_transport_client=fed_transport_client, ) # Load the modules into the homeserver module_api = hs.get_module_api() diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 65b18fbd7a..b457dad6d2 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -336,7 +336,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): recovery """ mock_send_txn = self.hs.get_federation_transport_client().send_transaction - mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) # create devices u1 = self.register_user("user", "pass") @@ -376,7 +376,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): This case tests the behaviour when the server has never been reachable. """ mock_send_txn = self.hs.get_federation_transport_client().send_transaction - mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) # create devices u1 = self.register_user("user", "pass") @@ -429,7 +429,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # now the server goes offline mock_send_txn = self.hs.get_federation_transport_client().send_transaction - mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) self.login("user", "pass", device_id="D2") self.login("user", "pass", device_id="D3") diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 0b60cc4261..03e1e11f49 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -120,7 +120,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): self.assertEqual( channel.json_body["room_version"], - self.hs.config.default_room_version.identifier, + self.hs.config.server.default_room_version.identifier, ) members = set( diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 57cc3e2646..c153018fd8 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -110,7 +110,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_set_my_name_if_disabled(self): - self.hs.config.enable_set_displayname = False + self.hs.config.registration.enable_set_displayname = False # Setting displayname for the first time is allowed self.get_success( @@ -225,7 +225,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_set_my_avatar_if_disabled(self): - self.hs.config.enable_set_avatar_url = False + self.hs.config.registration.enable_set_avatar_url = False # Setting displayname for the first time is allowed self.get_success( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index d3efb67e3e..db691c4c1c 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -16,7 +16,12 @@ from unittest.mock import Mock from synapse.api.auth import Auth from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, ResourceLimitError, SynapseError +from synapse.api.errors import ( + CodeMessageException, + Codes, + ResourceLimitError, + SynapseError, +) from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, RoomID, UserID, create_requester @@ -120,14 +125,24 @@ class RegistrationTestCase(unittest.HomeserverTestCase): hs_config = self.default_config() # some of the tests rely on us having a user consent version - hs_config["user_consent"] = { - "version": "test_consent_version", - "template_dir": ".", - } + hs_config.setdefault("user_consent", {}).update( + { + "version": "test_consent_version", + "template_dir": ".", + } + ) hs_config["max_mau_value"] = 50 hs_config["limit_usage_by_mau"] = True - hs = self.setup_test_homeserver(config=hs_config) + # Don't attempt to reach out over federation. + self.mock_federation_client = Mock() + self.mock_federation_client.make_query.side_effect = CodeMessageException( + 500, "" + ) + + hs = self.setup_test_homeserver( + config=hs_config, federation_client=self.mock_federation_client + ) load_legacy_spam_checkers(hs) @@ -138,9 +153,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs): - self.mock_distributor = Mock() - self.mock_distributor.declare("registered_user") - self.mock_captcha_client = Mock() self.handler = self.hs.get_registration_handler() self.store = self.hs.get_datastore() self.lots_of_users = 100 @@ -174,21 +186,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) + @override_config({"limit_usage_by_mau": False}) def test_mau_limits_when_disabled(self): - self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "a", "display_name")) + @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self): - self.hs.config.limit_usage_by_mau = True self.store.count_monthly_users = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value - 1) + return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "c", "User")) + @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_blocked(self): - self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) @@ -198,15 +210,15 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), ResourceLimitError, ) + @override_config({"limit_usage_by_mau": True}) def test_register_mau_blocked(self): - self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) @@ -215,16 +227,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError ) + @override_config( + {"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False} + ) def test_auto_join_rooms_for_guests(self): - room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] - self.hs.config.auto_join_rooms_for_guests = False user_id = self.get_success( self.handler.register_user(localpart="jeff", make_guest=True), ) @@ -243,34 +255,33 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(room_id["room_id"] in rooms) self.assertEqual(len(rooms), 1) + @override_config({"auto_join_rooms": []}) def test_auto_create_auto_join_rooms_with_no_rooms(self): - self.hs.config.auto_join_rooms = [] frank = UserID.from_string("@frank:test") user_id = self.get_success(self.handler.register_user(frank.localpart)) self.assertEqual(user_id, frank.to_string()) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + @override_config({"auto_join_rooms": ["#room:another"]}) def test_auto_create_auto_join_where_room_is_another_domain(self): - self.hs.config.auto_join_rooms = ["#room:another"] frank = UserID.from_string("@frank:test") user_id = self.get_success(self.handler.register_user(frank.localpart)) self.assertEqual(user_id, frank.to_string()) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + @override_config( + {"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False} + ) def test_auto_create_auto_join_where_auto_create_is_false(self): - self.hs.config.autocreate_auto_join_rooms = False - room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] user_id = self.get_success(self.handler.register_user(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + @override_config({"auto_join_rooms": ["#room:test"]}) def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self): room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_real_user = Mock(return_value=make_awaitable(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) @@ -294,10 +305,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(room_id["room_id"] in rooms) self.assertEqual(len(rooms), 1) + @override_config({"auto_join_rooms": ["#room:test"]}) def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self): - room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=make_awaitable(2)) self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) @@ -510,6 +519,17 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertEqual(rooms, set()) self.assertEqual(invited_rooms, []) + @override_config( + { + "user_consent": { + "block_events_error": "Error", + "require_at_registration": True, + }, + "form_secret": "53cr3t", + "public_baseurl": "http://test", + "auto_join_rooms": ["#room:test"], + }, + ) def test_auto_create_auto_join_where_no_consent(self): """Test to ensure that the first user is not auto-joined to a room if they have not given general consent. @@ -521,25 +541,20 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # * The server is configured to auto-join to a room # (and autocreate if necessary) - event_creation_handler = self.hs.get_event_creation_handler() - # (Messing with the internals of event_creation_handler is fragile - # but can't see a better way to do this. One option could be to subclass - # the test with custom config.) - event_creation_handler._block_events_without_consent_error = "Error" - event_creation_handler._consent_uri_builder = Mock() - room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] - # When:- - # * the user is registered and post consent actions are called + # * the user is registered user_id = self.get_success(self.handler.register_user(localpart="jeff")) - self.get_success(self.handler.post_consent_actions(user_id)) # Then:- # * Ensure that they have not been joined to the room rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + # The user provides consent; ensure they are now in the rooms. + self.get_success(self.handler.post_consent_actions(user_id)) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 1) + def test_register_support_user(self): user_id = self.get_success( self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 266333c553..b3c3af113b 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -11,47 +11,208 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple -from unittest.mock import Mock +from typing import Tuple +from unittest.mock import Mock, patch from urllib.parse import quote from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.rest.client import login, room, user_directory +from synapse.appservice import ApplicationService +from synapse.rest.client import login, register, room, user_directory +from synapse.server import HomeServer from synapse.storage.roommember import ProfileInfo from synapse.types import create_requester +from synapse.util import Clock from tests import unittest +from tests.storage.test_user_directory import GetUserDirectoryTables +from tests.test_utils.event_injection import inject_member_event from tests.unittest import override_config class UserDirectoryTestCase(unittest.HomeserverTestCase): - """ - Tests the UserDirectoryHandler. + """Tests the UserDirectoryHandler. + + We're broadly testing two kinds of things here. + + 1. Check that we correctly update the user directory in response + to events (e.g. join a room, leave a room, change name, make public) + 2. Check that the search logic behaves as expected. + + The background process that rebuilds the user directory is tested in + tests/storage/test_user_directory.py. """ servlets = [ login.register_servlets, synapse.rest.admin.register_servlets, + register.register_servlets, room.register_servlets, ] - def make_homeserver(self, reactor, clock): - + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["update_user_directory"] = True - return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + self.appservice = ApplicationService( + token="i_am_an_app_service", + hostname="test", + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + mock_load_appservices = Mock(return_value=[self.appservice]) + with patch( + "synapse.storage.databases.main.appservice.load_appservices", + mock_load_appservices, + ): + hs = self.setup_test_homeserver(config=config) + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.handler = hs.get_user_directory_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.event_creation_handler = self.hs.get_event_creation_handler() + self.user_dir_helper = GetUserDirectoryTables(self.store) + + def test_normal_user_pair(self) -> None: + """Sanity check that the room-sharing tables are updated correctly.""" + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + + public = self.helper.create_room_as( + alice, + is_public=True, + extra_content={"visibility": "public"}, + tok=alice_token, + ) + private = self.helper.create_room_as(alice, is_public=False, tok=alice_token) + self.helper.invite(private, alice, bob, tok=alice_token) + self.helper.join(public, bob, tok=bob_token) + self.helper.join(private, bob, tok=bob_token) + + # Alice also makes a second public room but no-one else joins + public2 = self.helper.create_room_as( + alice, + is_public=True, + extra_content={"visibility": "public"}, + tok=alice_token, + ) + + users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) + in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + in_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + + self.assertEqual(users, {alice, bob}) + self.assertEqual( + set(in_public), {(alice, public), (bob, public), (alice, public2)} + ) + self.assertEqual( + self.user_dir_helper._compress_shared(in_private), + {(alice, bob, private), (bob, alice, private)}, + ) + + # The next three tests (test_population_excludes_*) all setup + # - A normal user included in the user dir + # - A public and private room created by that user + # - A user excluded from the room dir, belonging to both rooms + + # They match similar logic in storage/test_user_directory. But that tests + # rebuilding the directory; this tests updating it incrementally. + + def test_excludes_support_user(self) -> None: + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + support = "@support1:test" + self.get_success( + self.store.register_user( + user_id=support, password_hash=None, user_type=UserTypes.SUPPORT + ) + ) + + public, private = self._create_rooms_and_inject_memberships( + alice, alice_token, support + ) + self._check_only_one_user_in_directory(alice, public) + + def test_excludes_deactivated_user(self) -> None: + admin = self.register_user("admin", "pass", admin=True) + admin_token = self.login(admin, "pass") + user = self.register_user("naughty", "pass") + + # Deactivate the user. + channel = self.make_request( + "PUT", + f"/_synapse/admin/v2/users/{user}", + access_token=admin_token, + content={"deactivated": True}, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["deactivated"], True) - def test_handle_local_profile_change_with_support_user(self): + # Join the deactivated user to rooms owned by the admin. + # Is this something that could actually happen outside of a test? + public, private = self._create_rooms_and_inject_memberships( + admin, admin_token, user + ) + self._check_only_one_user_in_directory(admin, public) + + def test_excludes_appservices_user(self) -> None: + # Register an AS user. + user = self.register_user("user", "pass") + token = self.login(user, "pass") + as_user = self.register_appservice_user("as_user_potato", self.appservice.token) + + # Join the AS user to rooms owned by the normal user. + public, private = self._create_rooms_and_inject_memberships( + user, token, as_user + ) + self._check_only_one_user_in_directory(user, public) + + def _create_rooms_and_inject_memberships( + self, creator: str, token: str, joiner: str + ) -> Tuple[str, str]: + """Create a public and private room as a normal user. + Then get the `joiner` into those rooms. + """ + # TODO: Duplicates the same-named method in UserDirectoryInitialPopulationTest. + public_room = self.helper.create_room_as( + creator, + is_public=True, + # See https://github.com/matrix-org/synapse/issues/10951 + extra_content={"visibility": "public"}, + tok=token, + ) + private_room = self.helper.create_room_as(creator, is_public=False, tok=token) + + # HACK: get the user into these rooms + self.get_success(inject_member_event(self.hs, public_room, joiner, "join")) + self.get_success(inject_member_event(self.hs, private_room, joiner, "join")) + + return public_room, private_room + + def _check_only_one_user_in_directory(self, user: str, public: str) -> None: + users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) + in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + in_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + + self.assertEqual(users, {user}) + self.assertEqual(set(in_public), {(user, public)}) + self.assertEqual(in_private, []) + + def test_handle_local_profile_change_with_support_user(self) -> None: support_user_id = "@support:test" self.get_success( self.store.register_user( @@ -64,7 +225,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.handler.handle_local_profile_change(support_user_id, None) + self.handler.handle_local_profile_change( + support_user_id, ProfileInfo("I love support me", None) + ) ) profile = self.get_success(self.store.get_user_in_directory(support_user_id)) self.assertTrue(profile is None) @@ -77,7 +240,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): profile = self.get_success(self.store.get_user_in_directory(regular_user_id)) self.assertTrue(profile["display_name"] == display_name) - def test_handle_local_profile_change_with_deactivated_user(self): + def test_handle_local_profile_change_with_deactivated_user(self) -> None: # create user r_user_id = "@regular:test" self.get_success( @@ -112,7 +275,27 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): profile = self.get_success(self.store.get_user_in_directory(r_user_id)) self.assertTrue(profile is None) - def test_handle_user_deactivated_support_user(self): + def test_handle_local_profile_change_with_appservice_user(self) -> None: + # create user + as_user_id = self.register_appservice_user( + "as_user_alice", self.appservice.token + ) + + # profile is not in directory + profile = self.get_success(self.store.get_user_in_directory(as_user_id)) + self.assertTrue(profile is None) + + # update profile + profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3") + self.get_success( + self.handler.handle_local_profile_change(as_user_id, profile_info) + ) + + # profile is still not in directory + profile = self.get_success(self.store.get_user_in_directory(as_user_id)) + self.assertTrue(profile is None) + + def test_handle_user_deactivated_support_user(self) -> None: s_user_id = "@support:test" self.get_success( self.store.register_user( @@ -120,20 +303,29 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) - self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None)) - self.get_success(self.handler.handle_local_user_deactivated(s_user_id)) - self.store.remove_from_user_dir.not_called() + mock_remove_from_user_dir = Mock(return_value=defer.succeed(None)) + with patch.object( + self.store, "remove_from_user_dir", mock_remove_from_user_dir + ): + self.get_success(self.handler.handle_local_user_deactivated(s_user_id)) + # BUG: the correct spelling is assert_not_called, but that makes the test fail + # and it's not clear that this is actually the behaviour we want. + mock_remove_from_user_dir.not_called() - def test_handle_user_deactivated_regular_user(self): + def test_handle_user_deactivated_regular_user(self) -> None: r_user_id = "@regular:test" self.get_success( self.store.register_user(user_id=r_user_id, password_hash=None) ) - self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None)) - self.get_success(self.handler.handle_local_user_deactivated(r_user_id)) - self.store.remove_from_user_dir.called_once_with(r_user_id) - def test_reactivation_makes_regular_user_searchable(self): + mock_remove_from_user_dir = Mock(return_value=defer.succeed(None)) + with patch.object( + self.store, "remove_from_user_dir", mock_remove_from_user_dir + ): + self.get_success(self.handler.handle_local_user_deactivated(r_user_id)) + mock_remove_from_user_dir.assert_called_once_with(r_user_id) + + def test_reactivation_makes_regular_user_searchable(self) -> None: user = self.register_user("regular", "pass") user_token = self.login(user, "pass") admin_user = self.register_user("admin", "pass", admin=True) @@ -171,7 +363,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) self.assertEqual(s["results"][0]["user_id"], user) - def test_private_room(self): + def test_private_room(self) -> None: """ A user can be searched for only by people that are either in a public room, or that share a private chat. @@ -191,11 +383,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.helper.join(room, user=u2, tok=u2_token) # Check we have populated the database correctly. - shares_private = self.get_users_who_share_private_rooms() - public_users = self.get_users_in_public_rooms() + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) self.assertEqual( - self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} + self.user_dir_helper._compress_shared(shares_private), + {(u1, u2, room), (u2, u1, room)}, ) self.assertEqual(public_users, []) @@ -215,10 +412,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.helper.leave(room, user=u2, tok=u2_token) # Check we have removed the values. - shares_private = self.get_users_who_share_private_rooms() - public_users = self.get_users_in_public_rooms() + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) - self.assertEqual(self._compress_shared(shares_private), set()) + self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set()) self.assertEqual(public_users, []) # User1 now gets no search results for any of the other users. @@ -228,7 +429,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user3", 10)) self.assertEqual(len(s["results"]), 0) - def test_spam_checker(self): + def test_spam_checker(self) -> None: """ A user which fails the spam checks will not appear in search results. """ @@ -246,11 +447,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.helper.join(room, user=u2, tok=u2_token) # Check we have populated the database correctly. - shares_private = self.get_users_who_share_private_rooms() - public_users = self.get_users_in_public_rooms() + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) self.assertEqual( - self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} + self.user_dir_helper._compress_shared(shares_private), + {(u1, u2, room), (u2, u1, room)}, ) self.assertEqual(public_users, []) @@ -258,7 +464,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 1) - async def allow_all(user_profile): + async def allow_all(user_profile: ProfileInfo) -> bool: # Allow all users. return False @@ -272,7 +478,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) # Configure a spam checker that filters all users. - async def block_all(user_profile): + async def block_all(user_profile: ProfileInfo) -> bool: # All users are spammy. return True @@ -282,7 +488,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 0) - def test_legacy_spam_checker(self): + def test_legacy_spam_checker(self) -> None: """ A spam checker without the expected method should be ignored. """ @@ -300,11 +506,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.helper.join(room, user=u2, tok=u2_token) # Check we have populated the database correctly. - shares_private = self.get_users_who_share_private_rooms() - public_users = self.get_users_in_public_rooms() + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) self.assertEqual( - self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} + self.user_dir_helper._compress_shared(shares_private), + {(u1, u2, room), (u2, u1, room)}, ) self.assertEqual(public_users, []) @@ -317,134 +528,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 1) - def _compress_shared(self, shared): - """ - Compress a list of users who share rooms dicts to a list of tuples. - """ - r = set() - for i in shared: - r.add((i["user_id"], i["other_user_id"], i["room_id"])) - return r - - def get_users_in_public_rooms(self) -> List[Tuple[str, str]]: - r = self.get_success( - self.store.db_pool.simple_select_list( - "users_in_public_rooms", None, ("user_id", "room_id") - ) - ) - retval = [] - for i in r: - retval.append((i["user_id"], i["room_id"])) - return retval - - def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]: - return self.get_success( - self.store.db_pool.simple_select_list( - "users_who_share_private_rooms", - None, - ["user_id", "other_user_id", "room_id"], - ) - ) - - def _add_background_updates(self): - """ - Add the background updates we need to run. - """ - # Ugh, have to reset this flag - self.store.db_pool.updates._all_done = False - - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": "populate_user_directory_createtables", - "progress_json": "{}", - }, - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": "populate_user_directory_process_rooms", - "progress_json": "{}", - "depends_on": "populate_user_directory_createtables", - }, - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": "populate_user_directory_process_users", - "progress_json": "{}", - "depends_on": "populate_user_directory_process_rooms", - }, - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": "populate_user_directory_cleanup", - "progress_json": "{}", - "depends_on": "populate_user_directory_process_users", - }, - ) - ) - - def test_initial(self): - """ - The user directory's initial handler correctly updates the search tables. - """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - u3 = self.register_user("user3", "pass") - u3_token = self.login(u3, "pass") - - room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) - self.helper.invite(room, src=u1, targ=u2, tok=u1_token) - self.helper.join(room, user=u2, tok=u2_token) - - private_room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) - self.helper.invite(private_room, src=u1, targ=u3, tok=u1_token) - self.helper.join(private_room, user=u3, tok=u3_token) - - self.get_success(self.store.update_user_directory_stream_pos(None)) - self.get_success(self.store.delete_all_from_user_dir()) - - shares_private = self.get_users_who_share_private_rooms() - public_users = self.get_users_in_public_rooms() - - # Nothing updated yet - self.assertEqual(shares_private, []) - self.assertEqual(public_users, []) - - # Do the initial population of the user directory via the background update - self._add_background_updates() - - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) - - shares_private = self.get_users_who_share_private_rooms() - public_users = self.get_users_in_public_rooms() - - # User 1 and User 2 are in the same public room - self.assertEqual(set(public_users), {(u1, room), (u2, room)}) - - # User 1 and User 3 share private rooms - self.assertEqual( - self._compress_shared(shares_private), - {(u1, u3, private_room), (u3, u1, private_room)}, - ) - - def test_initial_share_all_users(self): + def test_initial_share_all_users(self) -> None: """ Search all users = True means that a user does not have to share a private room with the searching user or be in a public room to be search @@ -457,26 +541,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.register_user("user2", "pass") u3 = self.register_user("user3", "pass") - # Wipe the user dir - self.get_success(self.store.update_user_directory_stream_pos(None)) - self.get_success(self.store.delete_all_from_user_dir()) - - # Do the initial population of the user directory via the background update - self._add_background_updates() - - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) - - shares_private = self.get_users_who_share_private_rooms() - public_users = self.get_users_in_public_rooms() + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) # No users share rooms self.assertEqual(public_users, []) - self.assertEqual(self._compress_shared(shares_private), set()) + self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set()) # Despite not sharing a room, search_all_users means we get a search # result. @@ -501,7 +575,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): } } ) - def test_prefer_local_users(self): + def test_prefer_local_users(self) -> None: """Tests that local users are shown higher in search results when user_directory.prefer_local_users is True. """ @@ -535,15 +609,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): local_users = [local_user_1, local_user_2, local_user_3] remote_users = [remote_user_1, remote_user_2, remote_user_3] - # Populate the user directory via background update - self._add_background_updates() - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) - # The local searching user searches for the term "user", which other users have # in their user id results = self.get_success( @@ -565,7 +630,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): room_id: str, room_version: RoomVersion, user_id: str, - ): + ) -> None: # Add a user to the room. builder = self.event_builder_factory.for_room_version( room_version, @@ -588,8 +653,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): class TestUserDirSearchDisabled(unittest.HomeserverTestCase): - user_id = "@test:test" - servlets = [ user_directory.register_servlets, room.register_servlets, @@ -597,7 +660,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["update_user_directory"] = True hs = self.setup_test_homeserver(config=config) @@ -606,19 +669,24 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): return hs - def test_disabling_room_list(self): + def test_disabling_room_list(self) -> None: self.config.userdirectory.user_directory_search_enabled = True - # First we create a room with another user so that user dir is non-empty - # for our user - self.helper.create_room_as(self.user_id) + # Create two users and put them in the same room. + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") u2 = self.register_user("user2", "pass") - room = self.helper.create_room_as(self.user_id) - self.helper.join(room, user=u2) + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) - # Assert user directory is not empty + # Each should see the other when searching the user directory. channel = self.make_request( - "POST", b"user_directory/search", b'{"search_term":"user2"}' + "POST", + b"user_directory/search", + b'{"search_term":"user2"}', + access_token=u1_token, ) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) > 0) @@ -626,7 +694,10 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): # Disable user directory and check search returns nothing self.config.userdirectory.user_directory_search_enabled = False channel = self.make_request( - "POST", b"user_directory/search", b'{"search_term":"user2"}' + "POST", + b"user_directory/search", + b'{"search_term":"user2"}', + access_token=u1_token, ) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) == 0) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index d9a8b077d3..638babae69 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -226,7 +226,7 @@ class FederationClientTests(HomeserverTestCase): """Ensure that Synapse does not try to connect to blacklisted IPs""" # Set up the ip_range blacklist - self.hs.config.federation_ip_range_blacklist = IPSet( + self.hs.config.server.federation_ip_range_blacklist = IPSet( ["127.0.0.0/8", "fe80::/64"] ) self.reactor.lookups["internal"] = "127.0.0.1" diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 9d38974fba..e915dd5c7c 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -25,6 +25,7 @@ from synapse.types import create_requester from tests.events.test_presence_router import send_presence_update, sync_presence from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.test_utils import simple_async_mock from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config from tests.utils import USE_POSTGRES_FOR_TESTS @@ -46,8 +47,12 @@ class ModuleApiTestCase(HomeserverTestCase): self.auth_handler = homeserver.get_auth_handler() def make_homeserver(self, reactor, clock): + # Mock out the calls over federation. + fed_transport_client = Mock(spec=["send_transaction"]) + fed_transport_client.send_transaction = simple_async_mock({}) + return self.setup_test_homeserver( - federation_transport_client=Mock(spec=["send_transaction"]), + federation_transport_client=fed_transport_client, ) def test_can_register_user(self): diff --git a/tests/replication/_base.py b/tests/replication/_base.py index c7555c26db..eac4664b41 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -70,8 +70,16 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # databases objects are the same. self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool + # Normally we'd pass in the handler to `setup_test_homeserver`, which would + # eventually hit "Install @cache_in_self attributes" in tests/utils.py. + # Unfortunately our handler wants a reference to the homeserver. That leaves + # us with a chicken-and-egg problem. + # We can workaround this: create the homeserver first, create the handler + # and bodge it in after the fact. The bodging requires us to know the + # dirty details of how `cache_in_self` works. We politely ask mypy to + # ignore our dirty dealings. self.test_handler = self._build_replication_data_handler() - self.worker_hs._replication_data_handler = self.test_handler + self.worker_hs._replication_data_handler = self.test_handler # type: ignore[attr-defined] repl_handler = ReplicationCommandHandler(self.worker_hs) self.client = ClientReplicationStreamProtocol( @@ -240,7 +248,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): if self.hs.config.redis.redis_enabled: # Handle attempts to connect to fake redis server. self.reactor.add_tcp_client_callback( - b"localhost", + "localhost", 6379, self.connect_any_redis_attempts, ) @@ -315,12 +323,15 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): ) ) + # Copy the port into a new, non-Optional variable so mypy knows we're + # not going to reset `instance_loc` to `None` under its feet. See + # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions + port = instance_loc.port + self.reactor.add_tcp_client_callback( self.reactor.lookups[instance_loc.host], instance_loc.port, - lambda: self._handle_http_replication_attempt( - worker_hs, instance_loc.port - ), + lambda: self._handle_http_replication_attempt(worker_hs, port), ) store = worker_hs.get_datastore() @@ -424,7 +435,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): clients = self.reactor.tcpClients while clients: (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, b"localhost") + self.assertEqual(host, "localhost") self.assertEqual(port, 6379) client_protocol = client_factory.buildProtocol(None) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ee3ae9cce4..6ed9e42173 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -59,7 +59,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver() - self.hs.config.registration_shared_secret = "shared" + self.hs.config.registration.registration_shared_secret = "shared" self.hs.get_media_repository = Mock() self.hs.get_deactivate_account_handler = Mock() @@ -71,7 +71,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): If there is no shared secret, registration through this method will be prevented. """ - self.hs.config.registration_shared_secret = None + self.hs.config.registration.registration_shared_secret = None channel = self.make_request("POST", self.url, b"{}") @@ -422,7 +422,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -1485,7 +1485,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -1522,7 +1522,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 9e9e953cf4..89d85b0a17 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -470,13 +470,45 @@ class WhoamiTestCase(unittest.HomeserverTestCase): register.register_servlets, ] + def default_config(self): + config = super().default_config() + config["allow_guest_access"] = True + return config + def test_GET_whoami(self): device_id = "wouldgohere" user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test", device_id=device_id) - whoami = self.whoami(tok) - self.assertEqual(whoami, {"user_id": user_id, "device_id": device_id}) + whoami = self._whoami(tok) + self.assertEqual( + whoami, + { + "user_id": user_id, + "device_id": device_id, + # Unstable until MSC3069 enters spec + "org.matrix.msc3069.is_guest": False, + }, + ) + + def test_GET_whoami_guests(self): + channel = self.make_request( + b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}" + ) + tok = channel.json_body["access_token"] + user_id = channel.json_body["user_id"] + device_id = channel.json_body["device_id"] + + whoami = self._whoami(tok) + self.assertEqual( + whoami, + { + "user_id": user_id, + "device_id": device_id, + # Unstable until MSC3069 enters spec + "org.matrix.msc3069.is_guest": True, + }, + ) def test_GET_whoami_appservices(self): user_id = "@as:test" @@ -484,18 +516,25 @@ class WhoamiTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": user_id, "exclusive": True}]}, sender=user_id, ) self.hs.get_datastore().services_cache.append(appservice) - whoami = self.whoami(as_token) - self.assertEqual(whoami, {"user_id": user_id}) + whoami = self._whoami(as_token) + self.assertEqual( + whoami, + { + "user_id": user_id, + # Unstable until MSC3069 enters spec + "org.matrix.msc3069.is_guest": False, + }, + ) self.assertFalse(hasattr(whoami, "device_id")) - def whoami(self, tok): + def _whoami(self, tok): channel = self.make_request("GET", "account/whoami", {}, access_token=tok) self.assertEqual(channel.code, 200) return channel.json_body @@ -625,7 +664,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_if_disabled(self): """Test adding email to profile when doing so is disallowed""" - self.hs.config.enable_3pid_changes = False + self.hs.config.registration.enable_3pid_changes = False client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -695,7 +734,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_delete_email_if_disabled(self): """Test deleting an email from profile when disallowed""" - self.hs.config.enable_3pid_changes = False + self.hs.config.registration.enable_3pid_changes = False # Add a threepid self.get_success( diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index 422361b62a..b9e3602552 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -55,7 +55,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) self.assertEqual( - self.config.default_room_version.identifier, + self.config.server.default_room_version.identifier, capabilities["m.room_versions"]["default"], ) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index ca2e8ff8ef..becb4e8dcc 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -37,7 +37,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): return self.hs def test_3pid_lookup_disabled(self): - self.hs.config.enable_3pid_lookup = False + self.hs.config.registration.enable_3pid_lookup = False self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 371615a015..a63f04bd41 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -94,9 +94,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver() - self.hs.config.enable_registration = True - self.hs.config.registrations_require_3pid = [] - self.hs.config.auto_join_rooms = [] + self.hs.config.registration.enable_registration = True + self.hs.config.registration.registrations_require_3pid = [] + self.hs.config.registration.auto_join_rooms = [] self.hs.config.captcha.enable_registration_captcha = False return self.hs @@ -1064,13 +1064,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): register.register_servlets, ] - def register_as_user(self, username): - self.make_request( - b"POST", - "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), - {"username": username}, - ) - def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver() @@ -1107,7 +1100,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): def test_login_appservice_user(self): """Test that an appservice user can use /login""" - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, @@ -1121,7 +1114,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): def test_login_appservice_user_bot(self): """Test that the appservice bot can use /login""" - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, @@ -1135,7 +1128,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): def test_login_appservice_wrong_user(self): """Test that non-as users cannot login with the as token""" - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, @@ -1149,7 +1142,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): def test_login_appservice_wrong_as(self): """Test that as users cannot login with wrong as token""" - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, @@ -1165,7 +1158,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): """Test that users must provide a token when using the appservice login method """ - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 1d152352d1..56fe1a3d01 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -50,7 +50,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): PUT to the status endpoint with use_presence enabled will call set_state on the presence handler. """ - self.hs.config.use_presence = True + self.hs.config.server.use_presence = True body = {"presence": "here", "status_msg": "beep boop"} channel = self.make_request( diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 72a5a11b46..66dcfc9f88 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -50,7 +50,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", @@ -74,7 +74,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", @@ -147,7 +147,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_guest_registration(self): self.hs.config.key.macaroon_secret_key = "test" - self.hs.config.allow_guest_access = True + self.hs.config.registration.allow_guest_access = True channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") @@ -156,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self): - self.hs.config.allow_guest_access = False + self.hs.config.registration.allow_guest_access = False channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 4d09b5d07e..ce43de780b 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -21,11 +21,13 @@ from twisted.internet.error import DNSLookupError from twisted.test.proto_helpers import AccumulatingProtocol from synapse.config.oembed import OEmbedEndpointConfig +from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS from synapse.util.stringutils import parse_and_validate_mxc_uri from tests import unittest from tests.server import FakeTransport from tests.test_utils import SMALL_PNG +from tests.utils import MockClock try: import lxml @@ -851,3 +853,32 @@ class URLPreviewTests(unittest.HomeserverTestCase): 404, "URL cache thumbnail was unexpectedly retrieved from a storage provider", ) + + def test_cache_expiry(self): + """Test that URL cache files and thumbnails are cleaned up properly on expiry.""" + self.preview_url.clock = MockClock() + + _host, media_id = self._download_image() + + file_path = self.preview_url.filepaths.url_cache_filepath(media_id) + file_dirs = self.preview_url.filepaths.url_cache_filepath_dirs_to_delete( + media_id + ) + thumbnail_dir = self.preview_url.filepaths.url_cache_thumbnail_directory( + media_id + ) + thumbnail_dirs = self.preview_url.filepaths.url_cache_thumbnail_dirs_to_delete( + media_id + ) + + self.assertTrue(os.path.isfile(file_path)) + self.assertTrue(os.path.isdir(thumbnail_dir)) + + self.preview_url.clock.advance_time_msec(IMAGE_CACHE_EXPIRY_MS + 1) + self.get_success(self.preview_url._expire_url_cache_data()) + + for path in [file_path] + file_dirs + [thumbnail_dir] + thumbnail_dirs: + self.assertFalse( + os.path.exists(path), + f"{os.path.relpath(path, self.media_store_path)} was not deleted", + ) diff --git a/tests/server.py b/tests/server.py index 88dfa8058e..64645651ce 100644 --- a/tests/server.py +++ b/tests/server.py @@ -317,7 +317,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def __init__(self): self.threadpool = ThreadPool(self) - self._tcp_callbacks = {} + self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {} self._udp = [] self.lookups: Dict[str, str] = {} self._thread_callbacks: Deque[Callable[[], None]] = deque() @@ -355,7 +355,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def getThreadPool(self): return self.threadpool - def add_tcp_client_callback(self, host, port, callback): + def add_tcp_client_callback(self, host: str, port: int, callback: Callable): """Add a callback that will be invoked when we receive a connection attempt to the given IP/port using `connectTCP`. @@ -364,7 +364,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): """ self._tcp_callbacks[(host, port)] = callback - def connectTCP(self, host, port, factory, timeout=30, bindAddress=None): + def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None): """Fake L{IReactorTCP.connectTCP}.""" conn = super().connectTCP( @@ -475,7 +475,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): return server -def get_clock(): +def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: clock = ThreadedMemoryReactorClock() hs_clock = Clock(clock) return clock, hs_clock diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 7f25200a5d..36c495954f 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -346,7 +346,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): invites = [] # Register as many users as the MAU limit allows. - for i in range(self.hs.config.max_mau_value): + for i in range(self.hs.config.server.max_mau_value): localpart = "user%d" % i user_id = self.register_user(localpart, "password") tok = self.login(localpart, "password") diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 944dbc34a2..d6b4cdd788 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -51,7 +51,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)}) def test_initialise_reserved_users(self): - threepids = self.hs.config.mau_limits_reserved_threepids + threepids = self.hs.config.server.mau_limits_reserved_threepids # register three users, of which two have reserved 3pids, and a third # which is a support user. @@ -101,9 +101,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # XXX some of this is redundant. poking things into the config shouldn't # work, and in any case it's not obvious what we expect to happen when # we advance the reactor. - self.hs.config.max_mau_value = 0 + self.hs.config.server.max_mau_value = 0 self.reactor.advance(FORTY_DAYS) - self.hs.config.max_mau_value = 5 + self.hs.config.server.max_mau_value = 5 self.get_success(self.store.reap_monthly_active_users()) @@ -183,7 +183,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(d) count = self.get_success(self.store.get_monthly_active_count()) - self.assertEqual(count, self.hs.config.max_mau_value) + self.assertEqual(count, self.hs.config.server.max_mau_value) self.reactor.advance(FORTY_DAYS) @@ -199,7 +199,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_reap_monthly_active_users_reserved_users(self): """Tests that reaping correctly handles reaping where reserved users are present""" - threepids = self.hs.config.mau_limits_reserved_threepids + threepids = self.hs.config.server.mau_limits_reserved_threepids initial_users = len(threepids) reserved_user_number = initial_users - 1 for i in range(initial_users): @@ -234,7 +234,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(d) count = self.get_success(self.store.get_monthly_active_count()) - self.assertEqual(count, self.hs.config.max_mau_value) + self.assertEqual(count, self.hs.config.server.max_mau_value) def test_populate_monthly_users_is_guest(self): # Test that guest users are not added to mau list @@ -294,7 +294,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": user2_email}, ] - self.hs.config.mau_limits_reserved_threepids = threepids + self.hs.config.server.mau_limits_reserved_threepids = threepids d = self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 222e5d129d..6884ca9b7a 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -11,7 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, List, Set, Tuple +from unittest.mock import Mock, patch +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import UserTypes +from synapse.appservice import ApplicationService +from synapse.rest import admin +from synapse.rest.client import login, register, room +from synapse.server import HomeServer +from synapse.storage import DataStore +from synapse.util import Clock + +from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config ALICE = "@alice:a" @@ -21,8 +34,321 @@ BOBBY = "@bobby:a" BELA = "@somenickname:a" +class GetUserDirectoryTables: + """Helper functions that we want to reuse in tests/handlers/test_user_directory.py""" + + def __init__(self, store: DataStore): + self.store = store + + def _compress_shared( + self, shared: List[Dict[str, str]] + ) -> Set[Tuple[str, str, str]]: + """ + Compress a list of users who share rooms dicts to a list of tuples. + """ + r = set() + for i in shared: + r.add((i["user_id"], i["other_user_id"], i["room_id"])) + return r + + async def get_users_in_public_rooms(self) -> List[Tuple[str, str]]: + r = await self.store.db_pool.simple_select_list( + "users_in_public_rooms", None, ("user_id", "room_id") + ) + + retval = [] + for i in r: + retval.append((i["user_id"], i["room_id"])) + return retval + + async def get_users_who_share_private_rooms(self) -> List[Dict[str, str]]: + return await self.store.db_pool.simple_select_list( + "users_who_share_private_rooms", + None, + ["user_id", "other_user_id", "room_id"], + ) + + async def get_users_in_user_directory(self) -> Set[str]: + result = await self.store.db_pool.simple_select_list( + "user_directory", + None, + ["user_id"], + ) + return {row["user_id"] for row in result} + + +class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): + """Ensure that rebuilding the directory writes the correct data to the DB. + + See also tests/handlers/test_user_directory.py for similar checks. They + test the incremental updates, rather than the big rebuild. + """ + + servlets = [ + login.register_servlets, + admin.register_servlets, + room.register_servlets, + register.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.appservice = ApplicationService( + token="i_am_an_app_service", + hostname="test", + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + mock_load_appservices = Mock(return_value=[self.appservice]) + with patch( + "synapse.storage.databases.main.appservice.load_appservices", + mock_load_appservices, + ): + hs = super().make_homeserver(reactor, clock) + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastore() + self.user_dir_helper = GetUserDirectoryTables(self.store) + + def _purge_and_rebuild_user_dir(self) -> None: + """Nuke the user directory tables, start the background process to + repopulate them, and wait for the process to complete. This allows us + to inspect the outcome of the background process alone, without any of + the other incremental updates. + """ + self.get_success(self.store.update_user_directory_stream_pos(None)) + self.get_success(self.store.delete_all_from_user_dir()) + + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) + + # Nothing updated yet + self.assertEqual(shares_private, []) + self.assertEqual(public_users, []) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": "populate_user_directory_createtables", + "progress_json": "{}", + }, + ) + ) + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": "populate_user_directory_process_rooms", + "progress_json": "{}", + "depends_on": "populate_user_directory_createtables", + }, + ) + ) + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": "populate_user_directory_process_users", + "progress_json": "{}", + "depends_on": "populate_user_directory_process_rooms", + }, + ) + ) + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": "populate_user_directory_cleanup", + "progress_json": "{}", + "depends_on": "populate_user_directory_process_users", + }, + ) + ) + + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + def test_initial(self) -> None: + """ + The user directory's initial handler correctly updates the search tables. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + u3 = self.register_user("user3", "pass") + u3_token = self.login(u3, "pass") + + room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + private_room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(private_room, src=u1, targ=u3, tok=u1_token) + self.helper.join(private_room, user=u3, tok=u3_token) + + self.get_success(self.store.update_user_directory_stream_pos(None)) + self.get_success(self.store.delete_all_from_user_dir()) + + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) + + # Nothing updated yet + self.assertEqual(shares_private, []) + self.assertEqual(public_users, []) + + # Do the initial population of the user directory via the background update + self._purge_and_rebuild_user_dir() + + shares_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + public_users = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) + + # User 1 and User 2 are in the same public room + self.assertEqual(set(public_users), {(u1, room), (u2, room)}) + + # User 1 and User 3 share private rooms + self.assertEqual( + self.user_dir_helper._compress_shared(shares_private), + {(u1, u3, private_room), (u3, u1, private_room)}, + ) + + # All three should have entries in the directory + users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) + self.assertEqual(users, {u1, u2, u3}) + + # The next three tests (test_population_excludes_*) all set up + # - A normal user included in the user dir + # - A public and private room created by that user + # - A user excluded from the room dir, belonging to both rooms + + # They match similar logic in handlers/test_user_directory.py But that tests + # updating the directory; this tests rebuilding it from scratch. + + def _create_rooms_and_inject_memberships( + self, creator: str, token: str, joiner: str + ) -> Tuple[str, str]: + """Create a public and private room as a normal user. + Then get the `joiner` into those rooms. + """ + public_room = self.helper.create_room_as( + creator, + is_public=True, + # See https://github.com/matrix-org/synapse/issues/10951 + extra_content={"visibility": "public"}, + tok=token, + ) + private_room = self.helper.create_room_as(creator, is_public=False, tok=token) + + # HACK: get the user into these rooms + self.get_success(inject_member_event(self.hs, public_room, joiner, "join")) + self.get_success(inject_member_event(self.hs, private_room, joiner, "join")) + + return public_room, private_room + + def _check_room_sharing_tables( + self, normal_user: str, public_room: str, private_room: str + ) -> None: + # After rebuilding the directory, we should only see the normal user. + users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) + self.assertEqual(users, {normal_user}) + in_public_rooms = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) + self.assertEqual(set(in_public_rooms), {(normal_user, public_room)}) + in_private_rooms = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + self.assertEqual(in_private_rooms, []) + + def test_population_excludes_support_user(self) -> None: + # Create a normal and support user. + user = self.register_user("user", "pass") + token = self.login(user, "pass") + support = "@support1:test" + self.get_success( + self.store.register_user( + user_id=support, password_hash=None, user_type=UserTypes.SUPPORT + ) + ) + + # Join the support user to rooms owned by the normal user. + public, private = self._create_rooms_and_inject_memberships( + user, token, support + ) + + # Rebuild the directory. + self._purge_and_rebuild_user_dir() + + # Check the support user is not in the directory. + self._check_room_sharing_tables(user, public, private) + + def test_population_excludes_deactivated_user(self) -> None: + user = self.register_user("naughty", "pass") + admin = self.register_user("admin", "pass", admin=True) + admin_token = self.login(admin, "pass") + + # Deactivate the user. + channel = self.make_request( + "PUT", + f"/_synapse/admin/v2/users/{user}", + access_token=admin_token, + content={"deactivated": True}, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["deactivated"], True) + + # Join the deactivated user to rooms owned by the admin. + # Is this something that could actually happen outside of a test? + public, private = self._create_rooms_and_inject_memberships( + admin, admin_token, user + ) + + # Rebuild the user dir. The deactivated user should be missing. + self._purge_and_rebuild_user_dir() + self._check_room_sharing_tables(admin, public, private) + + def test_population_excludes_appservice_user(self) -> None: + # Register an AS user. + user = self.register_user("user", "pass") + token = self.login(user, "pass") + as_user = self.register_appservice_user("as_user_potato", self.appservice.token) + + # Join the AS user to rooms owned by the normal user. + public, private = self._create_rooms_and_inject_memberships( + user, token, as_user + ) + + # Rebuild the directory. + self._purge_and_rebuild_user_dir() + + # Check the AS user is not in the directory. + self._check_room_sharing_tables(user, public, private) + + class UserDirectoryStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() # alice and bob are both in !room_id. bobby is not but shares @@ -33,7 +359,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None)) self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))) - def test_search_user_dir(self): + def test_search_user_dir(self) -> None: # normally when alice searches the directory she should just find # bob because bobby doesn't share a room with her. r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10)) @@ -44,7 +370,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): ) @override_config({"user_directory": {"search_all_users": True}}) - def test_search_user_dir_all_users(self): + def test_search_user_dir_all_users(self) -> None: r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10)) self.assertFalse(r["limited"]) self.assertEqual(2, len(r["results"])) @@ -58,7 +384,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): ) @override_config({"user_directory": {"search_all_users": True}}) - def test_search_user_dir_stop_words(self): + def test_search_user_dir_stop_words(self) -> None: """Tests that a user can look up another user by searching for the start if its display name even if that name happens to be a common English word that would usually be ignored in full text searches. diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 1a4d078780..cf407c51cf 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -38,21 +38,19 @@ class EventAuthTestCase(unittest.TestCase): } # creator should be able to send state - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _random_state_event(creator), auth_events, - do_sig_check=False, ) # joiner should not be able to send state self.assertRaises( AuthError, - event_auth.check, + event_auth.check_auth_rules_for_event, RoomVersions.V1, _random_state_event(joiner), auth_events, - do_sig_check=False, ) def test_state_default_level(self): @@ -77,19 +75,17 @@ class EventAuthTestCase(unittest.TestCase): # pleb should not be able to send state self.assertRaises( AuthError, - event_auth.check, + event_auth.check_auth_rules_for_event, RoomVersions.V1, _random_state_event(pleb), auth_events, - do_sig_check=False, ), # king should be able to send state - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _random_state_event(king), auth_events, - do_sig_check=False, ) def test_alias_event(self): @@ -102,37 +98,33 @@ class EventAuthTestCase(unittest.TestCase): } # creator should be able to send aliases - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _alias_event(creator), auth_events, - do_sig_check=False, ) # Reject an event with no state key. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _alias_event(creator, state_key=""), auth_events, - do_sig_check=False, ) # If the domain of the sender does not match the state key, reject. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _alias_event(creator, state_key="test.com"), auth_events, - do_sig_check=False, ) # Note that the member does *not* need to be in the room. - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _alias_event(other), auth_events, - do_sig_check=False, ) def test_msc2432_alias_event(self): @@ -145,34 +137,30 @@ class EventAuthTestCase(unittest.TestCase): } # creator should be able to send aliases - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _alias_event(creator), auth_events, - do_sig_check=False, ) # No particular checks are done on the state key. - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _alias_event(creator, state_key=""), auth_events, - do_sig_check=False, ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _alias_event(creator, state_key="test.com"), auth_events, - do_sig_check=False, ) # Per standard auth rules, the member must be in the room. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _alias_event(other), auth_events, - do_sig_check=False, ) def test_msc2209(self): @@ -192,20 +180,18 @@ class EventAuthTestCase(unittest.TestCase): } # pleb should be able to modify the notifications power level. - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _power_levels_event(pleb, {"notifications": {"room": 100}}), auth_events, - do_sig_check=False, ) # But an MSC2209 room rejects this change. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _power_levels_event(pleb, {"notifications": {"room": 100}}), auth_events, - do_sig_check=False, ) def test_join_rules_public(self): @@ -222,59 +208,53 @@ class EventAuthTestCase(unittest.TestCase): } # Check join. - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user cannot be force-joined to a room. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _member_event(pleb, "join", sender=creator), auth_events, - do_sig_check=False, ) # Banned should be rejected. auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user who left can re-join. auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can send a join if they're in the room. auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can accept an invite. auth_events[("m.room.member", pleb)] = _member_event( pleb, "invite", sender=creator ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) def test_join_rules_invite(self): @@ -292,60 +272,54 @@ class EventAuthTestCase(unittest.TestCase): # A join without an invite is rejected. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user cannot be force-joined to a room. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _member_event(pleb, "join", sender=creator), auth_events, - do_sig_check=False, ) # Banned should be rejected. auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user who left cannot re-join. auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can send a join if they're in the room. auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can accept an invite. auth_events[("m.room.member", pleb)] = _member_event( pleb, "invite", sender=creator ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) def test_join_rules_msc3083_restricted(self): @@ -370,11 +344,10 @@ class EventAuthTestCase(unittest.TestCase): # Older room versions don't understand this join rule with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A properly formatted join event should work. @@ -384,11 +357,10 @@ class EventAuthTestCase(unittest.TestCase): EventContentFields.AUTHORISING_USER: "@creator:example.com" }, ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, auth_events, - do_sig_check=False, ) # A join issued by a specific user works (i.e. the power level checks @@ -400,7 +372,7 @@ class EventAuthTestCase(unittest.TestCase): pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event( "@inviter:foo.test" ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event( pleb, @@ -409,16 +381,14 @@ class EventAuthTestCase(unittest.TestCase): }, ), pl_auth_events, - do_sig_check=False, ) # A join which is missing an authorised server is rejected. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), auth_events, - do_sig_check=False, ) # An join authorised by a user who is not in the room is rejected. @@ -427,7 +397,7 @@ class EventAuthTestCase(unittest.TestCase): creator, {"invite": 100, "users": {"@other:example.com": 150}} ) with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event( pleb, @@ -436,13 +406,12 @@ class EventAuthTestCase(unittest.TestCase): }, ), auth_events, - do_sig_check=False, ) # A user cannot be force-joined to a room. (This uses an event which # *would* be valid, but is sent be a different user.) with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _member_event( pleb, @@ -453,36 +422,32 @@ class EventAuthTestCase(unittest.TestCase): }, ), auth_events, - do_sig_check=False, ) # Banned should be rejected. auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, auth_events, - do_sig_check=False, ) # A user who left can re-join. auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, auth_events, - do_sig_check=False, ) # A user can send a join if they're in the room. (This doesn't need to # be authorised since the user is already joined.) auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can accept an invite. (This doesn't need to be authorised since @@ -490,11 +455,10 @@ class EventAuthTestCase(unittest.TestCase): auth_events[("m.room.member", pleb)] = _member_event( pleb, "invite", sender=creator ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), auth_events, - do_sig_check=False, ) diff --git a/tests/test_federation.py b/tests/test_federation.py index c51e018da1..24fc77d7a7 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -82,7 +82,6 @@ class MessageAcceptTests(unittest.HomeserverTestCase): event, context, state=None, - claimed_auth_event_map=None, backfilled=False, ): return context diff --git a/tests/test_mau.py b/tests/test_mau.py index 66111eb367..80ab40e255 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -165,7 +165,7 @@ class TestMauLimit(unittest.HomeserverTestCase): @override_config({"mau_trial_days": 1}) def test_trial_users_cant_come_back(self): - self.hs.config.mau_trial_days = 1 + self.hs.config.server.mau_trial_days = 1 # We should be able to register more than the limit initially token1 = self.create_user("kermit1") diff --git a/tests/unittest.py b/tests/unittest.py index 7a6f5954d0..ae393ee53e 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,7 @@ import inspect import logging import secrets import time -from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union from unittest.mock import Mock, patch from canonicaljson import json @@ -28,6 +28,7 @@ from canonicaljson import json from twisted.internet.defer import Deferred, ensureDeferred, succeed from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool +from twisted.test.proto_helpers import MemoryReactor from twisted.trial import unittest from twisted.web.resource import Resource @@ -46,6 +47,7 @@ from synapse.logging.context import ( ) from synapse.server import HomeServer from synapse.types import UserID, create_requester +from synapse.util import Clock from synapse.util.httpresourcetree import create_resource_tree from synapse.util.ratelimitutils import FederationRateLimiter @@ -232,7 +234,7 @@ class HomeserverTestCase(TestCase): # Honour the `use_frozen_dicts` config option. We have to do this # manually because this is taken care of in the app `start` code, which # we don't run. Plus we want to reset it on tearDown. - events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts + events.USE_FROZEN_DICTS = self.hs.config.server.use_frozen_dicts if self.hs is None: raise Exception("No homeserver returned from make_homeserver.") @@ -371,7 +373,7 @@ class HomeserverTestCase(TestCase): return config - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): """ Prepare for the test. This involves things like mocking out parts of the homeserver, or building test data common across the whole test @@ -447,7 +449,7 @@ class HomeserverTestCase(TestCase): client_ip, ) - def setup_test_homeserver(self, *args, **kwargs): + def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer: """ Set up the test homeserver, meant to be called by the overridable make_homeserver. It automatically passes through the test class's @@ -558,7 +560,7 @@ class HomeserverTestCase(TestCase): Returns: The MXID of the new user. """ - self.hs.config.registration_shared_secret = "shared" + self.hs.config.registration.registration_shared_secret = "shared" # Create the user channel = self.make_request("GET", "/_synapse/admin/v1/register") @@ -594,6 +596,35 @@ class HomeserverTestCase(TestCase): user_id = channel.json_body["user_id"] return user_id + def register_appservice_user( + self, + username: str, + appservice_token: str, + ) -> str: + """Register an appservice user as an application service. + Requires the client-facing registration API be registered. + + Args: + username: the user to be registered by an application service. + Should be a full username, i.e. ""@localpart:hostname" as opposed to just "localpart" + appservice_token: the acccess token for that application service. + + Raises: if the request to '/register' does not return 200 OK. + + Returns: the MXID of the new user. + """ + channel = self.make_request( + "POST", + "/_matrix/client/r0/register", + { + "username": username, + "type": "m.login.application_service", + }, + access_token=appservice_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + return channel.json_body["user_id"] + def login( self, username, |