diff options
73 files changed, 1859 insertions, 1454 deletions
diff --git a/changelog.d/12087.bugfix b/changelog.d/12087.bugfix new file mode 100644 index 0000000000..6dacdddd0d --- /dev/null +++ b/changelog.d/12087.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which caused the `/_matrix/federation/v1/state` and `.../state_ids` endpoints to return incorrect or invalid data when called for an event which we have stored as an "outlier". diff --git a/changelog.d/12198.misc b/changelog.d/12198.misc new file mode 100644 index 0000000000..6b184a9053 --- /dev/null +++ b/changelog.d/12198.misc @@ -0,0 +1 @@ +Add tests for database transaction callbacks. diff --git a/changelog.d/12199.misc b/changelog.d/12199.misc new file mode 100644 index 0000000000..16dec1d26d --- /dev/null +++ b/changelog.d/12199.misc @@ -0,0 +1 @@ +Handle cancellation in `DatabasePool.runInteraction()`. diff --git a/changelog.d/12216.misc b/changelog.d/12216.misc new file mode 100644 index 0000000000..dc398ac1e0 --- /dev/null +++ b/changelog.d/12216.misc @@ -0,0 +1 @@ +Add missing type hints for cache storage. diff --git a/changelog.d/12219.misc b/changelog.d/12219.misc new file mode 100644 index 0000000000..6079414092 --- /dev/null +++ b/changelog.d/12219.misc @@ -0,0 +1 @@ +Clean-up logic around rebasing URLs for URL image previews. diff --git a/changelog.d/12224.misc b/changelog.d/12224.misc new file mode 100644 index 0000000000..b67a701dbb --- /dev/null +++ b/changelog.d/12224.misc @@ -0,0 +1 @@ +Add type hints to tests files. diff --git a/changelog.d/12225.misc b/changelog.d/12225.misc new file mode 100644 index 0000000000..23105c727c --- /dev/null +++ b/changelog.d/12225.misc @@ -0,0 +1 @@ +Use the `ignored_users` table in additional places instead of re-parsing the account data. diff --git a/changelog.d/12227.misc b/changelog.d/12227.misc new file mode 100644 index 0000000000..41c9dcbd37 --- /dev/null +++ b/changelog.d/12227.misc @@ -0,0 +1 @@ +Refactor the relations endpoints to add a `RelationsHandler`. diff --git a/changelog.d/12228.bugfix b/changelog.d/12228.bugfix new file mode 100644 index 0000000000..4755777139 --- /dev/null +++ b/changelog.d/12228.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.53.0 where an unnecessary query could be performed when fetching bundled aggregations for threads. diff --git a/changelog.d/12231.doc b/changelog.d/12231.doc new file mode 100644 index 0000000000..16593d2b92 --- /dev/null +++ b/changelog.d/12231.doc @@ -0,0 +1 @@ +Fix the link to the module documentation in the legacy spam checker warning message. diff --git a/changelog.d/12232.misc b/changelog.d/12232.misc new file mode 100644 index 0000000000..4a4132edff --- /dev/null +++ b/changelog.d/12232.misc @@ -0,0 +1 @@ +Refactor relations tests to improve code re-use. diff --git a/changelog.d/12237.misc b/changelog.d/12237.misc new file mode 100644 index 0000000000..41c9dcbd37 --- /dev/null +++ b/changelog.d/12237.misc @@ -0,0 +1 @@ +Refactor the relations endpoints to add a `RelationsHandler`. diff --git a/changelog.d/12240.misc b/changelog.d/12240.misc new file mode 100644 index 0000000000..c5b6356799 --- /dev/null +++ b/changelog.d/12240.misc @@ -0,0 +1 @@ +Add type hints to tests files. \ No newline at end of file diff --git a/changelog.d/12242.misc b/changelog.d/12242.misc new file mode 100644 index 0000000000..38e7e0f7d1 --- /dev/null +++ b/changelog.d/12242.misc @@ -0,0 +1 @@ +Generate announcement links in the release script. diff --git a/changelog.d/12243.doc b/changelog.d/12243.doc new file mode 100644 index 0000000000..b2031f0a40 --- /dev/null +++ b/changelog.d/12243.doc @@ -0,0 +1 @@ +Remove incorrect prefixes in the worker documentation for some endpoints. diff --git a/changelog.d/12244.misc b/changelog.d/12244.misc new file mode 100644 index 0000000000..950d48e4c6 --- /dev/null +++ b/changelog.d/12244.misc @@ -0,0 +1 @@ +Improve error message when dependencies check finds a broken installation. \ No newline at end of file diff --git a/changelog.d/12246.doc b/changelog.d/12246.doc new file mode 100644 index 0000000000..e7fcc1b99c --- /dev/null +++ b/changelog.d/12246.doc @@ -0,0 +1 @@ +Correct `check_username_for_spam` annotations and docs. \ No newline at end of file diff --git a/changelog.d/12248.misc b/changelog.d/12248.misc new file mode 100644 index 0000000000..2b1290d1e1 --- /dev/null +++ b/changelog.d/12248.misc @@ -0,0 +1 @@ +Add missing type hints for storage. \ No newline at end of file diff --git a/changelog.d/12256.misc b/changelog.d/12256.misc new file mode 100644 index 0000000000..c5b6356799 --- /dev/null +++ b/changelog.d/12256.misc @@ -0,0 +1 @@ +Add type hints to tests files. \ No newline at end of file diff --git a/changelog.d/12258.misc b/changelog.d/12258.misc new file mode 100644 index 0000000000..80024c8e91 --- /dev/null +++ b/changelog.d/12258.misc @@ -0,0 +1 @@ +Compress metrics HTTP resource when enabled. Contributed by Nick @ Beeper. diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 2b672b78f9..472d957180 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -172,7 +172,7 @@ any of the subsequent implementations of this callback. _First introduced in Synapse v1.37.0_ ```python -async def check_username_for_spam(user_profile: Dict[str, str]) -> bool +async def check_username_for_spam(user_profile: synapse.module_api.UserProfile) -> bool ``` Called when computing search results in the user directory. The module must return a @@ -182,9 +182,11 @@ search results; otherwise return `False`. The profile is represented as a dictionary with the following keys: -* `user_id`: The Matrix ID for this user. -* `display_name`: The user's display name. -* `avatar_url`: The `mxc://` URL to the user's avatar. +* `user_id: str`. The Matrix ID for this user. +* `display_name: Optional[str]`. The user's display name, or `None` if this user + has not set a display name. +* `avatar_url: Optional[str]`. The `mxc://` URL to the user's avatar, or `None` + if this user has not set an avatar. The module is given a copy of the original dictionary, so modifying it from within the module cannot modify a user's profile when included in user directory search results. diff --git a/docs/workers.md b/docs/workers.md index 8751134e65..9eb4194e4d 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -185,8 +185,8 @@ worker: refer to the [stream writers](#stream-writers) section below for further information. # Sync requests - ^/_matrix/client/(v2_alpha|r0|v3)/sync$ - ^/_matrix/client/(api/v1|v2_alpha|r0|v3)/events$ + ^/_matrix/client/(r0|v3)/sync$ + ^/_matrix/client/(api/v1|r0|v3)/events$ ^/_matrix/client/(api/v1|r0|v3)/initialSync$ ^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$ @@ -200,13 +200,9 @@ information. ^/_matrix/federation/v1/query/ ^/_matrix/federation/v1/make_join/ ^/_matrix/federation/v1/make_leave/ - ^/_matrix/federation/v1/send_join/ - ^/_matrix/federation/v2/send_join/ - ^/_matrix/federation/v1/send_leave/ - ^/_matrix/federation/v2/send_leave/ - ^/_matrix/federation/v1/invite/ - ^/_matrix/federation/v2/invite/ - ^/_matrix/federation/v1/query_auth/ + ^/_matrix/federation/(v1|v2)/send_join/ + ^/_matrix/federation/(v1|v2)/send_leave/ + ^/_matrix/federation/(v1|v2)/invite/ ^/_matrix/federation/v1/event_auth/ ^/_matrix/federation/v1/exchange_third_party_invite/ ^/_matrix/federation/v1/user/devices/ @@ -274,6 +270,8 @@ information. Additionally, the following REST endpoints can be handled for GET requests: ^/_matrix/federation/v1/groups/ + ^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/ + ^/_matrix/client/(r0|v3|unstable)/groups/ Pagination requests can also be handled, but all requests for a given room must be routed to the same instance. Additionally, care must be taken to @@ -397,23 +395,23 @@ the stream writer for the `typing` stream: The following endpoints should be routed directly to the worker configured as the stream writer for the `to_device` stream: - ^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/ + ^/_matrix/client/(r0|v3|unstable)/sendToDevice/ ##### The `account_data` stream The following endpoints should be routed directly to the worker configured as the stream writer for the `account_data` stream: - ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags - ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data + ^/_matrix/client/(r0|v3|unstable)/.*/tags + ^/_matrix/client/(r0|v3|unstable)/.*/account_data ##### The `receipts` stream The following endpoints should be routed directly to the worker configured as the stream writer for the `receipts` stream: - ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt - ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers + ^/_matrix/client/(r0|v3|unstable)/rooms/.*/receipt + ^/_matrix/client/(r0|v3|unstable)/rooms/.*/read_markers ##### The `presence` stream @@ -528,7 +526,7 @@ Note that if a reverse proxy is used , then `/_matrix/media/` must be routed for Handles searches in the user directory. It can handle REST endpoints matching the following regular expressions: - ^/_matrix/client/(api/v1|r0|v3|unstable)/user_directory/search$ + ^/_matrix/client/(r0|v3|unstable)/user_directory/search$ When using this worker you must also set `update_user_directory: False` in the shared configuration file to stop the main synapse running background @@ -540,7 +538,7 @@ Proxies some frequently-requested client endpoints to add caching and remove load from the main synapse. It can handle REST endpoints matching the following regular expressions: - ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload + ^/_matrix/client/(r0|v3|unstable)/keys/upload If `use_presence` is False in the homeserver config, it can also handle REST endpoints matching the following regular expressions: diff --git a/mypy.ini b/mypy.ini index f9c39fcaae..24d4ba15d4 100644 --- a/mypy.ini +++ b/mypy.ini @@ -42,9 +42,6 @@ exclude = (?x) |synapse/storage/databases/main/cache.py |synapse/storage/databases/main/devices.py |synapse/storage/databases/main/event_federation.py - |synapse/storage/databases/main/group_server.py - |synapse/storage/databases/main/metrics.py - |synapse/storage/databases/main/monthly_active_users.py |synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/receipts.py |synapse/storage/databases/main/roommember.py @@ -66,14 +63,6 @@ exclude = (?x) |tests/federation/test_federation_server.py |tests/federation/transport/test_knocking.py |tests/federation/transport/test_server.py - |tests/handlers/test_cas.py - |tests/handlers/test_directory.py - |tests/handlers/test_e2e_keys.py - |tests/handlers/test_federation.py - |tests/handlers/test_oidc.py - |tests/handlers/test_presence.py - |tests/handlers/test_profile.py - |tests/handlers/test_saml.py |tests/handlers/test_typing.py |tests/http/federation/test_matrix_federation_agent.py |tests/http/federation/test_srv_resolver.py @@ -85,7 +74,6 @@ exclude = (?x) |tests/logging/test_terse_json.py |tests/module_api/test_api.py |tests/push/test_email.py - |tests/push/test_http.py |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py |tests/rest/client/test_transactions.py @@ -94,12 +82,7 @@ exclude = (?x) |tests/server.py |tests/server_notices/test_resource_limits_server_notices.py |tests/state/test_v2.py - |tests/storage/test_background_update.py |tests/storage/test_base.py - |tests/storage/test_client_ips.py - |tests/storage/test_database.py - |tests/storage/test_event_federation.py - |tests/storage/test_id_generators.py |tests/storage/test_roommember.py |tests/test_metrics.py |tests/test_phone_home.py diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 046453e65f..685fa32b03 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -66,11 +66,15 @@ def cli(): ./scripts-dev/release.py tag - # ... wait for asssets to build ... + # ... wait for assets to build ... ./scripts-dev/release.py publish ./scripts-dev/release.py upload + # Optional: generate some nice links for the announcement + + ./scripts-dev/release.py upload + If the env var GH_TOKEN (or GITHUB_TOKEN) is set, or passed into the `tag`/`publish` command, then a new draft release will be created/published. """ @@ -415,6 +419,41 @@ def upload(): ) +@cli.command() +def announce(): + """Generate markdown to announce the release.""" + + current_version, _, _ = parse_version_from_module() + tag_name = f"v{current_version}" + + click.echo( + f""" +Hi everyone. Synapse {current_version} has just been released. + +[notes](https://github.com/matrix-org/synapse/releases/tag/{tag_name}) |\ +[docker](https://hub.docker.com/r/matrixdotorg/synapse/tags?name={tag_name}) | \ +[debs](https://packages.matrix.org/debian/) | \ +[pypi](https://pypi.org/project/matrix-synapse/{current_version}/)""" + ) + + if "rc" in tag_name: + click.echo( + """ +Announce the RC in +- #homeowners:matrix.org (Synapse Announcements) +- #synapse-dev:matrix.org""" + ) + else: + click.echo( + """ +Announce the release in +- #homeowners:matrix.org (Synapse Announcements), bumping the version in the topic +- #synapse:matrix.org (Synapse Admins), bumping the version in the topic +- #synapse-dev:matrix.org +- #synapse-package-maintainers:matrix.org""" + ) + + def parse_version_from_module() -> Tuple[ version.Version, redbaron.RedBaron, redbaron.Node ]: diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index e4dc04c0b4..ad2b7c9515 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -261,7 +261,10 @@ class SynapseHomeServer(HomeServer): resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) if name == "metrics" and self.config.metrics.enable_metrics: - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) + metrics_resource: Resource = MetricsResource(RegistryProxy) + if compress: + metrics_resource = gz_wrap(metrics_resource) + resources[METRICS_PREFIX] = metrics_resource if name == "replication": resources[REPLICATION_PREFIX] = ReplicationRestResource(self) diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py index a233a9ce03..4c52103b1c 100644 --- a/synapse/config/spam_checker.py +++ b/synapse/config/spam_checker.py @@ -25,8 +25,8 @@ logger = logging.getLogger(__name__) LEGACY_SPAM_CHECKER_WARNING = """ This server is using a spam checker module that is implementing the deprecated spam checker interface. Please check with the module's maintainer to see if a new version -supporting Synapse's generic modules system is available. -For more information, please see https://matrix-org.github.io/synapse/latest/modules.html +supporting Synapse's generic modules system is available. For more information, please +see https://matrix-org.github.io/synapse/latest/modules/index.html ---------------------------------------------------------------------------------------""" diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 60904a55f5..cd80fcf9d1 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -21,7 +21,6 @@ from typing import ( Awaitable, Callable, Collection, - Dict, List, Optional, Tuple, @@ -31,7 +30,7 @@ from typing import ( from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour -from synapse.types import RoomAlias +from synapse.types import RoomAlias, UserProfile from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: @@ -50,7 +49,7 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bo USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]] USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]] USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]] -CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]] +CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]] LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[ [ Optional[dict], @@ -383,7 +382,7 @@ class SpamChecker: return True - async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: + async def check_username_for_spam(self, user_profile: UserProfile) -> bool: """Checks if a user ID or display name are considered "spammy" by this server. If the server considers a username spammy, then it will not be included in diff --git a/synapse/events/utils.py b/synapse/events/utils.py index a0520068e0..7120062127 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze from . import EventBase if TYPE_CHECKING: + from synapse.handlers.relations import BundledAggregations from synapse.server import HomeServer - from synapse.storage.databases.main.relations import BundledAggregations # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 482bbdd867..af2d0f7d79 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -22,7 +22,6 @@ from typing import ( Callable, Collection, Dict, - Iterable, List, Optional, Tuple, @@ -577,10 +576,10 @@ class FederationServer(FederationBase): async def _on_context_state_request_compute( self, room_id: str, event_id: Optional[str] ) -> Dict[str, list]: + pdus: Collection[EventBase] if event_id: - pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu( - room_id, event_id - ) + event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) + pdus = await self.store.get_events_as_list(event_ids) else: pdus = (await self.state.get_current_state(room_id)).values() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index db39aeabde..350ec9c03a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -950,54 +950,35 @@ class FederationHandler: return event - async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: - """Returns the state at the event. i.e. not including said event.""" - - event = await self.store.get_event(event_id, check_room_id=room_id) - - state_groups = await self.state_store.get_state_groups(room_id, [event_id]) - - if state_groups: - _, state = list(state_groups.items()).pop() - results = {(e.type, e.state_key): e for e in state} - - if event.is_state(): - # Get previous state - if "replaces_state" in event.unsigned: - prev_id = event.unsigned["replaces_state"] - if prev_id != event.event_id: - prev_event = await self.store.get_event(prev_id) - results[(event.type, event.state_key)] = prev_event - else: - del results[(event.type, event.state_key)] - - res = list(results.values()) - return res - else: - return [] - async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event.""" event = await self.store.get_event(event_id, check_room_id=room_id) + if event.internal_metadata.outlier: + raise NotFoundError("State not known at event %s" % (event_id,)) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) - if state_groups: - _, state = list(state_groups.items()).pop() - results = state + # get_state_groups_ids should return exactly one result + assert len(state_groups) == 1 - if event.is_state(): - # Get previous state - if "replaces_state" in event.unsigned: - prev_id = event.unsigned["replaces_state"] - if prev_id != event.event_id: - results[(event.type, event.state_key)] = prev_id - else: - results.pop((event.type, event.state_key), None) + state_map = next(iter(state_groups.values())) - return list(results.values()) - else: - return [] + state_key = event.get_state_key() + if state_key is not None: + # the event was not rejected (get_event raises a NotFoundError for rejected + # events) so the state at the event should include the event itself. + assert ( + state_map.get((event.type, state_key)) == event.event_id + ), "State at event did not include event itself" + + # ... but we need the state *before* that event + if "replaces_state" in event.unsigned: + prev_id = event.unsigned["replaces_state"] + state_map[(event.type, state_key)] = prev_id + else: + del state_map[(event.type, state_key)] + + return list(state_map.values()) async def on_backfill_request( self, origin: str, room_id: str, pdu_list: List[str], limit: int diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 60059fec3e..876b879483 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set import attr @@ -134,6 +134,7 @@ class PaginationHandler: self.clock = hs.get_clock() self._server_name = hs.hostname self._room_shutdown_handler = hs.get_room_shutdown_handler() + self._relations_handler = hs.get_relations_handler() self.pagination_lock = ReadWriteLock() # IDs of rooms in which there currently an active purge *or delete* operation. @@ -422,7 +423,7 @@ class PaginationHandler: pagin_config: PaginationConfig, as_client_event: bool = True, event_filter: Optional[Filter] = None, - ) -> Dict[str, Any]: + ) -> JsonDict: """Get messages in a room. Args: @@ -431,6 +432,7 @@ class PaginationHandler: pagin_config: The pagination config rules to apply, if any. as_client_event: True to get events in client-server format. event_filter: Filter to apply to results or None + Returns: Pagination API results """ @@ -538,7 +540,9 @@ class PaginationHandler: state_dict = await self.store.get_events(list(state_ids.values())) state = state_dict.values() - aggregations = await self.store.get_bundled_aggregations(events, user_id) + aggregations = await self._relations_handler.get_bundled_aggregations( + events, user_id + ) time_now = self.clock.time_msec() diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py new file mode 100644 index 0000000000..57135d4519 --- /dev/null +++ b/synapse/handlers/relations.py @@ -0,0 +1,264 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast + +import attr +from frozendict import frozendict + +from synapse.api.constants import RelationTypes +from synapse.api.errors import SynapseError +from synapse.events import EventBase +from synapse.types import JsonDict, Requester, StreamToken + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore + + +logger = logging.getLogger(__name__) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _ThreadAggregation: + # The latest event in the thread. + latest_event: EventBase + # The latest edit to the latest event in the thread. + latest_edit: Optional[EventBase] + # The total number of events in the thread. + count: int + # True if the current user has sent an event to the thread. + current_user_participated: bool + + +@attr.s(slots=True, auto_attribs=True) +class BundledAggregations: + """ + The bundled aggregations for an event. + + Some values require additional processing during serialization. + """ + + annotations: Optional[JsonDict] = None + references: Optional[JsonDict] = None + replace: Optional[EventBase] = None + thread: Optional[_ThreadAggregation] = None + + def __bool__(self) -> bool: + return bool(self.annotations or self.references or self.replace or self.thread) + + +class RelationsHandler: + def __init__(self, hs: "HomeServer"): + self._main_store = hs.get_datastores().main + self._auth = hs.get_auth() + self._clock = hs.get_clock() + self._event_handler = hs.get_event_handler() + self._event_serializer = hs.get_event_client_serializer() + + async def get_relations( + self, + requester: Requester, + event_id: str, + room_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + aggregation_key: Optional[str] = None, + limit: int = 5, + direction: str = "b", + from_token: Optional[StreamToken] = None, + to_token: Optional[StreamToken] = None, + ) -> JsonDict: + """Get related events of a event, ordered by topological ordering. + + TODO Accept a PaginationConfig instead of individual pagination parameters. + + Args: + requester: The user requesting the relations. + event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. + relation_type: Only fetch events with this relation type, if given. + event_type: Only fetch events with this event type, if given. + aggregation_key: Only fetch events with this aggregation key, if given. + limit: Only fetch the most recent `limit` events. + direction: Whether to fetch the most recent first (`"b"`) or the + oldest first (`"f"`). + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. + + Returns: + The pagination chunk. + """ + + user_id = requester.user.to_string() + + await self._auth.check_user_in_room_or_world_readable( + room_id, user_id, allow_departed_users=True + ) + + # This gets the original event and checks that a) the event exists and + # b) the user is allowed to view it. + event = await self._event_handler.get_event(requester.user, room_id, event_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") + + pagination_chunk = await self._main_store.get_relations_for_event( + event_id=event_id, + event=event, + room_id=room_id, + relation_type=relation_type, + event_type=event_type, + aggregation_key=aggregation_key, + limit=limit, + direction=direction, + from_token=from_token, + to_token=to_token, + ) + + events = await self._main_store.get_events_as_list( + [c["event_id"] for c in pagination_chunk.chunk] + ) + + now = self._clock.time_msec() + # Do not bundle aggregations when retrieving the original event because + # we want the content before relations are applied to it. + original_event = self._event_serializer.serialize_event( + event, now, bundle_aggregations=None + ) + # The relations returned for the requested event do include their + # bundled aggregations. + aggregations = await self.get_bundled_aggregations( + events, requester.user.to_string() + ) + serialized_events = self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations + ) + + return_value = await pagination_chunk.to_dict(self._main_store) + return_value["chunk"] = serialized_events + return_value["original_event"] = original_event + + return return_value + + async def _get_bundled_aggregation_for_event( + self, event: EventBase, user_id: str + ) -> Optional[BundledAggregations]: + """Generate bundled aggregations for an event. + + Note that this does not use a cache, but depends on cached methods. + + Args: + event: The event to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. + + Returns: + The bundled aggregations for an event, if bundled aggregations are + enabled and the event can have bundled aggregations. + """ + + # Do not bundle aggregations for an event which represents an edit or an + # annotation. It does not make sense for them to have related events. + relates_to = event.content.get("m.relates_to") + if isinstance(relates_to, (dict, frozendict)): + relation_type = relates_to.get("rel_type") + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + return None + + event_id = event.event_id + room_id = event.room_id + + # The bundled aggregations to include, a mapping of relation type to a + # type-specific value. Some types include the direct return type here + # while others need more processing during serialization. + aggregations = BundledAggregations() + + annotations = await self._main_store.get_aggregation_groups_for_event( + event_id, room_id + ) + if annotations.chunk: + aggregations.annotations = await annotations.to_dict( + cast("DataStore", self) + ) + + references = await self._main_store.get_relations_for_event( + event_id, event, room_id, RelationTypes.REFERENCE, direction="f" + ) + if references.chunk: + aggregations.references = await references.to_dict(cast("DataStore", self)) + + # Store the bundled aggregations in the event metadata for later use. + return aggregations + + async def get_bundled_aggregations( + self, events: Iterable[EventBase], user_id: str + ) -> Dict[str, BundledAggregations]: + """Generate bundled aggregations for events. + + Args: + events: The iterable of events to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. + + Returns: + A map of event ID to the bundled aggregation for the event. Not all + events may have bundled aggregations in the results. + """ + # De-duplicate events by ID to handle the same event requested multiple times. + # + # State events do not get bundled aggregations. + events_by_id = { + event.event_id: event for event in events if not event.is_state() + } + + # event ID -> bundled aggregation in non-serialized form. + results: Dict[str, BundledAggregations] = {} + + # Fetch other relations per event. + for event in events_by_id.values(): + event_result = await self._get_bundled_aggregation_for_event(event, user_id) + if event_result: + results[event.event_id] = event_result + + # Fetch any edits (but not for redacted events). + edits = await self._main_store.get_applicable_edits( + [ + event_id + for event_id, event in events_by_id.items() + if not event.internal_metadata.is_redacted() + ] + ) + for event_id, edit in edits.items(): + results.setdefault(event_id, BundledAggregations()).replace = edit + + # Fetch thread summaries. + summaries = await self._main_store.get_thread_summaries(events_by_id.keys()) + # Only fetch participated for a limited selection based on what had + # summaries. + participated = await self._main_store.get_threads_participated( + [event_id for event_id, summary in summaries.items() if summary], user_id + ) + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event, edit = summary + results.setdefault( + event_id, BundledAggregations() + ).thread = _ThreadAggregation( + latest_event=latest_thread_event, + latest_edit=edit, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=participated[event_id], + ) + + return results diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b9735631fc..092e185c99 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -60,8 +60,8 @@ from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents from synapse.federation.federation_client import InvalidResponseError from synapse.handlers.federation import get_domains_from_state +from synapse.handlers.relations import BundledAggregations from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( @@ -1118,6 +1118,7 @@ class RoomContextHandler: self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state + self._relations_handler = hs.get_relations_handler() async def get_event_context( self, @@ -1190,7 +1191,7 @@ class RoomContextHandler: event = filtered[0] # Fetch the aggregations. - aggregations = await self.store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( itertools.chain(events_before, (event,), events_after), user.to_string(), ) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index aa16e417eb..30eddda65f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -54,6 +54,7 @@ class SearchHandler: self.clock = hs.get_clock() self.hs = hs self._event_serializer = hs.get_event_client_serializer() + self._relations_handler = hs.get_relations_handler() self.storage = hs.get_storage() self.state_store = self.storage.state self.auth = hs.get_auth() @@ -354,7 +355,7 @@ class SearchHandler: aggregations = None if self._msc3666_enabled: - aggregations = await self.store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( # Generate an iterable of EventBase for all the events that will be # returned, including contextual events. itertools.chain( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0aa3052fd6..6c569cfb1c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -28,16 +28,16 @@ from typing import ( import attr from prometheus_client import Counter -from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes +from synapse.api.constants import EventTypes, Membership, ReceiptTypes from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.handlers.relations import BundledAggregations from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import NotifCounts -from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -269,6 +269,7 @@ class SyncHandler: self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() + self._relations_handler = hs.get_relations_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() self.state = hs.get_state_handler() @@ -638,8 +639,10 @@ class SyncHandler: # as clients will have all the necessary information. bundled_aggregations = None if limited or newly_joined_room: - bundled_aggregations = await self.store.get_bundled_aggregations( - recents, sync_config.user.to_string() + bundled_aggregations = ( + await self._relations_handler.get_bundled_aggregations( + recents, sync_config.user.to_string() + ) ) return TimelineBatch( @@ -1601,7 +1604,7 @@ class SyncHandler: return set(), set(), set(), set() # 3. Work out which rooms need reporting in the sync response. - ignored_users = await self._get_ignored_users(user_id) + ignored_users = await self.store.ignored_users(user_id) if since_token: room_changes = await self._get_rooms_changed( sync_result_builder, ignored_users @@ -1627,7 +1630,6 @@ class SyncHandler: logger.debug("Generating room entry for %s", room_entry.room_id) await self._generate_room_entry( sync_result_builder, - ignored_users, room_entry, ephemeral=ephemeral_by_room.get(room_entry.room_id, []), tags=tags_by_room.get(room_entry.room_id), @@ -1657,29 +1659,6 @@ class SyncHandler: newly_left_users, ) - async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]: - """Retrieve the users ignored by the given user from their global account_data. - - Returns an empty set if - - there is no global account_data entry for ignored_users - - there is such an entry, but it's not a JSON object. - """ - # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead? - ignored_account_data = ( - await self.store.get_global_account_data_by_type_for_user( - user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST - ) - ) - - # If there is ignored users account data and it matches the proper type, - # then use it. - ignored_users: FrozenSet[str] = frozenset() - if ignored_account_data: - ignored_users_data = ignored_account_data.get("ignored_users", {}) - if isinstance(ignored_users_data, dict): - ignored_users = frozenset(ignored_users_data.keys()) - return ignored_users - async def _have_rooms_changed( self, sync_result_builder: "SyncResultBuilder" ) -> bool: @@ -2022,7 +2001,6 @@ class SyncHandler: async def _generate_room_entry( self, sync_result_builder: "SyncResultBuilder", - ignored_users: FrozenSet[str], room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], tags: Optional[Dict[str, Dict[str, Any]]], @@ -2051,7 +2029,6 @@ class SyncHandler: Args: sync_result_builder - ignored_users: Set of users ignored by user. room_builder ephemeral: List of new ephemeral events for room tags: List of *all* tags for room, or None if there has been diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index d27ed2be6a..048fd4bb82 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -19,8 +19,8 @@ import synapse.metrics from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main.user_directory import SearchResult from synapse.storage.roommember import ProfileInfo -from synapse.types import JsonDict from synapse.util.metrics import Measure if TYPE_CHECKING: @@ -78,7 +78,7 @@ class UserDirectoryHandler(StateDeltasHandler): async def search_users( self, user_id: str, search_term: str, limit: int - ) -> JsonDict: + ) -> SearchResult: """Searches for users in directory Returns: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d735c1d461..aa8256b36f 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -111,6 +111,7 @@ from synapse.types import ( StateMap, UserID, UserInfo, + UserProfile, create_requester, ) from synapse.util import Clock @@ -150,6 +151,7 @@ __all__ = [ "EventBase", "StateMap", "ProfileInfo", + "UserProfile", ] logger = logging.getLogger(__name__) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8140afcb6b..030898e4d0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -213,7 +213,7 @@ class BulkPushRuleEvaluator: if not event.is_state(): ignorers = await self.store.ignored_by(event.sender) else: - ignorers = set() + ignorers = frozenset() for uid, rules in rules_by_user.items(): if event.sender == uid: diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index d9a6be43f7..c16078b187 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.clock = hs.get_clock() - self._event_serializer = hs.get_event_client_serializer() - self.event_handler = hs.get_event_handler() + self._relations_handler = hs.get_relations_handler() async def on_GET( self, @@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - await self.auth.check_user_in_room_or_world_readable( - room_id, requester.user.to_string(), allow_departed_users=True - ) - - # This gets the original event and checks that a) the event exists and - # b) the user is allowed to view it. - event = await self.event_handler.get_event(requester.user, room_id, parent_id) - if event is None: - raise SynapseError(404, "Unknown parent event.") - limit = parse_integer(request, "limit", default=5) direction = parse_string( request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"] @@ -90,9 +78,9 @@ class RelationPaginationServlet(RestServlet): if to_token_str: to_token = await StreamToken.from_string(self.store, to_token_str) - pagination_chunk = await self.store.get_relations_for_event( + result = await self._relations_handler.get_relations( + requester=requester, event_id=parent_id, - event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, @@ -102,30 +90,7 @@ class RelationPaginationServlet(RestServlet): to_token=to_token, ) - events = await self.store.get_events_as_list( - [c["event_id"] for c in pagination_chunk.chunk] - ) - - now = self.clock.time_msec() - # Do not bundle aggregations when retrieving the original event because - # we want the content before relations are applied to it. - original_event = self._event_serializer.serialize_event( - event, now, bundle_aggregations=None - ) - # The relations returned for the requested event do include their - # bundled aggregations. - aggregations = await self.store.get_bundled_aggregations( - events, requester.user.to_string() - ) - serialized_events = self._event_serializer.serialize_events( - events, now, bundle_aggregations=aggregations - ) - - return_value = await pagination_chunk.to_dict(self.store) - return_value["chunk"] = serialized_events - return_value["original_event"] = original_event - - return 200, return_value + return 200, result class RelationAggregationPaginationServlet(RestServlet): @@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.clock = hs.get_clock() - self._event_serializer = hs.get_event_client_serializer() - self.event_handler = hs.get_event_handler() + self._relations_handler = hs.get_relations_handler() async def on_GET( self, @@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - await self.auth.check_user_in_room_or_world_readable( - room_id, - requester.user.to_string(), - allow_departed_users=True, - ) - - # This checks that a) the event exists and b) the user is allowed to - # view it. - event = await self.event_handler.get_event(requester.user, room_id, parent_id) - if event is None: - raise SynapseError(404, "Unknown parent event.") - if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -286,9 +237,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet): if to_token_str: to_token = await StreamToken.from_string(self.store, to_token_str) - result = await self.store.get_relations_for_event( + result = await self._relations_handler.get_relations( + requester=requester, event_id=parent_id, - event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, @@ -298,17 +249,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): to_token=to_token, ) - events = await self.store.get_events_as_list( - [c["event_id"] for c in result.chunk] - ) - - now = self.clock.time_msec() - serialized_events = self._event_serializer.serialize_events(events, now) - - return_value = await result.to_dict(self.store) - return_value["chunk"] = serialized_events - - return 200, return_value + return 200, result def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 8a06ab8c5f..47e152c8cc 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -645,6 +645,7 @@ class RoomEventServlet(RestServlet): self._store = hs.get_datastores().main self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self._relations_handler = hs.get_relations_handler() self.auth = hs.get_auth() async def on_GET( @@ -663,7 +664,7 @@ class RoomEventServlet(RestServlet): if event: # Ensure there are bundled aggregations available. - aggregations = await self._store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( [event], requester.user.to_string() ) diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py index a47d9bd01d..116c982ce6 100644 --- a/synapse/rest/client/user_directory.py +++ b/synapse/rest/client/user_directory.py @@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.types import JsonDict +from synapse.types import JsonMapping from ._base import client_patterns @@ -38,7 +38,7 @@ class UserDirectorySearchRestServlet(RestServlet): self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]: """Searches for users in directory Returns: diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py index 872a9e72e8..4cc9c66fbe 100644 --- a/synapse/rest/media/v1/preview_html.py +++ b/synapse/rest/media/v1/preview_html.py @@ -16,7 +16,6 @@ import itertools import logging import re from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union -from urllib import parse as urlparse if TYPE_CHECKING: from lxml import etree @@ -144,9 +143,7 @@ def decode_body( return etree.fromstring(body, parser) -def parse_html_to_open_graph( - tree: "etree.Element", media_uri: str -) -> Dict[str, Optional[str]]: +def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: """ Parse the HTML document into an Open Graph response. @@ -155,7 +152,6 @@ def parse_html_to_open_graph( Args: tree: The parsed HTML document. - media_url: The URI used to download the body. Returns: The Open Graph response as a dictionary. @@ -209,7 +205,7 @@ def parse_html_to_open_graph( "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" ) if meta_image: - og["og:image"] = rebase_url(meta_image[0], media_uri) + og["og:image"] = meta_image[0] else: # TODO: consider inlined CSS styles as well as width & height attribs images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") @@ -320,37 +316,6 @@ def _iterate_over_text( ) -def rebase_url(url: str, base: str) -> str: - """ - Resolves a potentially relative `url` against an absolute `base` URL. - - For example: - - >>> rebase_url("subpage", "https://example.com/foo/") - 'https://example.com/foo/subpage' - >>> rebase_url("sibling", "https://example.com/foo") - 'https://example.com/sibling' - >>> rebase_url("/bar", "https://example.com/foo/") - 'https://example.com/bar' - >>> rebase_url("https://alice.com/a/", "https://example.com/foo/") - 'https://alice.com/a' - """ - base_parts = urlparse.urlparse(base) - # Convert the parsed URL to a list for (potential) modification. - url_parts = list(urlparse.urlparse(url)) - # Add a scheme, if one does not exist. - if not url_parts[0]: - url_parts[0] = base_parts.scheme or "http" - # Fix up the hostname, if this is not a data URL. - if url_parts[0] != "data" and not url_parts[1]: - url_parts[1] = base_parts.netloc - # If the path does not start with a /, nest it under the base path's last - # directory. - if not url_parts[2].startswith("/"): - url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2] - return urlparse.urlunparse(url_parts) - - def summarize_paragraphs( text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 ) -> Optional[str]: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 14ea88b240..d47af8ead6 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -22,7 +22,7 @@ import shutil import sys import traceback from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple -from urllib import parse as urlparse +from urllib.parse import urljoin, urlparse, urlsplit from urllib.request import urlopen import attr @@ -44,11 +44,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.oembed import OEmbedProvider -from synapse.rest.media.v1.preview_html import ( - decode_body, - parse_html_to_open_graph, - rebase_url, -) +from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph from synapse.types import JsonDict, UserID from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred @@ -187,7 +183,7 @@ class PreviewUrlResource(DirectServeJsonResource): ts = self.clock.time_msec() # XXX: we could move this into _do_preview if we wanted. - url_tuple = urlparse.urlsplit(url) + url_tuple = urlsplit(url) for entry in self.url_preview_url_blacklist: match = True for attrib in entry: @@ -322,7 +318,7 @@ class PreviewUrlResource(DirectServeJsonResource): # Parse Open Graph information from the HTML in case the oEmbed # response failed or is incomplete. - og_from_html = parse_html_to_open_graph(tree, media_info.uri) + og_from_html = parse_html_to_open_graph(tree) # Compile the Open Graph response by using the scraped # information from the HTML and overlaying any information @@ -588,12 +584,17 @@ class PreviewUrlResource(DirectServeJsonResource): if "og:image" not in og or not og["og:image"]: return + # The image URL from the HTML might be relative to the previewed page, + # convert it to an URL which can be requested directly. + image_url = og["og:image"] + url_parts = urlparse(image_url) + if url_parts.scheme != "data": + image_url = urljoin(media_info.uri, image_url) + # FIXME: it might be cleaner to use the same flow as the main /preview_url # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. - image_info = await self._handle_url( - rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True - ) + image_info = await self._handle_url(image_url, user, allow_data_urls=True) if _is_media(image_info.media_type): # TODO: make sure we don't choke on white-on-transparent images diff --git a/synapse/server.py b/synapse/server.py index 2fcf18a7a6..380369db92 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -94,6 +94,7 @@ from synapse.handlers.profile import ProfileHandler from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.register import RegistrationHandler +from synapse.handlers.relations import RelationsHandler from synapse.handlers.room import ( RoomContextHandler, RoomCreationHandler, @@ -720,6 +721,10 @@ class HomeServer(metaclass=abc.ABCMeta): return PaginationHandler(self) @cache_in_self + def get_relations_handler(self) -> RelationsHandler: + return RelationsHandler(self) + + @cache_in_self def get_room_context_handler(self) -> RoomContextHandler: return RoomContextHandler(self) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 99802228c9..9749f0c06e 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -41,6 +41,7 @@ from prometheus_client import Histogram from typing_extensions import Literal from twisted.enterprise import adbapi +from twisted.internet import defer from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig @@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +from synapse.util.async_helpers import delay_cancellation from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -732,34 +734,45 @@ class DatabasePool: Returns: The result of func """ - after_callbacks: List[_CallbackListEntry] = [] - exception_callbacks: List[_CallbackListEntry] = [] - if not current_context(): - logger.warning("Starting db txn '%s' from sentinel context", desc) + async def _runInteraction() -> R: + after_callbacks: List[_CallbackListEntry] = [] + exception_callbacks: List[_CallbackListEntry] = [] - try: - with opentracing.start_active_span(f"db.{desc}"): - result = await self.runWithConnection( - self.new_transaction, - desc, - after_callbacks, - exception_callbacks, - func, - *args, - db_autocommit=db_autocommit, - isolation_level=isolation_level, - **kwargs, - ) + if not current_context(): + logger.warning("Starting db txn '%s' from sentinel context", desc) - for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) - except Exception: - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) - raise + try: + with opentracing.start_active_span(f"db.{desc}"): + result = await self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + db_autocommit=db_autocommit, + isolation_level=isolation_level, + **kwargs, + ) - return cast(R, result) + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + + return cast(R, result) + except Exception: + for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise + + # To handle cancellation, we ensure that `after_callback`s and + # `exception_callback`s are always run, since the transaction will complete + # on another thread regardless of cancellation. + # + # We also wait until everything above is done before releasing the + # `CancelledError`, so that logging contexts won't get used after they have been + # finished. + return await delay_cancellation(defer.ensureDeferred(_runInteraction())) async def runWithConnection( self, diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 52146aacc8..9af9f4f18e 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,7 +14,17 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Tuple, + cast, +) from synapse.api.constants import AccountDataTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker @@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) @cached(max_entries=5000, iterable=True) - async def ignored_by(self, user_id: str) -> Set[str]: + async def ignored_by(self, user_id: str) -> FrozenSet[str]: """ Get users which ignore the given user. @@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) Return: The user IDs which ignore the given user. """ - return set( + return frozenset( await self.db_pool.simple_select_onecol( table="ignored_users", keyvalues={"ignored_user_id": user_id}, @@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) ) + @cached(max_entries=5000, iterable=True) + async def ignored_users(self, user_id: str) -> FrozenSet[str]: + """ + Get users which the given user ignores. + + Params: + user_id: The user ID which is making the request. + + Return: + The user IDs which are ignored by the given user. + """ + return frozenset( + await self.db_pool.simple_select_onecol( + table="ignored_users", + keyvalues={"ignorer_user_id": user_id}, + retcol="ignored_user_id", + desc="ignored_users", + ) + ) + def process_replication_rows( self, stream_name: str, @@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) else: currently_ignored_users = set() + # If the data has not changed, nothing to do. + if previously_ignored_users == currently_ignored_users: + return + # Delete entries which are no longer ignored. self.db_pool.simple_delete_many_txn( txn, @@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # Invalidate the cache for any ignored users which were added or removed. for ignored_user_id in previously_ignored_users ^ currently_ignored_users: self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) + self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) async def purge_account_data_for_user(self, user_id: str) -> None: """ diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index d6a2df1afe..2d7511d613 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamCurrentStateRow, EventsStreamEventRow, + EventsStreamRow, ) from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -31,6 +32,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine +from synapse.util.caches.descriptors import _CachedFunction from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_updated_caches_txn(txn): + def get_all_updated_caches_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: # We purposefully don't bound by the current token, as we want to # send across cache invalidations as quickly as possible. Cache # invalidations are idempotent, so duplicates are fine. @@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): "get_all_updated_caches", get_all_updated_caches_txn ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: if stream_name == EventsStream.NAME: for row in rows: self._process_event_stream_row(token, row) @@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) - def _process_event_stream_row(self, token, row): + def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data if row.type == EventsStreamEventRow.TypeId: + assert isinstance(data, EventsStreamEventRow) self._invalidate_caches_for_event( token, data.event_id, @@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: - self._curr_state_delta_stream_cache.entity_has_changed( - row.data.room_id, token - ) + assert isinstance(data, EventsStreamCurrentStateRow) + self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) if data.type == EventTypes.Member: self.get_rooms_for_user_with_stream_ordering.invalidate( @@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore): def _invalidate_caches_for_event( self, - stream_ordering, - event_id, - room_id, - etype, - state_key, - redacts, - relates_to, - backfilled, - ): + stream_ordering: int, + event_id: str, + room_id: str, + etype: str, + state_key: Optional[str], + redacts: Optional[str], + relates_to: Optional[str], + backfilled: bool, + ) -> None: self._invalidate_get_event_cache(event_id) self.have_seen_event.invalidate((room_id, event_id)) @@ -207,7 +213,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_thread_summary.invalidate((relates_to,)) self.get_thread_participated.invalidate((relates_to,)) - async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): + async def invalidate_cache_and_stream( + self, cache_name: str, keys: Tuple[Any, ...] + ) -> None: """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -227,7 +235,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore): keys, ) - def _invalidate_cache_and_stream(self, txn, cache_func, keys): + def _invalidate_cache_and_stream( + self, + txn: LoggingTransaction, + cache_func: _CachedFunction, + keys: Tuple[Any, ...], + ) -> None: """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -238,7 +251,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): txn.call_after(cache_func.invalidate, keys) self._send_invalidation_to_replication(txn, cache_func.__name__, keys) - def _invalidate_all_cache_and_stream(self, txn, cache_func): + def _invalidate_all_cache_and_stream( + self, txn: LoggingTransaction, cache_func: _CachedFunction + ) -> None: """Invalidates the entire cache and adds it to the cache stream so slaves will know to invalidate their caches. """ @@ -279,8 +294,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) def _send_invalidation_to_replication( - self, txn, cache_name: str, keys: Optional[Iterable[Any]] - ): + self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] + ) -> None: """Notifies replication that given cache has been invalidated. Note that this does *not* invalidate the cache locally. @@ -315,7 +330,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): "instance_name": self._instance_name, "cache_func": cache_name, "keys": keys, - "invalidation_ts": self.clock.time_msec(), + "invalidation_ts": self._clock.time_msec(), }, ) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 3f6086050b..0aef121d83 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -13,13 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast from typing_extensions import TypedDict from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.types import JsonDict from synapse.util import json_encoder @@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore): ) -> List[Dict[str, Any]]: # TODO: Pagination - keyvalues = {"group_id": group_id} + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore): # TODO: Pagination - def _get_rooms_in_group_txn(txn): + def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]: sql = """ SELECT room_id, is_public FROM group_rooms WHERE group_id = ? @@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore): * "order": int, the sort order of rooms in this category """ - def _get_rooms_for_summary_txn(txn): - keyvalues = {"group_id": group_id} + def _get_rooms_for_summary_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore): "get_rooms_for_summary", _get_rooms_for_summary_txn ) - async def get_group_categories(self, group_id): + async def get_group_categories(self, group_id: str) -> JsonDict: rows = await self.db_pool.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, @@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - async def get_group_category(self, group_id, category_id): + async def get_group_category(self, group_id: str, category_id: str) -> JsonDict: category = await self.db_pool.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore): return category - async def get_group_roles(self, group_id): + async def get_group_roles(self, group_id: str) -> JsonDict: rows = await self.db_pool.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, @@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - async def get_group_role(self, group_id, role_id): + async def get_group_role(self, group_id: str, role_id: str) -> JsonDict: role = await self.db_pool.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_local_groups_for_room", ) - async def get_users_for_summary_by_role(self, group_id, include_private=False): + async def get_users_for_summary_by_role( + self, group_id: str, include_private: bool = False + ) -> Tuple[List[JsonDict], JsonDict]: """Get the users and roles that should be included in a summary request Returns: ([users], [roles]) """ - def _get_users_for_summary_txn(txn): - keyvalues = {"group_id": group_id} + def _get_users_for_summary_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], JsonDict]: + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore): allow_none=True, ) - async def get_users_membership_info_in_group(self, group_id, user_id): + async def get_users_membership_info_in_group( + self, group_id: str, user_id: str + ) -> JsonDict: """Get a dict describing the membership of a user in a group. Example if joined: @@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore): An empty dict if the user is not join/invite/etc """ - def _get_users_membership_in_group_txn(txn): + def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict: row = self.db_pool.simple_select_one_txn( txn, table="group_users", @@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_publicised_groups_for_user", ) - async def get_attestations_need_renewals(self, valid_until_ms): + async def get_attestations_need_renewals( + self, valid_until_ms: int + ) -> List[Dict[str, Any]]: """Get all attestations that need to be renewed until givent time""" - def _get_attestations_need_renewals_txn(txn): + def _get_attestations_need_renewals_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, Any]]: sql = """ SELECT group_id, user_id FROM group_attestations_renewals WHERE valid_until_ms <= ? @@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore): "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) - async def get_remote_attestation(self, group_id, user_id): + async def get_remote_attestation( + self, group_id: str, user_id: str + ) -> Optional[JsonDict]: """Get the attestation that proves the remote agrees that the user is in the group. """ @@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_joined_groups", ) - async def get_all_groups_for_user(self, user_id, now_token): - def _get_all_groups_for_user_txn(txn): + async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]: + def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: sql = """ SELECT group_id, type, membership, u.content FROM local_group_updates AS u @@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore): "get_all_groups_for_user", _get_all_groups_for_user_txn ) - async def get_groups_changes_for_user(self, user_id, from_token, to_token): - from_token = int(from_token) - has_changed = self._group_updates_stream_cache.has_entity_changed( + async def get_groups_changes_for_user( + self, user_id: str, from_token: int, to_token: int + ) -> List[JsonDict]: + has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined] user_id, from_token ) if not has_changed: return [] - def _get_groups_changes_for_user_txn(txn): + def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: sql = """ SELECT group_id, membership, type, u.content FROM local_group_updates AS u @@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore): """ last_id = int(last_id) - has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) + has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined] if not has_changed: return [], current_id, False - def _get_all_groups_changes_txn(txn): + def _get_all_groups_changes_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: sql = """ SELECT stream_id, group_id, user_id, type, content FROM local_group_updates @@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore): LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - updates = [ - (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) - for stream_id, group_id, user_id, gtype, content_json in txn - ] + updates = cast( + List[Tuple[int, tuple]], + [ + (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) + for stream_id, group_id, user_id, gtype, content_json in txn + ], + ) limited = False upto_token = current_id @@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore): self, group_id: str, room_id: str, - category_id: str, - order: int, + category_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) room's entry in summary. @@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore): def _add_room_to_summary_txn( self, - txn, + txn: LoggingTransaction, group_id: str, room_id: str, - category_id: str, - order: int, + category_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) room's entry in summary. @@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore): WHERE group_id = ? AND category_id = ? """ txn.execute(sql, (group_id, category_id)) - (order,) = txn.fetchone() + (order,) = cast(Tuple[int], txn.fetchone()) if existing: to_update = {} @@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore): "category_id": category_id, "room_id": room_id, }, - values=to_update, + updatevalues=to_update, ) else: if is_public is None: @@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore): ) async def remove_room_from_summary( - self, group_id: str, room_id: str, category_id: str + self, group_id: str, room_id: str, category_id: Optional[str] ) -> int: if category_id is None: category_id = _DEFAULT_CATEGORY_ID @@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore): is_public: Optional[bool], ) -> None: """Add/update room category for group""" - insertion_values = {} - update_values = {"category_id": category_id} # This cannot be empty + insertion_values: JsonDict = {} + update_values: JsonDict = {"category_id": category_id} # This cannot be empty if profile is None: insertion_values["profile"] = "{}" @@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore): is_public: Optional[bool], ) -> None: """Add/remove user role""" - insertion_values = {} - update_values = {"role_id": role_id} # This cannot be empty + insertion_values: JsonDict = {} + update_values: JsonDict = {"role_id": role_id} # This cannot be empty if profile is None: insertion_values["profile"] = "{}" @@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore): self, group_id: str, user_id: str, - role_id: str, - order: int, + role_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) user's entry in summary. @@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore): def _add_user_to_summary_txn( self, - txn, + txn: LoggingTransaction, group_id: str, user_id: str, - role_id: str, - order: int, + role_id: Optional[str], + order: Optional[int], is_public: Optional[bool], - ): + ) -> None: """Add (or update) user's entry in summary. Args: @@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore): WHERE group_id = ? AND role_id = ? """ txn.execute(sql, (group_id, role_id)) - (order,) = txn.fetchone() + (order,) = cast(Tuple[int], txn.fetchone()) if existing: to_update = {} @@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore): "role_id": role_id, "user_id": user_id, }, - values=to_update, + updatevalues=to_update, ) else: if is_public is None: @@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore): ) async def remove_user_from_summary( - self, group_id: str, user_id: str, role_id: str + self, group_id: str, user_id: str, role_id: Optional[str] ) -> int: if role_id is None: role_id = _DEFAULT_ROLE_ID @@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore): Optional if the user and group are on the same server """ - def _add_user_to_group_txn(txn): + def _add_user_to_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_insert_txn( txn, table="group_users", @@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore): await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) async def remove_user_from_group(self, group_id: str, user_id: str) -> None: - def _remove_user_from_group_txn(txn): + def _remove_user_from_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="group_users", @@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore): ) async def remove_room_from_group(self, group_id: str, room_id: str) -> None: - def _remove_room_from_group_txn(txn): + def _remove_room_from_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="group_rooms", @@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore): content = content or {} - def _register_user_group_membership_txn(txn, next_id): + def _register_user_group_membership_txn( + txn: LoggingTransaction, next_id: int + ) -> int: # TODO: Upsert? self.db_pool.simple_delete_txn( txn, @@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore): ), }, ) - self._group_updates_stream_cache.entity_has_changed(user_id, next_id) + self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined] # TODO: Insert profile to ensure it comes down stream if its a join. @@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - async with self._group_updates_id_gen.get_next() as next_id: + async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined] res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, @@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore): return res async def create_group( - self, group_id, user_id, name, avatar_url, short_description, long_description + self, + group_id: str, + user_id: str, + name: str, + avatar_url: str, + short_description: str, + long_description: str, ) -> None: await self.db_pool.simple_insert( table="groups", @@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore): desc="create_group", ) - async def update_group_profile(self, group_id, profile): + async def update_group_profile(self, group_id: str, profile: JsonDict) -> None: await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, @@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_attestation_renewal", ) - def get_group_stream_token(self): - return self._group_updates_id_gen.get_current_token() + def get_group_stream_token(self) -> int: + return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined] async def delete_group(self, group_id: str) -> None: """Deletes a group fully from the database. @@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore): group_id: The group ID to delete. """ - def _delete_group_txn(txn): + def _delete_group_txn(txn: LoggingTransaction) -> None: tables = [ "groups", "group_users", diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e9a0cdc6be..216622964a 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, + LoggingTransaction, make_in_list_sql_clause, ) +from synapse.storage.databases.main.registration import RegistrationWorkerStore from synapse.util.caches.descriptors import cached from synapse.util.threepids import canonicalise_email @@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): Number of current monthly active users """ - def _count_users(txn): + def _count_users(txn: LoggingTransaction) -> int: # Exclude app service users sql = """ SELECT COUNT(*) @@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): WHERE (users.appservice_id IS NULL OR users.appservice_id = ''); """ txn.execute(sql) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_users", _count_users) @@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """ - def _count_users_by_service(txn): + def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]: sql = """ SELECT COALESCE(appservice_id, 'native'), COUNT(*) FROM monthly_active_users @@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """ txn.execute(sql) - result = txn.fetchall() + result = cast(List[Tuple[str, int]], txn.fetchall()) return dict(result) return await self.db_pool.runInteraction( @@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): ) @wrap_as_background_process("reap_monthly_active_users") - async def reap_monthly_active_users(self): + async def reap_monthly_active_users(self) -> None: """Cleans out monthly active user table to ensure that no stale entries exist. """ - def _reap_users(txn, reserved_users): + def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None: """ Args: reserved_users (tuple): reserved users to preserve @@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): # is racy. # Have resolved to invalidate the whole cache for now and do # something about it if and when the perf becomes significant - self._invalidate_all_cache_and_stream( + self._invalidate_all_cache_and_stream( # type: ignore[attr-defined] txn, self.user_last_seen_monthly_active ) - self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) + self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined] reserved_users = await self.get_registered_reserved_users() await self.db_pool.runInteraction( @@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): ) -class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): +class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore): def __init__( self, database: DatabasePool, @@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], ) - def _initialise_reserved_users(self, txn, threepids): + def _initialise_reserved_users( + self, txn: LoggingTransaction, threepids: List[dict] + ) -> None: """Ensures that reserved threepids are accounted for in the MAU table, should be called on start up. Args: - txn (cursor): - threepids (list[dict]): List of threepid dicts to reserve + txn: + threepids: List of threepid dicts to reserve """ # XXX what is this function trying to achieve? It upserts into @@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) - def upsert_monthly_active_user_txn(self, txn, user_id): + def upsert_monthly_active_user_txn( + self, txn: LoggingTransaction, user_id: str + ) -> None: """Updates or inserts monthly active user member We consciously do not call is_support_txn from this method because it @@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): txn, self.user_last_seen_monthly_active, (user_id,) ) - async def populate_monthly_active_users(self, user_id): + async def populate_monthly_active_users(self, user_id: str) -> None: """Checks on the state of monthly active user limits and optionally add the user to the monthly active tables @@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): """ if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group - is_guest = await self.is_guest(user_id) + is_guest = await self.is_guest(user_id) # type: ignore[attr-defined] if is_guest: return is_trial = await self.is_trial_user(user_id) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c4869d64e6..b2295fd51f 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -27,7 +27,6 @@ from typing import ( ) import attr -from frozendict import frozendict from synapse.api.constants import RelationTypes from synapse.events import EventBase @@ -41,45 +40,15 @@ from synapse.storage.database import ( from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.storage.relations import AggregationPaginationToken, PaginationChunk -from synapse.types import JsonDict, RoomStreamToken, StreamToken +from synapse.types import RoomStreamToken, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _ThreadAggregation: - # The latest event in the thread. - latest_event: EventBase - # The latest edit to the latest event in the thread. - latest_edit: Optional[EventBase] - # The total number of events in the thread. - count: int - # True if the current user has sent an event to the thread. - current_user_participated: bool - - -@attr.s(slots=True, auto_attribs=True) -class BundledAggregations: - """ - The bundled aggregations for an event. - - Some values require additional processing during serialization. - """ - - annotations: Optional[JsonDict] = None - references: Optional[JsonDict] = None - replace: Optional[EventBase] = None - thread: Optional[_ThreadAggregation] = None - - def __bool__(self) -> bool: - return bool(self.annotations or self.references or self.replace or self.thread) - - class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") - async def _get_applicable_edits( + async def get_applicable_edits( self, event_ids: Collection[str] ) -> Dict[str, Optional[EventBase]]: """Get the most recent edit (if any) that has happened for the given @@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") - async def _get_thread_summaries( + async def get_thread_summaries( self, event_ids: Collection[str] ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]: """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event. @@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore): latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] # Check to see if any of those events are edited. - latest_edits = await self._get_applicable_edits(latest_event_ids.values()) + latest_edits = await self.get_applicable_edits(latest_event_ids.values()) # Map to the event IDs to the thread summary. # @@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_thread_participated", list_name="event_ids") - async def _get_threads_participated( + async def get_threads_participated( self, event_ids: Collection[str], user_id: str ) -> Dict[str, bool]: """Get whether the requesting user participated in the given threads. @@ -766,114 +735,6 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - async def _get_bundled_aggregation_for_event( - self, event: EventBase, user_id: str - ) -> Optional[BundledAggregations]: - """Generate bundled aggregations for an event. - - Note that this does not use a cache, but depends on cached methods. - - Args: - event: The event to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. - - Returns: - The bundled aggregations for an event, if bundled aggregations are - enabled and the event can have bundled aggregations. - """ - - # Do not bundle aggregations for an event which represents an edit or an - # annotation. It does not make sense for them to have related events. - relates_to = event.content.get("m.relates_to") - if isinstance(relates_to, (dict, frozendict)): - relation_type = relates_to.get("rel_type") - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): - return None - - event_id = event.event_id - room_id = event.room_id - - # The bundled aggregations to include, a mapping of relation type to a - # type-specific value. Some types include the direct return type here - # while others need more processing during serialization. - aggregations = BundledAggregations() - - annotations = await self.get_aggregation_groups_for_event(event_id, room_id) - if annotations.chunk: - aggregations.annotations = await annotations.to_dict( - cast("DataStore", self) - ) - - references = await self.get_relations_for_event( - event_id, event, room_id, RelationTypes.REFERENCE, direction="f" - ) - if references.chunk: - aggregations.references = await references.to_dict(cast("DataStore", self)) - - # Store the bundled aggregations in the event metadata for later use. - return aggregations - - async def get_bundled_aggregations( - self, events: Iterable[EventBase], user_id: str - ) -> Dict[str, BundledAggregations]: - """Generate bundled aggregations for events. - - Args: - events: The iterable of events to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. - - Returns: - A map of event ID to the bundled aggregation for the event. Not all - events may have bundled aggregations in the results. - """ - # De-duplicate events by ID to handle the same event requested multiple times. - # - # State events do not get bundled aggregations. - events_by_id = { - event.event_id: event for event in events if not event.is_state() - } - - # event ID -> bundled aggregation in non-serialized form. - results: Dict[str, BundledAggregations] = {} - - # Fetch other relations per event. - for event in events_by_id.values(): - event_result = await self._get_bundled_aggregation_for_event(event, user_id) - if event_result: - results[event.event_id] = event_result - - # Fetch any edits (but not for redacted events). - edits = await self._get_applicable_edits( - [ - event_id - for event_id, event in events_by_id.items() - if not event.internal_metadata.is_redacted() - ] - ) - for event_id, edit in edits.items(): - results.setdefault(event_id, BundledAggregations()).replace = edit - - # Fetch thread summaries. - summaries = await self._get_thread_summaries(events_by_id.keys()) - # Only fetch participated for a limited selection based on what had - # summaries. - participated = await self._get_threads_participated(summaries.keys(), user_id) - for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event, edit = summary - results.setdefault( - event_id, BundledAggregations() - ).thread = _ThreadAggregation( - latest_event=latest_thread_event, - latest_edit=edit, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=participated[event_id], - ) - - return results - class RelationsStore(RelationsWorkerStore): pass diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index e7fddd2426..55cc9178f0 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -26,6 +26,8 @@ from typing import ( cast, ) +from typing_extensions import TypedDict + from synapse.api.errors import StoreError if TYPE_CHECKING: @@ -40,7 +42,12 @@ from synapse.storage.database import ( from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id +from synapse.types import ( + JsonDict, + UserProfile, + get_domain_from_id, + get_localpart_from_id, +) from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) +class SearchResult(TypedDict): + limited: bool + results: List[UserProfile] + + class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # How many records do we calculate before sending it to # add_users_who_share_private_rooms? @@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): async def search_user_dir( self, user_id: str, search_term: str, limit: int - ) -> JsonDict: + ) -> SearchResult: """Searches for users in directory Returns: @@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # This should be unreachable. raise Exception("Unrecognized database engine") - results = await self.db_pool.execute( - "search_user_dir", self.db_pool.cursor_to_dict, sql, *args + results = cast( + List[UserProfile], + await self.db_pool.execute( + "search_user_dir", self.db_pool.cursor_to_dict, sql, *args + ), ) limited = len(results) > limit diff --git a/synapse/types.py b/synapse/types.py index 53be3583a0..5ce2a5b0a5 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -34,6 +34,7 @@ from typing import ( import attr from frozendict import frozendict from signedjson.key import decode_verify_key_bytes +from typing_extensions import TypedDict from unpaddedbase64 import decode_base64 from zope.interface import Interface @@ -63,6 +64,10 @@ MutableStateMap = MutableMapping[StateKey, T] # JSON types. These could be made stronger, but will do for now. # A JSON-serialisable dict. JsonDict = Dict[str, Any] +# A JSON-serialisable mapping; roughly speaking an immutable JSONDict. +# Useful when you have a TypedDict which isn't going to be mutated and you don't want +# to cast to JsonDict everywhere. +JsonMapping = Mapping[str, Any] # A JSON-serialisable object. JsonSerializable = object @@ -791,3 +796,9 @@ class UserInfo: is_deactivated: bool is_guest: bool is_shadow_banned: bool + + +class UserProfile(TypedDict): + user_id: str + display_name: Optional[str] + avatar_url: Optional[str] diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 12cd804939..66f1da7502 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -128,6 +128,19 @@ def _incorrect_version( ) +def _no_reported_version(requirement: Requirement, extra: Optional[str] = None) -> str: + if extra: + return ( + f"Synapse {VERSION} needs {requirement} for {extra}, " + f"but can't determine {requirement.name}'s version" + ) + else: + return ( + f"Synapse {VERSION} needs {requirement}, " + f"but can't determine {requirement.name}'s version" + ) + + def check_requirements(extra: Optional[str] = None) -> None: """Check Synapse's dependencies are present and correctly versioned. @@ -163,8 +176,17 @@ def check_requirements(extra: Optional[str] = None) -> None: deps_unfulfilled.append(requirement.name) errors.append(_not_installed(requirement, extra)) else: + if dist.version is None: + # This shouldn't happen---it suggests a borked virtualenv. (See #12223) + # Try to give a vaguely helpful error message anyway. + # Type-ignore: the annotations don't reflect reality: see + # https://github.com/python/typeshed/issues/7513 + # https://bugs.python.org/issue47060 + deps_unfulfilled.append(requirement.name) # type: ignore[unreachable] + errors.append(_no_reported_version(requirement, extra)) + # We specify prereleases=True to allow prereleases such as RCs. - if not requirement.specifier.contains(dist.version, prereleases=True): + elif not requirement.specifier.contains(dist.version, prereleases=True): deps_unfulfilled.append(requirement.name) errors.append(_incorrect_version(requirement, dist.version, extra)) diff --git a/synapse/visibility.py b/synapse/visibility.py index 281cbe4d88..49519eb8f5 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -14,12 +14,7 @@ import logging from typing import Dict, FrozenSet, List, Optional -from synapse.api.constants import ( - AccountDataTypes, - EventTypes, - HistoryVisibility, - Membership, -) +from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.events import EventBase from synapse.events.utils import prune_event from synapse.storage import Storage @@ -87,15 +82,8 @@ async def filter_events_for_client( state_filter=StateFilter.from_types(types), ) - ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( - user_id, AccountDataTypes.IGNORED_USER_LIST - ) - - ignore_list: FrozenSet[str] = frozenset() - if ignore_dict_content: - ignored_users_dict = ignore_dict_content.get("ignored_users", {}) - if isinstance(ignored_users_dict, dict): - ignore_list = frozenset(ignored_users_dict.keys()) + # Get the users who are ignored by the requesting user. + ignore_list = await storage.main.ignored_users(user_id) erased_senders = await storage.main.are_users_erased(e.sender for e in events) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index a267228846..a54aa29cf1 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -11,9 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.handlers.cas import CasResponse +from synapse.server import HomeServer +from synapse.util import Clock from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -24,7 +29,7 @@ SERVER_URL = "https://issuer/" class CasHandlerTestCase(HomeserverTestCase): - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL cas_config = { @@ -40,7 +45,7 @@ class CasHandlerTestCase(HomeserverTestCase): return config - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver() self.handler = hs.get_cas_handler() @@ -51,7 +56,7 @@ class CasHandlerTestCase(HomeserverTestCase): return hs - def test_map_cas_user_to_user(self): + def test_map_cas_user_to_user(self) -> None: """Ensure that mapping the CAS user returned from a provider to an MXID works properly.""" # stub out the auth handler @@ -75,7 +80,7 @@ class CasHandlerTestCase(HomeserverTestCase): auth_provider_session_id=None, ) - def test_map_cas_user_to_existing_user(self): + def test_map_cas_user_to_existing_user(self) -> None: """Existing users can log in with CAS account.""" store = self.hs.get_datastores().main self.get_success( @@ -119,7 +124,7 @@ class CasHandlerTestCase(HomeserverTestCase): auth_provider_session_id=None, ) - def test_map_cas_user_to_invalid_localpart(self): + def test_map_cas_user_to_invalid_localpart(self) -> None: """CAS automaps invalid characters to base-64 encoding.""" # stub out the auth handler @@ -150,7 +155,7 @@ class CasHandlerTestCase(HomeserverTestCase): } } ) - def test_required_attributes(self): + def test_required_attributes(self) -> None: """The required attributes must be met from the CAS response.""" # stub out the auth handler @@ -166,7 +171,7 @@ class CasHandlerTestCase(HomeserverTestCase): auth_handler.complete_sso_login.assert_not_called() # The response doesn't have any department. - cas_response = CasResponse("test_user", {"userGroup": "staff"}) + cas_response = CasResponse("test_user", {"userGroup": ["staff"]}) request.reset_mock() self.get_success( self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 6e403a87c5..11ad44223d 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -12,14 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Any, Awaitable, Callable, Dict from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + import synapse.api.errors import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client import directory, login, room -from synapse.types import RoomAlias, create_requester +from synapse.server import HomeServer +from synapse.types import JsonDict, RoomAlias, create_requester +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable class DirectoryTestCase(unittest.HomeserverTestCase): """Tests the directory service.""" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.mock_federation = Mock() self.mock_registry = Mock() - self.query_handlers = {} + self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} - def register_query_handler(query_type, handler): + def register_query_handler( + query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] + ) -> None: self.query_handlers[query_type] = handler self.mock_registry.register_query_handler = register_query_handler @@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): return hs - def test_get_local_association(self): + def test_get_local_association(self) -> None: self.get_success( self.store.create_room_alias_association( self.my_room, "!8765qwer:test", ["test"] @@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) - def test_get_remote_association(self): + def test_get_remote_association(self) -> None: self.mock_federation.make_query.return_value = make_awaitable( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} ) @@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ignore_backoff=True, ) - def test_incoming_fed_query(self): + def test_incoming_fed_query(self) -> None: self.get_success( self.store.create_room_alias_association( self.your_room, "!8765asdf:test", ["test"] @@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_directory_handler() # Create user @@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): self.test_user_tok = self.login("user", "pass") self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) - def test_create_alias_joined_room(self): + def test_create_alias_joined_room(self) -> None: """A user can create an alias for a room they're in.""" self.get_success( self.handler.create_association( @@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): ) ) - def test_create_alias_other_room(self): + def test_create_alias_other_room(self) -> None: """A user cannot create an alias for a room they're NOT in.""" other_room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok @@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): synapse.api.errors.SynapseError, ) - def test_create_alias_admin(self): + def test_create_alias_admin(self) -> None: """An admin can create an alias for a room they're NOT in.""" other_room_id = self.helper.create_room_as( self.test_user, tok=self.test_user_tok @@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): self.test_user_tok = self.login("user", "pass") self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) - def _create_alias(self, user): + def _create_alias(self, user) -> None: # Create a new alias to this room. self.get_success( self.store.create_room_alias_association( @@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): ) ) - def test_delete_alias_not_allowed(self): + def test_delete_alias_not_allowed(self) -> None: """A user that doesn't meet the expected guidelines cannot delete an alias.""" self._create_alias(self.admin_user) self.get_failure( @@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): synapse.api.errors.AuthError, ) - def test_delete_alias_creator(self): + def test_delete_alias_creator(self) -> None: """An alias creator can delete their own alias.""" # Create an alias from a different user. self._create_alias(self.test_user) @@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): synapse.api.errors.SynapseError, ) - def test_delete_alias_admin(self): + def test_delete_alias_admin(self) -> None: """A server admin can delete an alias created by another user.""" # Create an alias from a different user. self._create_alias(self.test_user) @@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): synapse.api.errors.SynapseError, ) - def test_delete_alias_sufficient_power(self): + def test_delete_alias_sufficient_power(self) -> None: """A user with a sufficient power level should be able to delete an alias.""" self._create_alias(self.admin_user) @@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) return room_alias - def _set_canonical_alias(self, content): + def _set_canonical_alias(self, content) -> None: """Configure the canonical alias state on the room.""" self.helper.send_state( self.room_id, @@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) ) - def test_remove_alias(self): + def test_remove_alias(self) -> None: """Removing an alias that is the canonical alias should remove it there too.""" # Set this new alias as the canonical alias for this room self._set_canonical_alias( @@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): self.assertNotIn("alias", data["content"]) self.assertNotIn("alt_aliases", data["content"]) - def test_remove_other_alias(self): + def test_remove_other_alias(self) -> None: """Removing an alias listed as in alt_aliases should remove it there too.""" # Create a second alias. other_test_alias = "#test2:test" @@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): servlets = [directory.register_servlets, room.register_servlets] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # Add custom alias creation rules to the config. @@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): return config - def test_denied(self): + def test_denied(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request( @@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): ) self.assertEqual(403, channel.code, channel.result) - def test_allowed(self): + def test_allowed(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request( @@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): ) self.assertEqual(200, channel.code, channel.result) - def test_denied_during_creation(self): + def test_denied_during_creation(self) -> None: """A room alias that is not allowed should be rejected during creation.""" # Invalid room alias. self.helper.create_room_as( @@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): extra_content={"room_alias_name": "foo"}, ) - def test_allowed_during_creation(self): + def test_allowed_during_creation(self) -> None: """A valid room alias should be allowed during creation.""" room_id = self.helper.create_room_as( self.user_id, @@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): data = {"room_alias_name": "unofficial_test"} allowed_localpart = "allowed" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # Add custom room list publication rules to the config. @@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): return config - def prepare(self, reactor, clock, hs): + def prepare( + self, reactor: MemoryReactor, clock: Clock, hs: HomeServer + ) -> HomeServer: self.allowed_user_id = self.register_user(self.allowed_localpart, "pass") self.allowed_access_token = self.login(self.allowed_localpart, "pass") @@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): return hs - def test_denied_without_publication_permission(self): + def test_denied_without_publication_permission(self) -> None: """ Try to create a room, register an alias for it, and publish it, as a user without permission to publish rooms. @@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): expect_code=403, ) - def test_allowed_when_creating_private_room(self): + def test_allowed_when_creating_private_room(self) -> None: """ Try to create a room, register an alias for it, and NOT publish it, as a user without permission to publish rooms. @@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): expect_code=200, ) - def test_allowed_with_publication_permission(self): + def test_allowed_with_publication_permission(self) -> None: """ Try to create a room, register an alias for it, and publish it, as a user WITH permission to publish rooms. @@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): expect_code=200, ) - def test_denied_publication_with_invalid_alias(self): + def test_denied_publication_with_invalid_alias(self) -> None: """ Try to create a room, register an alias for it, and publish it, as a user WITH permission to publish rooms. @@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): expect_code=403, ) - def test_can_create_as_private_room_after_rejection(self): + def test_can_create_as_private_room_after_rejection(self) -> None: """ After failing to publish a room with an alias as a user without publish permission, retry as the same user, but without publishing the room. @@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): self.test_denied_without_publication_permission() self.test_allowed_when_creating_private_room() - def test_can_create_with_permission_after_rejection(self): + def test_can_create_with_permission_after_rejection(self) -> None: """ After failing to publish a room with an alias as a user without publish permission, retry as someone with permission, using the same alias. @@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): servlets = [directory.register_servlets, room.register_servlets] - def prepare(self, reactor, clock, hs): + def prepare( + self, reactor: MemoryReactor, clock: Clock, hs: HomeServer + ) -> HomeServer: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request( @@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): return hs - def test_disabling_room_list(self): + def test_disabling_room_list(self) -> None: self.room_list_handler.enable_room_list_search = True self.directory_handler.enable_room_list_search = True diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 9338ab92e9..ac21a28c43 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -20,33 +20,37 @@ from parameterized import parameterized from signedjson import key as key, sign as sign from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(federation_client=mock.Mock()) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_e2e_keys_handler() self.store = self.hs.get_datastores().main - def test_query_local_devices_no_devices(self): + def test_query_local_devices_no_devices(self) -> None: """If the user has no devices, we expect an empty list.""" local_user = "@boris:" + self.hs.hostname res = self.get_success(self.handler.query_local_devices({local_user: None})) self.assertDictEqual(res, {local_user: {}}) - def test_reupload_one_time_keys(self): + def test_reupload_one_time_keys(self) -> None: """we should be able to re-upload the same keys""" local_user = "@boris:" + self.hs.hostname device_id = "xyz" - keys = { + keys: JsonDict = { "alg1:k1": "key1", "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, "alg2:k3": {"key": "key3"}, @@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}} ) - def test_change_one_time_keys(self): + def test_change_one_time_keys(self) -> None: """attempts to change one-time-keys should be rejected""" local_user = "@boris:" + self.hs.hostname @@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): SynapseError, ) - def test_claim_one_time_key(self): + def test_claim_one_time_key(self) -> None: local_user = "@boris:" + self.hs.hostname device_id = "xyz" keys = {"alg1:k1": "key1"} @@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) - def test_fallback_key(self): + def test_fallback_key(self) -> None: local_user = "@boris:" + self.hs.hostname device_id = "xyz" fallback_key = {"alg1:k1": "fallback_key1"} @@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, ) - def test_replace_master_key(self): + def test_replace_master_key(self) -> None: """uploading a new signing key should make the old signing key unavailable""" local_user = "@boris:" + self.hs.hostname keys1 = { @@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) - def test_reupload_signatures(self): + def test_reupload_signatures(self) -> None: """re-uploading a signature should not fail""" local_user = "@boris:" + self.hs.hostname keys1 = { @@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2) - def test_self_signing_key_doesnt_show_up_as_device(self): + def test_self_signing_key_doesnt_show_up_as_device(self) -> None: """signing keys should be hidden when fetching a user's devices""" local_user = "@boris:" + self.hs.hostname keys1 = { @@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): res = self.get_success(self.handler.query_local_devices({local_user: None})) self.assertDictEqual(res, {local_user: {}}) - def test_upload_signatures(self): + def test_upload_signatures(self) -> None: """should check signatures that are uploaded""" # set up a user with cross-signing keys and a device. This user will # try uploading signatures @@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], ) - def test_query_devices_remote_no_sync(self): + def test_query_devices_remote_no_sync(self) -> None: """Tests that querying keys for a remote user that we don't share a room with returns the cross signing keys correctly. """ @@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) - def test_query_devices_remote_sync(self): + def test_query_devices_remote_sync(self) -> None: """Tests that querying keys for a remote user that we share a room with, but haven't yet fetched the keys for, returns the cross signing keys correctly. @@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): (["device_1", "device_2"],), ] ) - def test_query_all_devices_caches_result(self, device_ids: Iterable[str]): + def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None: """Test that requests for all of a remote user's devices are cached. We do this by asserting that only one call over federation was made, and that @@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): """ local_user_id = "@test:test" remote_user_id = "@test:other" - request_body = {"device_keys": {remote_user_id: []}} + request_body: JsonDict = {"device_keys": {remote_user_id: []}} response_devices = [ { diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index e8b4e39d1a..89078fc637 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List +from typing import List, cast from unittest import TestCase +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError from synapse.api.room_versions import RoomVersions @@ -23,7 +25,9 @@ from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.types import create_requester +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest @@ -42,7 +46,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastores().main @@ -50,7 +54,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self._event_auth_handler = hs.get_event_auth_handler() return hs - def test_exchange_revoked_invite(self): + def test_exchange_revoked_invite(self) -> None: user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -96,7 +100,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure) self.assertEqual(failure.msg, "You are not invited to this room.") - def test_rejected_message_event_state(self): + def test_rejected_message_event_state(self) -> None: """ Check that we store the state group correctly for rejected non-state events. @@ -126,7 +130,7 @@ class FederationTestCase(unittest.HomeserverTestCase): "content": {}, "room_id": room_id, "sender": "@yetanotheruser:" + OTHER_SERVER, - "depth": join_event["depth"] + 1, + "depth": cast(int, join_event["depth"]) + 1, "prev_events": [join_event.event_id], "auth_events": [], "origin_server_ts": self.clock.time_msec(), @@ -149,7 +153,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(sg, sg2) - def test_rejected_state_event_state(self): + def test_rejected_state_event_state(self) -> None: """ Check that we store the state group correctly for rejected state events. @@ -180,7 +184,7 @@ class FederationTestCase(unittest.HomeserverTestCase): "content": {}, "room_id": room_id, "sender": "@yetanotheruser:" + OTHER_SERVER, - "depth": join_event["depth"] + 1, + "depth": cast(int, join_event["depth"]) + 1, "prev_events": [join_event.event_id], "auth_events": [], "origin_server_ts": self.clock.time_msec(), @@ -203,7 +207,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(sg, sg2) - def test_backfill_with_many_backward_extremities(self): + def test_backfill_with_many_backward_extremities(self) -> None: """ Check that we can backfill with many backward extremities. The goal is to make sure that when we only use a portion @@ -262,7 +266,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ) self.get_success(d) - def test_backfill_floating_outlier_membership_auth(self): + def test_backfill_floating_outlier_membership_auth(self) -> None: """ As the local homeserver, check that we can properly process a federated event from the OTHER_SERVER with auth_events that include a floating @@ -377,7 +381,7 @@ class FederationTestCase(unittest.HomeserverTestCase): for ae in auth_events ] - self.handler.federation_client.get_event_auth = get_event_auth + self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment] with LoggingContext("receive_pdu"): # Fake the OTHER_SERVER federating the message event over to our local homeserver @@ -397,7 +401,7 @@ class FederationTestCase(unittest.HomeserverTestCase): @unittest.override_config( {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} ) - def test_invite_by_user_ratelimit(self): + def test_invite_by_user_ratelimit(self) -> None: """Tests that invites from federation to a particular user are actually rate-limited. """ @@ -446,7 +450,9 @@ class FederationTestCase(unittest.HomeserverTestCase): exc=LimitExceededError, ) - def _build_and_send_join_event(self, other_server, other_user, room_id): + def _build_and_send_join_event( + self, other_server: str, other_user: str, room_id: str + ) -> EventBase: join_event = self.get_success( self.handler.on_make_join_request(other_server, room_id, other_user) ) @@ -469,7 +475,7 @@ class FederationTestCase(unittest.HomeserverTestCase): class EventFromPduTestCase(TestCase): - def test_valid_json(self): + def test_valid_json(self) -> None: """Valid JSON should be turned into an event.""" ev = event_from_pdu_json( { @@ -487,7 +493,7 @@ class EventFromPduTestCase(TestCase): self.assertIsInstance(ev, EventBase) - def test_invalid_numbers(self): + def test_invalid_numbers(self) -> None: """Invalid values for an integer should be rejected, all floats should be rejected.""" for value in [ -(2 ** 53), @@ -512,7 +518,7 @@ class EventFromPduTestCase(TestCase): RoomVersions.V6, ) - def test_invalid_nested(self): + def test_invalid_nested(self) -> None: """List and dictionaries are recursively searched.""" with self.assertRaises(SynapseError): event_from_pdu_json( diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e8418b6638..014815db6e 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,14 +13,18 @@ # limitations under the License. import json import os +from typing import Any, Dict from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse import pymacaroons +from twisted.test.proto_helpers import MemoryReactor + from synapse.handlers.sso import MappingException from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import JsonDict, UserID +from synapse.util import Clock from synapse.util.macaroons import get_value_from_macaroon from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock @@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider): } -async def get_json(url): +async def get_json(url: str) -> JsonDict: # Mock get_json calls to handle jwks & oidc discovery endpoints if url == WELL_KNOWN: # Minimal discovery document, as defined in OpenID.Discovery @@ -116,6 +120,8 @@ async def get_json(url): elif url == JWKS_URI: return {"keys": []} + return {} + def _key_file_path() -> str: """path to a file containing the private half of a test key""" @@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase): if not HAS_OIDC: skip = "requires OIDC" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL return config - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.http_client = Mock(spec=["get_json"]) self.http_client.get_json.side_effect = get_json self.http_client.user_agent = b"Synapse Test" @@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase): sso_handler = hs.get_sso_handler() # Mock the render error method. self.render_error = Mock(return_value=None) - sso_handler.render_error = self.render_error + sso_handler.render_error = self.render_error # type: ignore[assignment] # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 @@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase): return args @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_config(self): + def test_config(self) -> None: """Basic config correctly sets up the callback URL and client auth correctly.""" self.assertEqual(self.provider._callback_url, CALLBACK_URL) self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}}) - def test_discovery(self): + def test_discovery(self) -> None: """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid metadata = self.get_success(self.provider.load_metadata()) @@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.http_client.get_json.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) - def test_no_discovery(self): + def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) - def test_load_jwks(self): + def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) self.http_client.get_json.assert_called_once_with(JWKS_URI) @@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_failure(self.provider.load_jwks(force=True), RuntimeError) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_validate_config(self): + def test_validate_config(self) -> None: """Provider metadatas are extensively validated.""" h = self.provider @@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase): force_load_metadata() @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}}) - def test_skip_verification(self): + def test_skip_verification(self) -> None: """Provider metadata validation can be disabled by config.""" with self.metadata_edit({"issuer": "http://insecure"}): # This should not throw get_awaitable_result(self.provider.load_metadata()) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_redirect_request(self): + def test_redirect_request(self) -> None: """The redirect request has the right arguments & generates a valid session cookie.""" req = Mock(spec=["cookies"]) req.cookies = [] @@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(redirect, "http://client/redirect") @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_callback_error(self): + def test_callback_error(self) -> None: """Errors from the provider returned in the callback are displayed.""" request = Mock(args={}) request.args[b"error"] = [b"invalid_client"] @@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("invalid_client", "some description") @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_callback(self): + def test_callback(self) -> None: """Code callback works and display errors if something went wrong. A lot of scenarios are tested here: @@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.provider._exchange_code = simple_async_mock(return_value=token) - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) - self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] + self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("mapping_error") # Handle ID token errors - self.provider._parse_id_token = simple_async_mock(raises=Exception()) + self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment] self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") @@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "type": "bearer", "access_token": "access_token", } - self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( @@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase): id_token = { "sid": "abcdefgh", } - self.provider._parse_id_token = simple_async_mock(return_value=id_token) - self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment] + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] auth_handler.complete_sso_login.reset_mock() self.provider._fetch_userinfo.reset_mock() self.get_success(self.handler.handle_oidc_callback(request)) @@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase): self.render_error.assert_not_called() # Handle userinfo fetching error - self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) + self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment] self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") # Handle code exchange failure from synapse.handlers.oidc import OidcError - self.provider._exchange_code = simple_async_mock( + self.provider._exchange_code = simple_async_mock( # type: ignore[assignment] raises=OidcError("invalid_request") ) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request") @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_callback_session(self): + def test_callback_session(self) -> None: """The callback verifies the session presence and validity""" request = Mock(spec=["args", "getCookie", "cookies"]) @@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config( {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}} ) - def test_exchange_code(self): + def test_exchange_code(self) -> None: """Code exchange behaves correctly and handles various error scenarios.""" token = {"type": "bearer"} token_json = json.dumps(token).encode("utf-8") @@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_exchange_code_jwt_key(self): + def test_exchange_code_jwt_key(self) -> None: """Test that code exchange works with a JWK client secret.""" from authlib.jose import jwt @@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_exchange_code_no_auth(self): + def test_exchange_code_no_auth(self) -> None: """Test that code exchange works with no client secret.""" token = {"type": "bearer"} self.http_client.request = simple_async_mock( @@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_extra_attributes(self): + def test_extra_attributes(self) -> None: """ Login while using a mapping provider that implements get_extra_attributes. """ @@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "foo", "phone": "1234567", } - self.provider._exchange_code = simple_async_mock(return_value=token) - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase): ) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_map_userinfo_to_user(self): + def test_map_userinfo_to_user(self) -> None: """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() - userinfo = { + userinfo: dict = { "sub": "test_user", "username": "test_user", } @@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ) @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) - def test_map_userinfo_to_existing_user(self): + def test_map_userinfo_to_existing_user(self) -> None: """Existing users can log in with OpenID Connect when allow_existing_users is True.""" store = self.hs.get_datastores().main user = UserID.from_string("@test_user:test") @@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_map_userinfo_to_invalid_localpart(self): + def test_map_userinfo_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" self.get_success( _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) @@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_map_userinfo_to_user_retries(self): + def test_map_userinfo_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_empty_localpart(self): + def test_empty_localpart(self) -> None: """Attempts to map onto an empty localpart should be rejected.""" userinfo = { "sub": "tester", @@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_null_localpart(self): + def test_null_localpart(self) -> None: """Mapping onto a null localpart via an empty OIDC attribute should be rejected""" userinfo = { "sub": "tester", @@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_attribute_requirements(self): + def test_attribute_requirements(self) -> None: """The required attributes must be met from the OIDC userinfo response.""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_attribute_requirements_contains(self): + def test_attribute_requirements_contains(self) -> None: """Test that auth succeeds if userinfo attribute CONTAINS required value""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_attribute_requirements_mismatch(self): + def test_attribute_requirements_mismatch(self) -> None: """ Test that auth fails if attributes exist but don't match, or are non-string values. @@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": "not_foobar" attribute should fail - userinfo = { + userinfo: dict = { "sub": "tester", "username": "tester", "test": "not_foobar", @@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo( handler = hs.get_oidc_handler() provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) - provider._parse_id_token = simple_async_mock(return_value=userinfo) - provider._fetch_userinfo = simple_async_mock(return_value=userinfo) + provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment] + provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] + provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] state = "state" session = handler._token_generator.generate_oidc_session_token( diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 6ddec9ecf1..b2ed9cbe37 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -331,11 +331,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): # Extract presence update user ID and state information into lists of tuples db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]] - presence_states = [(ps.user_id, ps.state) for ps in presence_states] + presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states] # Compare what we put into the storage with what we got out. # They should be identical. - self.assertEqual(presence_states, db_presence_states) + self.assertEqual(presence_states_compare, db_presence_states) class PresenceTimeoutTestCase(unittest.TestCase): @@ -357,6 +357,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.UNAVAILABLE) self.assertEqual(new_state.status_msg, status_msg) @@ -380,6 +381,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.BUSY) self.assertEqual(new_state.status_msg, status_msg) @@ -399,6 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.status_msg, status_msg) @@ -420,6 +423,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.ONLINE) self.assertEqual(new_state.status_msg, status_msg) @@ -477,6 +481,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.status_msg, status_msg) @@ -653,13 +658,13 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase): self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None) def _set_presencestate_with_status_msg( - self, user_id: str, state: PresenceState, status_msg: Optional[str] + self, user_id: str, state: str, status_msg: Optional[str] ): """Set a PresenceState and status_msg and check the result. Args: user_id: User for that the status is to be set. - PresenceState: The new PresenceState. + state: The new PresenceState. status_msg: Status message that is to be set. """ self.get_success( diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 972cbac6e4..08733a9f2d 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any, Awaitable, Callable, Dict from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + import synapse.types from synapse.api.errors import AuthError, SynapseError from synapse.rest import admin from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import JsonDict, UserID +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase): servlets = [admin.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.mock_federation = Mock() self.mock_registry = Mock() - self.query_handlers = {} + self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} - def register_query_handler(query_type, handler): + def register_query_handler( + query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] + ) -> None: self.query_handlers[query_type] = handler self.mock_registry.register_query_handler = register_query_handler @@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) return hs - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.frank = UserID.from_string("@1234abcd:test") @@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.handler = hs.get_profile_handler() - def test_get_my_name(self): + def test_get_my_name(self) -> None: self.get_success( self.store.set_profile_displayname(self.frank.localpart, "Frank") ) @@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual("Frank", displayname) - def test_set_my_name(self): + def test_set_my_name(self) -> None: self.get_success( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "Frank Jr." @@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.get_success(self.store.get_profile_displayname(self.frank.localpart)) ) - def test_set_my_name_if_disabled(self): + def test_set_my_name_if_disabled(self) -> None: self.hs.config.registration.enable_set_displayname = False # Setting displayname for the first time is allowed @@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): SynapseError, ) - def test_set_my_name_noauth(self): + def test_set_my_name_noauth(self) -> None: self.get_failure( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.bob), "Frank Jr." @@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): AuthError, ) - def test_get_other_name(self): + def test_get_other_name(self) -> None: self.mock_federation.make_query.return_value = make_awaitable( {"displayname": "Alice"} ) @@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ignore_backoff=True, ) - def test_incoming_fed_query(self): + def test_incoming_fed_query(self) -> None: self.get_success(self.store.create_profile("caroline")) self.get_success(self.store.set_profile_displayname("caroline", "Caroline")) @@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual({"displayname": "Caroline"}, response) - def test_get_my_avatar(self): + def test_get_my_avatar(self) -> None: self.get_success( self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" @@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual("http://my.server/me.png", avatar_url) - def test_set_my_avatar(self): + def test_set_my_avatar(self) -> None: self.get_success( self.handler.set_avatar_url( self.frank, @@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), ) - def test_set_my_avatar_if_disabled(self): + def test_set_my_avatar_if_disabled(self) -> None: self.hs.config.registration.enable_set_avatar_url = False # Setting displayname for the first time is allowed @@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): SynapseError, ) - def test_avatar_constraints_no_config(self): + def test_avatar_constraints_no_config(self) -> None: """Tests that the method to check an avatar against configured constraints skips all of its check if no constraint is configured. """ @@ -263,7 +268,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertTrue(res) @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_constraints_missing(self): + def test_avatar_constraints_missing(self) -> None: """Tests that an avatar isn't allowed if the file at the given MXC URI couldn't be found. """ @@ -273,7 +278,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertFalse(res) @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_constraints_file_size(self): + def test_avatar_constraints_file_size(self) -> None: """Tests that a file that's above the allowed file size is forbidden but one that's below it is allowed. """ @@ -295,7 +300,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertFalse(res) @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) - def test_avatar_constraint_mime_type(self): + def test_avatar_constraint_mime_type(self) -> None: """Tests that a file with an unauthorised MIME type is forbidden but one with an authorised content type is allowed. """ diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 23941abed8..8d4404eda1 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, Optional from unittest.mock import Mock import attr +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import RedirectException +from synapse.server import HomeServer +from synapse.util import Clock from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider): class SamlHandlerTestCase(HomeserverTestCase): - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL - saml_config = { + saml_config: Dict[str, Any] = { "sp_config": {"metadata": {}}, # Disable grandfathering. "grandfathered_mxid_source_attribute": None, @@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase): return config - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver() self.handler = hs.get_saml_handler() @@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase): elif not has_xmlsec1: skip = "Requires xmlsec1" - def test_map_saml_response_to_user(self): + def test_map_saml_response_to_user(self) -> None: """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" # stub out the auth handler @@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase): ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) - def test_map_saml_response_to_existing_user(self): + def test_map_saml_response_to_existing_user(self) -> None: """Existing users can log in with SAML account.""" store = self.hs.get_datastores().main self.get_success( @@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase): auth_provider_session_id=None, ) - def test_map_saml_response_to_invalid_localpart(self): + def test_map_saml_response_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" # stub out the auth handler @@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase): ) auth_handler.complete_sso_login.assert_not_called() - def test_map_saml_response_to_user_retries(self): + def test_map_saml_response_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" # stub out the auth handler and error renderer @@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase): } } ) - def test_map_saml_response_redirect(self): + def test_map_saml_response_redirect(self) -> None: """Test a mapping provider that raises a RedirectException""" saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) @@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase): }, } ) - def test_attribute_requirements(self): + def test_attribute_requirements(self) -> None: """The required attributes must be met from the SAML response.""" # stub out the auth handler diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f91a80b9fa..ffd5c4cb93 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -18,11 +18,14 @@ from typing import Dict from unittest.mock import ANY, Mock, call from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource from synapse.api.errors import AuthError from synapse.federation.transport.server import TransportLayerServer -from synapse.types import UserID, create_requester +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID, create_requester +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -42,7 +45,9 @@ ROOM_ID = "a-room" OTHER_ROOM_ID = "another-room" -def _expect_edu_transaction(edu_type, content, origin="test"): +def _expect_edu_transaction( + edu_type: str, content: JsonDict, origin: str = "test" +) -> JsonDict: return { "origin": origin, "origin_server_ts": 1000000, @@ -51,12 +56,12 @@ def _expect_edu_transaction(edu_type, content, origin="test"): } -def _make_edu_transaction_json(edu_type, content): +def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes: return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8") class TypingNotificationsTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # we mock out the keyring so as to skip the authentication check on the # federation API call. mock_keyring = Mock(spec=["verify_json_for_server"]) @@ -83,7 +88,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): d["/_matrix/federation"] = TransportLayerServer(self.hs) return d - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: mock_notifier = hs.get_notifier() self.on_new_event = mock_notifier.on_new_event @@ -111,24 +116,24 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.room_members = [] - async def check_user_in_room(room_id, user_id): + async def check_user_in_room(room_id: str, user_id: str) -> None: if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") return None hs.get_auth().check_user_in_room = check_user_in_room - async def check_host_in_room(room_id, server_name): + async def check_host_in_room(room_id: str, server_name: str) -> bool: return room_id == ROOM_ID hs.get_event_auth_handler().check_host_in_room = check_host_in_room - def get_joined_hosts_for_room(room_id): + def get_joined_hosts_for_room(room_id: str): return {member.domain for member in self.room_members} self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room - async def get_users_in_room(room_id): + async def get_users_in_room(room_id: str): return {str(u) for u in self.room_members} self.datastore.get_users_in_room = get_users_in_room @@ -153,7 +158,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): lambda *args, **kwargs: make_awaitable(None) ) - def test_started_typing_local(self): + def test_started_typing_local(self) -> None: self.room_members = [U_APPLE, U_BANANA] self.assertEqual(self.event_source.get_current_key(), 0) @@ -187,7 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) @override_config({"send_federation": True}) - def test_started_typing_remote_send(self): + def test_started_typing_remote_send(self) -> None: self.room_members = [U_APPLE, U_ONION] self.get_success( @@ -217,7 +222,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): try_trailing_slash_on_400=True, ) - def test_started_typing_remote_recv(self): + def test_started_typing_remote_recv(self) -> None: self.room_members = [U_APPLE, U_ONION] self.assertEqual(self.event_source.get_current_key(), 0) @@ -256,7 +261,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) - def test_started_typing_remote_recv_not_in_room(self): + def test_started_typing_remote_recv_not_in_room(self) -> None: self.room_members = [U_APPLE, U_ONION] self.assertEqual(self.event_source.get_current_key(), 0) @@ -292,7 +297,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEqual(events[1], 0) @override_config({"send_federation": True}) - def test_stopped_typing(self): + def test_stopped_typing(self) -> None: self.room_members = [U_APPLE, U_BANANA, U_ONION] # Gut-wrenching @@ -343,7 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) - def test_typing_timeout(self): + def test_typing_timeout(self) -> None: self.room_members = [U_APPLE, U_BANANA] self.assertEqual(self.event_source.get_current_key(), 0) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 6691e07128..ba158f5d93 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -11,15 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple +from typing import Any, Dict, List, Optional, Tuple from unittest.mock import Mock from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable from synapse.push import PusherConfigException from synapse.rest.client import login, push_rule, receipts, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase, override_config @@ -35,13 +39,13 @@ class HTTPPusherTests(HomeserverTestCase): user_id = True hijack_auth = False - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["start_pushers"] = True return config - def make_homeserver(self, reactor, clock): - self.push_attempts: List[tuple[Deferred, str, dict]] = [] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.push_attempts: List[Tuple[Deferred, str, dict]] = [] m = Mock() @@ -56,7 +60,7 @@ class HTTPPusherTests(HomeserverTestCase): return hs - def test_invalid_configuration(self): + def test_invalid_configuration(self) -> None: """Invalid push configurations should be rejected.""" # Register the user who gets notified user_id = self.register_user("user", "pass") @@ -68,7 +72,7 @@ class HTTPPusherTests(HomeserverTestCase): ) token_id = user_tuple.token_id - def test_data(data): + def test_data(data: Optional[JsonDict]) -> None: self.get_failure( self.hs.get_pusherpool().add_pusher( user_id=user_id, @@ -95,7 +99,7 @@ class HTTPPusherTests(HomeserverTestCase): # A url with an incorrect path isn't accepted. test_data({"url": "http://example.com/foo"}) - def test_sends_http(self): + def test_sends_http(self) -> None: """ The HTTP pusher will send pushes for each message to a HTTP endpoint when configured to do so. @@ -200,7 +204,7 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) - def test_sends_high_priority_for_encrypted(self): + def test_sends_high_priority_for_encrypted(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to an encrypted message. @@ -321,7 +325,7 @@ class HTTPPusherTests(HomeserverTestCase): ) self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") - def test_sends_high_priority_for_one_to_one_only(self): + def test_sends_high_priority_for_one_to_one_only(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to a message in a one-to-one room. @@ -404,7 +408,7 @@ class HTTPPusherTests(HomeserverTestCase): # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") - def test_sends_high_priority_for_mention(self): + def test_sends_high_priority_for_mention(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to a message containing the user's display name. @@ -480,7 +484,7 @@ class HTTPPusherTests(HomeserverTestCase): # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") - def test_sends_high_priority_for_atroom(self): + def test_sends_high_priority_for_atroom(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to a message that contains @room. @@ -563,7 +567,7 @@ class HTTPPusherTests(HomeserverTestCase): # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") - def test_push_unread_count_group_by_room(self): + def test_push_unread_count_group_by_room(self) -> None: """ The HTTP pusher will group unread count by number of unread rooms. """ @@ -576,7 +580,7 @@ class HTTPPusherTests(HomeserverTestCase): self._check_push_attempt(6, 1) @override_config({"push": {"group_unread_count_by_room": False}}) - def test_push_unread_count_message_count(self): + def test_push_unread_count_message_count(self) -> None: """ The HTTP pusher will send the total unread message count. """ @@ -589,7 +593,7 @@ class HTTPPusherTests(HomeserverTestCase): # last read receipt self._check_push_attempt(6, 3) - def _test_push_unread_count(self): + def _test_push_unread_count(self) -> None: """ Tests that the correct unread count appears in sent push notifications @@ -681,7 +685,7 @@ class HTTPPusherTests(HomeserverTestCase): self.helper.send(room_id, body="HELLO???", tok=other_access_token) - def _advance_time_and_make_push_succeed(self, expected_push_attempts): + def _advance_time_and_make_push_succeed(self, expected_push_attempts: int) -> None: self.pump() self.push_attempts[expected_push_attempts - 1][0].callback({}) @@ -708,7 +712,9 @@ class HTTPPusherTests(HomeserverTestCase): expected_unread_count_last_push, ) - def _send_read_request(self, access_token, message_event_id, room_id): + def _send_read_request( + self, access_token: str, message_event_id: str, room_id: str + ) -> None: # Now set the user's read receipt position to the first event # # This will actually trigger a new notification to be sent out so that @@ -748,7 +754,7 @@ class HTTPPusherTests(HomeserverTestCase): return user_id, access_token - def test_dont_notify_rule_overrides_message(self): + def test_dont_notify_rule_overrides_message(self) -> None: """ The override push rule will suppress notification """ diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 3849beb9d6..5dba187076 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Dict, Optional, Union import frozendict @@ -20,12 +20,13 @@ from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent from synapse.push import push_rule_evaluator from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent +from synapse.types import JsonDict from tests import unittest class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator(self, content): + def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent: event = FrozenEvent( { "event_id": "$event_id", @@ -39,12 +40,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): ) room_member_count = 0 sender_power_level = 0 - power_levels = {} + power_levels: Dict[str, Union[int, Dict[str, int]]] = {} return PushRuleEvaluatorForEvent( event, room_member_count, sender_power_level, power_levels ) - def test_display_name(self): + def test_display_name(self) -> None: """Check for a matching display name in the body of the event.""" evaluator = self._get_evaluator({"body": "foo bar baz"}) @@ -71,20 +72,20 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) def _assert_matches( - self, condition: Dict[str, Any], content: Dict[str, Any], msg=None + self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None ) -> None: evaluator = self._get_evaluator(content) self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg) def _assert_not_matches( - self, condition: Dict[str, Any], content: Dict[str, Any], msg=None + self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None ) -> None: evaluator = self._get_evaluator(content) self.assertFalse( evaluator.matches(condition, "@user:test", "display_name"), msg ) - def test_event_match_body(self): + def test_event_match_body(self) -> None: """Check that event_match conditions on content.body work as expected""" # if the key is `content.body`, the pattern matches substrings. @@ -165,7 +166,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): r"? after \ should match any character", ) - def test_event_match_non_body(self): + def test_event_match_non_body(self) -> None: """Check that event_match conditions on other keys work as expected""" # if the key is anything other than 'content.body', the pattern must match the @@ -241,7 +242,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "pattern should not match before a newline", ) - def test_no_body(self): + def test_no_body(self) -> None: """Not having a body shouldn't break the evaluator.""" evaluator = self._get_evaluator({}) @@ -250,7 +251,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): } self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - def test_invalid_body(self): + def test_invalid_body(self) -> None: """A non-string body should not break the evaluator.""" condition = { "kind": "contains_display_name", @@ -260,7 +261,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): evaluator = self._get_evaluator({"body": body}) self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - def test_tweaks_for_actions(self): + def test_tweaks_for_actions(self) -> None: """ This tests the behaviour of tweaks_for_actions. """ diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 171f4e97c8..329690f8f7 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -15,7 +15,7 @@ import itertools import urllib.parse -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor @@ -79,6 +79,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): content: Optional[dict] = None, access_token: Optional[str] = None, parent_id: Optional[str] = None, + expected_response_code: int = 200, ) -> FakeChannel: """Helper function to send a relation pointing at `self.parent_id` @@ -115,16 +116,60 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): content, access_token=access_token, ) + self.assertEqual(expected_response_code, channel.code, channel.json_body) return channel + def _get_related_events(self) -> List[str]: + """ + Requests /relations on the parent ID and returns a list of event IDs. + """ + # Request the relations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + return [ev["event_id"] for ev in channel.json_body["chunk"]] + + def _get_bundled_aggregations(self) -> JsonDict: + """ + Requests /event on the parent ID and returns the m.relations field (from unsigned), if it exists. + """ + # Fetch the bundled aggregations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + return channel.json_body["unsigned"].get("m.relations", {}) + + def _get_aggregations(self) -> List[JsonDict]: + """Request /aggregations on the parent ID and includes the returned chunk.""" + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + return channel.json_body["chunk"] + + def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + return event + + raise AssertionError(f"Event {self.parent_id} not found in chunk") + class RelationsTestCase(BaseRelationsTestCase): def test_send_relation(self) -> None: """Tests that sending a relation works.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEqual(200, channel.code, channel.json_body) - event_id = channel.json_body["event_id"] channel = self.make_request( @@ -151,13 +196,13 @@ class RelationsTestCase(BaseRelationsTestCase): def test_deny_invalid_event(self) -> None: """Test that we deny relations on non-existant events""" - channel = self._send_relation( + self._send_relation( RelationTypes.ANNOTATION, EventTypes.Message, parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, + expected_response_code=400, ) - self.assertEqual(400, channel.code, channel.json_body) # Unless that event is referenced from another event! self.get_success( @@ -171,13 +216,12 @@ class RelationsTestCase(BaseRelationsTestCase): desc="test_deny_invalid_event", ) ) - channel = self._send_relation( + self._send_relation( RelationTypes.THREAD, EventTypes.Message, parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) def test_deny_invalid_room(self) -> None: """Test that we deny relations on non-existant events""" @@ -187,18 +231,20 @@ class RelationsTestCase(BaseRelationsTestCase): parent_id = res["event_id"] # Attempt to send an annotation to that event. - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A" + self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + parent_id=parent_id, + key="A", + expected_response_code=400, ) - self.assertEqual(400, channel.code, channel.json_body) def test_deny_double_react(self) -> None: """Test that we deny relations on membership events""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(400, channel.code, channel.json_body) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", expected_response_code=400 + ) def test_deny_forked_thread(self) -> None: """It is invalid to start a thread off a thread.""" @@ -208,316 +254,24 @@ class RelationsTestCase(BaseRelationsTestCase): content={"msgtype": "m.text", "body": "foo"}, parent_id=self.parent_id, ) - self.assertEqual(200, channel.code, channel.json_body) parent_id = channel.json_body["event_id"] - channel = self._send_relation( + self._send_relation( RelationTypes.THREAD, "m.room.message", content={"msgtype": "m.text", "body": "foo"}, parent_id=parent_id, + expected_response_code=400, ) - self.assertEqual(400, channel.code, channel.json_body) - - def test_basic_paginate_relations(self) -> None: - """Tests that calling pagination API correctly the latest relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - first_annotation_id = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) - second_annotation_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # We expect to get back a single pagination result, which is the latest - # full relation event we sent above. - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( - { - "event_id": second_annotation_id, - "sender": self.user_id, - "type": "m.reaction", - }, - channel.json_body["chunk"][0], - ) - - # We also expect to get the original event (the id of which is self.parent_id) - self.assertEqual( - channel.json_body["original_event"]["event_id"], self.parent_id - ) - - # Make sure next_batch has something in it that looks like it could be a - # valid token. - self.assertIsInstance( - channel.json_body.get("next_batch"), str, channel.json_body - ) - - # Request the relations again, but with a different direction. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations" - f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # We expect to get back a single pagination result, which is the earliest - # full relation event we sent above. - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( - { - "event_id": first_annotation_id, - "sender": self.user_id, - "type": "m.reaction", - }, - channel.json_body["chunk"][0], - ) - - def test_repeated_paginate_relations(self) -> None: - """Test that if we paginate using a limit and tokens then we get the - expected events. - """ - - expected_event_ids = [] - for idx in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) - ) - self.assertEqual(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) - - prev_token = "" - found_event_ids: List[str] = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - - def test_pagination_from_sync_and_messages(self) -> None: - """Pagination tokens from /sync and /messages can be used to paginate /relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") - self.assertEqual(200, channel.code, channel.json_body) - annotation_id = channel.json_body["event_id"] - # Send an event after the relation events. - self.helper.send(self.room, body="Latest event", tok=self.user_token) - - # Request /sync, limiting it such that only the latest event is returned - # (and not the relation). - filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}') - channel = self.make_request( - "GET", f"/sync?filter={filter}", access_token=self.user_token - ) - self.assertEqual(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - sync_prev_batch = room_timeline["prev_batch"] - self.assertIsNotNone(sync_prev_batch) - # Ensure the relation event is not in the batch returned from /sync. - self.assertNotIn( - annotation_id, [ev["event_id"] for ev in room_timeline["events"]] - ) - - # Request /messages, limiting it such that only the latest event is - # returned (and not the relation). - channel = self.make_request( - "GET", - f"/rooms/{self.room}/messages?dir=b&limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - messages_end = channel.json_body["end"] - self.assertIsNotNone(messages_end) - # Ensure the relation event is not in the chunk returned from /messages. - self.assertNotIn( - annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] - ) - - # Request /relations with the pagination tokens received from both the - # /sync and /messages responses above, in turn. - # - # This is a tiny bit silly since the client wouldn't know the parent ID - # from the requests above; consider the parent ID to be known from a - # previous /sync. - for from_token in (sync_prev_batch, messages_end): - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # The relation should be in the returned chunk. - self.assertIn( - annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] - ) - - def test_aggregation_pagination_groups(self) -> None: - """Test that we can paginate annotation groups correctly.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} - for key in itertools.chain.from_iterable( - itertools.repeat(key, num) for key, num in sent_groups.items() - ): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key=key, - access_token=access_tokens[idx], - ) - self.assertEqual(200, channel.code, channel.json_body) - - idx += 1 - idx %= len(access_tokens) - - prev_token: Optional[str] = None - found_groups: Dict[str, int] = {} - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - for groups in channel.json_body["chunk"]: - # We only expect reactions - self.assertEqual(groups["type"], "m.reaction", channel.json_body) - - # We should only see each key once - self.assertNotIn(groups["key"], found_groups, channel.json_body) - - found_groups[groups["key"]] = groups["count"] - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - self.assertEqual(sent_groups, found_groups) - - def test_aggregation_pagination_within_group(self) -> None: - """Test that we can paginate within an annotation group.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="👍", - access_token=access_tokens[idx], - ) - self.assertEqual(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) - - idx += 1 - - # Also send a different type of reaction so that we test we don't see it - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEqual(200, channel.code, channel.json_body) - - prev_token = "" - found_event_ids: List[str] = [] - encoded_key = urllib.parse.quote_plus("👍".encode()) - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}" - f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" - f"/m.reaction/{encoded_key}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) def test_aggregation(self) -> None: """Test that annotations get correctly aggregated.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation( + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") channel = self.make_request( "GET", @@ -547,215 +301,6 @@ class RelationsTestCase(BaseRelationsTestCase): ) self.assertEqual(400, channel.code, channel.json_body) - @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) - def test_bundled_aggregations(self) -> None: - """ - Test that annotations, references, and threads get correctly bundled. - - Note that this doesn't test against /relations since only thread relations - get bundled via that API. See test_aggregation_get_event_for_thread. - - See test_edit for a similar test for edits. - """ - # Setup by sending a variety of relations. - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - reply_1 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - reply_2 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - thread_2 = channel.json_body["event_id"] - - def assert_bundle(event_json: JsonDict) -> None: - """Assert the expected values of the bundled aggregations.""" - relations_dict = event_json["unsigned"].get("m.relations") - - # Ensure the fields are as expected. - self.assertCountEqual( - relations_dict.keys(), - ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.THREAD, - ), - ) - - # Check the values of each field. - self.assertEqual( - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - relations_dict[RelationTypes.ANNOTATION], - ) - - self.assertEqual( - {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, - relations_dict[RelationTypes.REFERENCE], - ) - - self.assertEqual( - 2, - relations_dict[RelationTypes.THREAD].get("count"), - ) - self.assertTrue( - relations_dict[RelationTypes.THREAD].get("current_user_participated") - ) - # The latest thread event has some fields that don't matter. - self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "sender": self.user_id, - "type": "m.room.test", - }, - relations_dict[RelationTypes.THREAD].get("latest_event"), - ) - - # Request the event directly. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body) - - # Request the room messages. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) - - # Request the room context. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]) - - # Request sync. - channel = self.make_request("GET", "/sync", access_token=self.user_token) - self.assertEqual(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - self.assertTrue(room_timeline["limited"]) - assert_bundle(self._find_event_in_chunk(room_timeline["events"])) - - # Request search. - channel = self.make_request( - "POST", - "/search", - # Search term matches the parent message. - content={"search_categories": {"room_events": {"search_term": "Hi"}}}, - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - chunk = [ - result["result"] - for result in channel.json_body["search_categories"]["room_events"][ - "results" - ] - ] - assert_bundle(self._find_event_in_chunk(chunk)) - - def test_aggregation_get_event_for_annotation(self) -> None: - """Test that annotations do not get bundled aggregations included - when directly requested. - """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - annotation_id = channel.json_body["event_id"] - - # Annotate the annotation. - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id - ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{annotation_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) - - def test_aggregation_get_event_for_thread(self) -> None: - """Test that threads get bundled aggregations included when directly requested.""" - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - thread_id = channel.json_body["event_id"] - - # Annotate the annotation. - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id - ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{thread_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual( - channel.json_body["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, - }, - ) - - # It should also be included when the entire thread is requested. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(len(channel.json_body["chunk"]), 1) - - thread_message = channel.json_body["chunk"][0] - self.assertEqual( - thread_message["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, - }, - ) - def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -877,8 +422,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] def assert_bundle(event_json: JsonDict) -> None: @@ -954,7 +497,7 @@ class RelationsTestCase(BaseRelationsTestCase): shouldn't be allowed, are correctly handled. """ - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={ @@ -963,7 +506,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEqual(200, channel.code, channel.json_body) new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( @@ -971,11 +513,9 @@ class RelationsTestCase(BaseRelationsTestCase): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message.WRONG_TYPE", content={ @@ -984,7 +524,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, }, ) - self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -1015,7 +554,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.room.message", content={"msgtype": "m.text", "body": "A reply!"}, ) - self.assertEqual(200, channel.code, channel.json_body) reply = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} @@ -1025,8 +563,6 @@ class RelationsTestCase(BaseRelationsTestCase): content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, parent_id=reply, ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] channel = self.make_request( @@ -1071,17 +607,15 @@ class RelationsTestCase(BaseRelationsTestCase): "m.room.message", content={"msgtype": "m.text", "body": "A threaded reply!"}, ) - self.assertEqual(200, channel.code, channel.json_body) threaded_event_id = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, parent_id=threaded_event_id, ) - self.assertEqual(200, channel.code, channel.json_body) # Fetch the thread root, to get the bundled aggregation for the thread. channel = self.make_request( @@ -1113,11 +647,10 @@ class RelationsTestCase(BaseRelationsTestCase): "m.new_content": new_body, }, ) - self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] # Edit the edit event. - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={ @@ -1127,7 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase): }, parent_id=edit_event_id, ) - self.assertEqual(200, channel.code, channel.json_body) # Request the original event. channel = self.make_request( @@ -1154,7 +686,6 @@ class RelationsTestCase(BaseRelationsTestCase): def test_unknown_relations(self) -> None: """Unknown relations should be accepted.""" channel = self._send_relation("m.relation.test", "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] channel = self.make_request( @@ -1195,28 +726,15 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) - def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: - """ - Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. - """ - for event in events: - if event["event_id"] == self.parent_id: - return event - - raise AssertionError(f"Event {self.parent_id} not found in chunk") - def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEqual(200, channel.code, channel.json_body) annotation_event_id_good = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A") - self.assertEqual(200, channel.code, channel.json_body) annotation_event_id_bad = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) thread_event_id = channel.json_body["event_id"] # Clean-up the table as if the inserts did not happen during event creation. @@ -1267,6 +785,516 @@ class RelationsTestCase(BaseRelationsTestCase): [annotation_event_id_good, thread_event_id], ) + +class RelationPaginationTestCase(BaseRelationsTestCase): + def test_basic_paginate_relations(self) -> None: + """Tests that calling pagination API correctly the latest relations.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + first_annotation_id = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + second_annotation_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the latest + # full relation event we sent above. + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + { + "event_id": second_annotation_id, + "sender": self.user_id, + "type": "m.reaction", + }, + channel.json_body["chunk"][0], + ) + + # We also expect to get the original event (the id of which is self.parent_id) + self.assertEqual( + channel.json_body["original_event"]["event_id"], self.parent_id + ) + + # Make sure next_batch has something in it that looks like it could be a + # valid token. + self.assertIsInstance( + channel.json_body.get("next_batch"), str, channel.json_body + ) + + # Request the relations again, but with a different direction. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations" + f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the earliest + # full relation event we sent above. + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + { + "event_id": first_annotation_id, + "sender": self.user_id, + "type": "m.reaction", + }, + channel.json_body["chunk"][0], + ) + + def test_repeated_paginate_relations(self) -> None: + """Test that if we paginate using a limit and tokens then we get the + expected events. + """ + + expected_event_ids = [] + for idx in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) + ) + expected_event_ids.append(channel.json_body["event_id"]) + + prev_token = "" + found_event_ids: List[str] = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEqual(found_event_ids, expected_event_ids) + + def test_pagination_from_sync_and_messages(self) -> None: + """Pagination tokens from /sync and /messages can be used to paginate /relations.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") + annotation_id = channel.json_body["event_id"] + # Send an event after the relation events. + self.helper.send(self.room, body="Latest event", tok=self.user_token) + + # Request /sync, limiting it such that only the latest event is returned + # (and not the relation). + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}') + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + sync_prev_batch = room_timeline["prev_batch"] + self.assertIsNotNone(sync_prev_batch) + # Ensure the relation event is not in the batch returned from /sync. + self.assertNotIn( + annotation_id, [ev["event_id"] for ev in room_timeline["events"]] + ) + + # Request /messages, limiting it such that only the latest event is + # returned (and not the relation). + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b&limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + messages_end = channel.json_body["end"] + self.assertIsNotNone(messages_end) + # Ensure the relation event is not in the chunk returned from /messages. + self.assertNotIn( + annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + ) + + # Request /relations with the pagination tokens received from both the + # /sync and /messages responses above, in turn. + # + # This is a tiny bit silly since the client wouldn't know the parent ID + # from the requests above; consider the parent ID to be known from a + # previous /sync. + for from_token in (sync_prev_batch, messages_end): + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # The relation should be in the returned chunk. + self.assertIn( + annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + ) + + def test_aggregation_pagination_groups(self) -> None: + """Test that we can paginate annotation groups correctly.""" + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} + for key in itertools.chain.from_iterable( + itertools.repeat(key, num) for key, num in sent_groups.items() + ): + self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key=key, + access_token=access_tokens[idx], + ) + + idx += 1 + idx %= len(access_tokens) + + prev_token: Optional[str] = None + found_groups: Dict[str, int] = {} + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + for groups in channel.json_body["chunk"]: + # We only expect reactions + self.assertEqual(groups["type"], "m.reaction", channel.json_body) + + # We should only see each key once + self.assertNotIn(groups["key"], found_groups, channel.json_body) + + found_groups[groups["key"]] = groups["count"] + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + self.assertEqual(sent_groups, found_groups) + + def test_aggregation_pagination_within_group(self) -> None: + """Test that we can paginate within an annotation group.""" + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key="👍", + access_token=access_tokens[idx], + ) + expected_event_ids.append(channel.json_body["event_id"]) + + idx += 1 + + # Also send a different type of reaction so that we test we don't see it + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + + prev_token = "" + found_event_ids: List[str] = [] + encoded_key = urllib.parse.quote_plus("👍".encode()) + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}" + f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" + f"/m.reaction/{encoded_key}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEqual(found_event_ids, expected_event_ids) + + +class BundledAggregationsTestCase(BaseRelationsTestCase): + """ + See RelationsTestCase.test_edit for a similar test for edits. + + Note that this doesn't test against /relations since only thread relations + get bundled via that API. See test_aggregation_get_event_for_thread. + """ + + def _test_bundled_aggregations( + self, + relation_type: str, + assertion_callable: Callable[[JsonDict], None], + expected_db_txn_for_event: int, + ) -> None: + """ + Makes requests to various endpoints which should include bundled aggregations + and then calls an assertion function on the bundled aggregations. + + Args: + relation_type: The field to search for in the `m.relations` field in unsigned. + assertion_callable: Called with the contents of unsigned["m.relations"][relation_type] + for relation-specific assertions. + expected_db_txn_for_event: The number of database transactions which + are expected for a call to /event/. + """ + + def assert_bundle(event_json: JsonDict) -> None: + """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") + + # Ensure the fields are as expected. + self.assertCountEqual(relations_dict.keys(), (relation_type,)) + assertion_callable(relations_dict[relation_type]) + + # Request the event directly. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(channel.json_body) + assert channel.resource_usage is not None + self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event) + + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) + + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]) + + # Request sync. + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}') + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) + + # Request search. + channel = self.make_request( + "POST", + "/search", + # Search term matches the parent message. + content={"search_categories": {"room_events": {"search_term": "Hi"}}}, + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + chunk = [ + result["result"] + for result in channel.json_body["search_categories"]["room_events"][ + "results" + ] + ] + assert_bundle(self._find_event_in_chunk(chunk)) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_annotation(self) -> None: + """ + Test that annotations get correctly bundled. + """ + # Setup by sending a variety of relations. + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual( + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + bundled_aggregations, + ) + + self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_reference(self) -> None: + """ + Test that references get correctly bundled. + """ + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_2 = channel.json_body["event_id"] + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual( + {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, + bundled_aggregations, + ) + + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_thread(self) -> None: + """ + Test that threads get correctly bundled. + """ + self._send_relation(RelationTypes.THREAD, "m.room.test") + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_2 = channel.json_body["event_id"] + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual(2, bundled_aggregations.get("count")) + self.assertTrue(bundled_aggregations.get("current_user_participated")) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "sender": self.user_id, + "type": "m.room.test", + }, + bundled_aggregations.get("latest_event"), + ) + + self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9) + + def test_aggregation_get_event_for_annotation(self) -> None: + """Test that annotations do not get bundled aggregations included + when directly requested. + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + annotation_id = channel.json_body["event_id"] + + # Annotate the annotation. + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id + ) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{annotation_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) + + def test_aggregation_get_event_for_thread(self) -> None: + """Test that threads get bundled aggregations included when directly requested.""" + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_id = channel.json_body["event_id"] + + # Annotate the annotation. + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id + ) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{thread_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + + # It should also be included when the entire thread is requested. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1) + + thread_message = channel.json_body["chunk"][0] + self.assertEqual( + thread_message["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + def test_bundled_aggregations_with_filter(self) -> None: """ If "unsigned" is an omitted field (due to filtering), adding the bundled @@ -1322,46 +1350,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): ) self.assertEqual(200, channel.code, channel.json_body) - def _make_relation_requests(self) -> Tuple[List[str], JsonDict]: - """ - Makes requests and ensures they result in a 200 response, returns a - tuple of results: - - 1. `/relations` -> Returns a list of event IDs. - 2. `/event` -> Returns the response's m.relations field (from unsigned), - if it exists. - """ - - # Request the relations of the event. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]] - - # Fetch the bundled aggregations of the event. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - bundled_relations = channel.json_body["unsigned"].get("m.relations", {}) - - return event_ids, bundled_relations - - def _get_aggregations(self) -> List[JsonDict]: - """Request /aggregations on the parent ID and includes the returned chunk.""" - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - return channel.json_body["chunk"] - def test_redact_relation_annotation(self) -> None: """ Test that annotations of an event are properly handled after the @@ -1371,17 +1359,16 @@ class RelationRedactionTestCase(BaseRelationsTestCase): the response to relations. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) to_redact_event_id = channel.json_body["event_id"] channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEqual(200, channel.code, channel.json_body) unredacted_event_id = channel.json_body["event_id"] # Both relations should exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id]) self.assertEquals( relations["m.annotation"], @@ -1396,7 +1383,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(to_redact_event_id) # The unredacted relation should still exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [unredacted_event_id]) self.assertEquals( relations["m.annotation"], @@ -1419,7 +1407,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): EventTypes.Message, content={"body": "reply 1", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) unredacted_event_id = channel.json_body["event_id"] # Note that the *last* event in the thread is redacted, as that gets @@ -1429,11 +1416,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase): EventTypes.Message, content={"body": "reply 2", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) to_redact_event_id = channel.json_body["event_id"] # Both relations exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id]) self.assertDictContainsSubset( { @@ -1452,7 +1439,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(to_redact_event_id) # The unredacted relation should still exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [unredacted_event_id]) self.assertDictContainsSubset( { @@ -1472,7 +1460,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase): is redacted. """ # Add a relation - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", parent_id=self.parent_id, @@ -1482,10 +1470,10 @@ class RelationRedactionTestCase(BaseRelationsTestCase): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEqual(200, channel.code, channel.json_body) # Check the relation is returned - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 1) self.assertIn(RelationTypes.REPLACE, relations) @@ -1493,7 +1481,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(self.parent_id) # The relations are not returned. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 0) self.assertEqual(relations, {}) @@ -1503,11 +1492,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase): """ # Add a relation channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEqual(200, channel.code, channel.json_body) related_event_id = channel.json_body["event_id"] # The relations should exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 1) self.assertIn(RelationTypes.ANNOTATION, relations) @@ -1519,7 +1508,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(self.parent_id) # The relations are returned. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [related_event_id]) self.assertEquals( relations["m.annotation"], @@ -1540,14 +1530,14 @@ class RelationRedactionTestCase(BaseRelationsTestCase): EventTypes.Message, content={"body": "reply 1", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) related_event_id = channel.json_body["event_id"] # Redact one of the reactions. self._redact(self.parent_id) # The unredacted relation should still exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(len(event_ids), 1) self.assertDictContainsSubset( { diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py index 3fb37a2a59..62e308814d 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/rest/media/v1/test_html_preview.py @@ -16,7 +16,6 @@ from synapse.rest.media.v1.preview_html import ( _get_html_media_encodings, decode_body, parse_html_to_open_graph, - rebase_url, summarize_paragraphs, ) @@ -161,7 +160,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -177,7 +176,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -196,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual( og, @@ -218,7 +217,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -232,7 +231,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -247,7 +246,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) @@ -262,7 +261,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -290,7 +289,7 @@ class CalcOgTestCase(unittest.TestCase): <head><title>Foo</title></head><body>Some text.</body></html> """.strip() tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_invalid_encoding(self) -> None: @@ -304,7 +303,7 @@ class CalcOgTestCase(unittest.TestCase): </html> """ tree = decode_body(html, "http://example.com/test.html", "invalid-encoding") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_invalid_encoding2(self) -> None: @@ -319,7 +318,7 @@ class CalcOgTestCase(unittest.TestCase): </html> """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) def test_windows_1252(self) -> None: @@ -333,7 +332,7 @@ class CalcOgTestCase(unittest.TestCase): </html> """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."}) @@ -448,34 +447,3 @@ class MediaEncodingTestCase(unittest.TestCase): 'text/html; charset="invalid"', ) self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - - -class RebaseUrlTestCase(unittest.TestCase): - def test_relative(self) -> None: - """Relative URLs should be resolved based on the context of the base URL.""" - self.assertEqual( - rebase_url("subpage", "https://example.com/foo/"), - "https://example.com/foo/subpage", - ) - self.assertEqual( - rebase_url("sibling", "https://example.com/foo"), - "https://example.com/sibling", - ) - self.assertEqual( - rebase_url("/bar", "https://example.com/foo/"), - "https://example.com/bar", - ) - - def test_absolute(self) -> None: - """Absolute URLs should not be modified.""" - self.assertEqual( - rebase_url("https://alice.com/a/", "https://example.com/foo/"), - "https://alice.com/a/", - ) - - def test_data(self) -> None: - """Data URLs should not be modified.""" - self.assertEqual( - rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"), - "data:,Hello%2C%20World%21", - ) diff --git a/tests/server.py b/tests/server.py index 82990c2eb9..6ce2a17bf4 100644 --- a/tests/server.py +++ b/tests/server.py @@ -54,13 +54,18 @@ from twisted.internet.interfaces import ( ITransport, ) from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock +from twisted.test.proto_helpers import ( + AccumulatingProtocol, + MemoryReactor, + MemoryReactorClock, +) from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site from synapse.config.database import DatabaseConnectionConfig from synapse.http.site import SynapseRequest +from synapse.logging.context import ContextResourceUsage from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine @@ -88,18 +93,19 @@ class TimedOutException(Exception): """ -@attr.s +@attr.s(auto_attribs=True) class FakeChannel: """ A fake Twisted Web Channel (the part that interfaces with the wire). """ - site = attr.ib(type=Union[Site, "FakeSite"]) - _reactor = attr.ib() - result = attr.ib(type=dict, default=attr.Factory(dict)) - _ip = attr.ib(type=str, default="127.0.0.1") + site: Union[Site, "FakeSite"] + _reactor: MemoryReactor + result: dict = attr.Factory(dict) + _ip: str = "127.0.0.1" _producer: Optional[Union[IPullProducer, IPushProducer]] = None + resource_usage: Optional[ContextResourceUsage] = None @property def json_body(self): @@ -168,6 +174,8 @@ class FakeChannel: def requestDone(self, _self): self.result["done"] = True + if isinstance(_self, SynapseRequest): + self.resource_usage = _self.logcontext.get_resource_usage() def getPeer(self): # We give an address so that getClientIP returns a non null entry, diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index 272cd35402..72bf5b3d31 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -47,9 +47,18 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): expected_ignorer_user_ids, ) + def assert_ignored( + self, ignorer_user_id: str, expected_ignored_user_ids: Set[str] + ) -> None: + self.assertEqual( + self.get_success(self.store.ignored_users(ignorer_user_id)), + expected_ignored_user_ids, + ) + def test_ignoring_users(self): """Basic adding/removing of users from the ignore list.""" self._update_ignore_list("@other:test", "@another:remote") + self.assert_ignored(self.user, {"@other:test", "@another:remote"}) # Check a user which no one ignores. self.assert_ignorers("@user:test", set()) @@ -62,6 +71,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): # Add one user, remove one user, and leave one user. self._update_ignore_list("@foo:test", "@another:remote") + self.assert_ignored(self.user, {"@foo:test", "@another:remote"}) # Check the removed user. self.assert_ignorers("@other:test", set()) @@ -76,20 +86,24 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): """Ensure that caching works properly between different users.""" # The first user ignores a user. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # The second user ignores them. self._update_ignore_list("@other:test", ignorer_user_id="@second:test") + self.assert_ignored("@second:test", {"@other:test"}) self.assert_ignorers("@other:test", {self.user, "@second:test"}) # The first user un-ignores them. self._update_ignore_list() + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", {"@second:test"}) def test_invalid_data(self): """Invalid data ends up clearing out the ignored users list.""" # Add some data and ensure it is there. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # No ignored_users key. @@ -102,10 +116,12 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): ) # No one ignores the user now. + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", set()) # Add some data and ensure it is there. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # Invalid data. @@ -118,4 +134,5 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): ) # No one ignores the user now. + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", set()) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 5cf18b690e..fd619b64d4 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -17,8 +17,12 @@ from unittest.mock import Mock import yaml from twisted.internet.defer import Deferred, ensureDeferred +from twisted.test.proto_helpers import MemoryReactor +from synapse.server import HomeServer from synapse.storage.background_updates import BackgroundUpdater +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable, simple_async_mock @@ -26,7 +30,7 @@ from tests.unittest import override_config class BackgroundUpdateTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( @@ -39,7 +43,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): ) self.store = self.hs.get_datastores().main - async def update(self, progress, count): + async def update(self, progress: JsonDict, count: int) -> int: duration_ms = 10 await self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} @@ -51,7 +55,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): ) return count - def test_do_background_update(self): + def test_do_background_update(self) -> None: # the time we claim it takes to update one item when running the update duration_ms = 10 @@ -80,7 +84,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # second step: complete the update # we should now get run with a much bigger number of items to update - async def update(progress, count): + async def update(progress: JsonDict, count: int) -> int: self.assertEqual(progress, {"my_key": 2}) self.assertAlmostEqual( count, @@ -110,7 +114,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): """ ) ) - def test_background_update_default_batch_set_by_config(self): + def test_background_update_default_batch_set_by_config(self) -> None: """ Test that the background update is run with the default_batch_size set by the config """ @@ -133,7 +137,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # on the first call, we should get run with the default background update size specified in the config self.update_handler.assert_called_once_with({"my_key": 1}, 20) - def test_background_update_default_sleep_behavior(self): + def test_background_update_default_sleep_behavior(self) -> None: """ Test default background update behavior, which is to sleep """ @@ -147,7 +151,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.update_handler.side_effect = self.update self.update_handler.reset_mock() - self.updates.start_doing_background_updates(), + self.updates.start_doing_background_updates() # 2: advance the reactor less than the default sleep duration (1000ms) self.reactor.pump([0.5]) @@ -167,7 +171,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): """ ) ) - def test_background_update_sleep_set_in_config(self): + def test_background_update_sleep_set_in_config(self) -> None: """ Test that changing the sleep time in the config changes how long it sleeps """ @@ -181,7 +185,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.update_handler.side_effect = self.update self.update_handler.reset_mock() - self.updates.start_doing_background_updates(), + self.updates.start_doing_background_updates() # 2: advance the reactor less than the configured sleep duration (500ms) self.reactor.pump([0.45]) @@ -201,7 +205,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): """ ) ) - def test_disabling_background_update_sleep(self): + def test_disabling_background_update_sleep(self) -> None: """ Test that disabling sleep in the config results in bg update not sleeping """ @@ -215,7 +219,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.update_handler.side_effect = self.update self.update_handler.reset_mock() - self.updates.start_doing_background_updates(), + self.updates.start_doing_background_updates() # 2: advance the reactor very little self.reactor.pump([0.025]) @@ -230,7 +234,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): """ ) ) - def test_background_update_duration_set_in_config(self): + def test_background_update_duration_set_in_config(self) -> None: """ Test that the desired duration set in the config is used in determining batch size """ @@ -254,7 +258,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # the first update was run with the default batch size, this should be run with 500ms as the # desired duration - async def update(progress, count): + async def update(progress: JsonDict, count: int) -> int: self.assertEqual(progress, {"my_key": 2}) self.assertAlmostEqual( count, @@ -275,7 +279,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): """ ) ) - def test_background_update_min_batch_set_in_config(self): + def test_background_update_min_batch_set_in_config(self) -> None: """ Test that the minimum batch size set in the config is used """ @@ -290,7 +294,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): ) # Run the update with the long-running update item - async def update(progress, count): + async def update_long(progress: JsonDict, count: int) -> int: await self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} await self.store.db_pool.runInteraction( @@ -301,7 +305,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): ) return count - self.update_handler.side_effect = update + self.update_handler.side_effect = update_long self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update(False), @@ -311,25 +315,25 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # the first update was run with the default batch size, this should be run with minimum batch size # as the first items took a very long time - async def update(progress, count): + async def update_short(progress: JsonDict, count: int) -> int: self.assertEqual(progress, {"my_key": 2}) self.assertEqual(count, 5) await self.updates._end_background_update("test_update") return count - self.update_handler.side_effect = update + self.update_handler.side_effect = update_short self.get_success(self.updates.do_next_background_update(False)) class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) ) - self.update_deferred = Deferred() + self.update_deferred: Deferred[int] = Deferred() self.update_handler = Mock(return_value=self.update_deferred) self.updates.register_background_update_handler( "test_update", self.update_handler @@ -358,7 +362,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): ), ) - def test_controller(self): + def test_controller(self) -> None: store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( @@ -368,7 +372,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): ) # Set the return value for the context manager. - enter_defer = Deferred() + enter_defer: Deferred[int] = Deferred() self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer) # Start the background update. diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index 8597867563..a40fc20ef9 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -12,7 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.database import make_tuple_comparison_clause +from typing import Callable, Tuple +from unittest.mock import Mock, call + +from twisted.internet import defer +from twisted.internet.defer import CancelledError, Deferred +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_tuple_comparison_clause, +) +from synapse.util import Clock from tests import unittest @@ -22,3 +35,150 @@ class TupleComparisonClauseTestCase(unittest.TestCase): clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) self.assertEqual(clause, "(a,b) > (?,?)") self.assertEqual(args, [1, 2]) + + +class CallbacksTestCase(unittest.HomeserverTestCase): + """Tests for transaction callbacks.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + def _run_interaction( + self, func: Callable[[LoggingTransaction], object] + ) -> Tuple[Mock, Mock]: + """Run the given function in a database transaction, with callbacks registered. + + Args: + func: The function to be run in a transaction. The transaction will be + retried if `func` raises an `OperationalError`. + + Returns: + Two mocks, which were registered as an `after_callback` and an + `exception_callback` respectively, on every transaction attempt. + """ + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + func(txn) + + try: + self.get_success_or_raise( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + except Exception: + pass + + return after_callback, exception_callback + + def test_after_callback(self) -> None: + """Test that the after callback is called when a transaction succeeds.""" + after_callback, exception_callback = self._run_interaction(lambda txn: None) + + after_callback.assert_called_once_with(123, 456, extra=789) + exception_callback.assert_not_called() + + def test_exception_callback(self) -> None: + """Test that the exception callback is called when a transaction fails.""" + _test_txn = Mock(side_effect=ZeroDivisionError) + after_callback, exception_callback = self._run_interaction(_test_txn) + + after_callback.assert_not_called() + exception_callback.assert_called_once_with(987, 654, extra=321) + + def test_failed_retry(self) -> None: + """Test that the exception callback is called for every failed attempt.""" + # Always raise an `OperationalError`. + _test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError) + after_callback, exception_callback = self._run_interaction(_test_txn) + + after_callback.assert_not_called() + exception_callback.assert_has_calls( + [ + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + ] + ) + self.assertEqual(exception_callback.call_count, 6) # no additional calls + + def test_successful_retry(self) -> None: + """Test callbacks for a failed transaction followed by a successful attempt.""" + # Raise an `OperationalError` on the first attempt only. + _test_txn = Mock( + side_effect=[self.db_pool.engine.module.OperationalError, None] + ) + after_callback, exception_callback = self._run_interaction(_test_txn) + + # Calling both `after_callback`s when the first attempt failed is rather + # surprising (#12184). Let's document the behaviour in a test. + after_callback.assert_has_calls( + [ + call(123, 456, extra=789), + call(123, 456, extra=789), + ] + ) + self.assertEqual(after_callback.call_count, 2) # no additional calls + exception_callback.assert_not_called() + + +class CancellationTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + def test_after_callback(self) -> None: + """Test that the after callback is called when a transaction succeeds.""" + d: "Deferred[None]" + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + d.cancel() + + d = defer.ensureDeferred( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + self.get_failure(d, CancelledError) + + after_callback.assert_called_once_with(123, 456, extra=789) + exception_callback.assert_not_called() + + def test_exception_callback(self) -> None: + """Test that the exception callback is called when a transaction fails.""" + d: "Deferred[None]" + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + d.cancel() + # Simulate a retryable failure on every attempt. + raise self.db_pool.engine.module.OperationalError() + + d = defer.ensureDeferred( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + self.get_failure(d, CancelledError) + + after_callback.assert_not_called() + exception_callback.assert_has_calls( + [ + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + ] + ) + self.assertEqual(exception_callback.call_count, 6) # no additional calls diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 6ac4b93f98..395396340b 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -13,9 +13,13 @@ # limitations under the License. from typing import List, Optional -from synapse.storage.database import DatabasePool +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.util import Clock from tests.unittest import HomeserverTestCase from tests.utils import USE_POSTGRES_FOR_TESTS @@ -25,13 +29,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - def _setup_db(self, txn): + def _setup_db(self, txn: LoggingTransaction) -> None: txn.execute("CREATE SEQUENCE foobar_seq") txn.execute( """ @@ -59,12 +63,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) - def _insert_rows(self, instance_name: str, number: int): + def _insert_rows(self, instance_name: str, number: int) -> None: """Insert N rows as the given instance, inserting with stream IDs pulled from the postgres sequence. """ - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: for _ in range(number): txn.execute( "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", @@ -80,12 +84,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) - def _insert_row_with_id(self, instance_name: str, stream_id: int): + def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None: """Insert one row as the given instance with given stream_id, updating the postgres sequence position to match. """ - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: txn.execute( "INSERT INTO foobar VALUES (?, ?)", ( @@ -104,7 +108,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) - def test_empty(self): + def test_empty(self) -> None: """Test an ID generator against an empty database gives sensible current positions. """ @@ -114,7 +118,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # The table is empty so we expect an empty map for positions self.assertEqual(id_gen.get_positions(), {}) - def test_single_instance(self): + def test_single_instance(self) -> None: """Test that reads and writes from a single process are handled correctly. """ @@ -130,7 +134,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) @@ -142,7 +146,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) - def test_out_of_order_finish(self): + def test_out_of_order_finish(self) -> None: """Test that IDs persisted out of order are correctly handled""" # Prefill table with 7 rows written by 'master' @@ -191,7 +195,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 11}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) - def test_multi_instance(self): + def test_multi_instance(self) -> None: """Test that reads and writes from multiple processes are handled correctly. """ @@ -215,7 +219,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. - async def _get_next_async(): + async def _get_next_async() -> None: async with first_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) @@ -233,7 +237,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # ... but calling `get_next` on the second instance should give a unique # stream ID - async def _get_next_async(): + async def _get_next_async2() -> None: async with second_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 9) @@ -241,7 +245,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): second_id_gen.get_positions(), {"first": 3, "second": 7} ) - self.get_success(_get_next_async()) + self.get_success(_get_next_async2()) self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) @@ -249,7 +253,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): second_id_gen.advance("first", 8) self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) - def test_get_next_txn(self): + def test_get_next_txn(self) -> None: """Test that the `get_next_txn` function works correctly.""" # Prefill table with 7 rows written by 'master' @@ -263,7 +267,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. - def _get_next_txn(txn): + def _get_next_txn(txn: LoggingTransaction) -> None: stream_id = id_gen.get_next_txn(txn) self.assertEqual(stream_id, 8) @@ -275,7 +279,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) - def test_get_persisted_upto_position(self): + def test_get_persisted_upto_position(self) -> None: """Test that `get_persisted_upto_position` correctly tracks updates to positions. """ @@ -317,7 +321,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen.advance("second", 15) self.assertEqual(id_gen.get_persisted_upto_position(), 11) - def test_get_persisted_upto_position_get_next(self): + def test_get_persisted_upto_position_get_next(self) -> None: """Test that `get_persisted_upto_position` correctly tracks updates to positions when `get_next` is called. """ @@ -331,7 +335,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_persisted_upto_position(), 5) - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 6) self.assertEqual(id_gen.get_persisted_upto_position(), 5) @@ -344,7 +348,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # `persisted_upto_position` in this case, then it will be correct in the # other cases that are tested above (since they'll hit the same code). - def test_restart_during_out_of_order_persistence(self): + def test_restart_during_out_of_order_persistence(self) -> None: """Test that restarting a process while another process is writing out of order updates are handled correctly. """ @@ -388,7 +392,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen_worker.advance("master", 9) self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) - def test_writer_config_change(self): + def test_writer_config_change(self) -> None: """Test that changing the writer config correctly works.""" self._insert_row_with_id("first", 3) @@ -421,7 +425,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Check that we get a sane next stream ID with this new config. - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen_3.get_next() as stream_id: self.assertEqual(stream_id, 6) @@ -435,7 +439,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6) self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6) - def test_sequence_consistency(self): + def test_sequence_consistency(self) -> None: """Test that we error out if the table and sequence diverges.""" # Prefill with some rows @@ -458,13 +462,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - def _setup_db(self, txn): + def _setup_db(self, txn: LoggingTransaction) -> None: txn.execute("CREATE SEQUENCE foobar_seq") txn.execute( """ @@ -493,10 +497,10 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): return self.get_success(self.db_pool.runWithConnection(_create)) - def _insert_row(self, instance_name: str, stream_id: int): + def _insert_row(self, instance_name: str, stream_id: int) -> None: """Insert one row as the given instance with given stream_id.""" - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: txn.execute( "INSERT INTO foobar VALUES (?, ?)", ( @@ -514,13 +518,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) - def test_single_instance(self): + def test_single_instance(self) -> None: """Test that reads and writes from a single process are handled correctly. """ id_gen = self._create_id_generator() - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen.get_next() as stream_id: self._insert_row("master", stream_id) @@ -530,7 +534,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) self.assertEqual(id_gen.get_persisted_upto_position(), -1) - async def _get_next_async2(): + async def _get_next_async2() -> None: async with id_gen.get_next_mult(3) as stream_ids: for stream_id in stream_ids: self._insert_row("master", stream_id) @@ -548,14 +552,14 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4) self.assertEqual(second_id_gen.get_persisted_upto_position(), -4) - def test_multiple_instance(self): + def test_multiple_instance(self) -> None: """Tests that having multiple instances that get advanced over federation works corretly. """ id_gen_1 = self._create_id_generator("first", writers=["first", "second"]) id_gen_2 = self._create_id_generator("second", writers=["first", "second"]) - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen_1.get_next() as stream_id: self._insert_row("first", stream_id) id_gen_2.advance("first", stream_id) @@ -567,7 +571,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) - async def _get_next_async2(): + async def _get_next_async2() -> None: async with id_gen_2.get_next() as stream_id: self._insert_row("second", stream_id) id_gen_1.advance("second", stream_id) @@ -584,13 +588,13 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - def _setup_db(self, txn): + def _setup_db(self, txn: LoggingTransaction) -> None: txn.execute("CREATE SEQUENCE foobar_seq") txn.execute( """ @@ -642,7 +646,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): from the postgres sequence. """ - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: for _ in range(number): txn.execute( "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,), @@ -659,7 +663,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) - def test_load_existing_stream(self): + def test_load_existing_stream(self) -> None: """Test creating ID gens with multiple tables that have rows from after the position in `stream_positions` table. """ diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py index 38e9f58ac6..5d1aa025d1 100644 --- a/tests/util/test_check_dependencies.py +++ b/tests/util/test_check_dependencies.py @@ -12,7 +12,7 @@ from tests.unittest import TestCase class DummyDistribution(metadata.Distribution): - def __init__(self, version: str): + def __init__(self, version: object): self._version = version @property @@ -30,6 +30,7 @@ old = DummyDistribution("0.1.2") old_release_candidate = DummyDistribution("0.1.2rc3") new = DummyDistribution("1.2.3") new_release_candidate = DummyDistribution("1.2.3rc4") +distribution_with_no_version = DummyDistribution(None) # could probably use stdlib TestCase --- no need for twisted here @@ -67,6 +68,18 @@ class TestDependencyChecker(TestCase): # should not raise check_requirements() + def test_version_reported_as_none(self) -> None: + """Complain if importlib.metadata.version() returns None. + + This shouldn't normally happen, but it was seen in the wild (#12223). + """ + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1"], + ): + with self.mock_installed_package(distribution_with_no_version): + self.assertRaises(DependencyException, check_requirements) + def test_checks_ignore_dev_dependencies(self) -> None: """Bot generic and per-extra checks should ignore dev dependencies.""" with patch( |