diff options
author | Olivier Wilkinson (reivilibre) <oliverw@matrix.org> | 2021-12-13 15:54:21 +0000 |
---|---|---|
committer | Olivier Wilkinson (reivilibre) <oliverw@matrix.org> | 2021-12-13 15:54:21 +0000 |
commit | 075a2b7a5ded42c0058119687abd036e90a17296 (patch) | |
tree | 1618037e65a1f381e15cbcc4f1bdfd6d0ab526a0 | |
parent | Fix comment (diff) | |
parent | Make `get_device` return None if the device doesn't exist rather than raising... (diff) | |
download | synapse-075a2b7a5ded42c0058119687abd036e90a17296.tar.xz |
Merge branch 'develop' into rei/as_device_masquerading_msc3202
73 files changed, 1176 insertions, 440 deletions
diff --git a/changelog.d/11243.misc b/changelog.d/11243.misc new file mode 100644 index 0000000000..5ef7fe16d4 --- /dev/null +++ b/changelog.d/11243.misc @@ -0,0 +1 @@ +Allow specific, experimental events to be created without `prev_events`. Used by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716). diff --git a/changelog.d/11378.feature b/changelog.d/11378.feature new file mode 100644 index 0000000000..524bf84f32 --- /dev/null +++ b/changelog.d/11378.feature @@ -0,0 +1 @@ +Allow guests to send state events per [MSC3419](https://github.com/matrix-org/matrix-doc/pull/3419). \ No newline at end of file diff --git a/changelog.d/11427.doc b/changelog.d/11427.doc new file mode 100644 index 0000000000..01cdfcf2b7 --- /dev/null +++ b/changelog.d/11427.doc @@ -0,0 +1 @@ +Document the usage of refresh tokens. \ No newline at end of file diff --git a/changelog.d/11480.misc b/changelog.d/11480.misc new file mode 100644 index 0000000000..aadc938b2b --- /dev/null +++ b/changelog.d/11480.misc @@ -0,0 +1 @@ +Add missing type hints to `synapse.config` module. diff --git a/changelog.d/11487.misc b/changelog.d/11487.misc new file mode 100644 index 0000000000..376b9078be --- /dev/null +++ b/changelog.d/11487.misc @@ -0,0 +1 @@ +Add test to ensure we share the same `state_group` across the whole historical batch when using the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint. diff --git a/changelog.d/11516.bugfix b/changelog.d/11516.bugfix new file mode 100644 index 0000000000..22bba93671 --- /dev/null +++ b/changelog.d/11516.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where relations from other rooms could be included in the bundled aggregations of an event. diff --git a/changelog.d/11520.misc b/changelog.d/11520.misc new file mode 100644 index 0000000000..2d84120e19 --- /dev/null +++ b/changelog.d/11520.misc @@ -0,0 +1 @@ +Use HTTPStatus constants in place of literals in `tests.rest.client.test_auth`. \ No newline at end of file diff --git a/changelog.d/11531.misc b/changelog.d/11531.misc new file mode 100644 index 0000000000..ed6ef3bb3e --- /dev/null +++ b/changelog.d/11531.misc @@ -0,0 +1 @@ +Add a receipt types constant for `m.read`. diff --git a/changelog.d/11535.misc b/changelog.d/11535.misc new file mode 100644 index 0000000000..580ac354ab --- /dev/null +++ b/changelog.d/11535.misc @@ -0,0 +1 @@ +Clean up `synapse.rest.admin`. \ No newline at end of file diff --git a/changelog.d/11541.misc b/changelog.d/11541.misc new file mode 100644 index 0000000000..31c72c2a20 --- /dev/null +++ b/changelog.d/11541.misc @@ -0,0 +1 @@ +Support unprefixed versions of fallback key property names. diff --git a/changelog.d/11542.misc b/changelog.d/11542.misc new file mode 100644 index 0000000000..f614165037 --- /dev/null +++ b/changelog.d/11542.misc @@ -0,0 +1 @@ +Add missing `errcode` to `parse_string` and `parse_boolean`. \ No newline at end of file diff --git a/changelog.d/11543.misc b/changelog.d/11543.misc new file mode 100644 index 0000000000..99817d71a4 --- /dev/null +++ b/changelog.d/11543.misc @@ -0,0 +1 @@ +Use HTTPStatus constants in place of literals in `synapse.http`. \ No newline at end of file diff --git a/changelog.d/11547.bugfix b/changelog.d/11547.bugfix new file mode 100644 index 0000000000..3950c4c8d3 --- /dev/null +++ b/changelog.d/11547.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.17.0 where a pusher created for an email with capital letters would fail to be created. diff --git a/changelog.d/11550.misc b/changelog.d/11550.misc new file mode 100644 index 0000000000..d5577e0b63 --- /dev/null +++ b/changelog.d/11550.misc @@ -0,0 +1 @@ +Fix an inaccurate and misleading comment in the `/sync` code. \ No newline at end of file diff --git a/changelog.d/11558.misc b/changelog.d/11558.misc new file mode 100644 index 0000000000..7c334f17e0 --- /dev/null +++ b/changelog.d/11558.misc @@ -0,0 +1 @@ +Stop populating unused database column `state_events.prev_state`. diff --git a/changelog.d/11560.misc b/changelog.d/11560.misc new file mode 100644 index 0000000000..eb968167f5 --- /dev/null +++ b/changelog.d/11560.misc @@ -0,0 +1 @@ +Minor efficiency improvements in event persistence. diff --git a/changelog.d/11565.misc b/changelog.d/11565.misc new file mode 100644 index 0000000000..ddcafd32cb --- /dev/null +++ b/changelog.d/11565.misc @@ -0,0 +1 @@ +Make `get_device` return `None` if the device doesn't exist rather than raising an exception. diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index b05af6d690..11f597b3ed 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -30,6 +30,7 @@ - [SSO Mapping Providers](sso_mapping_providers.md) - [Password Auth Providers](password_auth_providers.md) - [JSON Web Tokens](jwt.md) + - [Refresh Tokens](usage/configuration/user_authentication/refresh_tokens.md) - [Registration Captcha](CAPTCHA_SETUP.md) - [Application Services](application_services.md) - [Server Notices](server_notices.md) diff --git a/docs/usage/configuration/user_authentication/refresh_tokens.md b/docs/usage/configuration/user_authentication/refresh_tokens.md new file mode 100644 index 0000000000..23b3cddae0 --- /dev/null +++ b/docs/usage/configuration/user_authentication/refresh_tokens.md @@ -0,0 +1,139 @@ +# Refresh Tokens + +Synapse supports refresh tokens since version 1.49 (some earlier versions had support for an earlier, experimental draft of [MSC2918] which is not compatible). + + +[MSC2918]: https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens + + +## Background and motivation + +Synapse users' sessions are identified by **access tokens**; access tokens are +issued to users on login. Each session gets a unique access token which identifies +it; the access token must be kept secret as it grants access to the user's account. + +Traditionally, these access tokens were eternally valid (at least until the user +explicitly chose to log out). + +In some cases, it may be desirable for these access tokens to expire so that the +potential damage caused by leaking an access token is reduced. +On the other hand, forcing a user to re-authenticate (log in again) often might +be too much of an inconvenience. + +**Refresh tokens** are a mechanism to avoid some of this inconvenience whilst +still getting most of the benefits of short access token lifetimes. +Refresh tokens are also a concept present in OAuth 2 — further reading is available +[here](https://datatracker.ietf.org/doc/html/rfc6749#section-1.5). + +When refresh tokens are in use, both an access token and a refresh token will be +issued to users on login. The access token will expire after a predetermined amount +of time, but otherwise works in the same way as before. When the access token is +close to expiring (or has expired), the user's client should present the homeserver +(Synapse) with the refresh token. + +The homeserver will then generate a new access token and refresh token for the user +and return them. The old refresh token is invalidated and can not be used again*. + +Finally, refresh tokens also make it possible for sessions to be logged out if they +are inactive for too long, before the session naturally ends; see the configuration +guide below. + + +*To prevent issues if clients lose connection half-way through refreshing a token, +the refresh token is only invalidated once the new access token has been used at +least once. For all intents and purposes, the above simplification is sufficient. + + +## Caveats + +There are some caveats: + +* If a third party gets both your access token and refresh token, they will be able to + continue to enjoy access to your session. + * This is still an improvement because you (the user) will notice when *your* + session expires and you're not able to use your refresh token. + That would be a giveaway that someone else has compromised your session. + You would be able to log in again and terminate that session. + Previously (with long-lived access tokens), a third party that has your access + token could go undetected for a very long time. +* Clients need to implement support for refresh tokens in order for them to be a + useful mechanism. + * It is up to homeserver administrators if they want to issue long-lived access + tokens to clients not implementing refresh tokens. + * For compatibility, it is likely that they should, at least until client support + is widespread. + * Users with clients that support refresh tokens will still benefit from the + added security; it's not possible to downgrade a session to using long-lived + access tokens so this effectively gives users the choice. + * In a closed environment where all users use known clients, this may not be + an issue as the homeserver administrator can know if the clients have refresh + token support. In that case, the non-refreshable access token lifetime + may be set to a short duration so that a similar level of security is provided. + + +## Configuration Guide + +The following configuration options, in the `registration` section, are related: + +* `session_lifetime`: maximum length of a session, even if it's refreshed. + In other words, the client must log in again after this time period. + In most cases, this can be unset (infinite) or set to a long time (years or months). +* `refreshable_access_token_lifetime`: lifetime of access tokens that are created + by clients supporting refresh tokens. + This should be short; a good value might be 5 minutes (`5m`). +* `nonrefreshable_access_token_lifetime`: lifetime of access tokens that are created + by clients which don't support refresh tokens. + Make this short if you want to effectively force use of refresh tokens. + Make this long if you don't want to inconvenience users of clients which don't + support refresh tokens (by forcing them to frequently re-authenticate using + login credentials). +* `refresh_token_lifetime`: lifetime of refresh tokens. + In other words, the client must refresh within this time period to maintain its session. + Unless you want to log inactive sessions out, it is often fine to use a long + value here or even leave it unset (infinite). + Beware that making it too short will inconvenience clients that do not connect + very often, including mobile clients and clients of infrequent users (by making + it more difficult for them to refresh in time, which may force them to need to + re-authenticate using login credentials). + +**Note:** All four options above only apply when tokens are created (by logging in or refreshing). +Changes to these settings do not apply retroactively. + + +### Using refresh token expiry to log out inactive sessions + +If you'd like to force sessions to be logged out upon inactivity, you can enable +refreshable access token expiry and refresh token expiry. + +This works because a client must refresh at least once within a period of +`refresh_token_lifetime` in order to maintain valid credentials to access the +account. + +(It's suggested that `refresh_token_lifetime` should be longer than +`refreshable_access_token_lifetime` and this section assumes that to be the case +for simplicity.) + +Note: this will only affect sessions using refresh tokens. You may wish to +set a short `nonrefreshable_access_token_lifetime` to prevent this being bypassed +by clients that do not support refresh tokens. + + +#### Choosing values that guarantee permitting some inactivity + +It may be desirable to permit some short periods of inactivity, for example to +accommodate brief outages in client connectivity. + +The following model aims to provide guidance for choosing `refresh_token_lifetime` +and `refreshable_access_token_lifetime` to satisfy requirements of the form: + +1. inactivity longer than `L` **MUST** cause the session to be logged out; and +2. inactivity shorter than `S` **MUST NOT** cause the session to be logged out. + +This model makes the weakest assumption that all active clients will refresh as +needed to maintain an active access token, but no sooner. +*In reality, clients may refresh more often than this model assumes, but the +above requirements will still hold.* + +To satisfy the above model, +* `refresh_token_lifetime` should be set to `L`; and +* `refreshable_access_token_lifetime` should be set to `L - S`. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f7d29b4319..52c083a20b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -253,5 +253,9 @@ class GuestAccess: FORBIDDEN: Final = "forbidden" +class ReceiptTypes: + READ: Final = "m.read" + + class ReadReceiptEventFields: MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden" diff --git a/synapse/config/key.py b/synapse/config/key.py index 035ee2416b..ee83c6c06b 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -16,12 +16,14 @@ import hashlib import logging import os -from typing import Any, Dict +from typing import Any, Dict, Iterator, List, Optional import attr import jsonschema from signedjson.key import ( NACL_ED25519, + SigningKey, + VerifyKey, decode_signing_key_base64, decode_verify_key_bytes, generate_signing_key, @@ -31,6 +33,7 @@ from signedjson.key import ( ) from unpaddedbase64 import decode_base64 +from synapse.types import JsonDict from synapse.util.stringutils import random_string, random_string_with_symbols from ._base import Config, ConfigError @@ -81,14 +84,13 @@ To suppress this warning and continue using 'matrix.org', admins should set logger = logging.getLogger(__name__) -@attr.s +@attr.s(slots=True, auto_attribs=True) class TrustedKeyServer: - # string: name of the server. - server_name = attr.ib() + # name of the server. + server_name: str - # dict[str,VerifyKey]|None: map from key id to key object, or None to disable - # signature verification. - verify_keys = attr.ib(default=None) + # map from key id to key object, or None to disable signature verification. + verify_keys: Optional[Dict[str, VerifyKey]] = None class KeyConfig(Config): @@ -279,15 +281,15 @@ class KeyConfig(Config): % locals() ) - def read_signing_keys(self, signing_key_path, name): + def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]: """Read the signing keys in the given path. Args: - signing_key_path (str) - name (str): Associated config key name + signing_key_path + name: Associated config key name Returns: - list[SigningKey] + The signing keys read from the given path. """ signing_keys = self.read_file(signing_key_path, name) @@ -296,7 +298,9 @@ class KeyConfig(Config): except Exception as e: raise ConfigError("Error reading %s: %s" % (name, str(e))) - def read_old_signing_keys(self, old_signing_keys): + def read_old_signing_keys( + self, old_signing_keys: Optional[JsonDict] + ) -> Dict[str, VerifyKey]: if old_signing_keys is None: return {} keys = {} @@ -340,7 +344,7 @@ class KeyConfig(Config): write_signing_keys(signing_key_file, (key,)) -def _perspectives_to_key_servers(config): +def _perspectives_to_key_servers(config: JsonDict) -> Iterator[JsonDict]: """Convert old-style 'perspectives' configs into new-style 'trusted_key_servers' Returns an iterable of entries to add to trusted_key_servers. @@ -402,7 +406,9 @@ TRUSTED_KEY_SERVERS_SCHEMA = { } -def _parse_key_servers(key_servers, federation_verify_certificates): +def _parse_key_servers( + key_servers: List[Any], federation_verify_certificates: bool +) -> Iterator[TrustedKeyServer]: try: jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA) except jsonschema.ValidationError as e: @@ -444,7 +450,7 @@ def _parse_key_servers(key_servers, federation_verify_certificates): yield result -def _assert_keyserver_has_verify_keys(trusted_key_server): +def _assert_keyserver_has_verify_keys(trusted_key_server: TrustedKeyServer) -> None: if not trusted_key_server.verify_keys: raise ConfigError(INSECURE_NOTARY_ERROR) diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 7ac82edb0e..1cc26e7578 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -22,10 +22,12 @@ from ._base import Config, ConfigError @attr.s class MetricsFlags: - known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool)) + known_servers: bool = attr.ib( + default=False, validator=attr.validators.instance_of(bool) + ) @classmethod - def all_off(cls): + def all_off(cls) -> "MetricsFlags": """ Instantiate the flags with all options set to off. """ diff --git a/synapse/config/server.py b/synapse/config/server.py index ba5b954263..1de2dea9b0 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -1257,7 +1257,7 @@ class ServerConfig(Config): help="Turn on the twisted telnet manhole service on the given port.", ) - def read_gc_intervals(self, durations) -> Optional[Tuple[float, float, float]]: + def read_gc_intervals(self, durations: Any) -> Optional[Tuple[float, float, float]]: """Reads the three durations for the GC min interval option, returning seconds.""" if durations is None: return None diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 4ca111618f..ffb316e4c0 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -132,7 +132,7 @@ class TlsConfig(Config): self.tls_certificate: Optional[crypto.X509] = None self.tls_private_key: Optional[crypto.PKey] = None - def read_certificate_from_disk(self): + def read_certificate_from_disk(self) -> None: """ Read the certificates and private key from disk. """ diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 84ef69df67..3da432c1df 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -454,23 +454,26 @@ class EventClientSerializer: return event_id = event.event_id + room_id = event.room_id # The bundled aggregations to include. aggregations = {} - annotations = await self.store.get_aggregation_groups_for_event(event_id) + annotations = await self.store.get_aggregation_groups_for_event( + event_id, room_id + ) if annotations.chunk: aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() references = await self.store.get_relations_for_event( - event_id, RelationTypes.REFERENCE, direction="f" + event_id, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: aggregations[RelationTypes.REFERENCE] = references.to_dict() edit = None if event.type == EventTypes.Message: - edit = await self.store.get_applicable_edit(event_id) + edit = await self.store.get_applicable_edit(event_id, room_id) if edit: # If there is an edit replace the content, preserving existing @@ -503,7 +506,7 @@ class EventClientSerializer: ( thread_count, latest_thread_event, - ) = await self.store.get_thread_summary(event_id) + ) = await self.store.get_thread_summary(event_id, room_id) if latest_thread_event: aggregations[RelationTypes.THREAD] = { # Don't bundle aggregations as this could recurse forever. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 61607cf2ba..84724b207c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -997,9 +997,7 @@ class AuthHandler: # really don't want is active access_tokens without a record of the # device, so we double-check it here. if device_id is not None: - try: - await self.store.get_device(user_id, device_id) - except StoreError: + if await self.store.get_device(user_id, device_id) is None: await self.store.delete_access_token(access_token) raise StoreError(400, "Login raced against device deletion") diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 82ee11e921..7665425232 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -106,10 +106,10 @@ class DeviceWorkerHandler: Raises: errors.NotFoundError: if the device was not found """ - try: - device = await self.store.get_device(user_id, device_id) - except errors.StoreError: - raise errors.NotFoundError + device = await self.store.get_device(user_id, device_id) + if device is None: + raise errors.NotFoundError() + ips = await self.store.get_last_client_ip_by_device(user_id, device_id) _update_device_from_client_ips(device, ips) @@ -602,6 +602,8 @@ class DeviceHandler(DeviceWorkerHandler): access_token, device_id ) old_device = await self.store.get_device(user_id, old_device_id) + if old_device is None: + raise errors.NotFoundError() await self.store.update_device(user_id, device_id, old_device["display_name"]) # can't call self.delete_device because that will clobber the # access token so call the storage layer directly diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index b2554bda04..14360b4e40 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -580,7 +580,9 @@ class E2eKeysHandler: log_kv( {"message": "Did not update one_time_keys", "reason": "no keys given"} ) - fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None) + fallback_keys = keys.get("fallback_keys") or keys.get( + "org.matrix.msc2732.fallback_keys" + ) if fallback_keys and isinstance(fallback_keys, dict): log_kv( { diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 87f671708c..38409fef38 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -496,6 +496,7 @@ class EventCreationHandler: require_consent: bool = True, outlier: bool = False, historical: bool = False, + allow_no_prev_events: bool = False, depth: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: """ @@ -607,6 +608,7 @@ class EventCreationHandler: prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, depth=depth, + allow_no_prev_events=allow_no_prev_events, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -882,6 +884,7 @@ class EventCreationHandler: prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, + allow_no_prev_events: bool = False, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client @@ -912,6 +915,7 @@ class EventCreationHandler: full_state_ids_at_event = None if auth_event_ids is not None: # If auth events are provided, prev events must be also. + # prev_event_ids could be an empty array though. assert prev_event_ids is not None # Copy the full auth state before it stripped down @@ -943,14 +947,22 @@ class EventCreationHandler: else: prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) - # we now ought to have some prev_events (unless it's a create event). - # - # do a quick sanity check here, rather than waiting until we've created the + # Do a quick sanity check here, rather than waiting until we've created the # event and then try to auth it (which fails with a somewhat confusing "No # create event in auth events") - assert ( - builder.type == EventTypes.Create or len(prev_event_ids) > 0 - ), "Attempting to create an event with no prev_events" + if allow_no_prev_events: + # We allow events with no `prev_events` but it better have some `auth_events` + assert ( + builder.type == EventTypes.Create + # Allow an event to have empty list of prev_event_ids + # only if it has auth_event_ids. + or auth_event_ids + ), "Attempting to create a non-m.room.create event with no prev_events or auth_event_ids" + else: + # we now ought to have some prev_events (unless it's a create event). + assert ( + builder.type == EventTypes.Create or prev_event_ids + ), "Attempting to create a non-m.room.create event with no prev_events" event = await builder.build( prev_event_ids=prev_event_ids, diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 4911a11535..5cb1ff749d 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -14,7 +14,7 @@ import logging from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.appservice import ApplicationService from synapse.streams import EventSource from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id @@ -178,7 +178,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): for event_id in content.keys(): event_content = content.get(event_id, {}) - m_read = event_content.get("m.read", {}) + m_read = event_content.get(ReceiptTypes.READ, {}) # If m_read is missing copy over the original event_content as there is nothing to process here if not m_read: @@ -206,7 +206,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): # Set new users unless empty if len(new_users.keys()) > 0: - new_event["content"][event_id] = {"m.read": new_users} + new_event["content"][event_id] = {ReceiptTypes.READ: new_users} # Append new_event to visible_events unless empty if len(new_event["content"].keys()) > 0: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a6dbff637f..447e3ce571 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -658,7 +658,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if block_invite: raise SynapseError(403, "Invites have been disabled on this server") - if prev_event_ids: + # An empty prev_events list is allowed as long as the auth_event_ids are present + if prev_event_ids is not None: return await self._local_membership_update( requester=requester, target=target, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f3039c3c3f..bcd10cbb30 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -28,7 +28,7 @@ from typing import ( import attr from prometheus_client import Counter -from synapse.api.constants import AccountDataTypes, EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -1046,7 +1046,7 @@ class SyncHandler: last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, - receipt_type="m.read", + receipt_type=ReceiptTypes.READ, ) notifs = await self.store.get_unread_event_push_actions_by_room_for_user( @@ -1662,20 +1662,20 @@ class SyncHandler: ) -> _RoomChanges: """Determine the changes in rooms to report to the user. - Ideally, we want to report all events whose stream ordering `s` lies in the - range `since_token < s <= now_token`, where the two tokens are read from the - sync_result_builder. + This function is a first pass at generating the rooms part of the sync response. + It determines which rooms have changed during the sync period, and categorises + them into four buckets: "knock", "invite", "join" and "leave". - If there are too many events in that range to report, things get complicated. - In this situation we return a truncated list of the most recent events, and - indicate in the response that there is a "gap" of omitted events. Additionally: + 1. Finds all membership changes for the user in the sync period (from + `since_token` up to `now_token`). + 2. Uses those to place the room in one of the four categories above. + 3. Builds a `_RoomChanges` struct to record this, and return that struct. - - we include a "state_delta", to describe the changes in state over the gap, - - we include all membership events applying to the user making the request, - even those in the gap. - - See the spec for the rationale: - https://spec.matrix.org/v1.1/client-server-api/#syncing + For rooms classified as "knock", "invite" or "leave", we just need to report + a single membership event in the eventual /sync response. For "join" we need + to fetch additional non-membership events, e.g. messages in the room. That is + more complicated, so instead we report an intermediary `RoomSyncResultBuilder` + struct, and leave the additional work to `_generate_room_entry`. The sync_result_builder is not modified by this function. """ @@ -1686,16 +1686,6 @@ class SyncHandler: assert since_token - # The spec - # https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync - # notes that membership events need special consideration: - # - # > When a sync is limited, the server MUST return membership events for events - # > in the gap (between since and the start of the returned timeline), regardless - # > as to whether or not they are redundant. - # - # We fetch such events here, but we only seem to use them for categorising rooms - # as newly joined, newly left, invited or knocked. # TODO: we've already called this function and ran this query in # _have_rooms_changed. We could keep the results in memory to avoid a # second query, at the cost of more complicated source code. @@ -2009,6 +1999,23 @@ class SyncHandler: """Populates the `joined` and `archived` section of `sync_result_builder` based on the `room_builder`. + Ideally, we want to report all events whose stream ordering `s` lies in the + range `since_token < s <= now_token`, where the two tokens are read from the + sync_result_builder. + + If there are too many events in that range to report, things get complicated. + In this situation we return a truncated list of the most recent events, and + indicate in the response that there is a "gap" of omitted events. Lots of this + is handled in `_load_filtered_recents`, but some of is handled in this method. + + Additionally: + - we include a "state_delta", to describe the changes in state over the gap, + - we include all membership events applying to the user making the request, + even those in the gap. + + See the spec for the rationale: + https://spec.matrix.org/v1.1/client-server-api/#syncing + Args: sync_result_builder ignored_users: Set of users ignored by user. diff --git a/synapse/http/client.py b/synapse/http/client.py index b5a2d333a6..fbbeceabeb 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -14,6 +14,7 @@ # limitations under the License. import logging import urllib.parse +from http import HTTPStatus from io import BytesIO from typing import ( TYPE_CHECKING, @@ -280,7 +281,9 @@ class BlacklistingAgentWrapper(Agent): ip_address, self._ip_whitelist, self._ip_blacklist ): logger.info("Blocking access to %s due to blacklist" % (ip_address,)) - e = SynapseError(403, "IP address blocked by IP blacklist entry") + e = SynapseError( + HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry" + ) return defer.fail(Failure(e)) return self._agent.request( @@ -719,7 +722,9 @@ class SimpleHttpClient: if response.code > 299: logger.warning("Got %d when downloading %s" % (response.code, url)) - raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) + raise SynapseError( + HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN + ) # TODO: if our Content-Type is HTML or something, just read the first # N bytes into RAM rather than saving it all to disk only to read it @@ -731,12 +736,14 @@ class SimpleHttpClient: ) except BodyExceededMaxSize: raise SynapseError( - 502, + HTTPStatus.BAD_GATEWAY, "Requested file is too large > %r bytes" % (max_size,), Codes.TOO_LARGE, ) except Exception as e: - raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e + raise SynapseError( + HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e) + ) from e return ( length, diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 203d723d41..deedde0b5b 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -19,6 +19,7 @@ import random import sys import typing import urllib.parse +from http import HTTPStatus from io import BytesIO, StringIO from typing import ( TYPE_CHECKING, @@ -1154,7 +1155,7 @@ class MatrixFederationHttpClient: request.destination, msg, ) - raise SynapseError(502, msg, Codes.TOO_LARGE) + raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE) except defer.TimeoutError as e: logger.warning( "{%s} [%s] Timed out reading response - %s %s", diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 6dd9b9ad03..e543cc6e01 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -14,6 +14,7 @@ """ This module contains base REST classes for constructing REST servlets. """ import logging +from http import HTTPStatus from typing import ( TYPE_CHECKING, Iterable, @@ -137,11 +138,15 @@ def parse_integer_from_args( return int(args[name_bytes][0]) except Exception: message = "Query parameter %r must be an integer" % (name,) - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) else: if required: message = "Missing integer query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) else: return default @@ -246,11 +251,15 @@ def parse_boolean_from_args( message = ( "Boolean query parameter %r must be one of ['true', 'false']" ) % (name,) - raise SynapseError(400, message) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) else: if required: message = "Missing boolean query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) else: return default @@ -313,7 +322,7 @@ def parse_bytes_from_args( return args[name_bytes][0] elif required: message = "Missing string query parameter %s" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM) return default @@ -407,14 +416,16 @@ def _parse_string_value( try: value_str = value.decode(encoding) except ValueError: - raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding)) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Query parameter %r must be %s" % (name, encoding) + ) if allowed_values is not None and value_str not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( name, ", ".join(repr(v) for v in allowed_values), ) - raise SynapseError(400, message) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM) else: return value_str @@ -510,7 +521,9 @@ def parse_strings_from_args( else: if required: message = "Missing string query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) return default @@ -638,7 +651,7 @@ def parse_json_value_from_request( try: content_bytes = request.content.read() # type: ignore except Exception: - raise SynapseError(400, "Error reading JSON content.") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Error reading JSON content.") if not content_bytes and allow_empty_body: return None @@ -647,7 +660,9 @@ def parse_json_value_from_request( content = json_decoder.decode(content_bytes.decode("utf-8")) except Exception as e: logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes) - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON + ) return content @@ -673,7 +688,7 @@ def parse_json_object_from_request( if not isinstance(content, dict): message = "Content must be a JSON object." - raise SynapseError(400, message, errcode=Codes.BAD_JSON) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON) return content @@ -685,7 +700,9 @@ def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None: absent.append(k) if len(absent) > 0: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Missing params: %r" % absent, Codes.MISSING_PARAM + ) class RestServlet: @@ -758,10 +775,12 @@ class ResolveRoomIdMixin: resolved_room_id = room_id.to_string() else: raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) + HTTPStatus.BAD_REQUEST, + "%s was not legal room ID or room alias" % (room_identifier,), ) if not resolved_room_id: raise SynapseError( - 400, "Unknown room ID or room alias %s" % room_identifier + HTTPStatus.BAD_REQUEST, + "Unknown room ID or room alias %s" % room_identifier, ) return resolved_room_id, remote_room_hosts diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 9c85200c0f..da641aca47 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Dict +from synapse.api.constants import ReceiptTypes from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage @@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") + my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ) badge = len(invites) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 26735447a6..7912311d24 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -27,6 +27,7 @@ from synapse.push.pusher import PusherFactory from synapse.replication.http.push import ReplicationRemovePusherRestServlet from synapse.types import JsonDict, RoomStreamToken from synapse.util.async_helpers import concurrently_execute +from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -113,7 +114,9 @@ class PusherPool: """ if kind == "email": - email_owner = await self.store.get_user_id_by_threepid("email", pushkey) + email_owner = await self.store.get_user_id_by_threepid( + "email", canonicalise_email(pushkey) + ) if email_owner != user_id: raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c499afd4be..701c609c12 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -108,7 +108,7 @@ class VersionServlet(RestServlet): class PurgeHistoryRestServlet(RestServlet): PATTERNS = admin_patterns( - "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?" + "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]*))?$" ) def __init__(self, hs: "HomeServer"): @@ -195,7 +195,7 @@ class PurgeHistoryRestServlet(RestServlet): class PurgeHistoryStatusRestServlet(RestServlet): - PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)") + PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.pagination_handler = hs.get_pagination_handler() diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py index 479672d4d5..6ec00ce0b9 100644 --- a/synapse/rest/admin/background_updates.py +++ b/synapse/rest/admin/background_updates.py @@ -22,7 +22,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, ) from synapse.http.site import SynapseRequest -from synapse.rest.admin._base import admin_patterns, assert_user_is_admin +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.types import JsonDict if TYPE_CHECKING: @@ -41,8 +41,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet): self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) @@ -51,8 +50,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet): return HTTPStatus.OK, {"enabled": enabled} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) body = parse_json_object_from_request(request) @@ -84,8 +82,7 @@ class BackgroundUpdateRestServlet(RestServlet): self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) @@ -111,15 +108,14 @@ class BackgroundUpdateRestServlet(RestServlet): class BackgroundUpdateStartJobRestServlet(RestServlet): """Allows to start specific background updates""" - PATTERNS = admin_patterns("/background_updates/start_job") + PATTERNS = admin_patterns("/background_updates/start_job$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() self._store = hs.get_datastore() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) body = parse_json_object_from_request(request) assert_params_in_dict(body, ["job_name"]) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 2e5a6600d3..d9905ff560 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -42,10 +42,10 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str, device_id: str @@ -53,7 +53,7 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -63,6 +63,8 @@ class DeviceRestServlet(RestServlet): device = await self.device_handler.get_device( target_user.to_string(), device_id ) + if device is None: + raise NotFoundError("No device found") return HTTPStatus.OK, device async def on_DELETE( @@ -71,7 +73,7 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -87,7 +89,7 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -109,14 +111,10 @@ class DevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2") def __init__(self, hs: "HomeServer"): - """ - Args: - hs: server - """ - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -124,7 +122,7 @@ class DevicesRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -144,10 +142,10 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_POST( self, request: SynapseRequest, user_id: str @@ -155,7 +153,7 @@ class DeleteDevicesRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index 5ee8b11110..38477f8ead 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -52,7 +52,6 @@ class EventReportsRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -115,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 744687be35..50d88c9109 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -100,7 +100,7 @@ class DestinationsRestServlet(RestServlet): 200 OK with details of a destination if success otherwise an error. """ - PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]+)$") + PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]*)$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py index a27110388f..cd697e180e 100644 --- a/synapse/rest/admin/groups.py +++ b/synapse/rest/admin/groups.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class DeleteGroupAdminRestServlet(RestServlet): """Allows deleting of local groups""" - PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)") + PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.group_server = hs.get_groups_server_handler() diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 9e23e2d8fc..7236e4027f 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -17,7 +17,7 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, Tuple -from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest @@ -41,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet): """ PATTERNS = [ - *admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine$"), + *admin_patterns("/room/(?P<room_id>[^/]*)/media/quarantine$"), # This path kept around for legacy reasons - *admin_patterns("/quarantine_media/(?P<room_id>[^/]+)"), + *admin_patterns("/quarantine_media/(?P<room_id>[^/]*)$"), ] def __init__(self, hs: "HomeServer"): @@ -71,7 +71,7 @@ class QuarantineMediaByUser(RestServlet): this server. """ - PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine$") + PATTERNS = admin_patterns("/user/(?P<user_id>[^/]*)/media/quarantine$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -99,7 +99,7 @@ class QuarantineMediaByID(RestServlet): """ PATTERNS = admin_patterns( - "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)" + "/media/quarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$" ) def __init__(self, hs: "HomeServer"): @@ -128,7 +128,7 @@ class UnquarantineMediaByID(RestServlet): """ PATTERNS = admin_patterns( - "/media/unquarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)" + "/media/unquarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$" ) def __init__(self, hs: "HomeServer"): @@ -138,8 +138,7 @@ class UnquarantineMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, server_name: str, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info( "Remove from quarantine local media by ID: %s/%s", server_name, media_id @@ -154,7 +153,7 @@ class UnquarantineMediaByID(RestServlet): class ProtectMediaByID(RestServlet): """Protect local media from being quarantined.""" - PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)") + PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -163,8 +162,7 @@ class ProtectMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info("Protecting local media by ID: %s", media_id) @@ -177,7 +175,7 @@ class ProtectMediaByID(RestServlet): class UnprotectMediaByID(RestServlet): """Unprotect local media from being quarantined.""" - PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]+)") + PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -186,8 +184,7 @@ class UnprotectMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info("Unprotecting local media by ID: %s", media_id) @@ -200,7 +197,7 @@ class UnprotectMediaByID(RestServlet): class ListMediaInRoom(RestServlet): """Lists all of the media in a given room.""" - PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media$") + PATTERNS = admin_patterns("/room/(?P<room_id>[^/]*)/media$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -209,10 +206,7 @@ class ListMediaInRoom(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - is_admin = await self.auth.is_server_admin(requester.user) - if not is_admin: - raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") + await assert_requester_is_admin(self.auth, request) local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id) @@ -254,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet): class DeleteMediaByID(RestServlet): """Delete local media by a given ID. Removes it from this server.""" - PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)") + PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -286,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet): timestamp and size. """ - PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete$") + PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -353,7 +347,7 @@ class UserMediaRestServlet(RestServlet): media that exist given for this user """ - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/media$") def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine @@ -403,16 +397,7 @@ class UserMediaRestServlet(RestServlet): request, "order_by", default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), + allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) direction = parse_string( request, "dir", default="f", allowed_values=("f", "b") @@ -470,16 +455,7 @@ class UserMediaRestServlet(RestServlet): request, "order_by", default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), + allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) direction = parse_string( request, "dir", default="f", allowed_values=("f", "b") diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 891b98c088..04948b6408 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -70,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -109,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens/new$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -260,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.clock = hs.get_clock() self.auth = hs.get_auth() self.store = hs.get_datastore() diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 829e86675a..17c6df1cc8 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -61,7 +61,7 @@ class RoomRestV2Servlet(RestServlet): If 'purge' is true, it will remove all traces of a room from the database. """ - PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$", "v2") + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -123,7 +123,7 @@ class RoomRestV2Servlet(RestServlet): class DeleteRoomStatusByRoomIdRestServlet(RestServlet): """Get the status of the delete room background task.""" - PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete_status$", "v2") + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/delete_status$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -160,7 +160,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet): class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): """Get the status of the delete room background task.""" - PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]+)$", "v2") + PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]*)$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -193,35 +193,17 @@ class ListRoomRestServlet(RestServlet): self.admin_handler = hs.get_admin_handler() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) # Extract query parameters start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) - order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value) - if order_by not in ( - RoomSortOrder.ALPHABETICAL.value, - RoomSortOrder.SIZE.value, - RoomSortOrder.NAME.value, - RoomSortOrder.CANONICAL_ALIAS.value, - RoomSortOrder.JOINED_MEMBERS.value, - RoomSortOrder.JOINED_LOCAL_MEMBERS.value, - RoomSortOrder.VERSION.value, - RoomSortOrder.CREATOR.value, - RoomSortOrder.ENCRYPTION.value, - RoomSortOrder.FEDERATABLE.value, - RoomSortOrder.PUBLIC.value, - RoomSortOrder.JOIN_RULES.value, - RoomSortOrder.GUEST_ACCESS.value, - RoomSortOrder.HISTORY_VISIBILITY.value, - RoomSortOrder.STATE_EVENTS.value, - ): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown value for order_by: %s" % (order_by,), - errcode=Codes.INVALID_PARAM, - ) + order_by = parse_string( + request, + "order_by", + default=RoomSortOrder.NAME.value, + allowed_values=[sort_order.value for sort_order in RoomSortOrder], + ) search_term = parse_string(request, "search_term", encoding="utf-8") if search_term == "": @@ -292,10 +274,9 @@ class RoomRestServlet(RestServlet): TODO: Add on_POST to allow room creation without joining the room """ - PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$") + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.room_shutdown_handler = hs.get_room_shutdown_handler() @@ -397,10 +378,9 @@ class RoomMembersRestServlet(RestServlet): Get members list of a room. """ - PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members") + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/members$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -424,10 +404,9 @@ class RoomStateRestServlet(RestServlet): Get full state within a room. """ - PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state") + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/state$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -436,8 +415,7 @@ class RoomStateRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room(room_id) if not ret: @@ -454,14 +432,14 @@ class RoomStateRestServlet(RestServlet): class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): - PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)") + PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() self.state_handler = hs.get_state_handler() + self.is_mine = hs.is_mine async def on_POST( self, request: SynapseRequest, room_identifier: str @@ -477,7 +455,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): assert_params_in_dict(content, ["user_id"]) target_user = UserID.from_string(content["user_id"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "This endpoint can only be used with local users", @@ -542,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): } """ - PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin") + PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_creation_handler = hs.get_event_creation_handler() @@ -688,19 +665,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities """ - PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities") + PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() async def on_DELETE( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) room_id, _ = await self.resolve_room_id(room_identifier) @@ -710,8 +685,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): async def on_GET( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) room_id, _ = await self.resolve_room_id(room_identifier) @@ -793,7 +767,7 @@ class BlockRoomRestServlet(RestServlet): On GET: Get blocking status of room and user who has blocked this room. """ - PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/block$") + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/block$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index b295fb078b..15da9cd881 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -52,11 +52,11 @@ class SendServerNoticeServlet(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.server_notices_manager = hs.get_server_notices_manager() self.admin_handler = hs.get_admin_handler() self.txns = HttpTransactionCache(hs) + self.is_mine = hs.is_mine def register(self, json_resource: HttpServer) -> None: PATTERN = "/send_server_notice" @@ -88,7 +88,7 @@ class SendServerNoticeServlet(RestServlet): ) target_user = UserID.from_string(body["user_id"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users" ) diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index ca41fd45f2..7a6546372e 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -37,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet): PATTERNS = admin_patterns("/statistics/users/media$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -45,19 +44,16 @@ class UserMediaStatisticsRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) order_by = parse_string( - request, "order_by", default=UserSortOrder.USER_ID.value + request, + "order_by", + default=UserSortOrder.USER_ID.value, + allowed_values=( + UserSortOrder.MEDIA_LENGTH.value, + UserSortOrder.MEDIA_COUNT.value, + UserSortOrder.USER_ID.value, + UserSortOrder.DISPLAYNAME.value, + ), ) - if order_by not in ( - UserSortOrder.MEDIA_LENGTH.value, - UserSortOrder.MEDIA_COUNT.value, - UserSortOrder.USER_ID.value, - UserSortOrder.DISPLAYNAME.value, - ): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown value for order_by: %s" % (order_by,), - errcode=Codes.INVALID_PARAM, - ) start = parse_integer(request, "from", default=0) if start < 0: diff --git a/synapse/rest/admin/username_available.py b/synapse/rest/admin/username_available.py index 2bf1472967..5353dc3682 100644 --- a/synapse/rest/admin/username_available.py +++ b/synapse/rest/admin/username_available.py @@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet): } """ - PATTERNS = admin_patterns("/username_available") + PATTERNS = admin_patterns("/username_available$") def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2a60b602b1..db678da4cf 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -126,7 +125,7 @@ class UsersRestServletV2(RestServlet): class UserRestServletV2(RestServlet): - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$", "v2") """Get request to list user details. This needs user to have administrator access in Synapse. @@ -414,7 +413,7 @@ class UserRegisterServlet(RestServlet): nonce to the time it was generated, in int seconds. """ - PATTERNS = admin_patterns("/register") + PATTERNS = admin_patterns("/register$") NONCE_TIMEOUT = 60 def __init__(self, hs: "HomeServer"): @@ -561,9 +560,9 @@ class WhoisRestServlet(RestServlet): ] def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -575,7 +574,7 @@ class WhoisRestServlet(RestServlet): if target_user != auth_user: await assert_user_is_admin(self.auth, auth_user) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user") ret = await self.admin_handler.get_whois(target_user) @@ -584,7 +583,7 @@ class WhoisRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): - PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self._deactivate_account_handler = hs.get_deactivate_account_handler() @@ -630,7 +629,6 @@ class AccountValidityRenewServlet(RestServlet): PATTERNS = admin_patterns("/account_validity/validity$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() @@ -674,11 +672,10 @@ class ResetPasswordRestServlet(RestServlet): 200 OK with empty object if success otherwise an error. """ - PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler() @@ -718,12 +715,12 @@ class SearchUsersRestServlet(RestServlet): 200 OK with json object {list[dict[str, Any]], count} or empty object. """ - PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, target_user_id: str @@ -740,7 +737,7 @@ class SearchUsersRestServlet(RestServlet): # if not is_admin and target_user != auth_user: # raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user") term = parse_string(request, "term", required=True) @@ -779,9 +776,9 @@ class UserAdminServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -790,7 +787,7 @@ class UserAdminServlet(RestServlet): target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be admins of this homeserver", @@ -813,7 +810,7 @@ class UserAdminServlet(RestServlet): assert_params_in_dict(body, ["admin"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be admins of this homeserver", @@ -834,7 +831,7 @@ class UserMembershipRestServlet(RestServlet): Get room list of an user. """ - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/joined_rooms$") def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine @@ -909,10 +906,10 @@ class UserTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.is_mine_id = hs.is_mine_id async def on_POST( self, request: SynapseRequest, user_id: str @@ -921,7 +918,7 @@ class UserTokenRestServlet(RestServlet): await assert_user_is_admin(self.auth, requester.user) auth_user = requester.user - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be logged in as" ) @@ -975,19 +972,19 @@ class ShadowBanRestServlet(RestServlet): {} """ - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" ) @@ -1001,7 +998,7 @@ class ShadowBanRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" ) @@ -1027,19 +1024,19 @@ class RateLimitRestServlet(RestServlet): } """ - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id async def on_GET( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") if not await self.store.get_user_by_id(user_id): @@ -1068,7 +1065,7 @@ class RateLimitRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" ) @@ -1113,7 +1110,7 @@ class RateLimitRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" ) diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 8566dc5cb5..ad6fd6492b 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -17,6 +17,7 @@ import logging from typing import TYPE_CHECKING, Tuple from synapse.api import errors +from synapse.api.errors import NotFoundError from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -24,10 +25,9 @@ from synapse.http.servlet import ( parse_json_object_from_request, ) from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.types import JsonDict -from ._base import client_patterns, interactive_auth_handler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -116,6 +116,8 @@ class DeviceRestServlet(RestServlet): device = await self.device_handler.get_device( requester.user.to_string(), device_id ) + if device is None: + raise NotFoundError("No device found") return 200, device @interactive_auth_handler diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index d1d8a984c6..b12a332776 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import ReceiptTypes from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string @@ -54,7 +55,7 @@ class NotificationsServlet(RestServlet): ) receipts_by_room = await self.store.get_receipts_for_user_with_orderings( - user_id, "m.read" + user_id, ReceiptTypes.READ ) notif_event_ids = [pa["event_id"] for pa in push_actions] diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 43c04fac6f..f51be511d1 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet): await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) - read_event_id = body.get("m.read", None) + read_event_id = body.get(ReceiptTypes.READ, None) hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) if not isinstance(hidden, bool): @@ -62,7 +62,7 @@ class ReadMarkerRestServlet(RestServlet): if read_event_id: await self.receipts_handler.received_client_receipt( room_id, - "m.read", + ReceiptTypes.READ, user_id=requester.user.to_string(), event_id=read_event_id, hidden=hidden, diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 2b25b9aad6..b24ad2d1be 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -16,7 +16,7 @@ import logging import re from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http import get_request_user_agent from synapse.http.server import HttpServer @@ -53,7 +53,7 @@ class ReceiptRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - if receipt_type != "m.read": + if receipt_type != ReceiptTypes.READ: raise SynapseError(400, "Receipt type must be 'm.read'") # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index fc4e6921c5..ffa37ef06c 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -212,6 +212,7 @@ class RelationPaginationServlet(RestServlet): pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, + room_id=room_id, relation_type=relation_type, event_type=event_type, limit=limit, @@ -317,6 +318,7 @@ class RelationAggregationPaginationServlet(RestServlet): pagination_chunk = await self.store.get_aggregation_groups_for_event( event_id=parent_id, + room_id=room_id, event_type=event_type, limit=limit, from_token=from_token, @@ -383,7 +385,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet): # This checks that a) the event exists and b) the user is allowed to # view it. - await self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -402,6 +406,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): result = await self.store.get_relations_for_event( event_id=parent_id, + room_id=room_id, relation_type=relation_type, event_type=event_type, aggregation_key=key, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index f48e2e6ca2..60719331b6 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -187,7 +187,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): state_key: str, txn_id: Optional[str] = None, ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request, allow_guest=True) if txn_id: set_tag("txn_id", txn_id) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 88e4f5e063..dd90ffa123 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -293,6 +293,9 @@ class SyncRestServlet(RestServlet): response[ "org.matrix.msc2732.device_unused_fallback_key_types" ] = sync_result.device_unused_fallback_key_types + response[ + "device_unused_fallback_key_types" + ] = sync_result.device_unused_fallback_key_types if joined: response["rooms"][Membership.JOIN] = joined diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0693d39006..5552dd3c5c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -896,6 +896,9 @@ class DatabasePool: ) -> None: """Executes an INSERT query on the named table. + The input is given as a list of dicts, with one dict per row. + Generally simple_insert_many_values should be preferred for new code. + Args: table: string giving the table name values: dict of new column names and values for them @@ -909,6 +912,9 @@ class DatabasePool: ) -> None: """Executes an INSERT query on the named table. + The input is given as a list of dicts, with one dict per row. + Generally simple_insert_many_values_txn should be preferred for new code. + Args: txn: The transaction to use. table: string giving the table name @@ -933,23 +939,66 @@ class DatabasePool: if k != keys[0]: raise RuntimeError("All items must have the same keys") + return DatabasePool.simple_insert_many_values_txn(txn, table, keys[0], vals) + + async def simple_insert_many_values( + self, + table: str, + keys: Collection[str], + values: Iterable[Iterable[Any]], + desc: str, + ) -> None: + """Executes an INSERT query on the named table. + + The input is given as a list of rows, where each row is a list of values. + (Actually any iterable is fine.) + + Args: + table: string giving the table name + keys: list of column names + values: for each row, a list of values in the same order as `keys` + desc: description of the transaction, for logging and metrics + """ + await self.runInteraction( + desc, self.simple_insert_many_values_txn, table, keys, values + ) + + @staticmethod + def simple_insert_many_values_txn( + txn: LoggingTransaction, + table: str, + keys: Collection[str], + values: Iterable[Iterable[Any]], + ) -> None: + """Executes an INSERT query on the named table. + + The input is given as a list of rows, where each row is a list of values. + (Actually any iterable is fine.) + + Args: + txn: The transaction to use. + table: string giving the table name + keys: list of column names + values: for each row, a list of values in the same order as `keys` + """ + if isinstance(txn.database_engine, PostgresEngine): # We use `execute_values` as it can be a lot faster than `execute_batch`, # but it's only available on postgres. sql = "INSERT INTO %s (%s) VALUES ?" % ( table, - ", ".join(k for k in keys[0]), + ", ".join(k for k in keys), ) - txn.execute_values(sql, vals, fetch=False) + txn.execute_values(sql, values, fetch=False) else: sql = "INSERT INTO %s (%s) VALUES(%s)" % ( table, - ", ".join(k for k in keys[0]), - ", ".join("?" for _ in keys[0]), + ", ".join(k for k in keys), + ", ".join("?" for _ in keys), ) - txn.execute_batch(sql, vals) + txn.execute_batch(sql, values) async def simple_upsert( self, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index afc516a978..54357343f2 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -101,7 +101,9 @@ class DeviceWorkerStore(SQLBaseStore): "count_devices_by_users", count_devices_by_users_txn, user_ids ) - async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: + async def get_device( + self, user_id: str, device_id: str + ) -> Optional[Dict[str, Any]]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -109,17 +111,15 @@ class DeviceWorkerStore(SQLBaseStore): user_id: The ID of the user which owns the device device_id: The ID of the device to retrieve Returns: - A dict containing the device information - Raises: - StoreError: if the device is not found - See also: - `get_device_opt` which returns None instead if the device is not found + A dict containing the device information, or `None` if the device does not + exist. """ return await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), desc="get_device", + allow_none=True, ) async def get_device_opt( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 4e528612ea..5184e6bf85 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -19,6 +19,7 @@ from collections import OrderedDict from typing import ( TYPE_CHECKING, Any, + Collection, Dict, Generator, Iterable, @@ -1319,14 +1320,13 @@ class PersistEventsStore: return [ec for ec in events_and_contexts if ec[0] not in to_remove] - def _store_event_txn(self, txn, events_and_contexts): + def _store_event_txn( + self, + txn: LoggingTransaction, + events_and_contexts: Collection[Tuple[EventBase, EventContext]], + ) -> None: """Insert new events into the event, event_json, redaction and state_events tables. - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting """ if not events_and_contexts: @@ -1339,46 +1339,58 @@ class PersistEventsStore: d.pop("redacted_because", None) return d - self.db_pool.simple_insert_many_txn( + self.db_pool.simple_insert_many_values_txn( txn, table="event_json", - values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "internal_metadata": json_encoder.encode( - event.internal_metadata.get_dict() - ), - "json": json_encoder.encode(event_dict(event)), - "format_version": event.format_version, - } + keys=("event_id", "room_id", "internal_metadata", "json", "format_version"), + values=( + ( + event.event_id, + event.room_id, + json_encoder.encode(event.internal_metadata.get_dict()), + json_encoder.encode(event_dict(event)), + event.format_version, + ) for event, _ in events_and_contexts - ], + ), ) - self.db_pool.simple_insert_many_txn( + self.db_pool.simple_insert_many_values_txn( txn, table="events", - values=[ - { - "instance_name": self._instance_name, - "stream_ordering": event.internal_metadata.stream_ordering, - "topological_ordering": event.depth, - "depth": event.depth, - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "processed": True, - "outlier": event.internal_metadata.is_outlier(), - "origin_server_ts": int(event.origin_server_ts), - "received_ts": self._clock.time_msec(), - "sender": event.sender, - "contains_url": ( - "url" in event.content and isinstance(event.content["url"], str) - ), - } + keys=( + "instance_name", + "stream_ordering", + "topological_ordering", + "depth", + "event_id", + "room_id", + "type", + "processed", + "outlier", + "origin_server_ts", + "received_ts", + "sender", + "contains_url", + ), + values=( + ( + self._instance_name, + event.internal_metadata.stream_ordering, + event.depth, # topological_ordering + event.depth, # depth + event.event_id, + event.room_id, + event.type, + True, # processed + event.internal_metadata.is_outlier(), + int(event.origin_server_ts), + self._clock.time_msec(), + event.sender, + "url" in event.content and isinstance(event.content["url"], str), + ) for event, _ in events_and_contexts - ], + ), ) # If we're persisting an unredacted event we go and ensure @@ -1397,27 +1409,15 @@ class PersistEventsStore: ) txn.execute(sql + clause, [False] + args) - state_events_and_contexts = [ - ec for ec in events_and_contexts if ec[0].is_state() - ] - - state_values = [] - for event, _ in state_events_and_contexts: - vals = { - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - } - - # TODO: How does this work with backfilling? - if hasattr(event, "replaces_state"): - vals["prev_state"] = event.replaces_state - - state_values.append(vals) - - self.db_pool.simple_insert_many_txn( - txn, table="state_events", values=state_values + self.db_pool.simple_insert_many_values_txn( + txn, + table="state_events", + keys=("event_id", "room_id", "type", "state_key"), + values=( + (event.event_id, event.room_id, event.type, event.state_key) + for event, _ in events_and_contexts + if event.is_state() + ), ) def _store_rejected_events_txn(self, txn, events_and_contexts): @@ -1780,10 +1780,14 @@ class PersistEventsStore: ) if rel_type == RelationTypes.REPLACE: - txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + txn.call_after( + self.store.get_applicable_edit.invalidate, (parent_id, event.room_id) + ) if rel_type == RelationTypes.THREAD: - txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) + txn.call_after( + self.store.get_thread_summary.invalidate, (parent_id, event.room_id) + ) def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): """Handles keeping track of insertion events and edges/connections. diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index b5284e4f67..3c98ef876f 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -18,6 +18,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.util.caches.descriptors import cached +from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -103,7 +104,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): : self.hs.config.server.max_mau_value ]: user_id = await self.hs.get_datastore().get_user_id_by_threepid( - tp["medium"], tp["address"] + tp["medium"], canonicalise_email(tp["address"]) ) if user_id: users.append(user_id) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index c99f8aebdb..9c5625c8bb 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -14,14 +14,25 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, +) from twisted.internet import defer +from synapse.api.constants import ReceiptTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict @@ -78,17 +89,13 @@ class ReceiptsWorkerStore(SQLBaseStore): "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) - def get_max_receipt_stream_id(self): - """Get the current max stream ID for receipts stream - - Returns: - int - """ + def get_max_receipt_stream_id(self) -> int: + """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @cached() - async def get_users_with_read_receipts_in_room(self, room_id): - receipts = await self.get_receipts_for_room(room_id, "m.read") + async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]: + receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ) return {r["user_id"] for r in receipts} @cached(num_args=2) @@ -119,7 +126,9 @@ class ReceiptsWorkerStore(SQLBaseStore): ) @cached(num_args=2) - async def get_receipts_for_user(self, user_id, receipt_type): + async def get_receipts_for_user( + self, user_id: str, receipt_type: str + ) -> Dict[str, str]: rows = await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, @@ -129,8 +138,10 @@ class ReceiptsWorkerStore(SQLBaseStore): return {row["room_id"]: row["event_id"] for row in rows} - async def get_receipts_for_user_with_orderings(self, user_id, receipt_type): - def f(txn): + async def get_receipts_for_user_with_orderings( + self, user_id: str, receipt_type: str + ) -> JsonDict: + def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]: sql = ( "SELECT rl.room_id, rl.event_id," " e.topological_ordering, e.stream_ordering" @@ -209,10 +220,10 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=3, tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[dict]: + ) -> List[JsonDict]: """See get_linearized_receipts_for_room""" - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = ( "SELECT * FROM receipts_linearized WHERE" @@ -250,11 +261,13 @@ class ReceiptsWorkerStore(SQLBaseStore): list_name="room_ids", num_args=3, ) - async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + async def _get_linearized_receipts_for_rooms( + self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None + ) -> Dict[str, List[JsonDict]]: if not room_ids: return {} - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -323,7 +336,7 @@ class ReceiptsWorkerStore(SQLBaseStore): A dictionary of roomids to a list of receipts. """ - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -379,7 +392,7 @@ class ReceiptsWorkerStore(SQLBaseStore): if last_id == current_id: return defer.succeed([]) - def _get_users_sent_receipts_between_txn(txn): + def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]: sql = """ SELECT DISTINCT user_id FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? @@ -419,7 +432,9 @@ class ReceiptsWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_updated_receipts_txn(txn): + def get_all_updated_receipts_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, list]], int, bool]: sql = """ SELECT stream_id, room_id, receipt_type, user_id, event_id, data FROM receipts_linearized @@ -446,8 +461,8 @@ class ReceiptsWorkerStore(SQLBaseStore): def _invalidate_get_users_with_receipts_in_room( self, room_id: str, receipt_type: str, user_id: str - ): - if receipt_type != "m.read": + ) -> None: + if receipt_type != ReceiptTypes.READ: return res = self.get_users_with_read_receipts_in_room.cache.get_immediate( @@ -461,7 +476,9 @@ class ReceiptsWorkerStore(SQLBaseStore): self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): + def invalidate_caches_for_receipt( + self, room_id: str, receipt_type: str, user_id: str + ) -> None: self.get_receipts_for_user.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( @@ -482,11 +499,18 @@ class ReceiptsWorkerStore(SQLBaseStore): return super().process_replication_rows(stream_name, instance_name, token, rows) def insert_linearized_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_id, data, stream_id - ): + self, + txn: LoggingTransaction, + room_id: str, + receipt_type: str, + user_id: str, + event_id: str, + data: JsonDict, + stream_id: int, + ) -> Optional[int]: """Inserts a read-receipt into the database if it's newer than the current RR - Returns: int|None + Returns: None if the RR is older than the current RR otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) @@ -550,7 +574,7 @@ class ReceiptsWorkerStore(SQLBaseStore): lock=False, ) - if receipt_type == "m.read" and stream_ordering is not None: + if receipt_type == ReceiptTypes.READ and stream_ordering is not None: self._remove_old_push_actions_before_txn( txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering ) @@ -580,7 +604,7 @@ class ReceiptsWorkerStore(SQLBaseStore): else: # we need to points in graph -> linearized form. # TODO: Make this better. - def graph_to_linear(txn): + def graph_to_linear(txn: LoggingTransaction) -> str: clause, args = make_in_list_sql_clause( self.database_engine, "event_id", event_ids ) @@ -634,11 +658,16 @@ class ReceiptsWorkerStore(SQLBaseStore): return stream_id, max_persisted_id async def insert_graph_receipt( - self, room_id, receipt_type, user_id, event_ids, data - ): + self, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: JsonDict, + ) -> None: assert self._can_write_to_receipts - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, @@ -649,8 +678,14 @@ class ReceiptsWorkerStore(SQLBaseStore): ) def insert_graph_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_ids, data - ): + self, + txn: LoggingTransaction, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: JsonDict, + ) -> None: assert self._can_write_to_receipts txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e1ddf06916..86c3425716 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -856,7 +856,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Args: medium: threepid medium e.g. email - address: threepid address e.g. me@example.com + address: threepid address e.g. me@example.com. This must already be + in canonical form. Returns: The user ID or None if no user id/threepid mapping exists diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 0a43acda07..3368a8b084 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_relations_for_event( self, event_id: str, + room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, aggregation_key: Optional[str] = None, @@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. aggregation_key: Only fetch events with this aggregation key, if given. @@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore): the form `{"event_id": "..."}`. """ - where_clause = ["relates_to_id = ?"] - where_args: List[Union[str, int]] = [event_id] + where_clause = ["relates_to_id = ?", "room_id = ?"] + where_args: List[Union[str, int]] = [event_id, room_id] if relation_type is not None: where_clause.append("relation_type = ?") @@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_aggregation_groups_for_event( self, event_id: str, + room_id: str, event_type: Optional[str] = None, limit: int = 5, direction: str = "b", @@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. event_type: Only fetch events with this event type, if given. limit: Only fetch the `limit` groups. direction: Whether to fetch the highest count first (`"b"`) or @@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore): `type`, `key` and `count` fields. """ - where_clause = ["relates_to_id = ?", "relation_type = ?"] - where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION] + where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"] + where_args: List[Union[str, int]] = [ + event_id, + room_id, + RelationTypes.ANNOTATION, + ] if event_type: where_clause.append("type = ?") @@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore): ) @cached() - async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: + async def get_applicable_edit( + self, event_id: str, room_id: str + ) -> Optional[EventBase]: """Get the most recent edit (if any) that has happened for the given event. @@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: The original event ID + room_id: The original event's room ID Returns: The most recent edit, if any. @@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore): WHERE relates_to_id = ? AND relation_type = ? + AND edit.room_id = ? AND edit.type = 'm.room.message' ORDER by edit.origin_server_ts DESC, edit.event_id DESC LIMIT 1 """ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]: - txn.execute(sql, (event_id, RelationTypes.REPLACE)) + txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id)) row = txn.fetchone() if row: return row[0] @@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore): @cached() async def get_thread_summary( - self, event_id: str + self, event_id: str, room_id: str ) -> Tuple[int, Optional[EventBase]]: """Get the number of threaded replies, the senders of those replies, and the latest reply (if any) for the given event. Args: - event_id: The original event ID + event_id: Summarize the thread related to this event ID. + room_id: The room the event belongs to. Returns: The number of items in the thread and the most recent response, if any. @@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore): INNER JOIN events USING (event_id) WHERE relates_to_id = ? + AND room_id = ? AND relation_type = ? ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT 1 """ - txn.execute(sql, (event_id, RelationTypes.THREAD)) + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) row = txn.fetchone() if row is None: return 0, None @@ -378,11 +392,13 @@ class RelationsWorkerStore(SQLBaseStore): sql = """ SELECT COALESCE(COUNT(event_id), 0) FROM event_relations + INNER JOIN events USING (event_id) WHERE relates_to_id = ? + AND room_id = ? AND relation_type = ? """ - txn.execute(sql, (event_id, RelationTypes.THREAD)) + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) count = txn.fetchone()[0] # type: ignore[index] return count, latest_event_id diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 50d08094d5..2a3d47185a 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 66 # remember to update the list below when updating +SCHEMA_VERSION = 67 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -50,6 +50,9 @@ Changes in SCHEMA_VERSION = 65: Changes in SCHEMA_VERSION = 66: - Queries on state_key columns are now disambiguated (ie, the codebase can handle the `events` table having a `state_key` column). + +Changes in SCHEMA_VERSION = 67: + - state_events.prev_state is no longer written to. """ diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index f0723892e4..ddcf3ee348 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -161,8 +161,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def test_fallback_key(self): local_user = "@boris:" + self.hs.hostname device_id = "xyz" - fallback_key = {"alg1:k1": "key1"} - fallback_key2 = {"alg1:k2": "key2"} + fallback_key = {"alg1:k1": "fallback_key1"} + fallback_key2 = {"alg1:k2": "fallback_key2"} + fallback_key3 = {"alg1:k2": "fallback_key3"} otk = {"alg1:k2": "key2"} # we shouldn't have any unused fallback keys yet @@ -175,7 +176,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): self.handler.upload_keys_for_user( local_user, device_id, - {"org.matrix.msc2732.fallback_keys": fallback_key}, + {"fallback_keys": fallback_key}, ) ) @@ -220,7 +221,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): self.handler.upload_keys_for_user( local_user, device_id, - {"org.matrix.msc2732.fallback_keys": fallback_key}, + {"fallback_keys": fallback_key}, ) ) @@ -234,7 +235,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): self.handler.upload_keys_for_user( local_user, device_id, - {"org.matrix.msc2732.fallback_keys": fallback_key2}, + {"fallback_keys": fallback_key2}, ) ) @@ -271,6 +272,25 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}}, ) + # using the unstable prefix should also set the fallback key + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"org.matrix.msc2732.fallback_keys": fallback_key3}, + ) + ) + + res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, + ) + def test_replace_master_key(self): """uploading a new signing key should make the old signing key unavailable""" local_user = "@boris:" + self.hs.hostname diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 8a8d369fac..5816295d8b 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -23,6 +23,7 @@ from synapse.types import create_requester from synapse.util.stringutils import random_string from tests import unittest +from tests.test_utils.event_injection import create_event logger = logging.getLogger(__name__) @@ -51,6 +52,24 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.requester = create_requester(self.user_id, access_token_id=self.token_id) + def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]: + # Create a member event we can use as an auth_event + memberEvent, memberEventContext = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.room.member", + sender=self.requester.user.to_string(), + state_key=self.requester.user.to_string(), + content={"membership": "join"}, + ) + ) + self.get_success( + self.persist_event_storage.persist_event(memberEvent, memberEventContext) + ) + + return memberEvent, memberEventContext + def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]: """Create a new event with the given transaction ID. All events produced by this method will be considered duplicates. @@ -156,6 +175,90 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events), 2) self.assertEqual(events[0].event_id, events[1].event_id) + def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self): + """When we set allow_no_prev_events=True, should be able to create a + event without any prev_events (only auth_events). + """ + # Create a member event we can use as an auth_event + memberEvent, _ = self._create_and_persist_member_event() + + # Try to create the event with empty prev_events bit with some auth_events + event, _ = self.get_success( + self.handler.create_event( + self.requester, + { + "type": EventTypes.Message, + "room_id": self.room_id, + "sender": self.requester.user.to_string(), + "content": {"msgtype": "m.text", "body": random_string(5)}, + }, + # Empty prev_events is the key thing we're testing here + prev_event_ids=[], + # But with some auth_events + auth_event_ids=[memberEvent.event_id], + # Allow no prev_events! + allow_no_prev_events=True, + ) + ) + self.assertIsNotNone(event) + + def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events( + self, + ): + """When we set allow_no_prev_events=False, shouldn't be able to create a + event without any prev_events even if it has auth_events. Expect an + exception to be raised. + """ + # Create a member event we can use as an auth_event + memberEvent, _ = self._create_and_persist_member_event() + + # Try to create the event with empty prev_events but with some auth_events + self.get_failure( + self.handler.create_event( + self.requester, + { + "type": EventTypes.Message, + "room_id": self.room_id, + "sender": self.requester.user.to_string(), + "content": {"msgtype": "m.text", "body": random_string(5)}, + }, + # Empty prev_events is the key thing we're testing here + prev_event_ids=[], + # But with some auth_events + auth_event_ids=[memberEvent.event_id], + # We expect the test to fail because empty prev_events are not + # allowed here! + allow_no_prev_events=False, + ), + AssertionError, + ) + + def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events( + self, + ): + """When we set allow_no_prev_events=True, should be able to create a + event without any prev_events or auth_events. Expect an exception to be + raised. + """ + # Try to create the event with empty prev_events and empty auth_events + self.get_failure( + self.handler.create_event( + self.requester, + { + "type": EventTypes.Message, + "room_id": self.room_id, + "sender": self.requester.user.to_string(), + "content": {"msgtype": "m.text", "body": random_string(5)}, + }, + prev_event_ids=[], + # The event should be rejected when there are no auth_events + auth_event_ids=[], + # Allow no prev_events! + allow_no_prev_events=True, + ), + AssertionError, + ) + class ServerAclValidationTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 5188499ef2..d1cd5b0751 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -95,7 +95,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order channel = self.make_request( @@ -105,7 +105,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid destination channel = self.make_request( diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 81e578fd26..3f727788ce 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -360,7 +360,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel.code, msg=channel.json_body, ) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']", channel.json_body["error"], diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4fedd5fd08..eea675991c 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -608,7 +608,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid deactivated channel = self.make_request( @@ -618,7 +618,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # unkown order_by channel = self.make_request( @@ -628,7 +628,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order channel = self.make_request( @@ -638,7 +638,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_limit(self): """ @@ -1550,7 +1550,8 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Create user body = { "password": "abc123", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + # Note that the given email is not in canonical form. + "threepids": [{"medium": "email", "address": "Bob@bob.bob"}], } channel = self.make_request( @@ -2896,7 +2897,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order channel = self.make_request( @@ -2906,7 +2907,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative limit channel = self.make_request( diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 72bbc87b4a..27cb856b0a 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -85,7 +85,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", "auth/m.login.recaptcha/fallback/web?session=" + session ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) channel = self.make_request( "POST", @@ -104,7 +104,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec channel = self.register( - 401, + HTTPStatus.UNAUTHORIZED, {"username": "user", "type": "m.login.password", "password": "bar"}, ) @@ -116,15 +116,17 @@ class FallbackAuthTests(unittest.HomeserverTestCase): ) # Complete the recaptcha step. - self.recaptcha(session, 200) + self.recaptcha(session, HTTPStatus.OK) # also complete the dummy auth - self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) + self.register( + HTTPStatus.OK, {"auth": {"session": session, "type": "m.login.dummy"}} + ) # Now we should have fulfilled a complete auth flow, including # the recaptcha fallback step, we can then send a # request to the register API with the session in the authdict. - channel = self.register(200, {"auth": {"session": session}}) + channel = self.register(HTTPStatus.OK, {"auth": {"session": session}}) # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") @@ -137,7 +139,8 @@ class FallbackAuthTests(unittest.HomeserverTestCase): # will be used.) # Returns a 401 as per the spec channel = self.register( - 401, {"username": "user", "type": "m.login.password", "password": "bar"} + HTTPStatus.UNAUTHORIZED, + {"username": "user", "type": "m.login.password", "password": "bar"}, ) # Grab the session @@ -231,7 +234,9 @@ class UIAuthTests(unittest.HomeserverTestCase): """ # Attempt to delete this device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # Grab the session session = channel.json_body["session"] @@ -242,7 +247,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -260,14 +265,16 @@ class UIAuthTests(unittest.HomeserverTestCase): UIA - check that still works. """ - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) session = channel.json_body["session"] # Make another request providing the UI auth flow. self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -293,7 +300,9 @@ class UIAuthTests(unittest.HomeserverTestCase): # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_devices(401, {"devices": [self.device_id]}) + channel = self.delete_devices( + HTTPStatus.UNAUTHORIZED, {"devices": [self.device_id]} + ) # Grab the session session = channel.json_body["session"] @@ -303,7 +312,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow, but try to delete the # second device. self.delete_devices( - 200, + HTTPStatus.OK, { "devices": ["dev2"], "auth": { @@ -324,7 +333,9 @@ class UIAuthTests(unittest.HomeserverTestCase): # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # Grab the session session = channel.json_body["session"] @@ -338,7 +349,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, "dev2", - 403, + HTTPStatus.FORBIDDEN, { "auth": { "type": "m.login.password", @@ -361,13 +372,13 @@ class UIAuthTests(unittest.HomeserverTestCase): self.login("test", self.user_pass, "dev3") # Attempt to delete a device. This works since the user just logged in. - self.delete_device(self.user_tok, "dev2", 200) + self.delete_device(self.user_tok, "dev2", HTTPStatus.OK) # Move the clock forward past the validation timeout. self.reactor.advance(6) # Deleting another devices throws the user into UI auth. - channel = self.delete_device(self.user_tok, "dev3", 401) + channel = self.delete_device(self.user_tok, "dev3", HTTPStatus.UNAUTHORIZED) # Grab the session session = channel.json_body["session"] @@ -378,7 +389,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, "dev3", - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -393,7 +404,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # due to re-using the previous session. # # Note that *no auth* information is provided, not even a session iD! - self.delete_device(self.user_tok, self.device_id, 200) + self.delete_device(self.user_tok, self.device_id, HTTPStatus.OK) @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) @@ -413,7 +424,9 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # check that SSO is offered flows = channel.json_body["flows"] @@ -426,13 +439,13 @@ class UIAuthTests(unittest.HomeserverTestCase): ) # that should serve a confirmation page - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # and now the delete request should succeed. self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, body={"auth": {"session": session_id}}, ) @@ -445,13 +458,15 @@ class UIAuthTests(unittest.HomeserverTestCase): # now call the device deletion API: we should get the option to auth with SSO # and not password. - channel = self.delete_device(user_tok, device_id, 401) + channel = self.delete_device(user_tok, device_id, HTTPStatus.UNAUTHORIZED) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) def test_does_not_offer_sso_for_password_user(self): - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.password"]}]) @@ -463,7 +478,9 @@ class UIAuthTests(unittest.HomeserverTestCase): login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] # we have no particular expectations of ordering here @@ -480,7 +497,9 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertEqual(login_resp["user_id"], self.user) # start a UI Auth flow by attempting to delete a device - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] self.assertIn({"stages": ["m.login.sso"]}, flows) @@ -496,7 +515,10 @@ class UIAuthTests(unittest.HomeserverTestCase): # ... and the delete op should now fail with a 403 self.delete_device( - self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} + self.user_tok, + self.device_id, + HTTPStatus.FORBIDDEN, + body={"auth": {"session": session_id}}, ) @@ -551,7 +573,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): login_without_refresh = self.make_request( "POST", "/_matrix/client/r0/login", body ) - self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result) + self.assertEqual( + login_without_refresh.code, HTTPStatus.OK, login_without_refresh.result + ) self.assertNotIn("refresh_token", login_without_refresh.json_body) login_with_refresh = self.make_request( @@ -559,7 +583,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", {"refresh_token": True, **body}, ) - self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) + self.assertEqual( + login_with_refresh.code, HTTPStatus.OK, login_with_refresh.result + ) self.assertIn("refresh_token", login_with_refresh.json_body) self.assertIn("expires_in_ms", login_with_refresh.json_body) @@ -577,7 +603,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): }, ) self.assertEqual( - register_without_refresh.code, 200, register_without_refresh.result + register_without_refresh.code, + HTTPStatus.OK, + register_without_refresh.result, ) self.assertNotIn("refresh_token", register_without_refresh.json_body) @@ -591,7 +619,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "refresh_token": True, }, ) - self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) + self.assertEqual( + register_with_refresh.code, HTTPStatus.OK, register_with_refresh.result + ) self.assertIn("refresh_token", register_with_refresh.json_body) self.assertIn("expires_in_ms", register_with_refresh.json_body) @@ -610,14 +640,14 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) refresh_response = self.make_request( "POST", "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertIn("access_token", refresh_response.json_body) self.assertIn("refresh_token", refresh_response.json_body) self.assertIn("expires_in_ms", refresh_response.json_body) @@ -648,7 +678,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) self.assertApproximates( login_response.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -658,7 +688,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertApproximates( refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -705,7 +735,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", {"refresh_token": True, **body}, ) - self.assertEqual(login_response1.code, 200, login_response1.result) + self.assertEqual(login_response1.code, HTTPStatus.OK, login_response1.result) self.assertApproximates( login_response1.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -716,7 +746,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response2.code, 200, login_response2.result) + self.assertEqual(login_response2.code, HTTPStatus.OK, login_response2.result) nonrefreshable_access_token = login_response2.json_body["access_token"] # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry) @@ -818,7 +848,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) refresh_token = login_response.json_body["refresh_token"] # Advance shy of 2 minutes into the future @@ -826,7 +856,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # Refresh our session. The refresh token should still be valid right now. refresh_response = self.use_refresh_token(refresh_token) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertIn( "refresh_token", refresh_response.json_body, @@ -846,7 +876,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This should fail because the refresh token's lifetime has also been # diminished as our session expired. refresh_response = self.use_refresh_token(refresh_token) - self.assertEqual(refresh_response.code, 403, refresh_response.result) + self.assertEqual( + refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result + ) def test_refresh_token_invalidation(self): """Refresh tokens are invalidated after first use of the next token. @@ -875,7 +907,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) # This first refresh should work properly first_refresh_response = self.make_request( @@ -884,7 +916,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - first_refresh_response.code, 200, first_refresh_response.result + first_refresh_response.code, HTTPStatus.OK, first_refresh_response.result ) # This one as well, since the token in the first one was never used @@ -894,7 +926,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - second_refresh_response.code, 200, second_refresh_response.result + second_refresh_response.code, HTTPStatus.OK, second_refresh_response.result ) # This one should not, since the token from the first refresh is not valid anymore @@ -904,7 +936,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": first_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( - third_refresh_response.code, 401, third_refresh_response.result + third_refresh_response.code, + HTTPStatus.UNAUTHORIZED, + third_refresh_response.result, ) # The associated access token should also be invalid @@ -913,7 +947,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/account/whoami", access_token=first_refresh_response.json_body["access_token"], ) - self.assertEqual(whoami_response.code, 401, whoami_response.result) + self.assertEqual( + whoami_response.code, HTTPStatus.UNAUTHORIZED, whoami_response.result + ) # But all other tokens should work (they will expire after some time) for access_token in [ @@ -923,7 +959,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): whoami_response = self.make_request( "GET", "/_matrix/client/r0/account/whoami", access_token=access_token ) - self.assertEqual(whoami_response.code, 200, whoami_response.result) + self.assertEqual( + whoami_response.code, HTTPStatus.OK, whoami_response.result + ) # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail fourth_refresh_response = self.make_request( @@ -932,7 +970,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - fourth_refresh_response.code, 403, fourth_refresh_response.result + fourth_refresh_response.code, + HTTPStatus.FORBIDDEN, + fourth_refresh_response.result, ) # But refreshing from the last valid refresh token still works @@ -942,5 +982,5 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": second_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( - fifth_refresh_response.code, 200, fifth_refresh_response.result + fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 397c12c2a6..55f4f0b1d0 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -16,6 +16,7 @@ import itertools import urllib.parse from typing import Dict, List, Optional, Tuple +from unittest.mock import patch from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin @@ -23,6 +24,8 @@ from synapse.rest.client import login, register, relations, room, sync from tests import unittest from tests.server import FakeChannel +from tests.test_utils import make_awaitable +from tests.test_utils.event_injection import inject_event class RelationsTestCase(unittest.HomeserverTestCase): @@ -651,6 +654,118 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, ) + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_ignore_invalid_room(self): + """Test that we ignore invalid relations over federation.""" + # Create another room and send a message in it. + room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) + res = self.helper.send(room2, body="Hi!", tok=self.user_token) + parent_id = res["event_id"] + + # Disable the validation to pretend this came over federation. + with patch( + "synapse.handlers.message.EventCreationHandler._validate_event_relation", + new=lambda self, event: make_awaitable(None), + ): + # Generate a various relations from a different room. + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.reaction", + sender=self.user_id, + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": parent_id, + "key": "A", + } + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "m.relates_to": { + "rel_type": RelationTypes.REFERENCE, + "event_id": parent_id, + }, + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": parent_id, + }, + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "new_content": { + "body": "new content", + "msgtype": "m.text", + }, + "m.relates_to": { + "rel_type": RelationTypes.REPLACE, + "event_id": parent_id, + }, + }, + ) + ) + + # They should be ignored when fetching relations. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) + + # And when fetching aggregations. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) + + # And for bundled aggregations. + channel = self.make_request( + "GET", + f"/rooms/{room2}/event/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertNotIn("m.relations", channel.json_body["unsigned"]) + def test_edit(self): """Test that a simple edit works.""" diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py new file mode 100644 index 0000000000..721454c187 --- /dev/null +++ b/tests/rest/client/test_room_batch.py @@ -0,0 +1,180 @@ +import logging +from typing import List, Tuple +from unittest.mock import Mock, patch + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventContentFields, EventTypes +from synapse.appservice import ApplicationService +from synapse.rest import admin +from synapse.rest.client import login, register, room, room_batch +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests import unittest + +logger = logging.getLogger(__name__) + + +def _create_join_state_events_for_batch_send_request( + virtual_user_ids: List[str], + insert_time: int, +) -> List[JsonDict]: + return [ + { + "type": EventTypes.Member, + "sender": virtual_user_id, + "origin_server_ts": insert_time, + "content": { + "membership": "join", + "displayname": "display-name-for-%s" % (virtual_user_id,), + }, + "state_key": virtual_user_id, + } + for virtual_user_id in virtual_user_ids + ] + + +def _create_message_events_for_batch_send_request( + virtual_user_id: str, insert_time: int, count: int +) -> List[JsonDict]: + return [ + { + "type": EventTypes.Message, + "sender": virtual_user_id, + "origin_server_ts": insert_time, + "content": { + "msgtype": "m.text", + "body": "Historical %d" % (i), + EventContentFields.MSC2716_HISTORICAL: True, + }, + } + for i in range(count) + ] + + +class RoomBatchTestCase(unittest.HomeserverTestCase): + """Test importing batches of historical messages.""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room_batch.register_servlets, + room.register_servlets, + register.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + self.appservice = ApplicationService( + token="i_am_an_app_service", + hostname="test", + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + + mock_load_appservices = Mock(return_value=[self.appservice]) + with patch( + "synapse.storage.databases.main.appservice.load_appservices", + mock_load_appservices, + ): + hs = self.setup_test_homeserver(config=config) + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.clock = clock + self.storage = hs.get_storage() + + self.virtual_user_id = self.register_appservice_user( + "as_user_potato", self.appservice.token + ) + + def _create_test_room(self) -> Tuple[str, str, str, str]: + room_id = self.helper.create_room_as( + self.appservice.sender, tok=self.appservice.token + ) + + res_a = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "A", + }, + tok=self.appservice.token, + ) + event_id_a = res_a["event_id"] + + res_b = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "B", + }, + tok=self.appservice.token, + ) + event_id_b = res_b["event_id"] + + res_c = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "C", + }, + tok=self.appservice.token, + ) + event_id_c = res_c["event_id"] + + return room_id, event_id_a, event_id_b, event_id_c + + @unittest.override_config({"experimental_features": {"msc2716_enabled": True}}) + def test_same_state_groups_for_whole_historical_batch(self): + """Make sure that when using the `/batch_send` endpoint to import a + bunch of historical messages, it re-uses the same `state_group` across + the whole batch. This is an easy optimization to make sure we're getting + right because the state for the whole batch is contained in + `state_events_at_start` and can be shared across everything. + """ + + time_before_room = int(self.clock.time_msec()) + room_id, event_id_a, _, _ = self._create_test_room() + + channel = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2716/rooms/%s/batch_send?prev_event_id=%s" + % (room_id, event_id_a), + content={ + "events": _create_message_events_for_batch_send_request( + self.virtual_user_id, time_before_room, 3 + ), + "state_events_at_start": _create_join_state_events_for_batch_send_request( + [self.virtual_user_id], time_before_room + ), + }, + access_token=self.appservice.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + # Get the historical event IDs that we just imported + historical_event_ids = channel.json_body["event_ids"] + self.assertEqual(len(historical_event_ids), 3) + + # Fetch the state_groups + state_group_map = self.get_success( + self.storage.state.get_state_groups_ids(room_id, historical_event_ids) + ) + + # We expect all of the historical events to be using the same state_group + # so there should only be a single state_group here! + self.assertEqual( + len(state_group_map.keys()), + 1, + "Expected a single state_group to be returned by saw state_groups=%s" + % (state_group_map.keys(),), + ) |