diff options
63 files changed, 552 insertions, 287 deletions
diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index 325c1f7d39..0beb418a07 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -12,6 +12,10 @@ on: # we do the full build on tags. tags: ["v*"] +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + permissions: contents: write diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cef4439477..4e61824ee5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -5,6 +5,10 @@ on: branches: ["develop", "release-*"] pull_request: +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: lint: runs-on: ubuntu-latest @@ -347,7 +351,12 @@ jobs: # a job which marks all the other jobs as complete, thus allowing PRs to be merged. tests-done: + if: ${{ always() }} needs: + - lint + - lint-crlf + - lint-newsfile + - lint-sdist - trial - trial-olddeps - sytest @@ -355,4 +364,16 @@ jobs: - complement runs-on: ubuntu-latest steps: - - run: "true" \ No newline at end of file + - name: Set build result + env: + NEEDS_CONTEXT: ${{ toJSON(needs) }} + # the `jq` incantation dumps out a series of "<job> <result>" lines + run: | + set -o pipefail + jq -r 'to_entries[] | [.key,.value.result] | join(" ")' \ + <<< $NEEDS_CONTEXT | + while read job result; do + if [ "$result" != "success" ]; then + echo "::set-failed ::Job $job returned $result" + fi + done diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a4e6688042..80ef6aa235 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -392,7 +392,7 @@ By now, you know the drill! # Notes for maintainers on merging PRs etc There are some notes for those with commit access to the project on how we -manage git [here](docs/dev/git.md). +manage git [here](docs/development/git.md). # Conclusion diff --git a/changelog.d/10283.feature b/changelog.d/10283.feature new file mode 100644 index 0000000000..99d633dbfb --- /dev/null +++ b/changelog.d/10283.feature @@ -0,0 +1 @@ +Initial support for MSC3244, Room version capabilities over the /capabilities API. \ No newline at end of file diff --git a/changelog.d/10426.feature b/changelog.d/10426.feature new file mode 100644 index 0000000000..9cca6dc456 --- /dev/null +++ b/changelog.d/10426.feature @@ -0,0 +1 @@ +Email notifications now state whether an invitation is to a room or a space. diff --git a/changelog.d/10429.misc b/changelog.d/10429.misc new file mode 100644 index 0000000000..ccb2217f64 --- /dev/null +++ b/changelog.d/10429.misc @@ -0,0 +1 @@ +Drop backwards-compatibility code that was required to support Ubuntu Xenial. diff --git a/changelog.d/10432.misc b/changelog.d/10432.misc new file mode 100644 index 0000000000..3a8cdf0ae0 --- /dev/null +++ b/changelog.d/10432.misc @@ -0,0 +1 @@ +Connect historical chunks together with chunk events instead of a content field (MSC2716). diff --git a/changelog.d/10437.misc b/changelog.d/10437.misc new file mode 100644 index 0000000000..a557578499 --- /dev/null +++ b/changelog.d/10437.misc @@ -0,0 +1 @@ +Improve servlet type hints. diff --git a/changelog.d/10438.misc b/changelog.d/10438.misc new file mode 100644 index 0000000000..a557578499 --- /dev/null +++ b/changelog.d/10438.misc @@ -0,0 +1 @@ +Improve servlet type hints. diff --git a/changelog.d/10442.misc b/changelog.d/10442.misc new file mode 100644 index 0000000000..b8d412d732 --- /dev/null +++ b/changelog.d/10442.misc @@ -0,0 +1 @@ +Replace usage of `or_ignore` in `simple_insert` with `simple_upsert` usage, to stop spamming postgres logs with spurious ERROR messages. diff --git a/changelog.d/10444.misc b/changelog.d/10444.misc new file mode 100644 index 0000000000..c012e89f4b --- /dev/null +++ b/changelog.d/10444.misc @@ -0,0 +1 @@ +Update the `tests-done` Github Actions status. diff --git a/changelog.d/10445.doc b/changelog.d/10445.doc new file mode 100644 index 0000000000..4c023ded7c --- /dev/null +++ b/changelog.d/10445.doc @@ -0,0 +1 @@ +Fix hierarchy of providers on the OpenID page. diff --git a/changelog.d/10446.misc b/changelog.d/10446.misc new file mode 100644 index 0000000000..a5a0ca80eb --- /dev/null +++ b/changelog.d/10446.misc @@ -0,0 +1 @@ +Update type annotations to work with forthcoming Twisted 21.7.0 release. diff --git a/changelog.d/10448.feature b/changelog.d/10448.feature new file mode 100644 index 0000000000..f6579e0ca8 --- /dev/null +++ b/changelog.d/10448.feature @@ -0,0 +1 @@ +Add `creation_ts` to list users admin API. \ No newline at end of file diff --git a/changelog.d/10451.misc b/changelog.d/10451.misc new file mode 100644 index 0000000000..e38f4b476d --- /dev/null +++ b/changelog.d/10451.misc @@ -0,0 +1 @@ +Cancel redundant GHA workflows when a new commit is pushed. diff --git a/changelog.d/10453.doc b/changelog.d/10453.doc new file mode 100644 index 0000000000..5d4db9bca2 --- /dev/null +++ b/changelog.d/10453.doc @@ -0,0 +1 @@ +Consolidate development documentation to `docs/development/`. diff --git a/changelog.d/10463.misc b/changelog.d/10463.misc new file mode 100644 index 0000000000..d7b4d2222e --- /dev/null +++ b/changelog.d/10463.misc @@ -0,0 +1 @@ +Disable `msc2716` Complement tests until Complement updates are merged. diff --git a/debian/build_virtualenv b/debian/build_virtualenv index 21caad90cc..68c8659953 100755 --- a/debian/build_virtualenv +++ b/debian/build_virtualenv @@ -33,13 +33,11 @@ esac # Use --builtin-venv to use the better `venv` module from CPython 3.4+ rather # than the 2/3 compatible `virtualenv`. -# Pin pip to 20.3.4 to fix breakage in 21.0 on py3.5 (xenial) - dh_virtualenv \ --install-suffix "matrix-synapse" \ --builtin-venv \ --python "$SNAKE" \ - --upgrade-pip-to="20.3.4" \ + --upgrade-pip \ --preinstall="lxml" \ --preinstall="mock" \ --extra-pip-arg="--no-cache-dir" \ diff --git a/debian/changelog b/debian/changelog index 2062c6caef..ce8e2105e7 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.39.0ubuntu1) UNRELEASED; urgency=medium + + * Drop backwards-compatibility code that was required to support Ubuntu Xenial. + + -- Richard van der Hoff <richard@matrix.org> Tue, 20 Jul 2021 00:10:03 +0100 + matrix-synapse-py3 (1.38.1) stable; urgency=medium * New synapse release 1.38.1. diff --git a/debian/compat b/debian/compat index ec635144f6..f599e28b8a 100644 --- a/debian/compat +++ b/debian/compat @@ -1 +1 @@ -9 +10 diff --git a/debian/control b/debian/control index 8167a901a4..763fabd6f6 100644 --- a/debian/control +++ b/debian/control @@ -3,11 +3,8 @@ Section: contrib/python Priority: extra Maintainer: Synapse Packaging team <packages@matrix.org> # keep this list in sync with the build dependencies in docker/Dockerfile-dhvirtualenv. -# TODO: Remove the dependency on dh-systemd after dropping support for Ubuntu xenial -# On all other supported releases, it's merely a transitional package which -# does nothing but depends on debhelper (> 9.20160709) Build-Depends: - debhelper (>= 9.20160709) | dh-systemd, + debhelper (>= 10), dh-virtualenv (>= 1.1), libsystemd-dev, libpq-dev, diff --git a/debian/rules b/debian/rules index c744060a57..b9d490adc9 100755 --- a/debian/rules +++ b/debian/rules @@ -51,7 +51,5 @@ override_dh_shlibdeps: override_dh_virtualenv: ./debian/build_virtualenv -# We are restricted to compat level 9 (because xenial), so have to -# enable the systemd bits manually. %: - dh $@ --with python-virtualenv --with systemd + dh $@ --with python-virtualenv diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv index 0d74630370..017be8555e 100644 --- a/docker/Dockerfile-dhvirtualenv +++ b/docker/Dockerfile-dhvirtualenv @@ -15,6 +15,15 @@ ARG distro="" ### ### Stage 0: build a dh-virtualenv ### + +# This is only really needed on bionic and focal, since other distributions we +# care about have a recent version of dh-virtualenv by default. Unfortunately, +# it looks like focal is going to be with us for a while. +# +# (focal doesn't have a dh-virtualenv package at all. There is a PPA at +# https://launchpad.net/~jyrki-pulliainen/+archive/ubuntu/dh-virtualenv, but +# it's not obviously easier to use that than to build our own.) + FROM ${distro} as builder RUN apt-get update -qq -o Acquire::Languages=none @@ -27,7 +36,7 @@ RUN env DEBIAN_FRONTEND=noninteractive apt-get install \ wget # fetch and unpack the package -# TODO: Upgrade to 1.2.2 once xenial is dropped +# TODO: Upgrade to 1.2.2 once bionic is dropped (1.2.2 requires debhelper 12; bionic has only 11) RUN mkdir /dh-virtualenv RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/ac6e1b1.tar.gz RUN tar -xv --strip-components=1 -C /dh-virtualenv -f /dh-virtualenv.tar.gz @@ -59,8 +68,6 @@ ENV LANG C.UTF-8 # # NB: keep this list in sync with the list of build-deps in debian/control # TODO: it would be nice to do that automatically. -# TODO: Remove the dh-systemd stanza after dropping support for Ubuntu xenial -# it's a transitional package on all other, more recent releases RUN apt-get update -qq -o Acquire::Languages=none \ && env DEBIAN_FRONTEND=noninteractive apt-get install \ -yqq --no-install-recommends -o Dpkg::Options::=--force-unsafe-io \ @@ -76,10 +83,7 @@ RUN apt-get update -qq -o Acquire::Languages=none \ python3-venv \ sqlite3 \ libpq-dev \ - xmlsec1 \ - && ( env DEBIAN_FRONTEND=noninteractive apt-get install \ - -yqq --no-install-recommends -o Dpkg::Options::=--force-unsafe-io \ - dh-systemd || true ) + xmlsec1 COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb / diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index db4ef1a44e..f1bde91420 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -67,7 +67,7 @@ # Development - [Contributing Guide](development/contributing_guide.md) - [Code Style](code_style.md) - - [Git Usage](dev/git.md) + - [Git Usage](development/git.md) - [Testing]() - [OpenTracing](opentracing.md) - [Database Schemas](development/database_schema.md) @@ -77,8 +77,8 @@ - [TCP Replication](tcp_replication.md) - [Internal Documentation](development/internal_documentation/README.md) - [Single Sign-On]() - - [SAML](dev/saml.md) - - [CAS](dev/cas.md) + - [SAML](development/saml.md) + - [CAS](development/cas.md) - [State Resolution]() - [The Auth Chain Difference Algorithm](auth_chain_difference_algorithm.md) - [Media Repository](media_repository.md) diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 4a65d0c3bc..160899754e 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -144,7 +144,8 @@ A response body like the following is returned: "deactivated": 0, "shadow_banned": 0, "displayname": "<User One>", - "avatar_url": null + "avatar_url": null, + "creation_ts": 1560432668000 }, { "name": "<user_id2>", "is_guest": 0, @@ -153,7 +154,8 @@ A response body like the following is returned: "deactivated": 0, "shadow_banned": 0, "displayname": "<User Two>", - "avatar_url": "<avatar_url>" + "avatar_url": "<avatar_url>", + "creation_ts": 1561550621000 } ], "next_token": "100", @@ -197,11 +199,12 @@ The following parameters should be set in the URL: - `shadow_banned` - Users are ordered by `shadow_banned` status. - `displayname` - Users are ordered alphabetically by `displayname`. - `avatar_url` - Users are ordered alphabetically by avatar URL. + - `creation_ts` - Users are ordered by when the users was created in ms. - `dir` - Direction of media order. Either `f` for forwards or `b` for backwards. Setting this value to `b` will reverse the above sort order. Defaults to `f`. -Caution. The database only has indexes on the columns `name` and `created_ts`. +Caution. The database only has indexes on the columns `name` and `creation_ts`. This means that if a different sort order is used (`is_guest`, `admin`, `user_type`, `deactivated`, `shadow_banned`, `avatar_url` or `displayname`), this can cause a large load on the database, especially for large environments. @@ -222,6 +225,7 @@ The following fields are returned in the JSON response body: - `shadow_banned` - bool - Status if that user has been marked as shadow banned. - `displayname` - string - The user's display name if they have set one. - `avatar_url` - string - The user's avatar URL if they have set one. + - `creation_ts` - integer - The user's creation timestamp in ms. - `next_token`: string representing a positive integer - Indication for pagination. See above. - `total` - integer - Total number of media. diff --git a/docs/dev/cas.md b/docs/development/cas.md index 592b2d8d4f..592b2d8d4f 100644 --- a/docs/dev/cas.md +++ b/docs/development/cas.md diff --git a/docs/dev/git.md b/docs/development/git.md index 87950f07b2..9b1ed54b65 100644 --- a/docs/dev/git.md +++ b/docs/development/git.md @@ -9,7 +9,7 @@ commits each of which contains a single change building on what came before. Here, by way of an arbitrary example, is the top of `git log --graph b2dba0607`: -<img src="git/clean.png" alt="clean git graph" width="500px"> +<img src="img/git/clean.png" alt="clean git graph" width="500px"> Note how the commit comment explains clearly what is changing and why. Also note the *absence* of merge commits, as well as the absence of commits called @@ -61,7 +61,7 @@ Ok, so that's what we'd like to achieve. How do we achieve it? The TL;DR is: when you come to merge a pull request, you *probably* want to “squash and merge”: -![squash and merge](git/squash.png). +![squash and merge](img/git/squash.png). (This applies whether you are merging your own PR, or that of another contributor.) @@ -105,7 +105,7 @@ complicated. Here's how we do it. Let's start with a picture: -![branching model](git/branches.jpg) +![branching model](img/git/branches.jpg) It looks complicated, but it's really not. There's one basic rule: *anyone* is free to merge from *any* more-stable branch to *any* less-stable branch at diff --git a/docs/dev/git/branches.jpg b/docs/development/img/git/branches.jpg index 715ecc8cd0..715ecc8cd0 100644 --- a/docs/dev/git/branches.jpg +++ b/docs/development/img/git/branches.jpg Binary files differdiff --git a/docs/dev/git/clean.png b/docs/development/img/git/clean.png index 3accd7ccef..3accd7ccef 100644 --- a/docs/dev/git/clean.png +++ b/docs/development/img/git/clean.png Binary files differdiff --git a/docs/dev/git/squash.png b/docs/development/img/git/squash.png index 234caca3e4..234caca3e4 100644 --- a/docs/dev/git/squash.png +++ b/docs/development/img/git/squash.png Binary files differdiff --git a/docs/dev/saml.md b/docs/development/saml.md index a9bfd2dc05..a9bfd2dc05 100644 --- a/docs/dev/saml.md +++ b/docs/development/saml.md diff --git a/docs/openid.md b/docs/openid.md index cfaafc5015..f685fd551a 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -410,7 +410,7 @@ oidc_providers: display_name_template: "{{ user.name }}" ``` -## Apple +### Apple Configuring "Sign in with Apple" (SiWA) requires an Apple Developer account. diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index aca32edc17..4df224be67 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -65,4 +65,4 @@ if [[ -n "$1" ]]; then fi # Run the tests! -go test -v -tags synapse_blacklist,msc2946,msc3083,msc2716,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests +go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 8363c2bb0f..4caafc0ac9 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -120,6 +120,7 @@ class EventTypes: SpaceParent = "m.space.parent" MSC2716_INSERTION = "org.matrix.msc2716.insertion" + MSC2716_CHUNK = "org.matrix.msc2716.chunk" MSC2716_MARKER = "org.matrix.msc2716.marker" @@ -190,9 +191,10 @@ class EventContentFields: # Used on normal messages to indicate they were historically imported after the fact MSC2716_HISTORICAL = "org.matrix.msc2716.historical" - # For "insertion" events + # For "insertion" events to indicate what the next chunk ID should be in + # order to connect to it MSC2716_NEXT_CHUNK_ID = "org.matrix.msc2716.next_chunk_id" - # Used on normal message events to indicate where the chunk connects to + # Used on "chunk" events to indicate which insertion event it connects to MSC2716_CHUNK_ID = "org.matrix.msc2716.chunk_id" # For "marker" events MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion" diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index a20abc5a65..8dd33dcb83 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict +from typing import Callable, Dict, Optional import attr @@ -208,5 +208,39 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = { RoomVersions.MSC3083, RoomVersions.V7, ) - # Note that we do not include MSC2043 here unless it is enabled in the config. +} + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomVersionCapability: + """An object which describes the unique attributes of a room version.""" + + identifier: str # the identifier for this capability + preferred_version: Optional[RoomVersion] + support_check_lambda: Callable[[RoomVersion], bool] + + +MSC3244_CAPABILITIES = { + cap.identifier: { + "preferred": cap.preferred_version.identifier + if cap.preferred_version is not None + else None, + "support": [ + v.identifier + for v in KNOWN_ROOM_VERSIONS.values() + if cap.support_check_lambda(v) + ], + } + for cap in ( + RoomVersionCapability( + "knock", + RoomVersions.V7, + lambda room_version: room_version.msc2403_knocking, + ), + RoomVersionCapability( + "restricted", + None, + lambda room_version: room_version.msc3083_join_rules, + ), + ) } diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index bcecbfec03..8d8f166e9b 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -39,12 +39,13 @@ DEFAULT_SUBJECTS = { "messages_from_person_and_others": "[%(app)s] You have messages on %(app)s from %(person)s and others...", "invite_from_person": "[%(app)s] %(person)s has invited you to chat on %(app)s...", "invite_from_person_to_room": "[%(app)s] %(person)s has invited you to join the %(room)s room on %(app)s...", + "invite_from_person_to_space": "[%(app)s] %(person)s has invited you to join the %(space)s space on %(app)s...", "password_reset": "[%(server_name)s] Password reset", "email_validation": "[%(server_name)s] Validate your email", } -@attr.s +@attr.s(slots=True, frozen=True) class EmailSubjectConfig: message_from_person_in_room = attr.ib(type=str) message_from_person = attr.ib(type=str) @@ -54,6 +55,7 @@ class EmailSubjectConfig: messages_from_person_and_others = attr.ib(type=str) invite_from_person = attr.ib(type=str) invite_from_person_to_room = attr.ib(type=str) + invite_from_person_to_space = attr.ib(type=str) password_reset = attr.ib(type=str) email_validation = attr.ib(type=str) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index e25ccba9ac..040c4504d8 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -32,3 +32,6 @@ class ExperimentalConfig(Config): # MSC2716 (backfill existing history) self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) + + # MSC3244 (room version capabilities) + self.msc3244_enabled: bool = experimental.get("msc3244_enabled", False) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 2974d4d0cc..5e059d6e09 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -984,7 +984,7 @@ class PublicRoomList(BaseFederationServlet): limit = parse_integer_from_args(query, "limit", 0) since_token = parse_string_from_args(query, "since", None) include_all_networks = parse_boolean_from_args( - query, "include_all_networks", False + query, "include_all_networks", default=False ) third_party_instance_id = parse_string_from_args( query, "third_party_instance_id", None @@ -1908,16 +1908,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): suggested_only = parse_boolean_from_args(query, "suggested_only", default=False) max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space") - exclude_rooms = [] - if b"exclude_rooms" in query: - try: - exclude_rooms = [ - room_id.decode("ascii") for room_id in query[b"exclude_rooms"] - ] - except Exception: - raise SynapseError( - 400, "Bad query parameter for exclude_rooms", Codes.INVALID_PARAM - ) + exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[]) return 200, await self.handler.federation_space_summary( origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 950770201a..c16b7f10e6 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -27,7 +27,7 @@ from twisted.internet.interfaces import ( ) from twisted.web.client import URI, Agent, HTTPConnectionPool from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer +from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer, IResponse from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.client import BlacklistingAgentWrapper @@ -116,7 +116,7 @@ class MatrixFederationAgent: uri: bytes, headers: Optional[Headers] = None, bodyProducer: Optional[IBodyProducer] = None, - ) -> Generator[defer.Deferred, Any, defer.Deferred]: + ) -> Generator[defer.Deferred, Any, IResponse]: """ Args: method: HTTP method: GET/POST/etc diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 04560fb589..732a1e6aeb 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -14,47 +14,86 @@ """ This module contains base REST classes for constructing REST servlets. """ import logging -from typing import Dict, Iterable, List, Optional, overload +from typing import Iterable, List, Mapping, Optional, Sequence, overload from typing_extensions import Literal from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError +from synapse.types import JsonDict from synapse.util import json_decoder logger = logging.getLogger(__name__) -def parse_integer(request, name, default=None, required=False): +@overload +def parse_integer(request: Request, name: str, default: int) -> int: + ... + + +@overload +def parse_integer(request: Request, name: str, *, required: Literal[True]) -> int: + ... + + +@overload +def parse_integer( + request: Request, name: str, default: Optional[int] = None, required: bool = False +) -> Optional[int]: + ... + + +def parse_integer( + request: Request, name: str, default: Optional[int] = None, required: bool = False +) -> Optional[int]: """Parse an integer parameter from the request string Args: request: the twisted HTTP request. - name (bytes/unicode): the name of the query parameter. - default (int|None): value to use if the parameter is absent, defaults - to None. - required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. + name: the name of the query parameter. + default: value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the parameter is absent, + defaults to False. Returns: - int|None: An int value or the default. + An int value or the default. Raises: SynapseError: if the parameter is absent and required, or if the parameter is present and not an integer. """ - return parse_integer_from_args(request.args, name, default, required) + args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore + return parse_integer_from_args(args, name, default, required) + +def parse_integer_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[int] = None, + required: bool = False, +) -> Optional[int]: + """Parse an integer parameter from the request string + + Args: + args: A mapping of request args as bytes to a list of bytes (e.g. request.args). + name: the name of the query parameter. + default: value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the parameter is absent, + defaults to False. -def parse_integer_from_args(args, name, default=None, required=False): + Returns: + An int value or the default. - if not isinstance(name, bytes): - name = name.encode("ascii") + Raises: + SynapseError: if the parameter is absent and required, or if the + parameter is present and not an integer. + """ + name_bytes = name.encode("ascii") - if name in args: + if name_bytes in args: try: - return int(args[name][0]) + return int(args[name_bytes][0]) except Exception: message = "Query parameter %r must be an integer" % (name,) raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) @@ -66,36 +105,102 @@ def parse_integer_from_args(args, name, default=None, required=False): return default -def parse_boolean(request, name, default=None, required=False): +@overload +def parse_boolean(request: Request, name: str, default: bool) -> bool: + ... + + +@overload +def parse_boolean(request: Request, name: str, *, required: Literal[True]) -> bool: + ... + + +@overload +def parse_boolean( + request: Request, name: str, default: Optional[bool] = None, required: bool = False +) -> Optional[bool]: + ... + + +def parse_boolean( + request: Request, name: str, default: Optional[bool] = None, required: bool = False +) -> Optional[bool]: """Parse a boolean parameter from the request query string Args: request: the twisted HTTP request. - name (bytes/unicode): the name of the query parameter. - default (bool|None): value to use if the parameter is absent, defaults - to None. - required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. + name: the name of the query parameter. + default: value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the parameter is absent, + defaults to False. Returns: - bool|None: A bool value or the default. + A bool value or the default. Raises: SynapseError: if the parameter is absent and required, or if the parameter is present and not one of "true" or "false". """ + args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore + return parse_boolean_from_args(args, name, default, required) + + +@overload +def parse_boolean_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: bool, +) -> bool: + ... + - return parse_boolean_from_args(request.args, name, default, required) +@overload +def parse_boolean_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + *, + required: Literal[True], +) -> bool: + ... -def parse_boolean_from_args(args, name, default=None, required=False): +@overload +def parse_boolean_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[bool] = None, + required: bool = False, +) -> Optional[bool]: + ... - if not isinstance(name, bytes): - name = name.encode("ascii") - if name in args: +def parse_boolean_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[bool] = None, + required: bool = False, +) -> Optional[bool]: + """Parse a boolean parameter from the request query string + + Args: + args: A mapping of request args as bytes to a list of bytes (e.g. request.args). + name: the name of the query parameter. + default: value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the parameter is absent, + defaults to False. + + Returns: + A bool value or the default. + + Raises: + SynapseError: if the parameter is absent and required, or if the + parameter is present and not one of "true" or "false". + """ + name_bytes = name.encode("ascii") + + if name_bytes in args: try: - return {b"true": True, b"false": False}[args[name][0]] + return {b"true": True, b"false": False}[args[name_bytes][0]] except Exception: message = ( "Boolean query parameter %r must be one of ['true', 'false']" @@ -111,7 +216,7 @@ def parse_boolean_from_args(args, name, default=None, required=False): @overload def parse_bytes_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[bytes] = None, ) -> Optional[bytes]: @@ -120,7 +225,7 @@ def parse_bytes_from_args( @overload def parse_bytes_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Literal[None] = None, *, @@ -131,7 +236,7 @@ def parse_bytes_from_args( @overload def parse_bytes_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[bytes] = None, required: bool = False, @@ -140,7 +245,7 @@ def parse_bytes_from_args( def parse_bytes_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[bytes] = None, required: bool = False, @@ -172,6 +277,42 @@ def parse_bytes_from_args( return default +@overload +def parse_string( + request: Request, + name: str, + default: str, + *, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> str: + ... + + +@overload +def parse_string( + request: Request, + name: str, + *, + required: Literal[True], + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> str: + ... + + +@overload +def parse_string( + request: Request, + name: str, + *, + required: bool = False, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> Optional[str]: + ... + + def parse_string( request: Request, name: str, @@ -179,7 +320,7 @@ def parse_string( required: bool = False, allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", -): +) -> Optional[str]: """ Parse a string parameter from the request query string. @@ -205,7 +346,7 @@ def parse_string( parameter is present, must be one of a list of allowed values and is not one of those allowed values. """ - args: Dict[bytes, List[bytes]] = request.args # type: ignore + args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore return parse_string_from_args( args, name, @@ -239,9 +380,8 @@ def _parse_string_value( @overload def parse_strings_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[List[str]] = None, *, allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", @@ -251,9 +391,20 @@ def parse_strings_from_args( @overload def parse_strings_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: List[str], + *, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> List[str]: + ... + + +@overload +def parse_strings_from_args( + args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[List[str]] = None, *, required: Literal[True], allowed_values: Optional[Iterable[str]] = None, @@ -264,7 +415,7 @@ def parse_strings_from_args( @overload def parse_strings_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[List[str]] = None, *, @@ -276,7 +427,7 @@ def parse_strings_from_args( def parse_strings_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[List[str]] = None, required: bool = False, @@ -325,7 +476,7 @@ def parse_strings_from_args( @overload def parse_string_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[str] = None, *, @@ -337,7 +488,7 @@ def parse_string_from_args( @overload def parse_string_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[str] = None, *, @@ -350,7 +501,7 @@ def parse_string_from_args( @overload def parse_string_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[str] = None, required: bool = False, @@ -361,7 +512,7 @@ def parse_string_from_args( def parse_string_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[str] = None, required: bool = False, @@ -409,13 +560,14 @@ def parse_string_from_args( return strings[0] -def parse_json_value_from_request(request, allow_empty_body=False): +def parse_json_value_from_request( + request: Request, allow_empty_body: bool = False +) -> Optional[JsonDict]: """Parse a JSON value from the body of a twisted HTTP request. Args: request: the twisted HTTP request. - allow_empty_body (bool): if True, an empty body will be accepted and - turned into None + allow_empty_body: if True, an empty body will be accepted and turned into None Returns: The JSON value. @@ -424,7 +576,7 @@ def parse_json_value_from_request(request, allow_empty_body=False): SynapseError if the request body couldn't be decoded as JSON. """ try: - content_bytes = request.content.read() + content_bytes = request.content.read() # type: ignore except Exception: raise SynapseError(400, "Error reading JSON content.") @@ -440,13 +592,15 @@ def parse_json_value_from_request(request, allow_empty_body=False): return content -def parse_json_object_from_request(request, allow_empty_body=False): +def parse_json_object_from_request( + request: Request, allow_empty_body: bool = False +) -> JsonDict: """Parse a JSON object from the body of a twisted HTTP request. Args: request: the twisted HTTP request. - allow_empty_body (bool): if True, an empty body will be accepted and - turned into an empty dict. + allow_empty_body: if True, an empty body will be accepted and turned into + an empty dict. Raises: SynapseError if the request body couldn't be decoded as JSON or @@ -457,14 +611,14 @@ def parse_json_object_from_request(request, allow_empty_body=False): if allow_empty_body and content is None: return {} - if type(content) != dict: + if not isinstance(content, dict): message = "Content must be a JSON object." raise SynapseError(400, message, errcode=Codes.BAD_JSON) return content -def assert_params_in_dict(body, required): +def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None: absent = [] for k in required: if k not in body: diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 18ac507802..02e5ddd2ef 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -25,7 +25,7 @@ See doc/log_contexts.rst for details on how this works. import inspect import logging import threading -import types +import typing import warnings from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union @@ -745,7 +745,7 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred: # by synchronous exceptions, so let's turn them into Failures. return defer.fail() - if isinstance(res, types.CoroutineType): + if isinstance(res, typing.Coroutine): res = defer.ensureDeferred(res) # At this point we should have a Deferred, if not then f was a synchronous diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 1259fc2d90..473812b8e2 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -484,7 +484,7 @@ class ModuleApi: @defer.inlineCallbacks def get_state_events_in_room( self, room_id: str, types: Iterable[Tuple[str, Optional[str]]] - ) -> Generator[defer.Deferred, Any, defer.Deferred]: + ) -> Generator[defer.Deferred, Any, Iterable[EventBase]]: """Gets current state events for the given room. (This is exposed for compatibility with the old SpamCheckerApi. We should diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 7be5fe1e9b..941fb238b7 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar import bleach import jinja2 -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RoomTypes from synapse.api.errors import StoreError from synapse.config.emailconfig import EmailSubjectConfig from synapse.events import EventBase @@ -600,6 +600,22 @@ class Mailer: "app": self.app_name, } + # If the room is a space, it gets a slightly different topic. + create_event_id = room_state_ids.get(("m.room.create", "")) + if create_event_id: + create_event = await self.store.get_event( + create_event_id, allow_none=True + ) + if ( + create_event + and create_event.content.get("room_type") == RoomTypes.SPACE + ): + return self.email_subjects.invite_from_person_to_space % { + "person": inviter_name, + "space": room_name, + "app": self.app_name, + } + return self.email_subjects.invite_from_person_to_room % { "person": inviter_name, "room": room_name, diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 589e47fa47..eef76ab18a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -62,6 +62,7 @@ class UsersRestServletV2(RestServlet): The parameter `name` can be used to filter by user id or display name. The parameter `guests` can be used to exclude guest users. The parameter `deactivated` can be used to include deactivated users. + The parameter `order_by` can be used to order the result. """ def __init__(self, hs: "HomeServer"): @@ -90,8 +91,8 @@ class UsersRestServletV2(RestServlet): errcode=Codes.INVALID_PARAM, ) - user_id = parse_string(request, "user_id", default=None) - name = parse_string(request, "name", default=None) + user_id = parse_string(request, "user_id") + name = parse_string(request, "name") guests = parse_boolean(request, "guests", default=True) deactivated = parse_boolean(request, "deactivated", default=False) @@ -108,6 +109,7 @@ class UsersRestServletV2(RestServlet): UserSortOrder.USER_TYPE.value, UserSortOrder.AVATAR_URL.value, UserSortOrder.SHADOW_BANNED.value, + UserSortOrder.CREATION_TS.value, ), ) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 31a1193cd3..25ba52c624 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -413,7 +413,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): assert_params_in_dict(body, ["state_events_at_start", "events"]) prev_events_from_query = parse_strings_from_args(request.args, "prev_event") - chunk_id_from_query = parse_string(request, "chunk_id", default=None) + chunk_id_from_query = parse_string(request, "chunk_id") if prev_events_from_query is None: raise SynapseError( @@ -553,9 +553,18 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): ] # Connect this current chunk to the insertion event from the previous chunk - last_event_in_chunk["content"][ - EventContentFields.MSC2716_CHUNK_ID - ] = chunk_id_to_connect_to + chunk_event = { + "type": EventTypes.MSC2716_CHUNK, + "sender": requester.user.to_string(), + "room_id": room_id, + "content": {EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to}, + # Since the chunk event is put at the end of the chunk, + # where the newest-in-time event is, copy the origin_server_ts from + # the last event we're inserting + "origin_server_ts": last_event_in_chunk["origin_server_ts"], + } + # Add the chunk event to the end of the chunk (newest-in-time) + events_to_create.append(chunk_event) # Add an "insertion" event to the start of each chunk (next to the oldest-in-time # event in the chunk) so the next chunk can be connected to this one. @@ -567,7 +576,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): # the first event we're inserting origin_server_ts=events_to_create[0]["origin_server_ts"], ) - # Prepend the insertion event to the start of the chunk + # Prepend the insertion event to the start of the chunk (oldest-in-time) events_to_create = [insertion_event] + events_to_create event_ids = [] @@ -726,7 +735,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): self.auth = hs.get_auth() async def on_GET(self, request): - server = parse_string(request, "server", default=None) + server = parse_string(request, "server") try: await self.auth.get_user_by_req(request, allow_guest=True) @@ -745,8 +754,8 @@ class PublicRoomListRestServlet(TransactionRestServlet): if server: raise e - limit = parse_integer(request, "limit", 0) - since_token = parse_string(request, "since", None) + limit: Optional[int] = parse_integer(request, "limit", 0) + since_token = parse_string(request, "since") if limit == 0: # zero is a special value which corresponds to no limit. @@ -780,7 +789,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): async def on_POST(self, request): await self.auth.get_user_by_req(request, allow_guest=True) - server = parse_string(request, "server", default=None) + server = parse_string(request, "server") content = parse_json_object_from_request(request) limit: Optional[int] = int(content.get("limit", 100)) diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index 6a24021484..88e3aac797 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -14,7 +14,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest from synapse.types import JsonDict @@ -55,6 +55,12 @@ class CapabilitiesRestServlet(RestServlet): "m.change_password": {"enabled": change_password}, } } + + if self.config.experimental.msc3244_enabled: + response["capabilities"]["m.room_versions"][ + "org.matrix.msc3244.room_capabilities" + ] = MSC3244_CAPABILITIES + return 200, response diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 33cf8de186..d0d9d30d40 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -194,7 +194,7 @@ class KeyChangesServlet(RestServlet): async def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True) - from_token_string = parse_string(request, "from") + from_token_string = parse_string(request, "from", required=True) set_tag("from", from_token_string) # We want to enforce they do pass us one, but we ignore it and return diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index c7da6759db..0821cd285f 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -158,19 +158,21 @@ class RelationPaginationServlet(RestServlet): event = await self.event_handler.get_event(requester.user, room_id, parent_id) limit = parse_integer(request, "limit", default=5) - from_token = parse_string(request, "from") - to_token = parse_string(request, "to") + from_token_str = parse_string(request, "from") + to_token_str = parse_string(request, "to") if event.internal_metadata.is_redacted(): # If the event is redacted, return an empty list of relations pagination_chunk = PaginationChunk(chunk=[]) else: # Return the relations - if from_token: - from_token = RelationPaginationToken.from_string(from_token) + from_token = None + if from_token_str: + from_token = RelationPaginationToken.from_string(from_token_str) - if to_token: - to_token = RelationPaginationToken.from_string(to_token) + to_token = None + if to_token_str: + to_token = RelationPaginationToken.from_string(to_token_str) pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, @@ -256,19 +258,21 @@ class RelationAggregationPaginationServlet(RestServlet): raise SynapseError(400, "Relation type must be 'annotation'") limit = parse_integer(request, "limit", default=5) - from_token = parse_string(request, "from") - to_token = parse_string(request, "to") + from_token_str = parse_string(request, "from") + to_token_str = parse_string(request, "to") if event.internal_metadata.is_redacted(): # If the event is redacted, return an empty list of relations pagination_chunk = PaginationChunk(chunk=[]) else: # Return the relations - if from_token: - from_token = AggregationPaginationToken.from_string(from_token) + from_token = None + if from_token_str: + from_token = AggregationPaginationToken.from_string(from_token_str) - if to_token: - to_token = AggregationPaginationToken.from_string(to_token) + to_token = None + if to_token_str: + to_token = AggregationPaginationToken.from_string(to_token_str) pagination_chunk = await self.store.get_aggregation_groups_for_event( event_id=parent_id, @@ -336,14 +340,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet): raise SynapseError(400, "Relation type must be 'annotation'") limit = parse_integer(request, "limit", default=5) - from_token = parse_string(request, "from") - to_token = parse_string(request, "to") + from_token_str = parse_string(request, "from") + to_token_str = parse_string(request, "to") - if from_token: - from_token = RelationPaginationToken.from_string(from_token) + from_token = None + if from_token_str: + from_token = RelationPaginationToken.from_string(from_token_str) - if to_token: - to_token = RelationPaginationToken.from_string(to_token) + to_token = None + if to_token_str: + to_token = RelationPaginationToken.from_string(to_token_str) result = await self.store.get_relations_for_event( event_id=parent_id, diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 32e8500795..e321668698 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -112,7 +112,7 @@ class SyncRestServlet(RestServlet): default="online", allowed_values=self.ALLOWED_PRESENCE, ) - filter_id = parse_string(request, "filter", default=None) + filter_id = parse_string(request, "filter") full_state = parse_boolean(request, "full_state", default=False) logger.debug( diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 4282e2b228..11f7320832 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -112,7 +112,7 @@ class ConsentResource(DirectServeHtmlResource): request (twisted.web.http.Request): """ version = parse_string(request, "v", default=self._default_consent_version) - username = parse_string(request, "u", required=False, default="") + username = parse_string(request, "u", default="") userhmac = None has_consented = False public_version = username == "" diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 8e7fead3a2..172212ee3a 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -186,15 +186,11 @@ class PreviewUrlResource(DirectServeJsonResource): respond_with_json(request, 200, {}, send_cors=True) async def _async_render_GET(self, request: SynapseRequest) -> None: - # This will always be set by the time Twisted calls us. - assert request.args is not None - # XXX: if get_user_by_req fails, what should we do in an async render? requester = await self.auth.get_user_by_req(request) - url = parse_string(request, "url") - if b"ts" in request.args: - ts = parse_integer(request, "ts") - else: + url = parse_string(request, "url", required=True) + ts = parse_integer(request, "ts") + if ts is None: ts = self.clock.time_msec() # XXX: we could move this into _do_preview if we wanted. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ccf9ac51ef..4d4643619f 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -832,31 +832,16 @@ class DatabasePool: self, table: str, values: Dict[str, Any], - or_ignore: bool = False, desc: str = "simple_insert", - ) -> bool: + ) -> None: """Executes an INSERT query on the named table. Args: table: string giving the table name values: dict of new column names and values for them - or_ignore: bool stating whether an exception should be raised - when a conflicting row already exists. If True, False will be - returned by the function instead desc: description of the transaction, for logging and metrics - - Returns: - Whether the row was inserted or not. Only useful when `or_ignore` is True """ - try: - await self.runInteraction(desc, self.simple_insert_txn, table, values) - except self.engine.module.IntegrityError: - # We have to do or_ignore flag at this layer, since we can't reuse - # a cursor after we receive an error from the db. - if not or_ignore: - raise - return False - return True + await self.runInteraction(desc, self.simple_insert_txn, table, values) @staticmethod def simple_insert_txn( @@ -930,7 +915,7 @@ class DatabasePool: insertion_values: Optional[Dict[str, Any]] = None, desc: str = "simple_upsert", lock: bool = True, - ) -> Optional[bool]: + ) -> bool: """ `lock` should generally be set to True (the default), but can be set @@ -951,8 +936,8 @@ class DatabasePool: desc: description of the transaction, for logging and metrics lock: True to lock the table when doing the upsert. Returns: - Native upserts always return None. Emulated upserts return True if a - new entry was created, False if an existing one was updated. + Returns True if a row was inserted or updated (i.e. if `values` is + not empty then this always returns True) """ insertion_values = insertion_values or {} @@ -995,7 +980,7 @@ class DatabasePool: values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, lock: bool = True, - ) -> Optional[bool]: + ) -> bool: """ Pick the UPSERT method which works best on the platform. Either the native one (Pg9.5+, recent SQLites), or fall back to an emulated method. @@ -1008,16 +993,15 @@ class DatabasePool: insertion_values: additional key/values to use only when inserting lock: True to lock the table when doing the upsert. Returns: - Native upserts always return None. Emulated upserts return True if a - new entry was created, False if an existing one was updated. + Returns True if a row was inserted or updated (i.e. if `values` is + not empty then this always returns True) """ insertion_values = insertion_values or {} if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: - self.simple_upsert_txn_native_upsert( + return self.simple_upsert_txn_native_upsert( txn, table, keyvalues, values, insertion_values=insertion_values ) - return None else: return self.simple_upsert_txn_emulated( txn, @@ -1045,8 +1029,8 @@ class DatabasePool: insertion_values: additional key/values to use only when inserting lock: True to lock the table when doing the upsert. Returns: - Returns True if a new entry was created, False if an existing - one was updated. + Returns True if a row was inserted or updated (i.e. if `values` is + not empty then this always returns True) """ insertion_values = insertion_values or {} @@ -1086,8 +1070,7 @@ class DatabasePool: txn.execute(sql, sqlargs) if txn.rowcount > 0: - # successfully updated at least one row. - return False + return True # We didn't find any existing rows, so insert a new one allvalues: Dict[str, Any] = {} @@ -1111,15 +1094,19 @@ class DatabasePool: keyvalues: Dict[str, Any], values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, - ) -> None: + ) -> bool: """ - Use the native UPSERT functionality in recent PostgreSQL versions. + Use the native UPSERT functionality in PostgreSQL. Args: table: The table to upsert into keyvalues: The unique key tables and their new values values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting + + Returns: + Returns True if a row was inserted or updated (i.e. if `values` is + not empty then this always returns True) """ allvalues: Dict[str, Any] = {} allvalues.update(keyvalues) @@ -1140,6 +1127,8 @@ class DatabasePool: ) txn.execute(sql, list(allvalues.values())) + return bool(txn.rowcount) + async def simple_upsert_many( self, table: str, diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index a3fddea042..8d9f07111d 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -249,7 +249,7 @@ class DataStore( name: Optional[str] = None, guests: bool = True, deactivated: bool = False, - order_by: UserSortOrder = UserSortOrder.USER_ID.value, + order_by: str = UserSortOrder.USER_ID.value, direction: str = "f", ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of users from @@ -297,27 +297,22 @@ class DataStore( where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" - sql_base = """ + sql_base = f""" FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? - {} - """.format( - where_clause - ) + {where_clause} + """ sql = "SELECT COUNT(*) as total_users " + sql_base txn.execute(sql, args) count = txn.fetchone()[0] - sql = """ - SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url + sql = f""" + SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, + displayname, avatar_url, creation_ts * 1000 as creation_ts {sql_base} ORDER BY {order_by_column} {order}, u.name ASC LIMIT ? OFFSET ? - """.format( - sql_base=sql_base, - order_by_column=order_by_column, - order=order, - ) + """ args += [limit, start] txn.execute(sql, args) users = self.db_pool.cursor_to_dict(txn) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 18f07d96dc..3816a0ca53 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1078,16 +1078,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return False try: - inserted = await self.db_pool.simple_insert( + inserted = await self.db_pool.simple_upsert( "devices", - values={ + keyvalues={ "user_id": user_id, "device_id": device_id, + }, + values={}, + insertion_values={ "display_name": initial_device_display_name, "hidden": False, }, desc="store_device", - or_ignore=True, ) if not inserted: # if the device already exists, check if it's a real device, or @@ -1099,6 +1101,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) + self.device_id_exists_cache.set(key, True) return inserted except StoreError: diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index fe25638289..d213b26703 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -297,17 +297,13 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): Args: txn (cursor): user_id (str): user to add/update - - Returns: - bool: True if a new entry was created, False if an - existing one was updated. """ # Am consciously deciding to lock the table on the basis that is ought # never be a big table and alternative approaches (batching multiple # upserts into a single txn) introduced a lot of extra complexity. # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = self.db_pool.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, @@ -322,8 +318,6 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): txn, self.user_last_seen_monthly_active, (user_id,) ) - return is_insert - async def populate_monthly_active_users(self, user_id): """Checks on the state of monthly active user limits and optionally add the user to the monthly active tables diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 6ddafe5434..443e5f3315 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -363,7 +363,7 @@ class RoomWorkerStore(SQLBaseStore): self, start: int, limit: int, - order_by: RoomSortOrder, + order_by: str, reverse_order: bool, search_term: Optional[str], ) -> Tuple[List[Dict[str, Any]], int]: diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 59d67c255b..42edbcc057 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -75,6 +75,7 @@ class UserSortOrder(Enum): USER_TYPE = ordered alphabetically by `user_type` AVATAR_URL = ordered alphabetically by `avatar_url` SHADOW_BANNED = ordered by `shadow_banned` + CREATION_TS = ordered by `creation_ts` """ MEDIA_LENGTH = "media_length" @@ -88,6 +89,7 @@ class UserSortOrder(Enum): USER_TYPE = "user_type" AVATAR_URL = "avatar_url" SHADOW_BANNED = "shadow_banned" + CREATION_TS = "creation_ts" class StatsStore(StateDeltasStore): @@ -647,10 +649,10 @@ class StatsStore(StateDeltasStore): limit: int, from_ts: Optional[int] = None, until_ts: Optional[int] = None, - order_by: Optional[UserSortOrder] = UserSortOrder.USER_ID.value, + order_by: Optional[str] = UserSortOrder.USER_ID.value, direction: Optional[str] = "f", search_term: Optional[str] = None, - ) -> Tuple[List[JsonDict], Dict[str, int]]: + ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of users and their uploaded local media (size and number). This will return a json list of users and the total number of users matching the filter criteria. diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index d211c423b2..7728d5f102 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -134,16 +134,18 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): response_dict: The response, to be encoded into JSON. """ - await self.db_pool.simple_insert( + await self.db_pool.simple_upsert( table="received_transactions", - values={ + keyvalues={ "transaction_id": transaction_id, "origin": origin, + }, + values={}, + insertion_values={ "response_code": code, "response_json": db_binary_type(encode_canonical_json(response_dict)), "ts": self._clock.time_msec(), }, - or_ignore=True, desc="set_received_txn_response", ) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index a6bfb4902a..9d28d69ac7 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -377,7 +377,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): avatar_url = None def _update_profile_in_user_dir_txn(txn): - new_entry = self.db_pool.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="user_directory", keyvalues={"user_id": user_id}, @@ -388,8 +388,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): if isinstance(self.database_engine, PostgresEngine): # We weight the localpart most highly, then display name and finally # server name - if self.database_engine.can_native_upsert: - sql = """ + sql = """ INSERT INTO user_directory_search(user_id, vector) VALUES (?, setweight(to_tsvector('simple', ?), 'A') @@ -397,58 +396,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector """ - txn.execute( - sql, - ( - user_id, - get_localpart_from_id(user_id), - get_domain_from_id(user_id), - display_name, - ), - ) - else: - # TODO: Remove this code after we've bumped the minimum version - # of postgres to always support upserts, so we can get rid of - # `new_entry` usage - if new_entry is True: - sql = """ - INSERT INTO user_directory_search(user_id, vector) - VALUES (?, - setweight(to_tsvector('simple', ?), 'A') - || setweight(to_tsvector('simple', ?), 'D') - || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') - ) - """ - txn.execute( - sql, - ( - user_id, - get_localpart_from_id(user_id), - get_domain_from_id(user_id), - display_name, - ), - ) - elif new_entry is False: - sql = """ - UPDATE user_directory_search - SET vector = setweight(to_tsvector('simple', ?), 'A') - || setweight(to_tsvector('simple', ?), 'D') - || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') - WHERE user_id = ? - """ - txn.execute( - sql, - ( - get_localpart_from_id(user_id), - get_domain_from_id(user_id), - display_name, - user_id, - ), - ) - else: - raise RuntimeError( - "upsert returned None when 'can_native_upsert' is False" - ) + txn.execute( + sql, + ( + user_id, + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, + ), + ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name) if display_name else user_id self.db_pool.simple_upsert_txn( diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 13d300588b..cf4005984b 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -47,20 +47,22 @@ class PaginationConfig: ) -> "PaginationConfig": direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) - from_tok = parse_string(request, "from") - to_tok = parse_string(request, "to") + from_tok_str = parse_string(request, "from") + to_tok_str = parse_string(request, "to") try: - if from_tok == "END": + from_tok = None + if from_tok_str == "END": from_tok = None # For backwards compat. - elif from_tok: - from_tok = await StreamToken.from_string(store, from_tok) + elif from_tok_str: + from_tok = await StreamToken.from_string(store, from_tok_str) except Exception: raise SynapseError(400, "'from' parameter is invalid") try: - if to_tok: - to_tok = await StreamToken.from_string(store, to_tok) + to_tok = None + if to_tok_str: + to_tok = await StreamToken.from_string(store, to_tok_str) except Exception: raise SynapseError(400, "'to' parameter is invalid") diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 6fee0f95b6..7198fd293f 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -261,7 +261,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) self.assertEqual( - "Missing integer query parameter b'before_ts'", channel.json_body["error"] + "Missing integer query parameter 'before_ts'", channel.json_body["error"] ) def test_invalid_parameter(self): @@ -303,7 +303,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual( - "Boolean query parameter b'keep_profiles' must be one of ['true', 'false']", + "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']", channel.json_body["error"], ) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4fccce34fd..42f50c0921 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -473,7 +473,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -485,7 +485,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, access_token=other_user_token) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_all_users(self): @@ -497,11 +497,11 @@ class UsersListTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", self.url + "?deactivated=true", - b"{}", + {}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(3, len(channel.json_body["users"])) self.assertEqual(3, channel.json_body["total"]) @@ -532,7 +532,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) channel = self.make_request( "GET", - url.encode("ascii"), + url, access_token=self.admin_user_tok, ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) @@ -598,7 +598,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -608,7 +608,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid guests @@ -618,7 +618,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid deactivated @@ -628,7 +628,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # unkown order_by @@ -648,7 +648,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) def test_limit(self): @@ -666,7 +666,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], "5") @@ -687,7 +687,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -708,7 +708,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["next_token"], "15") self.assertEqual(len(channel.json_body["users"]), 10) @@ -731,7 +731,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -744,7 +744,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -757,7 +757,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], "19") @@ -771,7 +771,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -781,7 +781,10 @@ class UsersListTestCase(unittest.HomeserverTestCase): Testing order list with parameter `order_by` """ + # make sure that the users do not have the same timestamps + self.reactor.advance(10) user1 = self.register_user("user1", "pass1", admin=False, displayname="Name Z") + self.reactor.advance(10) user2 = self.register_user("user2", "pass2", admin=False, displayname="Name Y") # Modify user @@ -841,6 +844,11 @@ class UsersListTestCase(unittest.HomeserverTestCase): self._order_test([self.admin_user, user2, user1], "avatar_url", "f") self._order_test([user1, user2, self.admin_user], "avatar_url", "b") + # order by creation_ts + self._order_test([self.admin_user, user1, user2], "creation_ts") + self._order_test([self.admin_user, user1, user2], "creation_ts", "f") + self._order_test([user2, user1, self.admin_user], "creation_ts", "b") + def _order_test( self, expected_user_list: List[str], @@ -863,7 +871,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): url += "dir=%s" % (dir,) channel = self.make_request( "GET", - url.encode("ascii"), + url, access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -887,6 +895,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertIn("shadow_banned", u) self.assertIn("displayname", u) self.assertIn("avatar_url", u) + self.assertIn("creation_ts", u) def _create_users(self, number_users: int): """ diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index 874052c61c..f80f48a455 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -102,3 +102,49 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertFalse(capabilities["m.change_password"]["enabled"]) + + def test_get_does_not_include_msc3244_fields_by_default(self): + localpart = "user" + password = "pass" + user = self.register_user(localpart, password) + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertNotIn( + "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] + ) + + @override_config({"experimental_features": {"msc3244_enabled": True}}) + def test_get_does_include_msc3244_fields_when_enabled(self): + localpart = "user" + password = "pass" + user = self.register_user(localpart, password) + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + for details in capabilities["m.room_versions"][ + "org.matrix.msc3244.room_capabilities" + ].values(): + if details["preferred"] is not None: + self.assertTrue( + details["preferred"] in KNOWN_ROOM_VERSIONS, + str(details["preferred"]), + ) + + self.assertGreater(len(details["support"]), 0) + for room_version in details["support"]: + self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version)) |