diff --git a/changelog.d/12087.bugfix b/changelog.d/12087.bugfix
new file mode 100644
index 0000000000..6dacdddd0d
--- /dev/null
+++ b/changelog.d/12087.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug which caused the `/_matrix/federation/v1/state` and `.../state_ids` endpoints to return incorrect or invalid data when called for an event which we have stored as an "outlier".
diff --git a/changelog.d/12198.misc b/changelog.d/12198.misc
new file mode 100644
index 0000000000..6b184a9053
--- /dev/null
+++ b/changelog.d/12198.misc
@@ -0,0 +1 @@
+Add tests for database transaction callbacks.
diff --git a/changelog.d/12199.misc b/changelog.d/12199.misc
new file mode 100644
index 0000000000..16dec1d26d
--- /dev/null
+++ b/changelog.d/12199.misc
@@ -0,0 +1 @@
+Handle cancellation in `DatabasePool.runInteraction()`.
diff --git a/changelog.d/12216.misc b/changelog.d/12216.misc
new file mode 100644
index 0000000000..dc398ac1e0
--- /dev/null
+++ b/changelog.d/12216.misc
@@ -0,0 +1 @@
+Add missing type hints for cache storage.
diff --git a/changelog.d/12219.misc b/changelog.d/12219.misc
new file mode 100644
index 0000000000..6079414092
--- /dev/null
+++ b/changelog.d/12219.misc
@@ -0,0 +1 @@
+Clean-up logic around rebasing URLs for URL image previews.
diff --git a/changelog.d/12224.misc b/changelog.d/12224.misc
new file mode 100644
index 0000000000..b67a701dbb
--- /dev/null
+++ b/changelog.d/12224.misc
@@ -0,0 +1 @@
+Add type hints to tests files.
diff --git a/changelog.d/12225.misc b/changelog.d/12225.misc
new file mode 100644
index 0000000000..23105c727c
--- /dev/null
+++ b/changelog.d/12225.misc
@@ -0,0 +1 @@
+Use the `ignored_users` table in additional places instead of re-parsing the account data.
diff --git a/changelog.d/12227.misc b/changelog.d/12227.misc
new file mode 100644
index 0000000000..41c9dcbd37
--- /dev/null
+++ b/changelog.d/12227.misc
@@ -0,0 +1 @@
+Refactor the relations endpoints to add a `RelationsHandler`.
diff --git a/changelog.d/12228.bugfix b/changelog.d/12228.bugfix
new file mode 100644
index 0000000000..4755777139
--- /dev/null
+++ b/changelog.d/12228.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in v1.53.0 where an unnecessary query could be performed when fetching bundled aggregations for threads.
diff --git a/changelog.d/12231.doc b/changelog.d/12231.doc
new file mode 100644
index 0000000000..16593d2b92
--- /dev/null
+++ b/changelog.d/12231.doc
@@ -0,0 +1 @@
+Fix the link to the module documentation in the legacy spam checker warning message.
diff --git a/changelog.d/12232.misc b/changelog.d/12232.misc
new file mode 100644
index 0000000000..4a4132edff
--- /dev/null
+++ b/changelog.d/12232.misc
@@ -0,0 +1 @@
+Refactor relations tests to improve code re-use.
diff --git a/changelog.d/12237.misc b/changelog.d/12237.misc
new file mode 100644
index 0000000000..41c9dcbd37
--- /dev/null
+++ b/changelog.d/12237.misc
@@ -0,0 +1 @@
+Refactor the relations endpoints to add a `RelationsHandler`.
diff --git a/changelog.d/12240.misc b/changelog.d/12240.misc
new file mode 100644
index 0000000000..c5b6356799
--- /dev/null
+++ b/changelog.d/12240.misc
@@ -0,0 +1 @@
+Add type hints to tests files.
\ No newline at end of file
diff --git a/changelog.d/12242.misc b/changelog.d/12242.misc
new file mode 100644
index 0000000000..38e7e0f7d1
--- /dev/null
+++ b/changelog.d/12242.misc
@@ -0,0 +1 @@
+Generate announcement links in the release script.
diff --git a/changelog.d/12243.doc b/changelog.d/12243.doc
new file mode 100644
index 0000000000..b2031f0a40
--- /dev/null
+++ b/changelog.d/12243.doc
@@ -0,0 +1 @@
+Remove incorrect prefixes in the worker documentation for some endpoints.
diff --git a/changelog.d/12244.misc b/changelog.d/12244.misc
new file mode 100644
index 0000000000..950d48e4c6
--- /dev/null
+++ b/changelog.d/12244.misc
@@ -0,0 +1 @@
+Improve error message when dependencies check finds a broken installation.
\ No newline at end of file
diff --git a/changelog.d/12246.doc b/changelog.d/12246.doc
new file mode 100644
index 0000000000..e7fcc1b99c
--- /dev/null
+++ b/changelog.d/12246.doc
@@ -0,0 +1 @@
+Correct `check_username_for_spam` annotations and docs.
\ No newline at end of file
diff --git a/changelog.d/12248.misc b/changelog.d/12248.misc
new file mode 100644
index 0000000000..2b1290d1e1
--- /dev/null
+++ b/changelog.d/12248.misc
@@ -0,0 +1 @@
+Add missing type hints for storage.
\ No newline at end of file
diff --git a/changelog.d/12256.misc b/changelog.d/12256.misc
new file mode 100644
index 0000000000..c5b6356799
--- /dev/null
+++ b/changelog.d/12256.misc
@@ -0,0 +1 @@
+Add type hints to tests files.
\ No newline at end of file
diff --git a/changelog.d/12258.misc b/changelog.d/12258.misc
new file mode 100644
index 0000000000..80024c8e91
--- /dev/null
+++ b/changelog.d/12258.misc
@@ -0,0 +1 @@
+Compress metrics HTTP resource when enabled. Contributed by Nick @ Beeper.
diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md
index 2b672b78f9..472d957180 100644
--- a/docs/modules/spam_checker_callbacks.md
+++ b/docs/modules/spam_checker_callbacks.md
@@ -172,7 +172,7 @@ any of the subsequent implementations of this callback.
_First introduced in Synapse v1.37.0_
```python
-async def check_username_for_spam(user_profile: Dict[str, str]) -> bool
+async def check_username_for_spam(user_profile: synapse.module_api.UserProfile) -> bool
```
Called when computing search results in the user directory. The module must return a
@@ -182,9 +182,11 @@ search results; otherwise return `False`.
The profile is represented as a dictionary with the following keys:
-* `user_id`: The Matrix ID for this user.
-* `display_name`: The user's display name.
-* `avatar_url`: The `mxc://` URL to the user's avatar.
+* `user_id: str`. The Matrix ID for this user.
+* `display_name: Optional[str]`. The user's display name, or `None` if this user
+ has not set a display name.
+* `avatar_url: Optional[str]`. The `mxc://` URL to the user's avatar, or `None`
+ if this user has not set an avatar.
The module is given a copy of the original dictionary, so modifying it from within the
module cannot modify a user's profile when included in user directory search results.
diff --git a/docs/workers.md b/docs/workers.md
index 8751134e65..9eb4194e4d 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -185,8 +185,8 @@ worker: refer to the [stream writers](#stream-writers) section below for further
information.
# Sync requests
- ^/_matrix/client/(v2_alpha|r0|v3)/sync$
- ^/_matrix/client/(api/v1|v2_alpha|r0|v3)/events$
+ ^/_matrix/client/(r0|v3)/sync$
+ ^/_matrix/client/(api/v1|r0|v3)/events$
^/_matrix/client/(api/v1|r0|v3)/initialSync$
^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$
@@ -200,13 +200,9 @@ information.
^/_matrix/federation/v1/query/
^/_matrix/federation/v1/make_join/
^/_matrix/federation/v1/make_leave/
- ^/_matrix/federation/v1/send_join/
- ^/_matrix/federation/v2/send_join/
- ^/_matrix/federation/v1/send_leave/
- ^/_matrix/federation/v2/send_leave/
- ^/_matrix/federation/v1/invite/
- ^/_matrix/federation/v2/invite/
- ^/_matrix/federation/v1/query_auth/
+ ^/_matrix/federation/(v1|v2)/send_join/
+ ^/_matrix/federation/(v1|v2)/send_leave/
+ ^/_matrix/federation/(v1|v2)/invite/
^/_matrix/federation/v1/event_auth/
^/_matrix/federation/v1/exchange_third_party_invite/
^/_matrix/federation/v1/user/devices/
@@ -274,6 +270,8 @@ information.
Additionally, the following REST endpoints can be handled for GET requests:
^/_matrix/federation/v1/groups/
+ ^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
+ ^/_matrix/client/(r0|v3|unstable)/groups/
Pagination requests can also be handled, but all requests for a given
room must be routed to the same instance. Additionally, care must be taken to
@@ -397,23 +395,23 @@ the stream writer for the `typing` stream:
The following endpoints should be routed directly to the worker configured as
the stream writer for the `to_device` stream:
- ^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/
+ ^/_matrix/client/(r0|v3|unstable)/sendToDevice/
##### The `account_data` stream
The following endpoints should be routed directly to the worker configured as
the stream writer for the `account_data` stream:
- ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags
- ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data
+ ^/_matrix/client/(r0|v3|unstable)/.*/tags
+ ^/_matrix/client/(r0|v3|unstable)/.*/account_data
##### The `receipts` stream
The following endpoints should be routed directly to the worker configured as
the stream writer for the `receipts` stream:
- ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt
- ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers
+ ^/_matrix/client/(r0|v3|unstable)/rooms/.*/receipt
+ ^/_matrix/client/(r0|v3|unstable)/rooms/.*/read_markers
##### The `presence` stream
@@ -528,7 +526,7 @@ Note that if a reverse proxy is used , then `/_matrix/media/` must be routed for
Handles searches in the user directory. It can handle REST endpoints matching
the following regular expressions:
- ^/_matrix/client/(api/v1|r0|v3|unstable)/user_directory/search$
+ ^/_matrix/client/(r0|v3|unstable)/user_directory/search$
When using this worker you must also set `update_user_directory: False` in the
shared configuration file to stop the main synapse running background
@@ -540,7 +538,7 @@ Proxies some frequently-requested client endpoints to add caching and remove
load from the main synapse. It can handle REST endpoints matching the following
regular expressions:
- ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload
+ ^/_matrix/client/(r0|v3|unstable)/keys/upload
If `use_presence` is False in the homeserver config, it can also handle REST
endpoints matching the following regular expressions:
diff --git a/mypy.ini b/mypy.ini
index f9c39fcaae..24d4ba15d4 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -42,9 +42,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py
- |synapse/storage/databases/main/group_server.py
- |synapse/storage/databases/main/metrics.py
- |synapse/storage/databases/main/monthly_active_users.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/roommember.py
@@ -66,14 +63,6 @@ exclude = (?x)
|tests/federation/test_federation_server.py
|tests/federation/transport/test_knocking.py
|tests/federation/transport/test_server.py
- |tests/handlers/test_cas.py
- |tests/handlers/test_directory.py
- |tests/handlers/test_e2e_keys.py
- |tests/handlers/test_federation.py
- |tests/handlers/test_oidc.py
- |tests/handlers/test_presence.py
- |tests/handlers/test_profile.py
- |tests/handlers/test_saml.py
|tests/handlers/test_typing.py
|tests/http/federation/test_matrix_federation_agent.py
|tests/http/federation/test_srv_resolver.py
@@ -85,7 +74,6 @@ exclude = (?x)
|tests/logging/test_terse_json.py
|tests/module_api/test_api.py
|tests/push/test_email.py
- |tests/push/test_http.py
|tests/push/test_presentable_names.py
|tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_transactions.py
@@ -94,12 +82,7 @@ exclude = (?x)
|tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py
|tests/state/test_v2.py
- |tests/storage/test_background_update.py
|tests/storage/test_base.py
- |tests/storage/test_client_ips.py
- |tests/storage/test_database.py
- |tests/storage/test_event_federation.py
- |tests/storage/test_id_generators.py
|tests/storage/test_roommember.py
|tests/test_metrics.py
|tests/test_phone_home.py
diff --git a/scripts-dev/release.py b/scripts-dev/release.py
index 046453e65f..685fa32b03 100755
--- a/scripts-dev/release.py
+++ b/scripts-dev/release.py
@@ -66,11 +66,15 @@ def cli():
./scripts-dev/release.py tag
- # ... wait for asssets to build ...
+ # ... wait for assets to build ...
./scripts-dev/release.py publish
./scripts-dev/release.py upload
+ # Optional: generate some nice links for the announcement
+
+ ./scripts-dev/release.py upload
+
If the env var GH_TOKEN (or GITHUB_TOKEN) is set, or passed into the
`tag`/`publish` command, then a new draft release will be created/published.
"""
@@ -415,6 +419,41 @@ def upload():
)
+@cli.command()
+def announce():
+ """Generate markdown to announce the release."""
+
+ current_version, _, _ = parse_version_from_module()
+ tag_name = f"v{current_version}"
+
+ click.echo(
+ f"""
+Hi everyone. Synapse {current_version} has just been released.
+
+[notes](https://github.com/matrix-org/synapse/releases/tag/{tag_name}) |\
+[docker](https://hub.docker.com/r/matrixdotorg/synapse/tags?name={tag_name}) | \
+[debs](https://packages.matrix.org/debian/) | \
+[pypi](https://pypi.org/project/matrix-synapse/{current_version}/)"""
+ )
+
+ if "rc" in tag_name:
+ click.echo(
+ """
+Announce the RC in
+- #homeowners:matrix.org (Synapse Announcements)
+- #synapse-dev:matrix.org"""
+ )
+ else:
+ click.echo(
+ """
+Announce the release in
+- #homeowners:matrix.org (Synapse Announcements), bumping the version in the topic
+- #synapse:matrix.org (Synapse Admins), bumping the version in the topic
+- #synapse-dev:matrix.org
+- #synapse-package-maintainers:matrix.org"""
+ )
+
+
def parse_version_from_module() -> Tuple[
version.Version, redbaron.RedBaron, redbaron.Node
]:
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e4dc04c0b4..ad2b7c9515 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -261,7 +261,10 @@ class SynapseHomeServer(HomeServer):
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
if name == "metrics" and self.config.metrics.enable_metrics:
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
+ metrics_resource: Resource = MetricsResource(RegistryProxy)
+ if compress:
+ metrics_resource = gz_wrap(metrics_resource)
+ resources[METRICS_PREFIX] = metrics_resource
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index a233a9ce03..4c52103b1c 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -25,8 +25,8 @@ logger = logging.getLogger(__name__)
LEGACY_SPAM_CHECKER_WARNING = """
This server is using a spam checker module that is implementing the deprecated spam
checker interface. Please check with the module's maintainer to see if a new version
-supporting Synapse's generic modules system is available.
-For more information, please see https://matrix-org.github.io/synapse/latest/modules.html
+supporting Synapse's generic modules system is available. For more information, please
+see https://matrix-org.github.io/synapse/latest/modules/index.html
---------------------------------------------------------------------------------------"""
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 60904a55f5..cd80fcf9d1 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -21,7 +21,6 @@ from typing import (
Awaitable,
Callable,
Collection,
- Dict,
List,
Optional,
Tuple,
@@ -31,7 +30,7 @@ from typing import (
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour
-from synapse.types import RoomAlias
+from synapse.types import RoomAlias, UserProfile
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
@@ -50,7 +49,7 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bo
USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
-CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]]
+CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
[
Optional[dict],
@@ -383,7 +382,7 @@ class SpamChecker:
return True
- async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
+ async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index a0520068e0..7120062127 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze
from . import EventBase
if TYPE_CHECKING:
+ from synapse.handlers.relations import BundledAggregations
from synapse.server import HomeServer
- from synapse.storage.databases.main.relations import BundledAggregations
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 482bbdd867..af2d0f7d79 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,7 +22,6 @@ from typing import (
Callable,
Collection,
Dict,
- Iterable,
List,
Optional,
Tuple,
@@ -577,10 +576,10 @@ class FederationServer(FederationBase):
async def _on_context_state_request_compute(
self, room_id: str, event_id: Optional[str]
) -> Dict[str, list]:
+ pdus: Collection[EventBase]
if event_id:
- pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu(
- room_id, event_id
- )
+ event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
+ pdus = await self.store.get_events_as_list(event_ids)
else:
pdus = (await self.state.get_current_state(room_id)).values()
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index db39aeabde..350ec9c03a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -950,54 +950,35 @@ class FederationHandler:
return event
- async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
- """Returns the state at the event. i.e. not including said event."""
-
- event = await self.store.get_event(event_id, check_room_id=room_id)
-
- state_groups = await self.state_store.get_state_groups(room_id, [event_id])
-
- if state_groups:
- _, state = list(state_groups.items()).pop()
- results = {(e.type, e.state_key): e for e in state}
-
- if event.is_state():
- # Get previous state
- if "replaces_state" in event.unsigned:
- prev_id = event.unsigned["replaces_state"]
- if prev_id != event.event_id:
- prev_event = await self.store.get_event(prev_id)
- results[(event.type, event.state_key)] = prev_event
- else:
- del results[(event.type, event.state_key)]
-
- res = list(results.values())
- return res
- else:
- return []
-
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
+ if event.internal_metadata.outlier:
+ raise NotFoundError("State not known at event %s" % (event_id,))
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
- if state_groups:
- _, state = list(state_groups.items()).pop()
- results = state
+ # get_state_groups_ids should return exactly one result
+ assert len(state_groups) == 1
- if event.is_state():
- # Get previous state
- if "replaces_state" in event.unsigned:
- prev_id = event.unsigned["replaces_state"]
- if prev_id != event.event_id:
- results[(event.type, event.state_key)] = prev_id
- else:
- results.pop((event.type, event.state_key), None)
+ state_map = next(iter(state_groups.values()))
- return list(results.values())
- else:
- return []
+ state_key = event.get_state_key()
+ if state_key is not None:
+ # the event was not rejected (get_event raises a NotFoundError for rejected
+ # events) so the state at the event should include the event itself.
+ assert (
+ state_map.get((event.type, state_key)) == event.event_id
+ ), "State at event did not include event itself"
+
+ # ... but we need the state *before* that event
+ if "replaces_state" in event.unsigned:
+ prev_id = event.unsigned["replaces_state"]
+ state_map[(event.type, state_key)] = prev_id
+ else:
+ del state_map[(event.type, state_key)]
+
+ return list(state_map.values())
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 60059fec3e..876b879483 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
import attr
@@ -134,6 +134,7 @@ class PaginationHandler:
self.clock = hs.get_clock()
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
+ self._relations_handler = hs.get_relations_handler()
self.pagination_lock = ReadWriteLock()
# IDs of rooms in which there currently an active purge *or delete* operation.
@@ -422,7 +423,7 @@ class PaginationHandler:
pagin_config: PaginationConfig,
as_client_event: bool = True,
event_filter: Optional[Filter] = None,
- ) -> Dict[str, Any]:
+ ) -> JsonDict:
"""Get messages in a room.
Args:
@@ -431,6 +432,7 @@ class PaginationHandler:
pagin_config: The pagination config rules to apply, if any.
as_client_event: True to get events in client-server format.
event_filter: Filter to apply to results or None
+
Returns:
Pagination API results
"""
@@ -538,7 +540,9 @@ class PaginationHandler:
state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values()
- aggregations = await self.store.get_bundled_aggregations(events, user_id)
+ aggregations = await self._relations_handler.get_bundled_aggregations(
+ events, user_id
+ )
time_now = self.clock.time_msec()
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
new file mode 100644
index 0000000000..57135d4519
--- /dev/null
+++ b/synapse/handlers/relations.py
@@ -0,0 +1,264 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
+
+import attr
+from frozendict import frozendict
+
+from synapse.api.constants import RelationTypes
+from synapse.api.errors import SynapseError
+from synapse.events import EventBase
+from synapse.types import JsonDict, Requester, StreamToken
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+ from synapse.storage.databases.main import DataStore
+
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:
+ # The latest event in the thread.
+ latest_event: EventBase
+ # The latest edit to the latest event in the thread.
+ latest_edit: Optional[EventBase]
+ # The total number of events in the thread.
+ count: int
+ # True if the current user has sent an event to the thread.
+ current_user_participated: bool
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+ """
+ The bundled aggregations for an event.
+
+ Some values require additional processing during serialization.
+ """
+
+ annotations: Optional[JsonDict] = None
+ references: Optional[JsonDict] = None
+ replace: Optional[EventBase] = None
+ thread: Optional[_ThreadAggregation] = None
+
+ def __bool__(self) -> bool:
+ return bool(self.annotations or self.references or self.replace or self.thread)
+
+
+class RelationsHandler:
+ def __init__(self, hs: "HomeServer"):
+ self._main_store = hs.get_datastores().main
+ self._auth = hs.get_auth()
+ self._clock = hs.get_clock()
+ self._event_handler = hs.get_event_handler()
+ self._event_serializer = hs.get_event_client_serializer()
+
+ async def get_relations(
+ self,
+ requester: Requester,
+ event_id: str,
+ room_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ aggregation_key: Optional[str] = None,
+ limit: int = 5,
+ direction: str = "b",
+ from_token: Optional[StreamToken] = None,
+ to_token: Optional[StreamToken] = None,
+ ) -> JsonDict:
+ """Get related events of a event, ordered by topological ordering.
+
+ TODO Accept a PaginationConfig instead of individual pagination parameters.
+
+ Args:
+ requester: The user requesting the relations.
+ event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
+ relation_type: Only fetch events with this relation type, if given.
+ event_type: Only fetch events with this event type, if given.
+ aggregation_key: Only fetch events with this aggregation key, if given.
+ limit: Only fetch the most recent `limit` events.
+ direction: Whether to fetch the most recent first (`"b"`) or the
+ oldest first (`"f"`).
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
+
+ Returns:
+ The pagination chunk.
+ """
+
+ user_id = requester.user.to_string()
+
+ await self._auth.check_user_in_room_or_world_readable(
+ room_id, user_id, allow_departed_users=True
+ )
+
+ # This gets the original event and checks that a) the event exists and
+ # b) the user is allowed to view it.
+ event = await self._event_handler.get_event(requester.user, room_id, event_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
+
+ pagination_chunk = await self._main_store.get_relations_for_event(
+ event_id=event_id,
+ event=event,
+ room_id=room_id,
+ relation_type=relation_type,
+ event_type=event_type,
+ aggregation_key=aggregation_key,
+ limit=limit,
+ direction=direction,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ events = await self._main_store.get_events_as_list(
+ [c["event_id"] for c in pagination_chunk.chunk]
+ )
+
+ now = self._clock.time_msec()
+ # Do not bundle aggregations when retrieving the original event because
+ # we want the content before relations are applied to it.
+ original_event = self._event_serializer.serialize_event(
+ event, now, bundle_aggregations=None
+ )
+ # The relations returned for the requested event do include their
+ # bundled aggregations.
+ aggregations = await self.get_bundled_aggregations(
+ events, requester.user.to_string()
+ )
+ serialized_events = self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=aggregations
+ )
+
+ return_value = await pagination_chunk.to_dict(self._main_store)
+ return_value["chunk"] = serialized_events
+ return_value["original_event"] = original_event
+
+ return return_value
+
+ async def _get_bundled_aggregation_for_event(
+ self, event: EventBase, user_id: str
+ ) -> Optional[BundledAggregations]:
+ """Generate bundled aggregations for an event.
+
+ Note that this does not use a cache, but depends on cached methods.
+
+ Args:
+ event: The event to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
+
+ Returns:
+ The bundled aggregations for an event, if bundled aggregations are
+ enabled and the event can have bundled aggregations.
+ """
+
+ # Do not bundle aggregations for an event which represents an edit or an
+ # annotation. It does not make sense for them to have related events.
+ relates_to = event.content.get("m.relates_to")
+ if isinstance(relates_to, (dict, frozendict)):
+ relation_type = relates_to.get("rel_type")
+ if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+ return None
+
+ event_id = event.event_id
+ room_id = event.room_id
+
+ # The bundled aggregations to include, a mapping of relation type to a
+ # type-specific value. Some types include the direct return type here
+ # while others need more processing during serialization.
+ aggregations = BundledAggregations()
+
+ annotations = await self._main_store.get_aggregation_groups_for_event(
+ event_id, room_id
+ )
+ if annotations.chunk:
+ aggregations.annotations = await annotations.to_dict(
+ cast("DataStore", self)
+ )
+
+ references = await self._main_store.get_relations_for_event(
+ event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
+ )
+ if references.chunk:
+ aggregations.references = await references.to_dict(cast("DataStore", self))
+
+ # Store the bundled aggregations in the event metadata for later use.
+ return aggregations
+
+ async def get_bundled_aggregations(
+ self, events: Iterable[EventBase], user_id: str
+ ) -> Dict[str, BundledAggregations]:
+ """Generate bundled aggregations for events.
+
+ Args:
+ events: The iterable of events to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
+
+ Returns:
+ A map of event ID to the bundled aggregation for the event. Not all
+ events may have bundled aggregations in the results.
+ """
+ # De-duplicate events by ID to handle the same event requested multiple times.
+ #
+ # State events do not get bundled aggregations.
+ events_by_id = {
+ event.event_id: event for event in events if not event.is_state()
+ }
+
+ # event ID -> bundled aggregation in non-serialized form.
+ results: Dict[str, BundledAggregations] = {}
+
+ # Fetch other relations per event.
+ for event in events_by_id.values():
+ event_result = await self._get_bundled_aggregation_for_event(event, user_id)
+ if event_result:
+ results[event.event_id] = event_result
+
+ # Fetch any edits (but not for redacted events).
+ edits = await self._main_store.get_applicable_edits(
+ [
+ event_id
+ for event_id, event in events_by_id.items()
+ if not event.internal_metadata.is_redacted()
+ ]
+ )
+ for event_id, edit in edits.items():
+ results.setdefault(event_id, BundledAggregations()).replace = edit
+
+ # Fetch thread summaries.
+ summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
+ # Only fetch participated for a limited selection based on what had
+ # summaries.
+ participated = await self._main_store.get_threads_participated(
+ [event_id for event_id, summary in summaries.items() if summary], user_id
+ )
+ for event_id, summary in summaries.items():
+ if summary:
+ thread_count, latest_thread_event, edit = summary
+ results.setdefault(
+ event_id, BundledAggregations()
+ ).thread = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ latest_edit=edit,
+ count=thread_count,
+ # If there's a thread summary it must also exist in the
+ # participated dictionary.
+ current_user_participated=participated[event_id],
+ )
+
+ return results
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index b9735631fc..092e185c99 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -60,8 +60,8 @@ from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state
+from synapse.handlers.relations import BundledAggregations
from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
@@ -1118,6 +1118,7 @@ class RoomContextHandler:
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.state_store = self.storage.state
+ self._relations_handler = hs.get_relations_handler()
async def get_event_context(
self,
@@ -1190,7 +1191,7 @@ class RoomContextHandler:
event = filtered[0]
# Fetch the aggregations.
- aggregations = await self.store.get_bundled_aggregations(
+ aggregations = await self._relations_handler.get_bundled_aggregations(
itertools.chain(events_before, (event,), events_after),
user.to_string(),
)
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index aa16e417eb..30eddda65f 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -54,6 +54,7 @@ class SearchHandler:
self.clock = hs.get_clock()
self.hs = hs
self._event_serializer = hs.get_event_client_serializer()
+ self._relations_handler = hs.get_relations_handler()
self.storage = hs.get_storage()
self.state_store = self.storage.state
self.auth = hs.get_auth()
@@ -354,7 +355,7 @@ class SearchHandler:
aggregations = None
if self._msc3666_enabled:
- aggregations = await self.store.get_bundled_aggregations(
+ aggregations = await self._relations_handler.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools.chain(
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 0aa3052fd6..6c569cfb1c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,16 +28,16 @@ from typing import (
import attr
from prometheus_client import Counter
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
+from synapse.handlers.relations import BundledAggregations
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
-from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
@@ -269,6 +269,7 @@ class SyncHandler:
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.presence_handler = hs.get_presence_handler()
+ self._relations_handler = hs.get_relations_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
@@ -638,8 +639,10 @@ class SyncHandler:
# as clients will have all the necessary information.
bundled_aggregations = None
if limited or newly_joined_room:
- bundled_aggregations = await self.store.get_bundled_aggregations(
- recents, sync_config.user.to_string()
+ bundled_aggregations = (
+ await self._relations_handler.get_bundled_aggregations(
+ recents, sync_config.user.to_string()
+ )
)
return TimelineBatch(
@@ -1601,7 +1604,7 @@ class SyncHandler:
return set(), set(), set(), set()
# 3. Work out which rooms need reporting in the sync response.
- ignored_users = await self._get_ignored_users(user_id)
+ ignored_users = await self.store.ignored_users(user_id)
if since_token:
room_changes = await self._get_rooms_changed(
sync_result_builder, ignored_users
@@ -1627,7 +1630,6 @@ class SyncHandler:
logger.debug("Generating room entry for %s", room_entry.room_id)
await self._generate_room_entry(
sync_result_builder,
- ignored_users,
room_entry,
ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
tags=tags_by_room.get(room_entry.room_id),
@@ -1657,29 +1659,6 @@ class SyncHandler:
newly_left_users,
)
- async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
- """Retrieve the users ignored by the given user from their global account_data.
-
- Returns an empty set if
- - there is no global account_data entry for ignored_users
- - there is such an entry, but it's not a JSON object.
- """
- # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
- ignored_account_data = (
- await self.store.get_global_account_data_by_type_for_user(
- user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
- )
- )
-
- # If there is ignored users account data and it matches the proper type,
- # then use it.
- ignored_users: FrozenSet[str] = frozenset()
- if ignored_account_data:
- ignored_users_data = ignored_account_data.get("ignored_users", {})
- if isinstance(ignored_users_data, dict):
- ignored_users = frozenset(ignored_users_data.keys())
- return ignored_users
-
async def _have_rooms_changed(
self, sync_result_builder: "SyncResultBuilder"
) -> bool:
@@ -2022,7 +2001,6 @@ class SyncHandler:
async def _generate_room_entry(
self,
sync_result_builder: "SyncResultBuilder",
- ignored_users: FrozenSet[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[Dict[str, Dict[str, Any]]],
@@ -2051,7 +2029,6 @@ class SyncHandler:
Args:
sync_result_builder
- ignored_users: Set of users ignored by user.
room_builder
ephemeral: List of new ephemeral events for room
tags: List of *all* tags for room, or None if there has been
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index d27ed2be6a..048fd4bb82 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -19,8 +19,8 @@ import synapse.metrics
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.user_directory import SearchResult
from synapse.storage.roommember import ProfileInfo
-from synapse.types import JsonDict
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -78,7 +78,7 @@ class UserDirectoryHandler(StateDeltasHandler):
async def search_users(
self, user_id: str, search_term: str, limit: int
- ) -> JsonDict:
+ ) -> SearchResult:
"""Searches for users in directory
Returns:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d735c1d461..aa8256b36f 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -111,6 +111,7 @@ from synapse.types import (
StateMap,
UserID,
UserInfo,
+ UserProfile,
create_requester,
)
from synapse.util import Clock
@@ -150,6 +151,7 @@ __all__ = [
"EventBase",
"StateMap",
"ProfileInfo",
+ "UserProfile",
]
logger = logging.getLogger(__name__)
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 8140afcb6b..030898e4d0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -213,7 +213,7 @@ class BulkPushRuleEvaluator:
if not event.is_state():
ignorers = await self.store.ignored_by(event.sender)
else:
- ignorers = set()
+ ignorers = frozenset()
for uid, rules in rules_by_user.items():
if event.sender == uid:
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index d9a6be43f7..c16078b187 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.clock = hs.get_clock()
- self._event_serializer = hs.get_event_client_serializer()
- self.event_handler = hs.get_event_handler()
+ self._relations_handler = hs.get_relations_handler()
async def on_GET(
self,
@@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- await self.auth.check_user_in_room_or_world_readable(
- room_id, requester.user.to_string(), allow_departed_users=True
- )
-
- # This gets the original event and checks that a) the event exists and
- # b) the user is allowed to view it.
- event = await self.event_handler.get_event(requester.user, room_id, parent_id)
- if event is None:
- raise SynapseError(404, "Unknown parent event.")
-
limit = parse_integer(request, "limit", default=5)
direction = parse_string(
request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
@@ -90,9 +78,9 @@ class RelationPaginationServlet(RestServlet):
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)
- pagination_chunk = await self.store.get_relations_for_event(
+ result = await self._relations_handler.get_relations(
+ requester=requester,
event_id=parent_id,
- event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
@@ -102,30 +90,7 @@ class RelationPaginationServlet(RestServlet):
to_token=to_token,
)
- events = await self.store.get_events_as_list(
- [c["event_id"] for c in pagination_chunk.chunk]
- )
-
- now = self.clock.time_msec()
- # Do not bundle aggregations when retrieving the original event because
- # we want the content before relations are applied to it.
- original_event = self._event_serializer.serialize_event(
- event, now, bundle_aggregations=None
- )
- # The relations returned for the requested event do include their
- # bundled aggregations.
- aggregations = await self.store.get_bundled_aggregations(
- events, requester.user.to_string()
- )
- serialized_events = self._event_serializer.serialize_events(
- events, now, bundle_aggregations=aggregations
- )
-
- return_value = await pagination_chunk.to_dict(self.store)
- return_value["chunk"] = serialized_events
- return_value["original_event"] = original_event
-
- return 200, return_value
+ return 200, result
class RelationAggregationPaginationServlet(RestServlet):
@@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.clock = hs.get_clock()
- self._event_serializer = hs.get_event_client_serializer()
- self.event_handler = hs.get_event_handler()
+ self._relations_handler = hs.get_relations_handler()
async def on_GET(
self,
@@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- await self.auth.check_user_in_room_or_world_readable(
- room_id,
- requester.user.to_string(),
- allow_departed_users=True,
- )
-
- # This checks that a) the event exists and b) the user is allowed to
- # view it.
- event = await self.event_handler.get_event(requester.user, room_id, parent_id)
- if event is None:
- raise SynapseError(404, "Unknown parent event.")
-
if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -286,9 +237,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)
- result = await self.store.get_relations_for_event(
+ result = await self._relations_handler.get_relations(
+ requester=requester,
event_id=parent_id,
- event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
@@ -298,17 +249,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
to_token=to_token,
)
- events = await self.store.get_events_as_list(
- [c["event_id"] for c in result.chunk]
- )
-
- now = self.clock.time_msec()
- serialized_events = self._event_serializer.serialize_events(events, now)
-
- return_value = await result.to_dict(self.store)
- return_value["chunk"] = serialized_events
-
- return 200, return_value
+ return 200, result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 8a06ab8c5f..47e152c8cc 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -645,6 +645,7 @@ class RoomEventServlet(RestServlet):
self._store = hs.get_datastores().main
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
+ self._relations_handler = hs.get_relations_handler()
self.auth = hs.get_auth()
async def on_GET(
@@ -663,7 +664,7 @@ class RoomEventServlet(RestServlet):
if event:
# Ensure there are bundled aggregations available.
- aggregations = await self._store.get_bundled_aggregations(
+ aggregations = await self._relations_handler.get_bundled_aggregations(
[event], requester.user.to_string()
)
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
index a47d9bd01d..116c982ce6 100644
--- a/synapse/rest/client/user_directory.py
+++ b/synapse/rest/client/user_directory.py
@@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict
+from synapse.types import JsonMapping
from ._base import client_patterns
@@ -38,7 +38,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
- async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]:
"""Searches for users in directory
Returns:
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
index 872a9e72e8..4cc9c66fbe 100644
--- a/synapse/rest/media/v1/preview_html.py
+++ b/synapse/rest/media/v1/preview_html.py
@@ -16,7 +16,6 @@ import itertools
import logging
import re
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
-from urllib import parse as urlparse
if TYPE_CHECKING:
from lxml import etree
@@ -144,9 +143,7 @@ def decode_body(
return etree.fromstring(body, parser)
-def parse_html_to_open_graph(
- tree: "etree.Element", media_uri: str
-) -> Dict[str, Optional[str]]:
+def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
"""
Parse the HTML document into an Open Graph response.
@@ -155,7 +152,6 @@ def parse_html_to_open_graph(
Args:
tree: The parsed HTML document.
- media_url: The URI used to download the body.
Returns:
The Open Graph response as a dictionary.
@@ -209,7 +205,7 @@ def parse_html_to_open_graph(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
- og["og:image"] = rebase_url(meta_image[0], media_uri)
+ og["og:image"] = meta_image[0]
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
@@ -320,37 +316,6 @@ def _iterate_over_text(
)
-def rebase_url(url: str, base: str) -> str:
- """
- Resolves a potentially relative `url` against an absolute `base` URL.
-
- For example:
-
- >>> rebase_url("subpage", "https://example.com/foo/")
- 'https://example.com/foo/subpage'
- >>> rebase_url("sibling", "https://example.com/foo")
- 'https://example.com/sibling'
- >>> rebase_url("/bar", "https://example.com/foo/")
- 'https://example.com/bar'
- >>> rebase_url("https://alice.com/a/", "https://example.com/foo/")
- 'https://alice.com/a'
- """
- base_parts = urlparse.urlparse(base)
- # Convert the parsed URL to a list for (potential) modification.
- url_parts = list(urlparse.urlparse(url))
- # Add a scheme, if one does not exist.
- if not url_parts[0]:
- url_parts[0] = base_parts.scheme or "http"
- # Fix up the hostname, if this is not a data URL.
- if url_parts[0] != "data" and not url_parts[1]:
- url_parts[1] = base_parts.netloc
- # If the path does not start with a /, nest it under the base path's last
- # directory.
- if not url_parts[2].startswith("/"):
- url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2]
- return urlparse.urlunparse(url_parts)
-
-
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 14ea88b240..d47af8ead6 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -22,7 +22,7 @@ import shutil
import sys
import traceback
from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple
-from urllib import parse as urlparse
+from urllib.parse import urljoin, urlparse, urlsplit
from urllib.request import urlopen
import attr
@@ -44,11 +44,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.oembed import OEmbedProvider
-from synapse.rest.media.v1.preview_html import (
- decode_body,
- parse_html_to_open_graph,
- rebase_url,
-)
+from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph
from synapse.types import JsonDict, UserID
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
@@ -187,7 +183,7 @@ class PreviewUrlResource(DirectServeJsonResource):
ts = self.clock.time_msec()
# XXX: we could move this into _do_preview if we wanted.
- url_tuple = urlparse.urlsplit(url)
+ url_tuple = urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
for attrib in entry:
@@ -322,7 +318,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# Parse Open Graph information from the HTML in case the oEmbed
# response failed or is incomplete.
- og_from_html = parse_html_to_open_graph(tree, media_info.uri)
+ og_from_html = parse_html_to_open_graph(tree)
# Compile the Open Graph response by using the scraped
# information from the HTML and overlaying any information
@@ -588,12 +584,17 @@ class PreviewUrlResource(DirectServeJsonResource):
if "og:image" not in og or not og["og:image"]:
return
+ # The image URL from the HTML might be relative to the previewed page,
+ # convert it to an URL which can be requested directly.
+ image_url = og["og:image"]
+ url_parts = urlparse(image_url)
+ if url_parts.scheme != "data":
+ image_url = urljoin(media_info.uri, image_url)
+
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
- image_info = await self._handle_url(
- rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True
- )
+ image_info = await self._handle_url(image_url, user, allow_data_urls=True)
if _is_media(image_info.media_type):
# TODO: make sure we don't choke on white-on-transparent images
diff --git a/synapse/server.py b/synapse/server.py
index 2fcf18a7a6..380369db92 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -94,6 +94,7 @@ from synapse.handlers.profile import ProfileHandler
from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.register import RegistrationHandler
+from synapse.handlers.relations import RelationsHandler
from synapse.handlers.room import (
RoomContextHandler,
RoomCreationHandler,
@@ -720,6 +721,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return PaginationHandler(self)
@cache_in_self
+ def get_relations_handler(self) -> RelationsHandler:
+ return RelationsHandler(self)
+
+ @cache_in_self
def get_room_context_handler(self) -> RoomContextHandler:
return RoomContextHandler(self)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 99802228c9..9749f0c06e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -41,6 +41,7 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
+from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
+from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -732,34 +734,45 @@ class DatabasePool:
Returns:
The result of func
"""
- after_callbacks: List[_CallbackListEntry] = []
- exception_callbacks: List[_CallbackListEntry] = []
- if not current_context():
- logger.warning("Starting db txn '%s' from sentinel context", desc)
+ async def _runInteraction() -> R:
+ after_callbacks: List[_CallbackListEntry] = []
+ exception_callbacks: List[_CallbackListEntry] = []
- try:
- with opentracing.start_active_span(f"db.{desc}"):
- result = await self.runWithConnection(
- self.new_transaction,
- desc,
- after_callbacks,
- exception_callbacks,
- func,
- *args,
- db_autocommit=db_autocommit,
- isolation_level=isolation_level,
- **kwargs,
- )
+ if not current_context():
+ logger.warning("Starting db txn '%s' from sentinel context", desc)
- for after_callback, after_args, after_kwargs in after_callbacks:
- after_callback(*after_args, **after_kwargs)
- except Exception:
- for after_callback, after_args, after_kwargs in exception_callbacks:
- after_callback(*after_args, **after_kwargs)
- raise
+ try:
+ with opentracing.start_active_span(f"db.{desc}"):
+ result = await self.runWithConnection(
+ self.new_transaction,
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ db_autocommit=db_autocommit,
+ isolation_level=isolation_level,
+ **kwargs,
+ )
- return cast(R, result)
+ for after_callback, after_args, after_kwargs in after_callbacks:
+ after_callback(*after_args, **after_kwargs)
+
+ return cast(R, result)
+ except Exception:
+ for after_callback, after_args, after_kwargs in exception_callbacks:
+ after_callback(*after_args, **after_kwargs)
+ raise
+
+ # To handle cancellation, we ensure that `after_callback`s and
+ # `exception_callback`s are always run, since the transaction will complete
+ # on another thread regardless of cancellation.
+ #
+ # We also wait until everything above is done before releasing the
+ # `CancelledError`, so that logging contexts won't get used after they have been
+ # finished.
+ return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
async def runWithConnection(
self,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 52146aacc8..9af9f4f18e 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,7 +14,17 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ FrozenSet,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ cast,
+)
from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
@cached(max_entries=5000, iterable=True)
- async def ignored_by(self, user_id: str) -> Set[str]:
+ async def ignored_by(self, user_id: str) -> FrozenSet[str]:
"""
Get users which ignore the given user.
@@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
Return:
The user IDs which ignore the given user.
"""
- return set(
+ return frozenset(
await self.db_pool.simple_select_onecol(
table="ignored_users",
keyvalues={"ignored_user_id": user_id},
@@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
)
+ @cached(max_entries=5000, iterable=True)
+ async def ignored_users(self, user_id: str) -> FrozenSet[str]:
+ """
+ Get users which the given user ignores.
+
+ Params:
+ user_id: The user ID which is making the request.
+
+ Return:
+ The user IDs which are ignored by the given user.
+ """
+ return frozenset(
+ await self.db_pool.simple_select_onecol(
+ table="ignored_users",
+ keyvalues={"ignorer_user_id": user_id},
+ retcol="ignored_user_id",
+ desc="ignored_users",
+ )
+ )
+
def process_replication_rows(
self,
stream_name: str,
@@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
else:
currently_ignored_users = set()
+ # If the data has not changed, nothing to do.
+ if previously_ignored_users == currently_ignored_users:
+ return
+
# Delete entries which are no longer ignored.
self.db_pool.simple_delete_many_txn(
txn,
@@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+ self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
async def purge_account_data_for_user(self, user_id: str) -> None:
"""
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index d6a2df1afe..2d7511d613 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
+ EventsStreamRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -31,6 +32,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_updated_caches_txn(txn):
+ def get_all_updated_caches_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
@@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_all_updated_caches", get_all_updated_caches_txn
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+ ) -> None:
if stream_name == EventsStream.NAME:
for row in rows:
self._process_event_stream_row(token, row)
@@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
- def _process_event_stream_row(self, token, row):
+ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data
if row.type == EventsStreamEventRow.TypeId:
+ assert isinstance(data, EventsStreamEventRow)
self._invalidate_caches_for_event(
token,
data.event_id,
@@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
- self._curr_state_delta_stream_cache.entity_has_changed(
- row.data.room_id, token
- )
+ assert isinstance(data, EventsStreamCurrentStateRow)
+ self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
@@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_caches_for_event(
self,
- stream_ordering,
- event_id,
- room_id,
- etype,
- state_key,
- redacts,
- relates_to,
- backfilled,
- ):
+ stream_ordering: int,
+ event_id: str,
+ room_id: str,
+ etype: str,
+ state_key: Optional[str],
+ redacts: Optional[str],
+ relates_to: Optional[str],
+ backfilled: bool,
+ ) -> None:
self._invalidate_get_event_cache(event_id)
self.have_seen_event.invalidate((room_id, event_id))
@@ -207,7 +213,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_thread_summary.invalidate((relates_to,))
self.get_thread_participated.invalidate((relates_to,))
- async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+ async def invalidate_cache_and_stream(
+ self, cache_name: str, keys: Tuple[Any, ...]
+ ) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -227,7 +235,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
keys,
)
- def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+ def _invalidate_cache_and_stream(
+ self,
+ txn: LoggingTransaction,
+ cache_func: _CachedFunction,
+ keys: Tuple[Any, ...],
+ ) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -238,7 +251,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
- def _invalidate_all_cache_and_stream(self, txn, cache_func):
+ def _invalidate_all_cache_and_stream(
+ self, txn: LoggingTransaction, cache_func: _CachedFunction
+ ) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
"""
@@ -279,8 +294,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
def _send_invalidation_to_replication(
- self, txn, cache_name: str, keys: Optional[Iterable[Any]]
- ):
+ self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
+ ) -> None:
"""Notifies replication that given cache has been invalidated.
Note that this does *not* invalidate the cache locally.
@@ -315,7 +330,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
- "invalidation_ts": self.clock.time_msec(),
+ "invalidation_ts": self._clock.time_msec(),
},
)
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 3f6086050b..0aef121d83 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,13 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from typing_extensions import TypedDict
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore):
) -> List[Dict[str, Any]]:
# TODO: Pagination
- keyvalues = {"group_id": group_id}
+ keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore):
# TODO: Pagination
- def _get_rooms_in_group_txn(txn):
+ def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]:
sql = """
SELECT room_id, is_public FROM group_rooms
WHERE group_id = ?
@@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore):
* "order": int, the sort order of rooms in this category
"""
- def _get_rooms_for_summary_txn(txn):
- keyvalues = {"group_id": group_id}
+ def _get_rooms_for_summary_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+ keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_rooms_for_summary", _get_rooms_for_summary_txn
)
- async def get_group_categories(self, group_id):
+ async def get_group_categories(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list(
table="group_room_categories",
keyvalues={"group_id": group_id},
@@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows
}
- async def get_group_category(self, group_id, category_id):
+ async def get_group_category(self, group_id: str, category_id: str) -> JsonDict:
category = await self.db_pool.simple_select_one(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return category
- async def get_group_roles(self, group_id):
+ async def get_group_roles(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list(
table="group_roles",
keyvalues={"group_id": group_id},
@@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows
}
- async def get_group_role(self, group_id, role_id):
+ async def get_group_role(self, group_id: str, role_id: str) -> JsonDict:
role = await self.db_pool.simple_select_one(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_local_groups_for_room",
)
- async def get_users_for_summary_by_role(self, group_id, include_private=False):
+ async def get_users_for_summary_by_role(
+ self, group_id: str, include_private: bool = False
+ ) -> Tuple[List[JsonDict], JsonDict]:
"""Get the users and roles that should be included in a summary request
Returns:
([users], [roles])
"""
- def _get_users_for_summary_txn(txn):
- keyvalues = {"group_id": group_id}
+ def _get_users_for_summary_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], JsonDict]:
+ keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore):
allow_none=True,
)
- async def get_users_membership_info_in_group(self, group_id, user_id):
+ async def get_users_membership_info_in_group(
+ self, group_id: str, user_id: str
+ ) -> JsonDict:
"""Get a dict describing the membership of a user in a group.
Example if joined:
@@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
An empty dict if the user is not join/invite/etc
"""
- def _get_users_membership_in_group_txn(txn):
+ def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict:
row = self.db_pool.simple_select_one_txn(
txn,
table="group_users",
@@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_publicised_groups_for_user",
)
- async def get_attestations_need_renewals(self, valid_until_ms):
+ async def get_attestations_need_renewals(
+ self, valid_until_ms: int
+ ) -> List[Dict[str, Any]]:
"""Get all attestations that need to be renewed until givent time"""
- def _get_attestations_need_renewals_txn(txn):
+ def _get_attestations_need_renewals_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, Any]]:
sql = """
SELECT group_id, user_id FROM group_attestations_renewals
WHERE valid_until_ms <= ?
@@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)
- async def get_remote_attestation(self, group_id, user_id):
+ async def get_remote_attestation(
+ self, group_id: str, user_id: str
+ ) -> Optional[JsonDict]:
"""Get the attestation that proves the remote agrees that the user is
in the group.
"""
@@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_joined_groups",
)
- async def get_all_groups_for_user(self, user_id, now_token):
- def _get_all_groups_for_user_txn(txn):
+ async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
+ def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """
SELECT group_id, type, membership, u.content
FROM local_group_updates AS u
@@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_all_groups_for_user", _get_all_groups_for_user_txn
)
- async def get_groups_changes_for_user(self, user_id, from_token, to_token):
- from_token = int(from_token)
- has_changed = self._group_updates_stream_cache.has_entity_changed(
+ async def get_groups_changes_for_user(
+ self, user_id: str, from_token: int, to_token: int
+ ) -> List[JsonDict]:
+ has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined]
user_id, from_token
)
if not has_changed:
return []
- def _get_groups_changes_for_user_txn(txn):
+ def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """
SELECT group_id, membership, type, u.content
FROM local_group_updates AS u
@@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore):
"""
last_id = int(last_id)
- has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
+ has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined]
if not has_changed:
return [], current_id, False
- def _get_all_groups_changes_txn(txn):
+ def _get_all_groups_changes_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """
SELECT stream_id, group_id, user_id, type, content
FROM local_group_updates
@@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [
- (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
- for stream_id, group_id, user_id, gtype, content_json in txn
- ]
+ updates = cast(
+ List[Tuple[int, tuple]],
+ [
+ (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
+ for stream_id, group_id, user_id, gtype, content_json in txn
+ ],
+ )
limited = False
upto_token = current_id
@@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore):
self,
group_id: str,
room_id: str,
- category_id: str,
- order: int,
+ category_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
@@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_room_to_summary_txn(
self,
- txn,
+ txn: LoggingTransaction,
group_id: str,
room_id: str,
- category_id: str,
- order: int,
+ category_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
@@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND category_id = ?
"""
txn.execute(sql, (group_id, category_id))
- (order,) = txn.fetchone()
+ (order,) = cast(Tuple[int], txn.fetchone())
if existing:
to_update = {}
@@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore):
"category_id": category_id,
"room_id": room_id,
},
- values=to_update,
+ updatevalues=to_update,
)
else:
if is_public is None:
@@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_room_from_summary(
- self, group_id: str, room_id: str, category_id: str
+ self, group_id: str, room_id: str, category_id: Optional[str]
) -> int:
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
@@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool],
) -> None:
"""Add/update room category for group"""
- insertion_values = {}
- update_values = {"category_id": category_id} # This cannot be empty
+ insertion_values: JsonDict = {}
+ update_values: JsonDict = {"category_id": category_id} # This cannot be empty
if profile is None:
insertion_values["profile"] = "{}"
@@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool],
) -> None:
"""Add/remove user role"""
- insertion_values = {}
- update_values = {"role_id": role_id} # This cannot be empty
+ insertion_values: JsonDict = {}
+ update_values: JsonDict = {"role_id": role_id} # This cannot be empty
if profile is None:
insertion_values["profile"] = "{}"
@@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore):
self,
group_id: str,
user_id: str,
- role_id: str,
- order: int,
+ role_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) user's entry in summary.
@@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_user_to_summary_txn(
self,
- txn,
+ txn: LoggingTransaction,
group_id: str,
user_id: str,
- role_id: str,
- order: int,
+ role_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
- ):
+ ) -> None:
"""Add (or update) user's entry in summary.
Args:
@@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND role_id = ?
"""
txn.execute(sql, (group_id, role_id))
- (order,) = txn.fetchone()
+ (order,) = cast(Tuple[int], txn.fetchone())
if existing:
to_update = {}
@@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore):
"role_id": role_id,
"user_id": user_id,
},
- values=to_update,
+ updatevalues=to_update,
)
else:
if is_public is None:
@@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_user_from_summary(
- self, group_id: str, user_id: str, role_id: str
+ self, group_id: str, user_id: str, role_id: Optional[str]
) -> int:
if role_id is None:
role_id = _DEFAULT_ROLE_ID
@@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore):
Optional if the user and group are on the same server
"""
- def _add_user_to_group_txn(txn):
+ def _add_user_to_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_insert_txn(
txn,
table="group_users",
@@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore):
await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
- def _remove_user_from_group_txn(txn):
+ def _remove_user_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="group_users",
@@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
- def _remove_room_from_group_txn(txn):
+ def _remove_room_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="group_rooms",
@@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore):
content = content or {}
- def _register_user_group_membership_txn(txn, next_id):
+ def _register_user_group_membership_txn(
+ txn: LoggingTransaction, next_id: int
+ ) -> int:
# TODO: Upsert?
self.db_pool.simple_delete_txn(
txn,
@@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore):
),
},
)
- self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+ self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined]
# TODO: Insert profile to ensure it comes down stream if its a join.
@@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- async with self._group_updates_id_gen.get_next() as next_id:
+ async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined]
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
@@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore):
return res
async def create_group(
- self, group_id, user_id, name, avatar_url, short_description, long_description
+ self,
+ group_id: str,
+ user_id: str,
+ name: str,
+ avatar_url: str,
+ short_description: str,
+ long_description: str,
) -> None:
await self.db_pool.simple_insert(
table="groups",
@@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore):
desc="create_group",
)
- async def update_group_profile(self, group_id, profile):
+ async def update_group_profile(self, group_id: str, profile: JsonDict) -> None:
await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
@@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_attestation_renewal",
)
- def get_group_stream_token(self):
- return self._group_updates_id_gen.get_current_token()
+ def get_group_stream_token(self) -> int:
+ return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined]
async def delete_group(self, group_id: str) -> None:
"""Deletes a group fully from the database.
@@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore):
group_id: The group ID to delete.
"""
- def _delete_group_txn(txn):
+ def _delete_group_txn(txn: LoggingTransaction) -> None:
tables = [
"groups",
"group_users",
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e9a0cdc6be..216622964a 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
+ LoggingTransaction,
make_in_list_sql_clause,
)
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
from synapse.util.caches.descriptors import cached
from synapse.util.threepids import canonicalise_email
@@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
Number of current monthly active users
"""
- def _count_users(txn):
+ def _count_users(txn: LoggingTransaction) -> int:
# Exclude app service users
sql = """
SELECT COUNT(*)
@@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
"""
txn.execute(sql)
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_users", _count_users)
@@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
- def _count_users_by_service(txn):
+ def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]:
sql = """
SELECT COALESCE(appservice_id, 'native'), COUNT(*)
FROM monthly_active_users
@@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
txn.execute(sql)
- result = txn.fetchall()
+ result = cast(List[Tuple[str, int]], txn.fetchall())
return dict(result)
return await self.db_pool.runInteraction(
@@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
)
@wrap_as_background_process("reap_monthly_active_users")
- async def reap_monthly_active_users(self):
+ async def reap_monthly_active_users(self) -> None:
"""Cleans out monthly active user table to ensure that no stale
entries exist.
"""
- def _reap_users(txn, reserved_users):
+ def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None:
"""
Args:
reserved_users (tuple): reserved users to preserve
@@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
# is racy.
# Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant
- self._invalidate_all_cache_and_stream(
+ self._invalidate_all_cache_and_stream( # type: ignore[attr-defined]
txn, self.user_last_seen_monthly_active
)
- self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+ self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined]
reserved_users = await self.get_registered_reserved_users()
await self.db_pool.runInteraction(
@@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
)
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
- def _initialise_reserved_users(self, txn, threepids):
+ def _initialise_reserved_users(
+ self, txn: LoggingTransaction, threepids: List[dict]
+ ) -> None:
"""Ensures that reserved threepids are accounted for in the MAU table, should
be called on start up.
Args:
- txn (cursor):
- threepids (list[dict]): List of threepid dicts to reserve
+ txn:
+ threepids: List of threepid dicts to reserve
"""
# XXX what is this function trying to achieve? It upserts into
@@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
- def upsert_monthly_active_user_txn(self, txn, user_id):
+ def upsert_monthly_active_user_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> None:
"""Updates or inserts monthly active user member
We consciously do not call is_support_txn from this method because it
@@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
txn, self.user_last_seen_monthly_active, (user_id,)
)
- async def populate_monthly_active_users(self, user_id):
+ async def populate_monthly_active_users(self, user_id: str) -> None:
"""Checks on the state of monthly active user limits and optionally
add the user to the monthly active tables
@@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"""
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
- is_guest = await self.is_guest(user_id)
+ is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
if is_guest:
return
is_trial = await self.is_trial_user(user_id)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index c4869d64e6..b2295fd51f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -27,7 +27,6 @@ from typing import (
)
import attr
-from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
@@ -41,45 +40,15 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.server import HomeServer
- from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _ThreadAggregation:
- # The latest event in the thread.
- latest_event: EventBase
- # The latest edit to the latest event in the thread.
- latest_edit: Optional[EventBase]
- # The total number of events in the thread.
- count: int
- # True if the current user has sent an event to the thread.
- current_user_participated: bool
-
-
-@attr.s(slots=True, auto_attribs=True)
-class BundledAggregations:
- """
- The bundled aggregations for an event.
-
- Some values require additional processing during serialization.
- """
-
- annotations: Optional[JsonDict] = None
- references: Optional[JsonDict] = None
- replace: Optional[EventBase] = None
- thread: Optional[_ThreadAggregation] = None
-
- def __bool__(self) -> bool:
- return bool(self.annotations or self.references or self.replace or self.thread)
-
-
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
- async def _get_applicable_edits(
+ async def get_applicable_edits(
self, event_ids: Collection[str]
) -> Dict[str, Optional[EventBase]]:
"""Get the most recent edit (if any) that has happened for the given
@@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
- async def _get_thread_summaries(
+ async def get_thread_summaries(
self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
"""Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
@@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
# Check to see if any of those events are edited.
- latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+ latest_edits = await self.get_applicable_edits(latest_event_ids.values())
# Map to the event IDs to the thread summary.
#
@@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
- async def _get_threads_participated(
+ async def get_threads_participated(
self, event_ids: Collection[str], user_id: str
) -> Dict[str, bool]:
"""Get whether the requesting user participated in the given threads.
@@ -766,114 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
- async def _get_bundled_aggregation_for_event(
- self, event: EventBase, user_id: str
- ) -> Optional[BundledAggregations]:
- """Generate bundled aggregations for an event.
-
- Note that this does not use a cache, but depends on cached methods.
-
- Args:
- event: The event to calculate bundled aggregations for.
- user_id: The user requesting the bundled aggregations.
-
- Returns:
- The bundled aggregations for an event, if bundled aggregations are
- enabled and the event can have bundled aggregations.
- """
-
- # Do not bundle aggregations for an event which represents an edit or an
- # annotation. It does not make sense for them to have related events.
- relates_to = event.content.get("m.relates_to")
- if isinstance(relates_to, (dict, frozendict)):
- relation_type = relates_to.get("rel_type")
- if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
- return None
-
- event_id = event.event_id
- room_id = event.room_id
-
- # The bundled aggregations to include, a mapping of relation type to a
- # type-specific value. Some types include the direct return type here
- # while others need more processing during serialization.
- aggregations = BundledAggregations()
-
- annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
- if annotations.chunk:
- aggregations.annotations = await annotations.to_dict(
- cast("DataStore", self)
- )
-
- references = await self.get_relations_for_event(
- event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
- )
- if references.chunk:
- aggregations.references = await references.to_dict(cast("DataStore", self))
-
- # Store the bundled aggregations in the event metadata for later use.
- return aggregations
-
- async def get_bundled_aggregations(
- self, events: Iterable[EventBase], user_id: str
- ) -> Dict[str, BundledAggregations]:
- """Generate bundled aggregations for events.
-
- Args:
- events: The iterable of events to calculate bundled aggregations for.
- user_id: The user requesting the bundled aggregations.
-
- Returns:
- A map of event ID to the bundled aggregation for the event. Not all
- events may have bundled aggregations in the results.
- """
- # De-duplicate events by ID to handle the same event requested multiple times.
- #
- # State events do not get bundled aggregations.
- events_by_id = {
- event.event_id: event for event in events if not event.is_state()
- }
-
- # event ID -> bundled aggregation in non-serialized form.
- results: Dict[str, BundledAggregations] = {}
-
- # Fetch other relations per event.
- for event in events_by_id.values():
- event_result = await self._get_bundled_aggregation_for_event(event, user_id)
- if event_result:
- results[event.event_id] = event_result
-
- # Fetch any edits (but not for redacted events).
- edits = await self._get_applicable_edits(
- [
- event_id
- for event_id, event in events_by_id.items()
- if not event.internal_metadata.is_redacted()
- ]
- )
- for event_id, edit in edits.items():
- results.setdefault(event_id, BundledAggregations()).replace = edit
-
- # Fetch thread summaries.
- summaries = await self._get_thread_summaries(events_by_id.keys())
- # Only fetch participated for a limited selection based on what had
- # summaries.
- participated = await self._get_threads_participated(summaries.keys(), user_id)
- for event_id, summary in summaries.items():
- if summary:
- thread_count, latest_thread_event, edit = summary
- results.setdefault(
- event_id, BundledAggregations()
- ).thread = _ThreadAggregation(
- latest_event=latest_thread_event,
- latest_edit=edit,
- count=thread_count,
- # If there's a thread summary it must also exist in the
- # participated dictionary.
- current_user_participated=participated[event_id],
- )
-
- return results
-
class RelationsStore(RelationsWorkerStore):
pass
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e7fddd2426..55cc9178f0 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -26,6 +26,8 @@ from typing import (
cast,
)
+from typing_extensions import TypedDict
+
from synapse.api.errors import StoreError
if TYPE_CHECKING:
@@ -40,7 +42,12 @@ from synapse.storage.database import (
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
+from synapse.types import (
+ JsonDict,
+ UserProfile,
+ get_domain_from_id,
+ get_localpart_from_id,
+)
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
+class SearchResult(TypedDict):
+ limited: bool
+ results: List[UserProfile]
+
+
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
@@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
async def search_user_dir(
self, user_id: str, search_term: str, limit: int
- ) -> JsonDict:
+ ) -> SearchResult:
"""Searches for users in directory
Returns:
@@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- results = await self.db_pool.execute(
- "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+ results = cast(
+ List[UserProfile],
+ await self.db_pool.execute(
+ "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+ ),
)
limited = len(results) > limit
diff --git a/synapse/types.py b/synapse/types.py
index 53be3583a0..5ce2a5b0a5 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -34,6 +34,7 @@ from typing import (
import attr
from frozendict import frozendict
from signedjson.key import decode_verify_key_bytes
+from typing_extensions import TypedDict
from unpaddedbase64 import decode_base64
from zope.interface import Interface
@@ -63,6 +64,10 @@ MutableStateMap = MutableMapping[StateKey, T]
# JSON types. These could be made stronger, but will do for now.
# A JSON-serialisable dict.
JsonDict = Dict[str, Any]
+# A JSON-serialisable mapping; roughly speaking an immutable JSONDict.
+# Useful when you have a TypedDict which isn't going to be mutated and you don't want
+# to cast to JsonDict everywhere.
+JsonMapping = Mapping[str, Any]
# A JSON-serialisable object.
JsonSerializable = object
@@ -791,3 +796,9 @@ class UserInfo:
is_deactivated: bool
is_guest: bool
is_shadow_banned: bool
+
+
+class UserProfile(TypedDict):
+ user_id: str
+ display_name: Optional[str]
+ avatar_url: Optional[str]
diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py
index 12cd804939..66f1da7502 100644
--- a/synapse/util/check_dependencies.py
+++ b/synapse/util/check_dependencies.py
@@ -128,6 +128,19 @@ def _incorrect_version(
)
+def _no_reported_version(requirement: Requirement, extra: Optional[str] = None) -> str:
+ if extra:
+ return (
+ f"Synapse {VERSION} needs {requirement} for {extra}, "
+ f"but can't determine {requirement.name}'s version"
+ )
+ else:
+ return (
+ f"Synapse {VERSION} needs {requirement}, "
+ f"but can't determine {requirement.name}'s version"
+ )
+
+
def check_requirements(extra: Optional[str] = None) -> None:
"""Check Synapse's dependencies are present and correctly versioned.
@@ -163,8 +176,17 @@ def check_requirements(extra: Optional[str] = None) -> None:
deps_unfulfilled.append(requirement.name)
errors.append(_not_installed(requirement, extra))
else:
+ if dist.version is None:
+ # This shouldn't happen---it suggests a borked virtualenv. (See #12223)
+ # Try to give a vaguely helpful error message anyway.
+ # Type-ignore: the annotations don't reflect reality: see
+ # https://github.com/python/typeshed/issues/7513
+ # https://bugs.python.org/issue47060
+ deps_unfulfilled.append(requirement.name) # type: ignore[unreachable]
+ errors.append(_no_reported_version(requirement, extra))
+
# We specify prereleases=True to allow prereleases such as RCs.
- if not requirement.specifier.contains(dist.version, prereleases=True):
+ elif not requirement.specifier.contains(dist.version, prereleases=True):
deps_unfulfilled.append(requirement.name)
errors.append(_incorrect_version(requirement, dist.version, extra))
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 281cbe4d88..49519eb8f5 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -14,12 +14,7 @@
import logging
from typing import Dict, FrozenSet, List, Optional
-from synapse.api.constants import (
- AccountDataTypes,
- EventTypes,
- HistoryVisibility,
- Membership,
-)
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase
from synapse.events.utils import prune_event
from synapse.storage import Storage
@@ -87,15 +82,8 @@ async def filter_events_for_client(
state_filter=StateFilter.from_types(types),
)
- ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
- user_id, AccountDataTypes.IGNORED_USER_LIST
- )
-
- ignore_list: FrozenSet[str] = frozenset()
- if ignore_dict_content:
- ignored_users_dict = ignore_dict_content.get("ignored_users", {})
- if isinstance(ignored_users_dict, dict):
- ignore_list = frozenset(ignored_users_dict.keys())
+ # Get the users who are ignored by the requesting user.
+ ignore_list = await storage.main.ignored_users(user_id)
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index a267228846..a54aa29cf1 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -11,9 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.handlers.cas import CasResponse
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@@ -24,7 +29,7 @@ SERVER_URL = "https://issuer/"
class CasHandlerTestCase(HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
cas_config = {
@@ -40,7 +45,7 @@ class CasHandlerTestCase(HomeserverTestCase):
return config
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()
self.handler = hs.get_cas_handler()
@@ -51,7 +56,7 @@ class CasHandlerTestCase(HomeserverTestCase):
return hs
- def test_map_cas_user_to_user(self):
+ def test_map_cas_user_to_user(self) -> None:
"""Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
# stub out the auth handler
@@ -75,7 +80,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
- def test_map_cas_user_to_existing_user(self):
+ def test_map_cas_user_to_existing_user(self) -> None:
"""Existing users can log in with CAS account."""
store = self.hs.get_datastores().main
self.get_success(
@@ -119,7 +124,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
- def test_map_cas_user_to_invalid_localpart(self):
+ def test_map_cas_user_to_invalid_localpart(self) -> None:
"""CAS automaps invalid characters to base-64 encoding."""
# stub out the auth handler
@@ -150,7 +155,7 @@ class CasHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_required_attributes(self):
+ def test_required_attributes(self) -> None:
"""The required attributes must be met from the CAS response."""
# stub out the auth handler
@@ -166,7 +171,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_handler.complete_sso_login.assert_not_called()
# The response doesn't have any department.
- cas_response = CasResponse("test_user", {"userGroup": "staff"})
+ cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
request.reset_mock()
self.get_success(
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 6e403a87c5..11ad44223d 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -12,14 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.api.errors
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client import directory, login, room
-from synapse.types import RoomAlias, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomAlias, create_requester
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the directory service."""
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock()
self.mock_registry = Mock()
- self.query_handlers = {}
+ self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
- def register_query_handler(query_type, handler):
+ def register_query_handler(
+ query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
+ ) -> None:
self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler
@@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
return hs
- def test_get_local_association(self):
+ def test_get_local_association(self) -> None:
self.get_success(
self.store.create_room_alias_association(
self.my_room, "!8765qwer:test", ["test"]
@@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
- def test_get_remote_association(self):
+ def test_get_remote_association(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
)
@@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
ignore_backoff=True,
)
- def test_incoming_fed_query(self):
+ def test_incoming_fed_query(self) -> None:
self.get_success(
self.store.create_room_alias_association(
self.your_room, "!8765asdf:test", ["test"]
@@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
directory.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_directory_handler()
# Create user
@@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
- def test_create_alias_joined_room(self):
+ def test_create_alias_joined_room(self) -> None:
"""A user can create an alias for a room they're in."""
self.get_success(
self.handler.create_association(
@@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
)
)
- def test_create_alias_other_room(self):
+ def test_create_alias_other_room(self) -> None:
"""A user cannot create an alias for a room they're NOT in."""
other_room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok
@@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
- def test_create_alias_admin(self):
+ def test_create_alias_admin(self) -> None:
"""An admin can create an alias for a room they're NOT in."""
other_room_id = self.helper.create_room_as(
self.test_user, tok=self.test_user_tok
@@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
directory.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
@@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
- def _create_alias(self, user):
+ def _create_alias(self, user) -> None:
# Create a new alias to this room.
self.get_success(
self.store.create_room_alias_association(
@@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
)
)
- def test_delete_alias_not_allowed(self):
+ def test_delete_alias_not_allowed(self) -> None:
"""A user that doesn't meet the expected guidelines cannot delete an alias."""
self._create_alias(self.admin_user)
self.get_failure(
@@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.AuthError,
)
- def test_delete_alias_creator(self):
+ def test_delete_alias_creator(self) -> None:
"""An alias creator can delete their own alias."""
# Create an alias from a different user.
self._create_alias(self.test_user)
@@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
- def test_delete_alias_admin(self):
+ def test_delete_alias_admin(self) -> None:
"""A server admin can delete an alias created by another user."""
# Create an alias from a different user.
self._create_alias(self.test_user)
@@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
- def test_delete_alias_sufficient_power(self):
+ def test_delete_alias_sufficient_power(self) -> None:
"""A user with a sufficient power level should be able to delete an alias."""
self._create_alias(self.admin_user)
@@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
directory.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
@@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
return room_alias
- def _set_canonical_alias(self, content):
+ def _set_canonical_alias(self, content) -> None:
"""Configure the canonical alias state on the room."""
self.helper.send_state(
self.room_id,
@@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
)
- def test_remove_alias(self):
+ def test_remove_alias(self) -> None:
"""Removing an alias that is the canonical alias should remove it there too."""
# Set this new alias as the canonical alias for this room
self._set_canonical_alias(
@@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
self.assertNotIn("alias", data["content"])
self.assertNotIn("alt_aliases", data["content"])
- def test_remove_other_alias(self):
+ def test_remove_other_alias(self) -> None:
"""Removing an alias listed as in alt_aliases should remove it there too."""
# Create a second alias.
other_test_alias = "#test2:test"
@@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets]
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# Add custom alias creation rules to the config.
@@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
return config
- def test_denied(self):
+ def test_denied(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
)
self.assertEqual(403, channel.code, channel.result)
- def test_allowed(self):
+ def test_allowed(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
)
self.assertEqual(200, channel.code, channel.result)
- def test_denied_during_creation(self):
+ def test_denied_during_creation(self) -> None:
"""A room alias that is not allowed should be rejected during creation."""
# Invalid room alias.
self.helper.create_room_as(
@@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
extra_content={"room_alias_name": "foo"},
)
- def test_allowed_during_creation(self):
+ def test_allowed_during_creation(self) -> None:
"""A valid room alias should be allowed during creation."""
room_id = self.helper.create_room_as(
self.user_id,
@@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
data = {"room_alias_name": "unofficial_test"}
allowed_localpart = "allowed"
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# Add custom room list publication rules to the config.
@@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return config
- def prepare(self, reactor, clock, hs):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
+ ) -> HomeServer:
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
@@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return hs
- def test_denied_without_publication_permission(self):
+ def test_denied_without_publication_permission(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user without permission to publish rooms.
@@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=403,
)
- def test_allowed_when_creating_private_room(self):
+ def test_allowed_when_creating_private_room(self) -> None:
"""
Try to create a room, register an alias for it, and NOT publish it,
as a user without permission to publish rooms.
@@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=200,
)
- def test_allowed_with_publication_permission(self):
+ def test_allowed_with_publication_permission(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms.
@@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=200,
)
- def test_denied_publication_with_invalid_alias(self):
+ def test_denied_publication_with_invalid_alias(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms.
@@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=403,
)
- def test_can_create_as_private_room_after_rejection(self):
+ def test_can_create_as_private_room_after_rejection(self) -> None:
"""
After failing to publish a room with an alias as a user without publish permission,
retry as the same user, but without publishing the room.
@@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
self.test_denied_without_publication_permission()
self.test_allowed_when_creating_private_room()
- def test_can_create_with_permission_after_rejection(self):
+ def test_can_create_with_permission_after_rejection(self) -> None:
"""
After failing to publish a room with an alias as a user without publish permission,
retry as someone with permission, using the same alias.
@@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets]
- def prepare(self, reactor, clock, hs):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
+ ) -> HomeServer:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
return hs
- def test_disabling_room_list(self):
+ def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 9338ab92e9..ac21a28c43 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -20,33 +20,37 @@ from parameterized import parameterized
from signedjson import key as key, sign as sign
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=mock.Mock())
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_keys_handler()
self.store = self.hs.get_datastores().main
- def test_query_local_devices_no_devices(self):
+ def test_query_local_devices_no_devices(self) -> None:
"""If the user has no devices, we expect an empty list."""
local_user = "@boris:" + self.hs.hostname
res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
- def test_reupload_one_time_keys(self):
+ def test_reupload_one_time_keys(self) -> None:
"""we should be able to re-upload the same keys"""
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
- keys = {
+ keys: JsonDict = {
"alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"},
@@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
)
- def test_change_one_time_keys(self):
+ def test_change_one_time_keys(self) -> None:
"""attempts to change one-time-keys should be rejected"""
local_user = "@boris:" + self.hs.hostname
@@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
SynapseError,
)
- def test_claim_one_time_key(self):
+ def test_claim_one_time_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {"alg1:k1": "key1"}
@@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_fallback_key(self):
+ def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "fallback_key1"}
@@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
- def test_replace_master_key(self):
+ def test_replace_master_key(self) -> None:
"""uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
- def test_reupload_signatures(self):
+ def test_reupload_signatures(self) -> None:
"""re-uploading a signature should not fail"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
- def test_self_signing_key_doesnt_show_up_as_device(self):
+ def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
"""signing keys should be hidden when fetching a user's devices"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
- def test_upload_signatures(self):
+ def test_upload_signatures(self) -> None:
"""should check signatures that are uploaded"""
# set up a user with cross-signing keys and a device. This user will
# try uploading signatures
@@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
)
- def test_query_devices_remote_no_sync(self):
+ def test_query_devices_remote_no_sync(self) -> None:
"""Tests that querying keys for a remote user that we don't share a room
with returns the cross signing keys correctly.
"""
@@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_query_devices_remote_sync(self):
+ def test_query_devices_remote_sync(self) -> None:
"""Tests that querying keys for a remote user that we share a room with,
but haven't yet fetched the keys for, returns the cross signing keys
correctly.
@@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
(["device_1", "device_2"],),
]
)
- def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
+ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None:
"""Test that requests for all of a remote user's devices are cached.
We do this by asserting that only one call over federation was made, and that
@@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
"""
local_user_id = "@test:test"
remote_user_id = "@test:other"
- request_body = {"device_keys": {remote_user_id: []}}
+ request_body: JsonDict = {"device_keys": {remote_user_id: []}}
response_devices = [
{
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index e8b4e39d1a..89078fc637 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List
+from typing import List, cast
from unittest import TestCase
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
@@ -23,7 +25,9 @@ from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.types import create_requester
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
@@ -42,7 +46,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main
@@ -50,7 +54,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self._event_auth_handler = hs.get_event_auth_handler()
return hs
- def test_exchange_revoked_invite(self):
+ def test_exchange_revoked_invite(self) -> None:
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
@@ -96,7 +100,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
self.assertEqual(failure.msg, "You are not invited to this room.")
- def test_rejected_message_event_state(self):
+ def test_rejected_message_event_state(self) -> None:
"""
Check that we store the state group correctly for rejected non-state events.
@@ -126,7 +130,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"content": {},
"room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER,
- "depth": join_event["depth"] + 1,
+ "depth": cast(int, join_event["depth"]) + 1,
"prev_events": [join_event.event_id],
"auth_events": [],
"origin_server_ts": self.clock.time_msec(),
@@ -149,7 +153,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
- def test_rejected_state_event_state(self):
+ def test_rejected_state_event_state(self) -> None:
"""
Check that we store the state group correctly for rejected state events.
@@ -180,7 +184,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"content": {},
"room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER,
- "depth": join_event["depth"] + 1,
+ "depth": cast(int, join_event["depth"]) + 1,
"prev_events": [join_event.event_id],
"auth_events": [],
"origin_server_ts": self.clock.time_msec(),
@@ -203,7 +207,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
- def test_backfill_with_many_backward_extremities(self):
+ def test_backfill_with_many_backward_extremities(self) -> None:
"""
Check that we can backfill with many backward extremities.
The goal is to make sure that when we only use a portion
@@ -262,7 +266,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
- def test_backfill_floating_outlier_membership_auth(self):
+ def test_backfill_floating_outlier_membership_auth(self) -> None:
"""
As the local homeserver, check that we can properly process a federated
event from the OTHER_SERVER with auth_events that include a floating
@@ -377,7 +381,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
for ae in auth_events
]
- self.handler.federation_client.get_event_auth = get_event_auth
+ self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment]
with LoggingContext("receive_pdu"):
# Fake the OTHER_SERVER federating the message event over to our local homeserver
@@ -397,7 +401,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
)
- def test_invite_by_user_ratelimit(self):
+ def test_invite_by_user_ratelimit(self) -> None:
"""Tests that invites from federation to a particular user are
actually rate-limited.
"""
@@ -446,7 +450,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
exc=LimitExceededError,
)
- def _build_and_send_join_event(self, other_server, other_user, room_id):
+ def _build_and_send_join_event(
+ self, other_server: str, other_user: str, room_id: str
+ ) -> EventBase:
join_event = self.get_success(
self.handler.on_make_join_request(other_server, room_id, other_user)
)
@@ -469,7 +475,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
class EventFromPduTestCase(TestCase):
- def test_valid_json(self):
+ def test_valid_json(self) -> None:
"""Valid JSON should be turned into an event."""
ev = event_from_pdu_json(
{
@@ -487,7 +493,7 @@ class EventFromPduTestCase(TestCase):
self.assertIsInstance(ev, EventBase)
- def test_invalid_numbers(self):
+ def test_invalid_numbers(self) -> None:
"""Invalid values for an integer should be rejected, all floats should be rejected."""
for value in [
-(2 ** 53),
@@ -512,7 +518,7 @@ class EventFromPduTestCase(TestCase):
RoomVersions.V6,
)
- def test_invalid_nested(self):
+ def test_invalid_nested(self) -> None:
"""List and dictionaries are recursively searched."""
with self.assertRaises(SynapseError):
event_from_pdu_json(
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e8418b6638..014815db6e 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,14 +13,18 @@
# limitations under the License.
import json
import os
+from typing import Any, Dict
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
import pymacaroons
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
@@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
}
-async def get_json(url):
+async def get_json(url: str) -> JsonDict:
# Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN:
# Minimal discovery document, as defined in OpenID.Discovery
@@ -116,6 +120,8 @@ async def get_json(url):
elif url == JWKS_URI:
return {"keys": []}
+ return {}
+
def _key_file_path() -> str:
"""path to a file containing the private half of a test key"""
@@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
if not HAS_OIDC:
skip = "requires OIDC"
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
return config
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = b"Synapse Test"
@@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
- sso_handler.render_error = self.render_error
+ sso_handler.render_error = self.render_error # type: ignore[assignment]
# Reduce the number of attempts when generating MXIDs.
sso_handler._MAP_USERNAME_RETRIES = 3
@@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
return args
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_config(self):
+ def test_config(self) -> None:
"""Basic config correctly sets up the callback URL and client auth correctly."""
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
- def test_discovery(self):
+ def test_discovery(self) -> None:
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
metadata = self.get_success(self.provider.load_metadata())
@@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
- def test_no_discovery(self):
+ def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
- def test_load_jwks(self):
+ def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
@@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_validate_config(self):
+ def test_validate_config(self) -> None:
"""Provider metadatas are extensively validated."""
h = self.provider
@@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
force_load_metadata()
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
- def test_skip_verification(self):
+ def test_skip_verification(self) -> None:
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
get_awaitable_result(self.provider.load_metadata())
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_redirect_request(self):
+ def test_redirect_request(self) -> None:
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["cookies"])
req.cookies = []
@@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(redirect, "http://client/redirect")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback_error(self):
+ def test_callback_error(self) -> None:
"""Errors from the provider returned in the callback are displayed."""
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
@@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("invalid_client", "some description")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback(self):
+ def test_callback(self) -> None:
"""Code callback works and display errors if something went wrong.
A lot of scenarios are tested here:
@@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": username,
}
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
- self.provider._exchange_code = simple_async_mock(return_value=token)
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
- self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
+ self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error")
# Handle ID token errors
- self.provider._parse_id_token = simple_async_mock(raises=Exception())
+ self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
@@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"type": "bearer",
"access_token": "access_token",
}
- self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
@@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
id_token = {
"sid": "abcdefgh",
}
- self.provider._parse_id_token = simple_async_mock(return_value=id_token)
- self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
auth_handler.complete_sso_login.reset_mock()
self.provider._fetch_userinfo.reset_mock()
self.get_success(self.handler.handle_oidc_callback(request))
@@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error.assert_not_called()
# Handle userinfo fetching error
- self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
+ self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
from synapse.handlers.oidc import OidcError
- self.provider._exchange_code = simple_async_mock(
+ self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
raises=OidcError("invalid_request")
)
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback_session(self):
+ def test_callback_session(self) -> None:
"""The callback verifies the session presence and validity"""
request = Mock(spec=["args", "getCookie", "cookies"])
@@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
)
- def test_exchange_code(self):
+ def test_exchange_code(self) -> None:
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
token_json = json.dumps(token).encode("utf-8")
@@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_exchange_code_jwt_key(self):
+ def test_exchange_code_jwt_key(self) -> None:
"""Test that code exchange works with a JWK client secret."""
from authlib.jose import jwt
@@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_exchange_code_no_auth(self):
+ def test_exchange_code_no_auth(self) -> None:
"""Test that code exchange works with no client secret."""
token = {"type": "bearer"}
self.http_client.request = simple_async_mock(
@@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_extra_attributes(self):
+ def test_extra_attributes(self) -> None:
"""
Login while using a mapping provider that implements get_extra_attributes.
"""
@@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "foo",
"phone": "1234567",
}
- self.provider._exchange_code = simple_async_mock(return_value=token)
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_map_userinfo_to_user(self):
+ def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
- userinfo = {
+ userinfo: dict = {
"sub": "test_user",
"username": "test_user",
}
@@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
- def test_map_userinfo_to_existing_user(self):
+ def test_map_userinfo_to_existing_user(self) -> None:
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastores().main
user = UserID.from_string("@test_user:test")
@@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_map_userinfo_to_invalid_localpart(self):
+ def test_map_userinfo_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
self.get_success(
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
@@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_map_userinfo_to_user_retries(self):
+ def test_map_userinfo_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_empty_localpart(self):
+ def test_empty_localpart(self) -> None:
"""Attempts to map onto an empty localpart should be rejected."""
userinfo = {
"sub": "tester",
@@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_null_localpart(self):
+ def test_null_localpart(self) -> None:
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
userinfo = {
"sub": "tester",
@@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements(self):
+ def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the OIDC userinfo response."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements_contains(self):
+ def test_attribute_requirements_contains(self) -> None:
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements_mismatch(self):
+ def test_attribute_requirements_mismatch(self) -> None:
"""
Test that auth fails if attributes exist but don't match,
or are non-string values.
@@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": "not_foobar" attribute should fail
- userinfo = {
+ userinfo: dict = {
"sub": "tester",
"username": "tester",
"test": "not_foobar",
@@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
handler = hs.get_oidc_handler()
provider = handler._providers["oidc"]
- provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
- provider._parse_id_token = simple_async_mock(return_value=userinfo)
- provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
+ provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
+ provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
state = "state"
session = handler._token_generator.generate_oidc_session_token(
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 6ddec9ecf1..b2ed9cbe37 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -331,11 +331,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
# Extract presence update user ID and state information into lists of tuples
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
- presence_states = [(ps.user_id, ps.state) for ps in presence_states]
+ presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
# Compare what we put into the storage with what we got out.
# They should be identical.
- self.assertEqual(presence_states, db_presence_states)
+ self.assertEqual(presence_states_compare, db_presence_states)
class PresenceTimeoutTestCase(unittest.TestCase):
@@ -357,6 +357,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -380,6 +381,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.BUSY)
self.assertEqual(new_state.status_msg, status_msg)
@@ -399,6 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -420,6 +423,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.ONLINE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -477,6 +481,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -653,13 +658,13 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
def _set_presencestate_with_status_msg(
- self, user_id: str, state: PresenceState, status_msg: Optional[str]
+ self, user_id: str, state: str, status_msg: Optional[str]
):
"""Set a PresenceState and status_msg and check the result.
Args:
user_id: User for that the status is to be set.
- PresenceState: The new PresenceState.
+ state: The new PresenceState.
status_msg: Status message that is to be set.
"""
self.get_success(
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 972cbac6e4..08733a9f2d 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.types
from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock()
self.mock_registry = Mock()
- self.query_handlers = {}
+ self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
- def register_query_handler(query_type, handler):
+ def register_query_handler(
+ query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
+ ) -> None:
self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler
@@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
return hs
- def prepare(self, reactor, clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.frank = UserID.from_string("@1234abcd:test")
@@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_profile_handler()
- def test_get_my_name(self):
+ def test_get_my_name(self) -> None:
self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
@@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("Frank", displayname)
- def test_set_my_name(self):
+ def test_set_my_name(self) -> None:
self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
@@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
)
- def test_set_my_name_if_disabled(self):
+ def test_set_my_name_if_disabled(self) -> None:
self.hs.config.registration.enable_set_displayname = False
# Setting displayname for the first time is allowed
@@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
SynapseError,
)
- def test_set_my_name_noauth(self):
+ def test_set_my_name_noauth(self) -> None:
self.get_failure(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
@@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
AuthError,
)
- def test_get_other_name(self):
+ def test_get_other_name(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"}
)
@@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
ignore_backoff=True,
)
- def test_incoming_fed_query(self):
+ def test_incoming_fed_query(self) -> None:
self.get_success(self.store.create_profile("caroline"))
self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
@@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual({"displayname": "Caroline"}, response)
- def test_get_my_avatar(self):
+ def test_get_my_avatar(self) -> None:
self.get_success(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
@@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("http://my.server/me.png", avatar_url)
- def test_set_my_avatar(self):
+ def test_set_my_avatar(self) -> None:
self.get_success(
self.handler.set_avatar_url(
self.frank,
@@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
)
- def test_set_my_avatar_if_disabled(self):
+ def test_set_my_avatar_if_disabled(self) -> None:
self.hs.config.registration.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
@@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
SynapseError,
)
- def test_avatar_constraints_no_config(self):
+ def test_avatar_constraints_no_config(self) -> None:
"""Tests that the method to check an avatar against configured constraints skips
all of its check if no constraint is configured.
"""
@@ -263,7 +268,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertTrue(res)
@unittest.override_config({"max_avatar_size": 50})
- def test_avatar_constraints_missing(self):
+ def test_avatar_constraints_missing(self) -> None:
"""Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
be found.
"""
@@ -273,7 +278,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertFalse(res)
@unittest.override_config({"max_avatar_size": 50})
- def test_avatar_constraints_file_size(self):
+ def test_avatar_constraints_file_size(self) -> None:
"""Tests that a file that's above the allowed file size is forbidden but one
that's below it is allowed.
"""
@@ -295,7 +300,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertFalse(res)
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
- def test_avatar_constraint_mime_type(self):
+ def test_avatar_constraint_mime_type(self) -> None:
"""Tests that a file with an unauthorised MIME type is forbidden but one with
an authorised content type is allowed.
"""
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 23941abed8..8d4404eda1 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import Any, Dict, Optional
from unittest.mock import Mock
import attr
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import RedirectException
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider):
class SamlHandlerTestCase(HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
- saml_config = {
+ saml_config: Dict[str, Any] = {
"sp_config": {"metadata": {}},
# Disable grandfathering.
"grandfathered_mxid_source_attribute": None,
@@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
return config
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()
self.handler = hs.get_saml_handler()
@@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
elif not has_xmlsec1:
skip = "Requires xmlsec1"
- def test_map_saml_response_to_user(self):
+ def test_map_saml_response_to_user(self) -> None:
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
# stub out the auth handler
@@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
- def test_map_saml_response_to_existing_user(self):
+ def test_map_saml_response_to_existing_user(self) -> None:
"""Existing users can log in with SAML account."""
store = self.hs.get_datastores().main
self.get_success(
@@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
- def test_map_saml_response_to_invalid_localpart(self):
+ def test_map_saml_response_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
# stub out the auth handler
@@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
auth_handler.complete_sso_login.assert_not_called()
- def test_map_saml_response_to_user_retries(self):
+ def test_map_saml_response_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
# stub out the auth handler and error renderer
@@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_map_saml_response_redirect(self):
+ def test_map_saml_response_redirect(self) -> None:
"""Test a mapping provider that raises a RedirectException"""
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
@@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
},
}
)
- def test_attribute_requirements(self):
+ def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the SAML response."""
# stub out the auth handler
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f91a80b9fa..ffd5c4cb93 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -18,11 +18,14 @@ from typing import Dict
from unittest.mock import ANY, Mock, call
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
-from synapse.types import UserID, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, create_requester
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -42,7 +45,9 @@ ROOM_ID = "a-room"
OTHER_ROOM_ID = "another-room"
-def _expect_edu_transaction(edu_type, content, origin="test"):
+def _expect_edu_transaction(
+ edu_type: str, content: JsonDict, origin: str = "test"
+) -> JsonDict:
return {
"origin": origin,
"origin_server_ts": 1000000,
@@ -51,12 +56,12 @@ def _expect_edu_transaction(edu_type, content, origin="test"):
}
-def _make_edu_transaction_json(edu_type, content):
+def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes:
return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
@@ -83,7 +88,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
d["/_matrix/federation"] = TransportLayerServer(self.hs)
return d
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
@@ -111,24 +116,24 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- async def check_user_in_room(room_id, user_id):
+ async def check_user_in_room(room_id: str, user_id: str) -> None:
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
return None
hs.get_auth().check_user_in_room = check_user_in_room
- async def check_host_in_room(room_id, server_name):
+ async def check_host_in_room(room_id: str, server_name: str) -> bool:
return room_id == ROOM_ID
hs.get_event_auth_handler().check_host_in_room = check_host_in_room
- def get_joined_hosts_for_room(room_id):
+ def get_joined_hosts_for_room(room_id: str):
return {member.domain for member in self.room_members}
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
- async def get_users_in_room(room_id):
+ async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members}
self.datastore.get_users_in_room = get_users_in_room
@@ -153,7 +158,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
lambda *args, **kwargs: make_awaitable(None)
)
- def test_started_typing_local(self):
+ def test_started_typing_local(self) -> None:
self.room_members = [U_APPLE, U_BANANA]
self.assertEqual(self.event_source.get_current_key(), 0)
@@ -187,7 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
@override_config({"send_federation": True})
- def test_started_typing_remote_send(self):
+ def test_started_typing_remote_send(self) -> None:
self.room_members = [U_APPLE, U_ONION]
self.get_success(
@@ -217,7 +222,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
try_trailing_slash_on_400=True,
)
- def test_started_typing_remote_recv(self):
+ def test_started_typing_remote_recv(self) -> None:
self.room_members = [U_APPLE, U_ONION]
self.assertEqual(self.event_source.get_current_key(), 0)
@@ -256,7 +261,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
- def test_started_typing_remote_recv_not_in_room(self):
+ def test_started_typing_remote_recv_not_in_room(self) -> None:
self.room_members = [U_APPLE, U_ONION]
self.assertEqual(self.event_source.get_current_key(), 0)
@@ -292,7 +297,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[1], 0)
@override_config({"send_federation": True})
- def test_stopped_typing(self):
+ def test_stopped_typing(self) -> None:
self.room_members = [U_APPLE, U_BANANA, U_ONION]
# Gut-wrenching
@@ -343,7 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
)
- def test_typing_timeout(self):
+ def test_typing_timeout(self) -> None:
self.room_members = [U_APPLE, U_BANANA]
self.assertEqual(self.event_source.get_current_key(), 0)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 6691e07128..ba158f5d93 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -11,15 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import Mock
from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfigException
from synapse.rest.client import login, push_rule, receipts, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config
@@ -35,13 +39,13 @@ class HTTPPusherTests(HomeserverTestCase):
user_id = True
hijack_auth = False
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["start_pushers"] = True
return config
- def make_homeserver(self, reactor, clock):
- self.push_attempts: List[tuple[Deferred, str, dict]] = []
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.push_attempts: List[Tuple[Deferred, str, dict]] = []
m = Mock()
@@ -56,7 +60,7 @@ class HTTPPusherTests(HomeserverTestCase):
return hs
- def test_invalid_configuration(self):
+ def test_invalid_configuration(self) -> None:
"""Invalid push configurations should be rejected."""
# Register the user who gets notified
user_id = self.register_user("user", "pass")
@@ -68,7 +72,7 @@ class HTTPPusherTests(HomeserverTestCase):
)
token_id = user_tuple.token_id
- def test_data(data):
+ def test_data(data: Optional[JsonDict]) -> None:
self.get_failure(
self.hs.get_pusherpool().add_pusher(
user_id=user_id,
@@ -95,7 +99,7 @@ class HTTPPusherTests(HomeserverTestCase):
# A url with an incorrect path isn't accepted.
test_data({"url": "http://example.com/foo"})
- def test_sends_http(self):
+ def test_sends_http(self) -> None:
"""
The HTTP pusher will send pushes for each message to a HTTP endpoint
when configured to do so.
@@ -200,7 +204,7 @@ class HTTPPusherTests(HomeserverTestCase):
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
- def test_sends_high_priority_for_encrypted(self):
+ def test_sends_high_priority_for_encrypted(self) -> None:
"""
The HTTP pusher will send pushes at high priority if they correspond
to an encrypted message.
@@ -321,7 +325,7 @@ class HTTPPusherTests(HomeserverTestCase):
)
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
- def test_sends_high_priority_for_one_to_one_only(self):
+ def test_sends_high_priority_for_one_to_one_only(self) -> None:
"""
The HTTP pusher will send pushes at high priority if they correspond
to a message in a one-to-one room.
@@ -404,7 +408,7 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
- def test_sends_high_priority_for_mention(self):
+ def test_sends_high_priority_for_mention(self) -> None:
"""
The HTTP pusher will send pushes at high priority if they correspond
to a message containing the user's display name.
@@ -480,7 +484,7 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
- def test_sends_high_priority_for_atroom(self):
+ def test_sends_high_priority_for_atroom(self) -> None:
"""
The HTTP pusher will send pushes at high priority if they correspond
to a message that contains @room.
@@ -563,7 +567,7 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
- def test_push_unread_count_group_by_room(self):
+ def test_push_unread_count_group_by_room(self) -> None:
"""
The HTTP pusher will group unread count by number of unread rooms.
"""
@@ -576,7 +580,7 @@ class HTTPPusherTests(HomeserverTestCase):
self._check_push_attempt(6, 1)
@override_config({"push": {"group_unread_count_by_room": False}})
- def test_push_unread_count_message_count(self):
+ def test_push_unread_count_message_count(self) -> None:
"""
The HTTP pusher will send the total unread message count.
"""
@@ -589,7 +593,7 @@ class HTTPPusherTests(HomeserverTestCase):
# last read receipt
self._check_push_attempt(6, 3)
- def _test_push_unread_count(self):
+ def _test_push_unread_count(self) -> None:
"""
Tests that the correct unread count appears in sent push notifications
@@ -681,7 +685,7 @@ class HTTPPusherTests(HomeserverTestCase):
self.helper.send(room_id, body="HELLO???", tok=other_access_token)
- def _advance_time_and_make_push_succeed(self, expected_push_attempts):
+ def _advance_time_and_make_push_succeed(self, expected_push_attempts: int) -> None:
self.pump()
self.push_attempts[expected_push_attempts - 1][0].callback({})
@@ -708,7 +712,9 @@ class HTTPPusherTests(HomeserverTestCase):
expected_unread_count_last_push,
)
- def _send_read_request(self, access_token, message_event_id, room_id):
+ def _send_read_request(
+ self, access_token: str, message_event_id: str, room_id: str
+ ) -> None:
# Now set the user's read receipt position to the first event
#
# This will actually trigger a new notification to be sent out so that
@@ -748,7 +754,7 @@ class HTTPPusherTests(HomeserverTestCase):
return user_id, access_token
- def test_dont_notify_rule_overrides_message(self):
+ def test_dont_notify_rule_overrides_message(self) -> None:
"""
The override push rule will suppress notification
"""
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 3849beb9d6..5dba187076 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from typing import Dict, Optional, Union
import frozendict
@@ -20,12 +20,13 @@ from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
from synapse.push import push_rule_evaluator
from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
+from synapse.types import JsonDict
from tests import unittest
class PushRuleEvaluatorTestCase(unittest.TestCase):
- def _get_evaluator(self, content):
+ def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
event = FrozenEvent(
{
"event_id": "$event_id",
@@ -39,12 +40,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
)
room_member_count = 0
sender_power_level = 0
- power_levels = {}
+ power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluatorForEvent(
event, room_member_count, sender_power_level, power_levels
)
- def test_display_name(self):
+ def test_display_name(self) -> None:
"""Check for a matching display name in the body of the event."""
evaluator = self._get_evaluator({"body": "foo bar baz"})
@@ -71,20 +72,20 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
def _assert_matches(
- self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
+ self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
) -> None:
evaluator = self._get_evaluator(content)
self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
def _assert_not_matches(
- self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
+ self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
) -> None:
evaluator = self._get_evaluator(content)
self.assertFalse(
evaluator.matches(condition, "@user:test", "display_name"), msg
)
- def test_event_match_body(self):
+ def test_event_match_body(self) -> None:
"""Check that event_match conditions on content.body work as expected"""
# if the key is `content.body`, the pattern matches substrings.
@@ -165,7 +166,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
r"? after \ should match any character",
)
- def test_event_match_non_body(self):
+ def test_event_match_non_body(self) -> None:
"""Check that event_match conditions on other keys work as expected"""
# if the key is anything other than 'content.body', the pattern must match the
@@ -241,7 +242,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"pattern should not match before a newline",
)
- def test_no_body(self):
+ def test_no_body(self) -> None:
"""Not having a body shouldn't break the evaluator."""
evaluator = self._get_evaluator({})
@@ -250,7 +251,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
}
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
- def test_invalid_body(self):
+ def test_invalid_body(self) -> None:
"""A non-string body should not break the evaluator."""
condition = {
"kind": "contains_display_name",
@@ -260,7 +261,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
evaluator = self._get_evaluator({"body": body})
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
- def test_tweaks_for_actions(self):
+ def test_tweaks_for_actions(self) -> None:
"""
This tests the behaviour of tweaks_for_actions.
"""
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 171f4e97c8..329690f8f7 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -15,7 +15,7 @@
import itertools
import urllib.parse
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor
@@ -79,6 +79,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
content: Optional[dict] = None,
access_token: Optional[str] = None,
parent_id: Optional[str] = None,
+ expected_response_code: int = 200,
) -> FakeChannel:
"""Helper function to send a relation pointing at `self.parent_id`
@@ -115,16 +116,60 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
content,
access_token=access_token,
)
+ self.assertEqual(expected_response_code, channel.code, channel.json_body)
return channel
+ def _get_related_events(self) -> List[str]:
+ """
+ Requests /relations on the parent ID and returns a list of event IDs.
+ """
+ # Request the relations of the event.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ return [ev["event_id"] for ev in channel.json_body["chunk"]]
+
+ def _get_bundled_aggregations(self) -> JsonDict:
+ """
+ Requests /event on the parent ID and returns the m.relations field (from unsigned), if it exists.
+ """
+ # Fetch the bundled aggregations of the event.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ return channel.json_body["unsigned"].get("m.relations", {})
+
+ def _get_aggregations(self) -> List[JsonDict]:
+ """Request /aggregations on the parent ID and includes the returned chunk."""
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ return channel.json_body["chunk"]
+
+ def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
+ """
+ Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
+ """
+ for event in events:
+ if event["event_id"] == self.parent_id:
+ return event
+
+ raise AssertionError(f"Event {self.parent_id} not found in chunk")
+
class RelationsTestCase(BaseRelationsTestCase):
def test_send_relation(self) -> None:
"""Tests that sending a relation works."""
-
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
- self.assertEqual(200, channel.code, channel.json_body)
-
event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -151,13 +196,13 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_deny_invalid_event(self) -> None:
"""Test that we deny relations on non-existant events"""
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.ANNOTATION,
EventTypes.Message,
parent_id="foo",
content={"body": "foo", "msgtype": "m.text"},
+ expected_response_code=400,
)
- self.assertEqual(400, channel.code, channel.json_body)
# Unless that event is referenced from another event!
self.get_success(
@@ -171,13 +216,12 @@ class RelationsTestCase(BaseRelationsTestCase):
desc="test_deny_invalid_event",
)
)
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.THREAD,
EventTypes.Message,
parent_id="foo",
content={"body": "foo", "msgtype": "m.text"},
)
- self.assertEqual(200, channel.code, channel.json_body)
def test_deny_invalid_room(self) -> None:
"""Test that we deny relations on non-existant events"""
@@ -187,18 +231,20 @@ class RelationsTestCase(BaseRelationsTestCase):
parent_id = res["event_id"]
# Attempt to send an annotation to that event.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
+ self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ parent_id=parent_id,
+ key="A",
+ expected_response_code=400,
)
- self.assertEqual(400, channel.code, channel.json_body)
def test_deny_double_react(self) -> None:
"""Test that we deny relations on membership events"""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(400, channel.code, channel.json_body)
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", expected_response_code=400
+ )
def test_deny_forked_thread(self) -> None:
"""It is invalid to start a thread off a thread."""
@@ -208,316 +254,24 @@ class RelationsTestCase(BaseRelationsTestCase):
content={"msgtype": "m.text", "body": "foo"},
parent_id=self.parent_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
parent_id = channel.json_body["event_id"]
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.THREAD,
"m.room.message",
content={"msgtype": "m.text", "body": "foo"},
parent_id=parent_id,
+ expected_response_code=400,
)
- self.assertEqual(400, channel.code, channel.json_body)
-
- def test_basic_paginate_relations(self) -> None:
- """Tests that calling pagination API correctly the latest relations."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
- first_annotation_id = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
- second_annotation_id = channel.json_body["event_id"]
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # We expect to get back a single pagination result, which is the latest
- # full relation event we sent above.
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
- self.assert_dict(
- {
- "event_id": second_annotation_id,
- "sender": self.user_id,
- "type": "m.reaction",
- },
- channel.json_body["chunk"][0],
- )
-
- # We also expect to get the original event (the id of which is self.parent_id)
- self.assertEqual(
- channel.json_body["original_event"]["event_id"], self.parent_id
- )
-
- # Make sure next_batch has something in it that looks like it could be a
- # valid token.
- self.assertIsInstance(
- channel.json_body.get("next_batch"), str, channel.json_body
- )
-
- # Request the relations again, but with a different direction.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations"
- f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # We expect to get back a single pagination result, which is the earliest
- # full relation event we sent above.
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
- self.assert_dict(
- {
- "event_id": first_annotation_id,
- "sender": self.user_id,
- "type": "m.reaction",
- },
- channel.json_body["chunk"][0],
- )
-
- def test_repeated_paginate_relations(self) -> None:
- """Test that if we paginate using a limit and tokens then we get the
- expected events.
- """
-
- expected_event_ids = []
- for idx in range(10):
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
- )
- self.assertEqual(200, channel.code, channel.json_body)
- expected_event_ids.append(channel.json_body["event_id"])
-
- prev_token = ""
- found_event_ids: List[str] = []
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
- def test_pagination_from_sync_and_messages(self) -> None:
- """Pagination tokens from /sync and /messages can be used to paginate /relations."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
- self.assertEqual(200, channel.code, channel.json_body)
- annotation_id = channel.json_body["event_id"]
- # Send an event after the relation events.
- self.helper.send(self.room, body="Latest event", tok=self.user_token)
-
- # Request /sync, limiting it such that only the latest event is returned
- # (and not the relation).
- filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
- channel = self.make_request(
- "GET", f"/sync?filter={filter}", access_token=self.user_token
- )
- self.assertEqual(200, channel.code, channel.json_body)
- room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- sync_prev_batch = room_timeline["prev_batch"]
- self.assertIsNotNone(sync_prev_batch)
- # Ensure the relation event is not in the batch returned from /sync.
- self.assertNotIn(
- annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
- )
-
- # Request /messages, limiting it such that only the latest event is
- # returned (and not the relation).
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/messages?dir=b&limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- messages_end = channel.json_body["end"]
- self.assertIsNotNone(messages_end)
- # Ensure the relation event is not in the chunk returned from /messages.
- self.assertNotIn(
- annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
- )
-
- # Request /relations with the pagination tokens received from both the
- # /sync and /messages responses above, in turn.
- #
- # This is a tiny bit silly since the client wouldn't know the parent ID
- # from the requests above; consider the parent ID to be known from a
- # previous /sync.
- for from_token in (sync_prev_batch, messages_end):
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # The relation should be in the returned chunk.
- self.assertIn(
- annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
- )
-
- def test_aggregation_pagination_groups(self) -> None:
- """Test that we can paginate annotation groups correctly."""
-
- # We need to create ten separate users to send each reaction.
- access_tokens = [self.user_token, self.user2_token]
- idx = 0
- while len(access_tokens) < 10:
- user_id, token = self._create_user("test" + str(idx))
- idx += 1
-
- self.helper.join(self.room, user=user_id, tok=token)
- access_tokens.append(token)
-
- idx = 0
- sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
- for key in itertools.chain.from_iterable(
- itertools.repeat(key, num) for key, num in sent_groups.items()
- ):
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key=key,
- access_token=access_tokens[idx],
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- idx += 1
- idx %= len(access_tokens)
-
- prev_token: Optional[str] = None
- found_groups: Dict[str, int] = {}
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- for groups in channel.json_body["chunk"]:
- # We only expect reactions
- self.assertEqual(groups["type"], "m.reaction", channel.json_body)
-
- # We should only see each key once
- self.assertNotIn(groups["key"], found_groups, channel.json_body)
-
- found_groups[groups["key"]] = groups["count"]
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- self.assertEqual(sent_groups, found_groups)
-
- def test_aggregation_pagination_within_group(self) -> None:
- """Test that we can paginate within an annotation group."""
-
- # We need to create ten separate users to send each reaction.
- access_tokens = [self.user_token, self.user2_token]
- idx = 0
- while len(access_tokens) < 10:
- user_id, token = self._create_user("test" + str(idx))
- idx += 1
-
- self.helper.join(self.room, user=user_id, tok=token)
- access_tokens.append(token)
-
- idx = 0
- expected_event_ids = []
- for _ in range(10):
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key="👍",
- access_token=access_tokens[idx],
- )
- self.assertEqual(200, channel.code, channel.json_body)
- expected_event_ids.append(channel.json_body["event_id"])
-
- idx += 1
-
- # Also send a different type of reaction so that we test we don't see it
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- prev_token = ""
- found_event_ids: List[str] = []
- encoded_key = urllib.parse.quote_plus("👍".encode())
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}"
- f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
- f"/m.reaction/{encoded_key}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
def test_aggregation(self) -> None:
"""Test that annotations get correctly aggregated."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
channel = self.make_request(
"GET",
@@ -547,215 +301,6 @@ class RelationsTestCase(BaseRelationsTestCase):
)
self.assertEqual(400, channel.code, channel.json_body)
- @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
- def test_bundled_aggregations(self) -> None:
- """
- Test that annotations, references, and threads get correctly bundled.
-
- Note that this doesn't test against /relations since only thread relations
- get bundled via that API. See test_aggregation_get_event_for_thread.
-
- See test_edit for a similar test for edits.
- """
- # Setup by sending a variety of relations.
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- reply_1 = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- reply_2 = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- thread_2 = channel.json_body["event_id"]
-
- def assert_bundle(event_json: JsonDict) -> None:
- """Assert the expected values of the bundled aggregations."""
- relations_dict = event_json["unsigned"].get("m.relations")
-
- # Ensure the fields are as expected.
- self.assertCountEqual(
- relations_dict.keys(),
- (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.THREAD,
- ),
- )
-
- # Check the values of each field.
- self.assertEqual(
- {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- ]
- },
- relations_dict[RelationTypes.ANNOTATION],
- )
-
- self.assertEqual(
- {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
- relations_dict[RelationTypes.REFERENCE],
- )
-
- self.assertEqual(
- 2,
- relations_dict[RelationTypes.THREAD].get("count"),
- )
- self.assertTrue(
- relations_dict[RelationTypes.THREAD].get("current_user_participated")
- )
- # The latest thread event has some fields that don't matter.
- self.assert_dict(
- {
- "content": {
- "m.relates_to": {
- "event_id": self.parent_id,
- "rel_type": RelationTypes.THREAD,
- }
- },
- "event_id": thread_2,
- "sender": self.user_id,
- "type": "m.room.test",
- },
- relations_dict[RelationTypes.THREAD].get("latest_event"),
- )
-
- # Request the event directly.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body)
-
- # Request the room messages.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/messages?dir=b",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
-
- # Request the room context.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/context/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["event"])
-
- # Request sync.
- channel = self.make_request("GET", "/sync", access_token=self.user_token)
- self.assertEqual(200, channel.code, channel.json_body)
- room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- self.assertTrue(room_timeline["limited"])
- assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
-
- # Request search.
- channel = self.make_request(
- "POST",
- "/search",
- # Search term matches the parent message.
- content={"search_categories": {"room_events": {"search_term": "Hi"}}},
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- chunk = [
- result["result"]
- for result in channel.json_body["search_categories"]["room_events"][
- "results"
- ]
- ]
- assert_bundle(self._find_event_in_chunk(chunk))
-
- def test_aggregation_get_event_for_annotation(self) -> None:
- """Test that annotations do not get bundled aggregations included
- when directly requested.
- """
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
- annotation_id = channel.json_body["event_id"]
-
- # Annotate the annotation.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{annotation_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
-
- def test_aggregation_get_event_for_thread(self) -> None:
- """Test that threads get bundled aggregations included when directly requested."""
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- thread_id = channel.json_body["event_id"]
-
- # Annotate the annotation.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{thread_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["unsigned"].get("m.relations"),
- {
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
- },
- )
-
- # It should also be included when the entire thread is requested.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(len(channel.json_body["chunk"]), 1)
-
- thread_message = channel.json_body["chunk"][0]
- self.assertEqual(
- thread_message["unsigned"].get("m.relations"),
- {
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
- },
- )
-
def test_ignore_invalid_room(self) -> None:
"""Test that we ignore invalid relations over federation."""
# Create another room and send a message in it.
@@ -877,8 +422,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
def assert_bundle(event_json: JsonDict) -> None:
@@ -954,7 +497,7 @@ class RelationsTestCase(BaseRelationsTestCase):
shouldn't be allowed, are correctly handled.
"""
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={
@@ -963,7 +506,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
channel = self._send_relation(
@@ -971,11 +513,9 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message.WRONG_TYPE",
content={
@@ -984,7 +524,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
channel = self.make_request(
"GET",
@@ -1015,7 +554,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "A reply!"},
)
- self.assertEqual(200, channel.code, channel.json_body)
reply = channel.json_body["event_id"]
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
@@ -1025,8 +563,6 @@ class RelationsTestCase(BaseRelationsTestCase):
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
parent_id=reply,
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -1071,17 +607,15 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "A threaded reply!"},
)
- self.assertEqual(200, channel.code, channel.json_body)
threaded_event_id = channel.json_body["event_id"]
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
parent_id=threaded_event_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
# Fetch the thread root, to get the bundled aggregation for the thread.
channel = self.make_request(
@@ -1113,11 +647,10 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": new_body,
},
)
- self.assertEqual(200, channel.code, channel.json_body)
edit_event_id = channel.json_body["event_id"]
# Edit the edit event.
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={
@@ -1127,7 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase):
},
parent_id=edit_event_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
# Request the original event.
channel = self.make_request(
@@ -1154,7 +686,6 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_unknown_relations(self) -> None:
"""Unknown relations should be accepted."""
channel = self._send_relation("m.relation.test", "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -1195,28 +726,15 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
- def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
- """
- Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
- """
- for event in events:
- if event["event_id"] == self.parent_id:
- return event
-
- raise AssertionError(f"Event {self.parent_id} not found in chunk")
-
def test_background_update(self) -> None:
"""Test the event_arbitrary_relations background update."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
- self.assertEqual(200, channel.code, channel.json_body)
annotation_event_id_good = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
- self.assertEqual(200, channel.code, channel.json_body)
annotation_event_id_bad = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
thread_event_id = channel.json_body["event_id"]
# Clean-up the table as if the inserts did not happen during event creation.
@@ -1267,6 +785,516 @@ class RelationsTestCase(BaseRelationsTestCase):
[annotation_event_id_good, thread_event_id],
)
+
+class RelationPaginationTestCase(BaseRelationsTestCase):
+ def test_basic_paginate_relations(self) -> None:
+ """Tests that calling pagination API correctly the latest relations."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ first_annotation_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+ second_annotation_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the latest
+ # full relation event we sent above.
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {
+ "event_id": second_annotation_id,
+ "sender": self.user_id,
+ "type": "m.reaction",
+ },
+ channel.json_body["chunk"][0],
+ )
+
+ # We also expect to get the original event (the id of which is self.parent_id)
+ self.assertEqual(
+ channel.json_body["original_event"]["event_id"], self.parent_id
+ )
+
+ # Make sure next_batch has something in it that looks like it could be a
+ # valid token.
+ self.assertIsInstance(
+ channel.json_body.get("next_batch"), str, channel.json_body
+ )
+
+ # Request the relations again, but with a different direction.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations"
+ f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the earliest
+ # full relation event we sent above.
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {
+ "event_id": first_annotation_id,
+ "sender": self.user_id,
+ "type": "m.reaction",
+ },
+ channel.json_body["chunk"][0],
+ )
+
+ def test_repeated_paginate_relations(self) -> None:
+ """Test that if we paginate using a limit and tokens then we get the
+ expected events.
+ """
+
+ expected_event_ids = []
+ for idx in range(10):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
+ )
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ prev_token = ""
+ found_event_ids: List[str] = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEqual(found_event_ids, expected_event_ids)
+
+ def test_pagination_from_sync_and_messages(self) -> None:
+ """Pagination tokens from /sync and /messages can be used to paginate /relations."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
+ annotation_id = channel.json_body["event_id"]
+ # Send an event after the relation events.
+ self.helper.send(self.room, body="Latest event", tok=self.user_token)
+
+ # Request /sync, limiting it such that only the latest event is returned
+ # (and not the relation).
+ filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ sync_prev_batch = room_timeline["prev_batch"]
+ self.assertIsNotNone(sync_prev_batch)
+ # Ensure the relation event is not in the batch returned from /sync.
+ self.assertNotIn(
+ annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
+ )
+
+ # Request /messages, limiting it such that only the latest event is
+ # returned (and not the relation).
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b&limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ messages_end = channel.json_body["end"]
+ self.assertIsNotNone(messages_end)
+ # Ensure the relation event is not in the chunk returned from /messages.
+ self.assertNotIn(
+ annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ )
+
+ # Request /relations with the pagination tokens received from both the
+ # /sync and /messages responses above, in turn.
+ #
+ # This is a tiny bit silly since the client wouldn't know the parent ID
+ # from the requests above; consider the parent ID to be known from a
+ # previous /sync.
+ for from_token in (sync_prev_batch, messages_end):
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # The relation should be in the returned chunk.
+ self.assertIn(
+ annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ )
+
+ def test_aggregation_pagination_groups(self) -> None:
+ """Test that we can paginate annotation groups correctly."""
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
+ for key in itertools.chain.from_iterable(
+ itertools.repeat(key, num) for key, num in sent_groups.items()
+ ):
+ self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key=key,
+ access_token=access_tokens[idx],
+ )
+
+ idx += 1
+ idx %= len(access_tokens)
+
+ prev_token: Optional[str] = None
+ found_groups: Dict[str, int] = {}
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ for groups in channel.json_body["chunk"]:
+ # We only expect reactions
+ self.assertEqual(groups["type"], "m.reaction", channel.json_body)
+
+ # We should only see each key once
+ self.assertNotIn(groups["key"], found_groups, channel.json_body)
+
+ found_groups[groups["key"]] = groups["count"]
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ self.assertEqual(sent_groups, found_groups)
+
+ def test_aggregation_pagination_within_group(self) -> None:
+ """Test that we can paginate within an annotation group."""
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ expected_event_ids = []
+ for _ in range(10):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key="👍",
+ access_token=access_tokens[idx],
+ )
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ idx += 1
+
+ # Also send a different type of reaction so that we test we don't see it
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+
+ prev_token = ""
+ found_event_ids: List[str] = []
+ encoded_key = urllib.parse.quote_plus("👍".encode())
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}"
+ f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+ f"/m.reaction/{encoded_key}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEqual(found_event_ids, expected_event_ids)
+
+
+class BundledAggregationsTestCase(BaseRelationsTestCase):
+ """
+ See RelationsTestCase.test_edit for a similar test for edits.
+
+ Note that this doesn't test against /relations since only thread relations
+ get bundled via that API. See test_aggregation_get_event_for_thread.
+ """
+
+ def _test_bundled_aggregations(
+ self,
+ relation_type: str,
+ assertion_callable: Callable[[JsonDict], None],
+ expected_db_txn_for_event: int,
+ ) -> None:
+ """
+ Makes requests to various endpoints which should include bundled aggregations
+ and then calls an assertion function on the bundled aggregations.
+
+ Args:
+ relation_type: The field to search for in the `m.relations` field in unsigned.
+ assertion_callable: Called with the contents of unsigned["m.relations"][relation_type]
+ for relation-specific assertions.
+ expected_db_txn_for_event: The number of database transactions which
+ are expected for a call to /event/.
+ """
+
+ def assert_bundle(event_json: JsonDict) -> None:
+ """Assert the expected values of the bundled aggregations."""
+ relations_dict = event_json["unsigned"].get("m.relations")
+
+ # Ensure the fields are as expected.
+ self.assertCountEqual(relations_dict.keys(), (relation_type,))
+ assertion_callable(relations_dict[relation_type])
+
+ # Request the event directly.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body)
+ assert channel.resource_usage is not None
+ self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event)
+
+ # Request the room messages.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
+
+ # Request the room context.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/context/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body["event"])
+
+ # Request sync.
+ filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
+
+ # Request search.
+ channel = self.make_request(
+ "POST",
+ "/search",
+ # Search term matches the parent message.
+ content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ chunk = [
+ result["result"]
+ for result in channel.json_body["search_categories"]["room_events"][
+ "results"
+ ]
+ ]
+ assert_bundle(self._find_event_in_chunk(chunk))
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_annotation(self) -> None:
+ """
+ Test that annotations get correctly bundled.
+ """
+ # Setup by sending a variety of relations.
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(
+ {
+ "chunk": [
+ {"type": "m.reaction", "key": "a", "count": 2},
+ {"type": "m.reaction", "key": "b", "count": 1},
+ ]
+ },
+ bundled_aggregations,
+ )
+
+ self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_reference(self) -> None:
+ """
+ Test that references get correctly bundled.
+ """
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ reply_1 = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ reply_2 = channel.json_body["event_id"]
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(
+ {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
+ bundled_aggregations,
+ )
+
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_thread(self) -> None:
+ """
+ Test that threads get correctly bundled.
+ """
+ self._send_relation(RelationTypes.THREAD, "m.room.test")
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ thread_2 = channel.json_body["event_id"]
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(2, bundled_aggregations.get("count"))
+ self.assertTrue(bundled_aggregations.get("current_user_participated"))
+ # The latest thread event has some fields that don't matter.
+ self.assert_dict(
+ {
+ "content": {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "rel_type": RelationTypes.THREAD,
+ }
+ },
+ "event_id": thread_2,
+ "sender": self.user_id,
+ "type": "m.room.test",
+ },
+ bundled_aggregations.get("latest_event"),
+ )
+
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9)
+
+ def test_aggregation_get_event_for_annotation(self) -> None:
+ """Test that annotations do not get bundled aggregations included
+ when directly requested.
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ annotation_id = channel.json_body["event_id"]
+
+ # Annotate the annotation.
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{annotation_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
+
+ def test_aggregation_get_event_for_thread(self) -> None:
+ """Test that threads get bundled aggregations included when directly requested."""
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ thread_id = channel.json_body["event_id"]
+
+ # Annotate the annotation.
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{thread_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertEqual(
+ channel.json_body["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+ },
+ },
+ )
+
+ # It should also be included when the entire thread is requested.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+
+ thread_message = channel.json_body["chunk"][0]
+ self.assertEqual(
+ thread_message["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+ },
+ },
+ )
+
def test_bundled_aggregations_with_filter(self) -> None:
"""
If "unsigned" is an omitted field (due to filtering), adding the bundled
@@ -1322,46 +1350,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
)
self.assertEqual(200, channel.code, channel.json_body)
- def _make_relation_requests(self) -> Tuple[List[str], JsonDict]:
- """
- Makes requests and ensures they result in a 200 response, returns a
- tuple of results:
-
- 1. `/relations` -> Returns a list of event IDs.
- 2. `/event` -> Returns the response's m.relations field (from unsigned),
- if it exists.
- """
-
- # Request the relations of the event.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEquals(200, channel.code, channel.json_body)
- event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
-
- # Fetch the bundled aggregations of the event.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEquals(200, channel.code, channel.json_body)
- bundled_relations = channel.json_body["unsigned"].get("m.relations", {})
-
- return event_ids, bundled_relations
-
- def _get_aggregations(self) -> List[JsonDict]:
- """Request /aggregations on the parent ID and includes the returned chunk."""
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- return channel.json_body["chunk"]
-
def test_redact_relation_annotation(self) -> None:
"""
Test that annotations of an event are properly handled after the
@@ -1371,17 +1359,16 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
the response to relations.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
to_redact_event_id = channel.json_body["event_id"]
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
- self.assertEqual(200, channel.code, channel.json_body)
unredacted_event_id = channel.json_body["event_id"]
# Both relations should exist.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
self.assertEquals(
relations["m.annotation"],
@@ -1396,7 +1383,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(to_redact_event_id)
# The unredacted relation should still exist.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEquals(event_ids, [unredacted_event_id])
self.assertEquals(
relations["m.annotation"],
@@ -1419,7 +1407,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
EventTypes.Message,
content={"body": "reply 1", "msgtype": "m.text"},
)
- self.assertEqual(200, channel.code, channel.json_body)
unredacted_event_id = channel.json_body["event_id"]
# Note that the *last* event in the thread is redacted, as that gets
@@ -1429,11 +1416,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
EventTypes.Message,
content={"body": "reply 2", "msgtype": "m.text"},
)
- self.assertEqual(200, channel.code, channel.json_body)
to_redact_event_id = channel.json_body["event_id"]
# Both relations exist.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id])
self.assertDictContainsSubset(
{
@@ -1452,7 +1439,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(to_redact_event_id)
# The unredacted relation should still exist.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEquals(event_ids, [unredacted_event_id])
self.assertDictContainsSubset(
{
@@ -1472,7 +1460,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
is redacted.
"""
# Add a relation
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
parent_id=self.parent_id,
@@ -1482,10 +1470,10 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
# Check the relation is returned
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEqual(len(event_ids), 1)
self.assertIn(RelationTypes.REPLACE, relations)
@@ -1493,7 +1481,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(self.parent_id)
# The relations are not returned.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEqual(len(event_ids), 0)
self.assertEqual(relations, {})
@@ -1503,11 +1492,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
"""
# Add a relation
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
- self.assertEqual(200, channel.code, channel.json_body)
related_event_id = channel.json_body["event_id"]
# The relations should exist.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEqual(len(event_ids), 1)
self.assertIn(RelationTypes.ANNOTATION, relations)
@@ -1519,7 +1508,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(self.parent_id)
# The relations are returned.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEquals(event_ids, [related_event_id])
self.assertEquals(
relations["m.annotation"],
@@ -1540,14 +1530,14 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
EventTypes.Message,
content={"body": "reply 1", "msgtype": "m.text"},
)
- self.assertEqual(200, channel.code, channel.json_body)
related_event_id = channel.json_body["event_id"]
# Redact one of the reactions.
self._redact(self.parent_id)
# The unredacted relation should still exist.
- event_ids, relations = self._make_relation_requests()
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
self.assertEquals(len(event_ids), 1)
self.assertDictContainsSubset(
{
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py
index 3fb37a2a59..62e308814d 100644
--- a/tests/rest/media/v1/test_html_preview.py
+++ b/tests/rest/media/v1/test_html_preview.py
@@ -16,7 +16,6 @@ from synapse.rest.media.v1.preview_html import (
_get_html_media_encodings,
decode_body,
parse_html_to_open_graph,
- rebase_url,
summarize_paragraphs,
)
@@ -161,7 +160,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -177,7 +176,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -196,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(
og,
@@ -218,7 +217,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -232,7 +231,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -247,7 +246,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
@@ -262,7 +261,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -290,7 +289,7 @@ class CalcOgTestCase(unittest.TestCase):
<head><title>Foo</title></head><body>Some text.</body></html>
""".strip()
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding(self) -> None:
@@ -304,7 +303,7 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding2(self) -> None:
@@ -319,7 +318,7 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
def test_windows_1252(self) -> None:
@@ -333,7 +332,7 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
@@ -448,34 +447,3 @@ class MediaEncodingTestCase(unittest.TestCase):
'text/html; charset="invalid"',
)
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
-
-
-class RebaseUrlTestCase(unittest.TestCase):
- def test_relative(self) -> None:
- """Relative URLs should be resolved based on the context of the base URL."""
- self.assertEqual(
- rebase_url("subpage", "https://example.com/foo/"),
- "https://example.com/foo/subpage",
- )
- self.assertEqual(
- rebase_url("sibling", "https://example.com/foo"),
- "https://example.com/sibling",
- )
- self.assertEqual(
- rebase_url("/bar", "https://example.com/foo/"),
- "https://example.com/bar",
- )
-
- def test_absolute(self) -> None:
- """Absolute URLs should not be modified."""
- self.assertEqual(
- rebase_url("https://alice.com/a/", "https://example.com/foo/"),
- "https://alice.com/a/",
- )
-
- def test_data(self) -> None:
- """Data URLs should not be modified."""
- self.assertEqual(
- rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),
- "data:,Hello%2C%20World%21",
- )
diff --git a/tests/server.py b/tests/server.py
index 82990c2eb9..6ce2a17bf4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -54,13 +54,18 @@ from twisted.internet.interfaces import (
ITransport,
)
from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
+from twisted.test.proto_helpers import (
+ AccumulatingProtocol,
+ MemoryReactor,
+ MemoryReactorClock,
+)
from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
from synapse.http.site import SynapseRequest
+from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
@@ -88,18 +93,19 @@ class TimedOutException(Exception):
"""
-@attr.s
+@attr.s(auto_attribs=True)
class FakeChannel:
"""
A fake Twisted Web Channel (the part that interfaces with the
wire).
"""
- site = attr.ib(type=Union[Site, "FakeSite"])
- _reactor = attr.ib()
- result = attr.ib(type=dict, default=attr.Factory(dict))
- _ip = attr.ib(type=str, default="127.0.0.1")
+ site: Union[Site, "FakeSite"]
+ _reactor: MemoryReactor
+ result: dict = attr.Factory(dict)
+ _ip: str = "127.0.0.1"
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
+ resource_usage: Optional[ContextResourceUsage] = None
@property
def json_body(self):
@@ -168,6 +174,8 @@ class FakeChannel:
def requestDone(self, _self):
self.result["done"] = True
+ if isinstance(_self, SynapseRequest):
+ self.resource_usage = _self.logcontext.get_resource_usage()
def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 272cd35402..72bf5b3d31 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -47,9 +47,18 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
expected_ignorer_user_ids,
)
+ def assert_ignored(
+ self, ignorer_user_id: str, expected_ignored_user_ids: Set[str]
+ ) -> None:
+ self.assertEqual(
+ self.get_success(self.store.ignored_users(ignorer_user_id)),
+ expected_ignored_user_ids,
+ )
+
def test_ignoring_users(self):
"""Basic adding/removing of users from the ignore list."""
self._update_ignore_list("@other:test", "@another:remote")
+ self.assert_ignored(self.user, {"@other:test", "@another:remote"})
# Check a user which no one ignores.
self.assert_ignorers("@user:test", set())
@@ -62,6 +71,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
# Add one user, remove one user, and leave one user.
self._update_ignore_list("@foo:test", "@another:remote")
+ self.assert_ignored(self.user, {"@foo:test", "@another:remote"})
# Check the removed user.
self.assert_ignorers("@other:test", set())
@@ -76,20 +86,24 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
"""Ensure that caching works properly between different users."""
# The first user ignores a user.
self._update_ignore_list("@other:test")
+ self.assert_ignored(self.user, {"@other:test"})
self.assert_ignorers("@other:test", {self.user})
# The second user ignores them.
self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
+ self.assert_ignored("@second:test", {"@other:test"})
self.assert_ignorers("@other:test", {self.user, "@second:test"})
# The first user un-ignores them.
self._update_ignore_list()
+ self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", {"@second:test"})
def test_invalid_data(self):
"""Invalid data ends up clearing out the ignored users list."""
# Add some data and ensure it is there.
self._update_ignore_list("@other:test")
+ self.assert_ignored(self.user, {"@other:test"})
self.assert_ignorers("@other:test", {self.user})
# No ignored_users key.
@@ -102,10 +116,12 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
)
# No one ignores the user now.
+ self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", set())
# Add some data and ensure it is there.
self._update_ignore_list("@other:test")
+ self.assert_ignored(self.user, {"@other:test"})
self.assert_ignorers("@other:test", {self.user})
# Invalid data.
@@ -118,4 +134,5 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
)
# No one ignores the user now.
+ self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", set())
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 5cf18b690e..fd619b64d4 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -17,8 +17,12 @@ from unittest.mock import Mock
import yaml
from twisted.internet.defer import Deferred, ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
+from synapse.server import HomeServer
from synapse.storage.background_updates import BackgroundUpdater
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable, simple_async_mock
@@ -26,7 +30,7 @@ from tests.unittest import override_config
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
# the base test class should have run the real bg updates for us
self.assertTrue(
@@ -39,7 +43,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
)
self.store = self.hs.get_datastores().main
- async def update(self, progress, count):
+ async def update(self, progress: JsonDict, count: int) -> int:
duration_ms = 10
await self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1}
@@ -51,7 +55,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
)
return count
- def test_do_background_update(self):
+ def test_do_background_update(self) -> None:
# the time we claim it takes to update one item when running the update
duration_ms = 10
@@ -80,7 +84,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# second step: complete the update
# we should now get run with a much bigger number of items to update
- async def update(progress, count):
+ async def update(progress: JsonDict, count: int) -> int:
self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual(
count,
@@ -110,7 +114,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
"""
)
)
- def test_background_update_default_batch_set_by_config(self):
+ def test_background_update_default_batch_set_by_config(self) -> None:
"""
Test that the background update is run with the default_batch_size set by the config
"""
@@ -133,7 +137,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# on the first call, we should get run with the default background update size specified in the config
self.update_handler.assert_called_once_with({"my_key": 1}, 20)
- def test_background_update_default_sleep_behavior(self):
+ def test_background_update_default_sleep_behavior(self) -> None:
"""
Test default background update behavior, which is to sleep
"""
@@ -147,7 +151,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = self.update
self.update_handler.reset_mock()
- self.updates.start_doing_background_updates(),
+ self.updates.start_doing_background_updates()
# 2: advance the reactor less than the default sleep duration (1000ms)
self.reactor.pump([0.5])
@@ -167,7 +171,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
"""
)
)
- def test_background_update_sleep_set_in_config(self):
+ def test_background_update_sleep_set_in_config(self) -> None:
"""
Test that changing the sleep time in the config changes how long it sleeps
"""
@@ -181,7 +185,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = self.update
self.update_handler.reset_mock()
- self.updates.start_doing_background_updates(),
+ self.updates.start_doing_background_updates()
# 2: advance the reactor less than the configured sleep duration (500ms)
self.reactor.pump([0.45])
@@ -201,7 +205,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
"""
)
)
- def test_disabling_background_update_sleep(self):
+ def test_disabling_background_update_sleep(self) -> None:
"""
Test that disabling sleep in the config results in bg update not sleeping
"""
@@ -215,7 +219,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = self.update
self.update_handler.reset_mock()
- self.updates.start_doing_background_updates(),
+ self.updates.start_doing_background_updates()
# 2: advance the reactor very little
self.reactor.pump([0.025])
@@ -230,7 +234,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
"""
)
)
- def test_background_update_duration_set_in_config(self):
+ def test_background_update_duration_set_in_config(self) -> None:
"""
Test that the desired duration set in the config is used in determining batch size
"""
@@ -254,7 +258,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# the first update was run with the default batch size, this should be run with 500ms as the
# desired duration
- async def update(progress, count):
+ async def update(progress: JsonDict, count: int) -> int:
self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual(
count,
@@ -275,7 +279,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
"""
)
)
- def test_background_update_min_batch_set_in_config(self):
+ def test_background_update_min_batch_set_in_config(self) -> None:
"""
Test that the minimum batch size set in the config is used
"""
@@ -290,7 +294,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
)
# Run the update with the long-running update item
- async def update(progress, count):
+ async def update_long(progress: JsonDict, count: int) -> int:
await self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1}
await self.store.db_pool.runInteraction(
@@ -301,7 +305,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
)
return count
- self.update_handler.side_effect = update
+ self.update_handler.side_effect = update_long
self.update_handler.reset_mock()
res = self.get_success(
self.updates.do_next_background_update(False),
@@ -311,25 +315,25 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# the first update was run with the default batch size, this should be run with minimum batch size
# as the first items took a very long time
- async def update(progress, count):
+ async def update_short(progress: JsonDict, count: int) -> int:
self.assertEqual(progress, {"my_key": 2})
self.assertEqual(count, 5)
await self.updates._end_background_update("test_update")
return count
- self.update_handler.side_effect = update
+ self.update_handler.side_effect = update_short
self.get_success(self.updates.do_next_background_update(False))
class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
# the base test class should have run the real bg updates for us
self.assertTrue(
self.get_success(self.updates.has_completed_background_updates())
)
- self.update_deferred = Deferred()
+ self.update_deferred: Deferred[int] = Deferred()
self.update_handler = Mock(return_value=self.update_deferred)
self.updates.register_background_update_handler(
"test_update", self.update_handler
@@ -358,7 +362,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
),
)
- def test_controller(self):
+ def test_controller(self) -> None:
store = self.hs.get_datastores().main
self.get_success(
store.db_pool.simple_insert(
@@ -368,7 +372,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
)
# Set the return value for the context manager.
- enter_defer = Deferred()
+ enter_defer: Deferred[int] = Deferred()
self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
# Start the background update.
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index 8597867563..a40fc20ef9 100644
--- a/tests/storage/test_database.py
+++ b/tests/storage/test_database.py
@@ -12,7 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.database import make_tuple_comparison_clause
+from typing import Callable, Tuple
+from unittest.mock import Mock, call
+
+from twisted.internet import defer
+from twisted.internet.defer import CancelledError, Deferred
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingTransaction,
+ make_tuple_comparison_clause,
+)
+from synapse.util import Clock
from tests import unittest
@@ -22,3 +35,150 @@ class TupleComparisonClauseTestCase(unittest.TestCase):
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
self.assertEqual(clause, "(a,b) > (?,?)")
self.assertEqual(args, [1, 2])
+
+
+class CallbacksTestCase(unittest.HomeserverTestCase):
+ """Tests for transaction callbacks."""
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.db_pool: DatabasePool = self.store.db_pool
+
+ def _run_interaction(
+ self, func: Callable[[LoggingTransaction], object]
+ ) -> Tuple[Mock, Mock]:
+ """Run the given function in a database transaction, with callbacks registered.
+
+ Args:
+ func: The function to be run in a transaction. The transaction will be
+ retried if `func` raises an `OperationalError`.
+
+ Returns:
+ Two mocks, which were registered as an `after_callback` and an
+ `exception_callback` respectively, on every transaction attempt.
+ """
+ after_callback = Mock()
+ exception_callback = Mock()
+
+ def _test_txn(txn: LoggingTransaction) -> None:
+ txn.call_after(after_callback, 123, 456, extra=789)
+ txn.call_on_exception(exception_callback, 987, 654, extra=321)
+ func(txn)
+
+ try:
+ self.get_success_or_raise(
+ self.db_pool.runInteraction("test_transaction", _test_txn)
+ )
+ except Exception:
+ pass
+
+ return after_callback, exception_callback
+
+ def test_after_callback(self) -> None:
+ """Test that the after callback is called when a transaction succeeds."""
+ after_callback, exception_callback = self._run_interaction(lambda txn: None)
+
+ after_callback.assert_called_once_with(123, 456, extra=789)
+ exception_callback.assert_not_called()
+
+ def test_exception_callback(self) -> None:
+ """Test that the exception callback is called when a transaction fails."""
+ _test_txn = Mock(side_effect=ZeroDivisionError)
+ after_callback, exception_callback = self._run_interaction(_test_txn)
+
+ after_callback.assert_not_called()
+ exception_callback.assert_called_once_with(987, 654, extra=321)
+
+ def test_failed_retry(self) -> None:
+ """Test that the exception callback is called for every failed attempt."""
+ # Always raise an `OperationalError`.
+ _test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError)
+ after_callback, exception_callback = self._run_interaction(_test_txn)
+
+ after_callback.assert_not_called()
+ exception_callback.assert_has_calls(
+ [
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ ]
+ )
+ self.assertEqual(exception_callback.call_count, 6) # no additional calls
+
+ def test_successful_retry(self) -> None:
+ """Test callbacks for a failed transaction followed by a successful attempt."""
+ # Raise an `OperationalError` on the first attempt only.
+ _test_txn = Mock(
+ side_effect=[self.db_pool.engine.module.OperationalError, None]
+ )
+ after_callback, exception_callback = self._run_interaction(_test_txn)
+
+ # Calling both `after_callback`s when the first attempt failed is rather
+ # surprising (#12184). Let's document the behaviour in a test.
+ after_callback.assert_has_calls(
+ [
+ call(123, 456, extra=789),
+ call(123, 456, extra=789),
+ ]
+ )
+ self.assertEqual(after_callback.call_count, 2) # no additional calls
+ exception_callback.assert_not_called()
+
+
+class CancellationTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.db_pool: DatabasePool = self.store.db_pool
+
+ def test_after_callback(self) -> None:
+ """Test that the after callback is called when a transaction succeeds."""
+ d: "Deferred[None]"
+ after_callback = Mock()
+ exception_callback = Mock()
+
+ def _test_txn(txn: LoggingTransaction) -> None:
+ txn.call_after(after_callback, 123, 456, extra=789)
+ txn.call_on_exception(exception_callback, 987, 654, extra=321)
+ d.cancel()
+
+ d = defer.ensureDeferred(
+ self.db_pool.runInteraction("test_transaction", _test_txn)
+ )
+ self.get_failure(d, CancelledError)
+
+ after_callback.assert_called_once_with(123, 456, extra=789)
+ exception_callback.assert_not_called()
+
+ def test_exception_callback(self) -> None:
+ """Test that the exception callback is called when a transaction fails."""
+ d: "Deferred[None]"
+ after_callback = Mock()
+ exception_callback = Mock()
+
+ def _test_txn(txn: LoggingTransaction) -> None:
+ txn.call_after(after_callback, 123, 456, extra=789)
+ txn.call_on_exception(exception_callback, 987, 654, extra=321)
+ d.cancel()
+ # Simulate a retryable failure on every attempt.
+ raise self.db_pool.engine.module.OperationalError()
+
+ d = defer.ensureDeferred(
+ self.db_pool.runInteraction("test_transaction", _test_txn)
+ )
+ self.get_failure(d, CancelledError)
+
+ after_callback.assert_not_called()
+ exception_callback.assert_has_calls(
+ [
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ ]
+ )
+ self.assertEqual(exception_callback.call_count, 6) # no additional calls
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 6ac4b93f98..395396340b 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -13,9 +13,13 @@
# limitations under the License.
from typing import List, Optional
-from synapse.storage.database import DatabasePool
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS
@@ -25,13 +29,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
- def _setup_db(self, txn):
+ def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
@@ -59,12 +63,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
- def _insert_rows(self, instance_name: str, number: int):
+ def _insert_rows(self, instance_name: str, number: int) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""
- def _insert(txn):
+ def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
txn.execute(
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
@@ -80,12 +84,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
- def _insert_row_with_id(self, instance_name: str, stream_id: int):
+ def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id, updating
the postgres sequence position to match.
"""
- def _insert(txn):
+ def _insert(txn: LoggingTransaction) -> None:
txn.execute(
"INSERT INTO foobar VALUES (?, ?)",
(
@@ -104,7 +108,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
- def test_empty(self):
+ def test_empty(self) -> None:
"""Test an ID generator against an empty database gives sensible
current positions.
"""
@@ -114,7 +118,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# The table is empty so we expect an empty map for positions
self.assertEqual(id_gen.get_positions(), {})
- def test_single_instance(self):
+ def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled
correctly.
"""
@@ -130,7 +134,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
@@ -142,7 +146,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
- def test_out_of_order_finish(self):
+ def test_out_of_order_finish(self) -> None:
"""Test that IDs persisted out of order are correctly handled"""
# Prefill table with 7 rows written by 'master'
@@ -191,7 +195,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 11})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
- def test_multi_instance(self):
+ def test_multi_instance(self) -> None:
"""Test that reads and writes from multiple processes are handled
correctly.
"""
@@ -215,7 +219,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with first_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
@@ -233,7 +237,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# ... but calling `get_next` on the second instance should give a unique
# stream ID
- async def _get_next_async():
+ async def _get_next_async2() -> None:
async with second_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 9)
@@ -241,7 +245,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.get_positions(), {"first": 3, "second": 7}
)
- self.get_success(_get_next_async())
+ self.get_success(_get_next_async2())
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
@@ -249,7 +253,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.advance("first", 8)
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
- def test_get_next_txn(self):
+ def test_get_next_txn(self) -> None:
"""Test that the `get_next_txn` function works correctly."""
# Prefill table with 7 rows written by 'master'
@@ -263,7 +267,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
- def _get_next_txn(txn):
+ def _get_next_txn(txn: LoggingTransaction) -> None:
stream_id = id_gen.get_next_txn(txn)
self.assertEqual(stream_id, 8)
@@ -275,7 +279,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
- def test_get_persisted_upto_position(self):
+ def test_get_persisted_upto_position(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions.
"""
@@ -317,7 +321,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen.advance("second", 15)
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
- def test_get_persisted_upto_position_get_next(self):
+ def test_get_persisted_upto_position_get_next(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions when `get_next` is called.
"""
@@ -331,7 +335,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
@@ -344,7 +348,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
- def test_restart_during_out_of_order_persistence(self):
+ def test_restart_during_out_of_order_persistence(self) -> None:
"""Test that restarting a process while another process is writing out
of order updates are handled correctly.
"""
@@ -388,7 +392,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen_worker.advance("master", 9)
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
- def test_writer_config_change(self):
+ def test_writer_config_change(self) -> None:
"""Test that changing the writer config correctly works."""
self._insert_row_with_id("first", 3)
@@ -421,7 +425,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Check that we get a sane next stream ID with this new config.
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen_3.get_next() as stream_id:
self.assertEqual(stream_id, 6)
@@ -435,7 +439,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6)
self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
- def test_sequence_consistency(self):
+ def test_sequence_consistency(self) -> None:
"""Test that we error out if the table and sequence diverges."""
# Prefill with some rows
@@ -458,13 +462,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
- def _setup_db(self, txn):
+ def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
@@ -493,10 +497,10 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success(self.db_pool.runWithConnection(_create))
- def _insert_row(self, instance_name: str, stream_id: int):
+ def _insert_row(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id."""
- def _insert(txn):
+ def _insert(txn: LoggingTransaction) -> None:
txn.execute(
"INSERT INTO foobar VALUES (?, ?)",
(
@@ -514,13 +518,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
- def test_single_instance(self):
+ def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled
correctly.
"""
id_gen = self._create_id_generator()
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id:
self._insert_row("master", stream_id)
@@ -530,7 +534,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
- async def _get_next_async2():
+ async def _get_next_async2() -> None:
async with id_gen.get_next_mult(3) as stream_ids:
for stream_id in stream_ids:
self._insert_row("master", stream_id)
@@ -548,14 +552,14 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
- def test_multiple_instance(self):
+ def test_multiple_instance(self) -> None:
"""Tests that having multiple instances that get advanced over
federation works corretly.
"""
id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen_1.get_next() as stream_id:
self._insert_row("first", stream_id)
id_gen_2.advance("first", stream_id)
@@ -567,7 +571,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
- async def _get_next_async2():
+ async def _get_next_async2() -> None:
async with id_gen_2.get_next() as stream_id:
self._insert_row("second", stream_id)
id_gen_1.advance("second", stream_id)
@@ -584,13 +588,13 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
- def _setup_db(self, txn):
+ def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
@@ -642,7 +646,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
from the postgres sequence.
"""
- def _insert(txn):
+ def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
txn.execute(
"INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
@@ -659,7 +663,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
- def test_load_existing_stream(self):
+ def test_load_existing_stream(self) -> None:
"""Test creating ID gens with multiple tables that have rows from after
the position in `stream_positions` table.
"""
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 38e9f58ac6..5d1aa025d1 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -12,7 +12,7 @@ from tests.unittest import TestCase
class DummyDistribution(metadata.Distribution):
- def __init__(self, version: str):
+ def __init__(self, version: object):
self._version = version
@property
@@ -30,6 +30,7 @@ old = DummyDistribution("0.1.2")
old_release_candidate = DummyDistribution("0.1.2rc3")
new = DummyDistribution("1.2.3")
new_release_candidate = DummyDistribution("1.2.3rc4")
+distribution_with_no_version = DummyDistribution(None)
# could probably use stdlib TestCase --- no need for twisted here
@@ -67,6 +68,18 @@ class TestDependencyChecker(TestCase):
# should not raise
check_requirements()
+ def test_version_reported_as_none(self) -> None:
+ """Complain if importlib.metadata.version() returns None.
+
+ This shouldn't normally happen, but it was seen in the wild (#12223).
+ """
+ with patch(
+ "synapse.util.check_dependencies.metadata.requires",
+ return_value=["dummypkg >= 1"],
+ ):
+ with self.mock_installed_package(distribution_with_no_version):
+ self.assertRaises(DependencyException, check_requirements)
+
def test_checks_ignore_dev_dependencies(self) -> None:
"""Bot generic and per-extra checks should ignore dev dependencies."""
with patch(
|