diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 8d7e8cafd9..21c9ee7823 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -374,7 +374,7 @@ jobs:
working-directory: complement/dockerfiles
# Run Complement
- - run: go test -v -tags synapse_blacklist,msc2403,msc2946,msc3083 ./tests/...
+ - run: go test -v -tags synapse_blacklist,msc2403 ./tests/...
env:
COMPLEMENT_BASE_IMAGE: complement-synapse:latest
working-directory: complement
diff --git a/changelog.d/11029.misc b/changelog.d/11029.misc
new file mode 100644
index 0000000000..111de5fc7a
--- /dev/null
+++ b/changelog.d/11029.misc
@@ -0,0 +1 @@
+Improve type annotations in `synapse.module_api`.
\ No newline at end of file
diff --git a/changelog.d/11220.bugfix b/changelog.d/11220.bugfix
new file mode 100644
index 0000000000..8baae28d5b
--- /dev/null
+++ b/changelog.d/11220.bugfix
@@ -0,0 +1 @@
+Fix using MSC2716 batch sending in combination with event persistence workers. Contributed by @tulir at Beeper.
diff --git a/changelog.d/11306.feature b/changelog.d/11306.feature
new file mode 100644
index 0000000000..aba3292015
--- /dev/null
+++ b/changelog.d/11306.feature
@@ -0,0 +1 @@
+Add plugin support for controlling database background updates.
diff --git a/changelog.d/11329.feature b/changelog.d/11329.feature
new file mode 100644
index 0000000000..7e0efb3b00
--- /dev/null
+++ b/changelog.d/11329.feature
@@ -0,0 +1 @@
+Support the stable API endpoints for [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946): the room `/hierarchy` endpoint.
diff --git a/changelog.d/11356.misc b/changelog.d/11356.misc
new file mode 100644
index 0000000000..01ce6a306c
--- /dev/null
+++ b/changelog.d/11356.misc
@@ -0,0 +1 @@
+Add `Final` annotation to string constants in `synapse.api.constants` so that they get typed as `Literal`s.
diff --git a/changelog.d/11376.bugfix b/changelog.d/11376.bugfix
new file mode 100644
index 0000000000..639e48b59b
--- /dev/null
+++ b/changelog.d/11376.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where all requests that read events from the database could get stuck as a result of losing the database connection, for real this time. Also fix a race condition introduced in the previous insufficient fix in 1.47.0.
diff --git a/changelog.d/11409.misc b/changelog.d/11409.misc
new file mode 100644
index 0000000000..f9e8ae9e3a
--- /dev/null
+++ b/changelog.d/11409.misc
@@ -0,0 +1 @@
+Improve internal types in push code.
diff --git a/changelog.d/11411.misc b/changelog.d/11411.misc
new file mode 100644
index 0000000000..86594a332d
--- /dev/null
+++ b/changelog.d/11411.misc
@@ -0,0 +1 @@
+Add type hints to storage classes.
diff --git a/changelog.d/11413.bugfix b/changelog.d/11413.bugfix
new file mode 100644
index 0000000000..44111d8152
--- /dev/null
+++ b/changelog.d/11413.bugfix
@@ -0,0 +1 @@
+The `/send_join` response now includes the stable `event` field instead of the unstable field from [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083).
diff --git a/changelog.d/11415.doc b/changelog.d/11415.doc
new file mode 100644
index 0000000000..e405531867
--- /dev/null
+++ b/changelog.d/11415.doc
@@ -0,0 +1 @@
+Update the media repository documentation.
diff --git a/changelog.d/11417.misc b/changelog.d/11417.misc
new file mode 100644
index 0000000000..88dc4722da
--- /dev/null
+++ b/changelog.d/11417.misc
@@ -0,0 +1 @@
+Refactor `backfilled` into specific behavior function arguments (`_persist_events_and_state_updates` and downstream calls).
diff --git a/changelog.d/11425.feature b/changelog.d/11425.feature
new file mode 100644
index 0000000000..806dd5d91c
--- /dev/null
+++ b/changelog.d/11425.feature
@@ -0,0 +1 @@
+Support expiry of refresh tokens and expiry of the overall session when refresh tokens are in use.
\ No newline at end of file
diff --git a/changelog.d/11428.misc b/changelog.d/11428.misc
new file mode 100644
index 0000000000..2f814fa5fb
--- /dev/null
+++ b/changelog.d/11428.misc
@@ -0,0 +1 @@
+Add type annotations to some of the configuration surrounding refresh tokens.
\ No newline at end of file
diff --git a/changelog.d/11429.docker b/changelog.d/11429.docker
new file mode 100644
index 0000000000..81db719ed6
--- /dev/null
+++ b/changelog.d/11429.docker
@@ -0,0 +1 @@
+Update `Dockerfile-workers` to healthcheck all workers in the container.
diff --git a/changelog.d/11430.misc b/changelog.d/11430.misc
new file mode 100644
index 0000000000..28f06f4c4e
--- /dev/null
+++ b/changelog.d/11430.misc
@@ -0,0 +1 @@
+Update [MSC2918 refresh token](https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens) support to conform to the latest revision: accept the `refresh_tokens` parameter in the request body rather than in the URL parameters.
\ No newline at end of file
diff --git a/changelog.d/11439.bugfix b/changelog.d/11439.bugfix
new file mode 100644
index 0000000000..fc6bc82b36
--- /dev/null
+++ b/changelog.d/11439.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in 1.47.0 where `send_join` could fail due to an outdated `ijson` version.
diff --git a/changelog.d/11440.bugfix b/changelog.d/11440.bugfix
new file mode 100644
index 0000000000..02ce2e428f
--- /dev/null
+++ b/changelog.d/11440.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse 1.36 which could cause problems fetching event-signing keys from trusted key servers.
diff --git a/changelog.d/11441.bugfix b/changelog.d/11441.bugfix
new file mode 100644
index 0000000000..1baef41d70
--- /dev/null
+++ b/changelog.d/11441.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in 1.47.0 where `send_join` could fail due to an outdated `ijson` version.
\ No newline at end of file
diff --git a/changelog.d/11452.misc b/changelog.d/11452.misc
new file mode 100644
index 0000000000..7c83f62e3f
--- /dev/null
+++ b/changelog.d/11452.misc
@@ -0,0 +1 @@
+Convert status codes to `HTTPStatus` in `synapse.rest.admin`.
\ No newline at end of file
diff --git a/changelog.d/11455.misc b/changelog.d/11455.misc
new file mode 100644
index 0000000000..7c83f62e3f
--- /dev/null
+++ b/changelog.d/11455.misc
@@ -0,0 +1 @@
+Convert status codes to `HTTPStatus` in `synapse.rest.admin`.
\ No newline at end of file
diff --git a/changelog.d/11459.feature b/changelog.d/11459.feature
new file mode 100644
index 0000000000..4cb97dc1d0
--- /dev/null
+++ b/changelog.d/11459.feature
@@ -0,0 +1 @@
+`synctl stop` will now wait for Synapse to exit before returning.
diff --git a/changelog.d/11460.misc b/changelog.d/11460.misc
new file mode 100644
index 0000000000..fc6bc82b36
--- /dev/null
+++ b/changelog.d/11460.misc
@@ -0,0 +1 @@
+Fix a bug introduced in 1.47.0 where `send_join` could fail due to an outdated `ijson` version.
diff --git a/docker/Dockerfile-workers b/docker/Dockerfile-workers
index 969cf97286..46f2e17382 100644
--- a/docker/Dockerfile-workers
+++ b/docker/Dockerfile-workers
@@ -21,3 +21,6 @@ VOLUME ["/data"]
# files to run the desired worker configuration. Will start supervisord.
COPY ./docker/configure_workers_and_start.py /configure_workers_and_start.py
ENTRYPOINT ["/configure_workers_and_start.py"]
+
+HEALTHCHECK --start-period=5s --interval=15s --timeout=5s \
+ CMD /bin/sh /healthcheck.sh
diff --git a/docker/conf-workers/healthcheck.sh.j2 b/docker/conf-workers/healthcheck.sh.j2
new file mode 100644
index 0000000000..79c621f89c
--- /dev/null
+++ b/docker/conf-workers/healthcheck.sh.j2
@@ -0,0 +1,6 @@
+#!/bin/sh
+# This healthcheck script succeeds only when every configured
+# healthcheck endpoint returns OK.
+{%- for healthcheck_url in healthcheck_urls %}
+curl -fSs {{ healthcheck_url }} || exit 1
+{%- endfor %}
diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py
index f4ac1c22a4..adbb551cee 100755
--- a/docker/configure_workers_and_start.py
+++ b/docker/configure_workers_and_start.py
@@ -474,10 +474,16 @@ def generate_worker_files(environ, config_path: str, data_dir: str):
# Determine the load-balancing upstreams to configure
nginx_upstream_config = ""
+
+    # At the same time, prepare a list of internal endpoints to healthcheck,
+    # starting with the main process, which exists even if no workers do.
+ healthcheck_urls = ["http://localhost:8080/health"]
+
for upstream_worker_type, upstream_worker_ports in nginx_upstreams.items():
body = ""
for port in upstream_worker_ports:
body += " server localhost:%d;\n" % (port,)
+ healthcheck_urls.append("http://localhost:%d/health" % (port,))
# Add to the list of configured upstreams
nginx_upstream_config += NGINX_UPSTREAM_CONFIG_BLOCK.format(
@@ -510,6 +516,13 @@ def generate_worker_files(environ, config_path: str, data_dir: str):
worker_config=supervisord_config,
)
+ # healthcheck config
+ convert(
+ "/conf/healthcheck.sh.j2",
+ "/healthcheck.sh",
+ healthcheck_urls=healthcheck_urls,
+ )
+
# Ensure the logging directory exists
log_dir = data_dir + "/logs"
if not os.path.exists(log_dir):
diff --git a/docs/media_repository.md b/docs/media_repository.md
index 99ee8f1ef7..ba17f8a856 100644
--- a/docs/media_repository.md
+++ b/docs/media_repository.md
@@ -2,29 +2,80 @@
*Synapse implementation-specific details for the media repository*
-The media repository is where attachments and avatar photos are stored.
-It stores attachment content and thumbnails for media uploaded by local users.
-It caches attachment content and thumbnails for media uploaded by remote users.
+The media repository
+ * stores avatars, attachments and their thumbnails for media uploaded by local
+ users.
+ * caches avatars, attachments and their thumbnails for media uploaded by remote
+ users.
+ * caches resources and thumbnails used for
+ [URL previews](development/url_previews.md).
-## Storage
+All media in Matrix can be identified by a unique
+[MXC URI](https://spec.matrix.org/latest/client-server-api/#matrix-content-mxc-uris),
+consisting of a server name and media ID:
+```
+mxc://<server-name>/<media-id>
+```
-Each item of media is assigned a `media_id` when it is uploaded.
-The `media_id` is a randomly chosen, URL safe 24 character string.
+## Local Media
+Synapse generates 24-character media IDs for content uploaded by local users.
+These media IDs consist of upper and lowercase letters and are case-sensitive.
+Other homeserver implementations may generate media IDs differently.
-Metadata such as the MIME type, upload time and length are stored in the
-sqlite3 database indexed by `media_id`.
+Local media is recorded in the `local_media_repository` table, which includes
+metadata such as MIME types, upload times and file sizes.
+Note that this table is shared by the URL cache, which has a different media ID
+scheme.
-Content is stored on the filesystem under a `"local_content"` directory.
+### Paths
+A file with media ID `aabbcccccccccccccccccccc` and its `128x96` `image/jpeg`
+thumbnail, created by scaling, would be stored at:
+```
+local_content/aa/bb/cccccccccccccccccccc
+local_thumbnails/aa/bb/cccccccccccccccccccc/128-96-image-jpeg-scale
+```
-Thumbnails are stored under a `"local_thumbnails"` directory.
+## Remote Media
+When media from a remote homeserver is requested from Synapse, it is assigned
+a local `filesystem_id`, with the same format as locally-generated media IDs,
+as described above.
-The item with `media_id` `"aabbccccccccdddddddddddd"` is stored under
-`"local_content/aa/bb/ccccccccdddddddddddd"`. Its thumbnail with width
-`128` and height `96` and type `"image/jpeg"` is stored under
-`"local_thumbnails/aa/bb/ccccccccdddddddddddd/128-96-image-jpeg"`
+A record of remote media is stored in the `remote_media_cache` table, which
+can be used to map remote MXC URIs (server names and media IDs) to local
+`filesystem_id`s.
-Remote content is cached under `"remote_content"` directory. Each item of
-remote content is assigned a local `"filesystem_id"` to ensure that the
-directory structure `"remote_content/server_name/aa/bb/ccccccccdddddddddddd"`
-is appropriate. Thumbnails for remote content are stored under
-`"remote_thumbnail/server_name/..."`
+### Paths
+A file from `matrix.org` with `filesystem_id` `aabbcccccccccccccccccccc` and its
+`128x96` `image/jpeg` thumbnail, created by scaling, would be stored at:
+```
+remote_content/matrix.org/aa/bb/cccccccccccccccccccc
+remote_thumbnail/matrix.org/aa/bb/cccccccccccccccccccc/128-96-image-jpeg-scale
+```
+Older thumbnails may omit the thumbnailing method:
+```
+remote_thumbnail/matrix.org/aa/bb/cccccccccccccccccccc/128-96-image-jpeg
+```
+
+Note that `remote_thumbnail/` does not have an `s`.
+
+## URL Previews
+See [URL Previews](development/url_previews.md) for documentation on the URL preview
+process.
+
+When generating previews for URLs, Synapse may download and cache various
+resources, including images. These resources are assigned temporary media IDs
+of the form `yyyy-mm-dd_aaaaaaaaaaaaaaaa`, where `yyyy-mm-dd` is the current
+date and `aaaaaaaaaaaaaaaa` is a random sequence of 16 case-sensitive letters.
+
+The metadata for these cached resources is stored in the
+`local_media_repository` and `local_media_repository_url_cache` tables.
+
+Resources for URL previews are deleted after a few days.
+
+### Paths
+The file with media ID `yyyy-mm-dd_aaaaaaaaaaaaaaaa` and its `128x96`
+`image/jpeg` thumbnail, created by scaling, would be stored at:
+```
+url_cache/yyyy-mm-dd/aaaaaaaaaaaaaaaa
+url_cache_thumbnails/yyyy-mm-dd/aaaaaaaaaaaaaaaa/128-96-image-jpeg-scale
+```
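+
+### Worked example
+To make the layout above concrete, here is a minimal illustrative Python sketch
+(not part of Synapse) that maps a local media ID to the content and thumbnail
+paths described in this document. The media store directory and thumbnail name
+are example values.
+```python
+import os
+
+def local_media_paths(media_store_dir: str, media_id: str):
+    """Illustrative only: where a local media file and one of its scaled
+    128x96 JPEG thumbnails would live, following the layout above."""
+    prefix = os.path.join(media_id[0:2], media_id[2:4], media_id[4:])
+    content = os.path.join(media_store_dir, "local_content", prefix)
+    thumbnail = os.path.join(
+        media_store_dir, "local_thumbnails", prefix, "128-96-image-jpeg-scale"
+    )
+    return content, thumbnail
+
+# local_media_paths("/data/media_store", "aabbcccccccccccccccccccc") returns:
+#   /data/media_store/local_content/aa/bb/cccccccccccccccccccc
+#   /data/media_store/local_thumbnails/aa/bb/cccccccccccccccccccc/128-96-image-jpeg-scale
+```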
diff --git a/docs/modules/background_update_controller_callbacks.md b/docs/modules/background_update_controller_callbacks.md
new file mode 100644
index 0000000000..b3e7c259f4
--- /dev/null
+++ b/docs/modules/background_update_controller_callbacks.md
@@ -0,0 +1,71 @@
+# Background update controller callbacks
+
+Background update controller callbacks allow module developers to control (e.g. rate-limit)
+how database background updates are run. A database background update is an operation
+Synapse runs on its database in the background after it starts. It is usually used for
+database operations that would take too long, and delay Synapse's startup too much, if they
+were run at the same time as schema updates (which are run on startup): populating a table
+with a large amount of data, adding an index on a big table, deleting superfluous data,
+and so on.
+
+Background update controller callbacks can be registered using the module API's
+`register_background_update_controller_callbacks` method. Only the first module (in order
+of appearance in Synapse's configuration file) calling this method can register background
+update controller callbacks; subsequent calls are ignored.
+
+The available background update controller callbacks are:
+
+### `on_update`
+
+_First introduced in Synapse v1.49.0_
+
+```python
+def on_update(update_name: str, database_name: str, one_shot: bool) -> AsyncContextManager[int]
+```
+
+Called when Synapse is about to run an iteration of a background update. The module is
+given the name of the update, the name of the database, and a flag indicating whether the
+background update will happen in one go and may take a long time (e.g. creating indices).
+If this last argument is set to `False`, the update will be run in batches.
+
+The module must return an async context manager. It will be entered before Synapse runs a
+background update; entering it should return the desired duration of the iteration, in
+milliseconds.
+
+The context manager will be exited when the iteration completes. Note that the duration
+returned by the context manager is a target, and an iteration may take substantially more
+or less time than that. If the `one_shot` flag is set to `True`, the duration returned is ignored.
+
+__Note__: Unlike most module callbacks in Synapse, this one is _synchronous_. This is
+because asynchronous operations are expected to be run by the async context manager.
+
+This callback is required when registering any other background update controller callback.
+
+### `default_batch_size`
+
+_First introduced in Synapse v1.49.0_
+
+```python
+async def default_batch_size(update_name: str, database_name: str) -> int
+```
+
+Called before the first iteration of a background update, with the name of the update and
+of the database. The module must return the number of elements to process in this first
+iteration.
+
+If this callback is not defined, Synapse will use a default value of 100.
+
+### `min_batch_size`
+
+_First introduced in Synapse v1.49.0_
+
+```python
+async def min_batch_size(update_name: str, database_name: str) -> int
+```
+
+Called before running a new batch for a background update, with the name of the update and
+of the database. The module must return an integer representing the minimum number of
+elements to process in this iteration. This number must be at least 1, and is used to
+ensure that progress is always made.
+
+If this callback is not defined, Synapse will use a default value of 100.
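+
+## Example
+
+The following is a minimal, illustrative sketch (not shipped with Synapse) of a module
+that registers all three callbacks to slow background updates down and keep batches
+small. The class name, batch sizes and target duration are arbitrary example values,
+and `contextlib.asynccontextmanager` requires Python 3.7 or later.
+
+```python
+import contextlib
+from typing import Any, AsyncContextManager
+
+
+class BackgroundUpdateLimiter:
+    """Illustrative module that rate-limits database background updates."""
+
+    def __init__(self, config: dict, api: Any):
+        # `api` is the ModuleApi instance passed to every Synapse module.
+        api.register_background_update_controller_callbacks(
+            on_update=self.on_update,
+            default_batch_size=self.default_batch_size,
+            min_batch_size=self.min_batch_size,
+        )
+
+    def on_update(
+        self, update_name: str, database_name: str, one_shot: bool
+    ) -> AsyncContextManager[int]:
+        @contextlib.asynccontextmanager
+        async def controller():
+            # Entering the context manager yields the target duration of the
+            # iteration, in milliseconds. Any asynchronous work (e.g. waiting
+            # on an external rate limiter) would happen around the yield.
+            yield 500
+
+        return controller()
+
+    async def default_batch_size(self, update_name: str, database_name: str) -> int:
+        # Process 50 elements in the first iteration instead of the default 100.
+        return 50
+
+    async def min_batch_size(self, update_name: str, database_name: str) -> int:
+        # Never process fewer than 10 elements per iteration.
+        return 10
+```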
diff --git a/docs/modules/writing_a_module.md b/docs/modules/writing_a_module.md
index 7764e06692..e7c0ffad58 100644
--- a/docs/modules/writing_a_module.md
+++ b/docs/modules/writing_a_module.md
@@ -71,15 +71,15 @@ Modules **must** register their web resources in their `__init__` method.
## Registering a callback
Modules can use Synapse's module API to register callbacks. Callbacks are functions that
-Synapse will call when performing specific actions. Callbacks must be asynchronous, and
-are split in categories. A single module may implement callbacks from multiple categories,
-and is under no obligation to implement all callbacks from the categories it registers
-callbacks for.
+Synapse will call when performing specific actions. Callbacks must be asynchronous (unless
+specified otherwise), and are split into categories. A single module may implement callbacks
+from multiple categories, and is under no obligation to implement all callbacks from the
+categories it registers callbacks for.
Modules can register callbacks using one of the module API's `register_[...]_callbacks`
methods. The callback functions are passed to these methods as keyword arguments, with
-the callback name as the argument name and the function as its value. This is demonstrated
-in the example below. A `register_[...]_callbacks` method exists for each category.
+the callback name as the argument name and the function as its value. A
+`register_[...]_callbacks` method exists for each category.
Callbacks for each category can be found on their respective page of the
[Synapse documentation website](https://matrix-org.github.io/synapse).
\ No newline at end of file
diff --git a/docs/templates.md b/docs/templates.md
index a240f58b54..2b66e9d862 100644
--- a/docs/templates.md
+++ b/docs/templates.md
@@ -71,7 +71,12 @@ Below are the templates Synapse will look for when generating the content of an
* `sender_avatar_url`: the avatar URL (as a `mxc://` URL) for the event's
sender
* `sender_hash`: a hash of the user ID of the sender
+ * `msgtype`: the type of the message
+ * `body_text_html`: HTML representation of the message
+ * `body_text_plain`: plaintext representation of the message
+ * `image_url`: the `mxc://` URL of the image, when `msgtype` is `m.image`
* `link`: a `matrix.to` link to the room
+ * `avator_url`: the URL of the room's avatar
* `reason`: information on the event that triggered the email to be sent. It's an
object with the following attributes:
* `room_id`: the ID of the room the event was sent in
diff --git a/docs/workers.md b/docs/workers.md
index 17c8bfeef6..fd83e2ddeb 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -210,7 +210,7 @@ expressions:
^/_matrix/federation/v1/get_groups_publicised$
^/_matrix/key/v2/query
^/_matrix/federation/unstable/org.matrix.msc2946/spaces/
- ^/_matrix/federation/unstable/org.matrix.msc2946/hierarchy/
+ ^/_matrix/federation/(v1|unstable/org.matrix.msc2946)/hierarchy/
# Inbound federation transaction request
^/_matrix/federation/v1/send/
@@ -223,7 +223,7 @@ expressions:
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$
^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/spaces$
- ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/hierarchy$
+ ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$
^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$
^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$
^/_matrix/client/(api/v1|r0|v3|unstable)/devices$
diff --git a/mypy.ini b/mypy.ini
index bc4f59154d..51056a8f64 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -33,7 +33,6 @@ exclude = (?x)
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/event_push_actions.py
|synapse/storage/databases/main/events_bg_updates.py
- |synapse/storage/databases/main/events_worker.py
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
@@ -166,6 +165,9 @@ disallow_untyped_defs = True
[mypy-synapse.metrics.*]
disallow_untyped_defs = True
+[mypy-synapse.module_api.*]
+disallow_untyped_defs = True
+
[mypy-synapse.push.*]
disallow_untyped_defs = True
@@ -184,6 +186,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.directory]
disallow_untyped_defs = True
+[mypy-synapse.storage.databases.main.events_worker]
+disallow_untyped_defs = True
+
[mypy-synapse.storage.databases.main.room_batch]
disallow_untyped_defs = True
@@ -220,6 +225,10 @@ disallow_untyped_defs = True
[mypy-tests.rest.client.test_directory]
disallow_untyped_defs = True
+[mypy-tests.federation.transport.test_client]
+disallow_untyped_defs = True
+
+
;; Dependencies without annotations
;; Before ignoring a module, check to see if type stubs are available.
;; The `typeshed` project maintains stubs here:
diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh
index 29568eded8..53295b58fc 100755
--- a/scripts-dev/complement.sh
+++ b/scripts-dev/complement.sh
@@ -65,4 +65,4 @@ if [[ -n "$1" ]]; then
fi
# Run the tests!
-go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/...
+go test -v -tags synapse_blacklist,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/...
diff --git a/setup.py b/setup.py
index 0ce8beb004..ad99b3bd2c 100755
--- a/setup.py
+++ b/setup.py
@@ -119,7 +119,9 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
# Tests assume that all optional dependencies are installed.
#
# parameterized_class decorator was introduced in parameterized 0.7.0
-CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"]
+#
+# We use the `mock` library, as it backports `AsyncMock` to Python 3.6.
+CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0", "mock>=4.0.0"]
CONDITIONAL_REQUIREMENTS["dev"] = (
CONDITIONAL_REQUIREMENTS["lint"]
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index a33ac34161..f7d29b4319 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -17,6 +17,8 @@
"""Contains constants from the specification."""
+from typing_extensions import Final
+
# the max size of a (canonical-json-encoded) event
MAX_PDU_SIZE = 65536
@@ -39,125 +41,125 @@ class Membership:
"""Represents the membership states of a user in a room."""
- INVITE = "invite"
- JOIN = "join"
- KNOCK = "knock"
- LEAVE = "leave"
- BAN = "ban"
- LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
+ INVITE: Final = "invite"
+ JOIN: Final = "join"
+ KNOCK: Final = "knock"
+ LEAVE: Final = "leave"
+ BAN: Final = "ban"
+ LIST: Final = (INVITE, JOIN, KNOCK, LEAVE, BAN)
class PresenceState:
"""Represents the presence state of a user."""
- OFFLINE = "offline"
- UNAVAILABLE = "unavailable"
- ONLINE = "online"
- BUSY = "org.matrix.msc3026.busy"
+ OFFLINE: Final = "offline"
+ UNAVAILABLE: Final = "unavailable"
+ ONLINE: Final = "online"
+ BUSY: Final = "org.matrix.msc3026.busy"
class JoinRules:
- PUBLIC = "public"
- KNOCK = "knock"
- INVITE = "invite"
- PRIVATE = "private"
+ PUBLIC: Final = "public"
+ KNOCK: Final = "knock"
+ INVITE: Final = "invite"
+ PRIVATE: Final = "private"
# As defined for MSC3083.
- RESTRICTED = "restricted"
+ RESTRICTED: Final = "restricted"
class RestrictedJoinRuleTypes:
"""Understood types for the allow rules in restricted join rules."""
- ROOM_MEMBERSHIP = "m.room_membership"
+ ROOM_MEMBERSHIP: Final = "m.room_membership"
class LoginType:
- PASSWORD = "m.login.password"
- EMAIL_IDENTITY = "m.login.email.identity"
- MSISDN = "m.login.msisdn"
- RECAPTCHA = "m.login.recaptcha"
- TERMS = "m.login.terms"
- SSO = "m.login.sso"
- DUMMY = "m.login.dummy"
- REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
+ PASSWORD: Final = "m.login.password"
+ EMAIL_IDENTITY: Final = "m.login.email.identity"
+ MSISDN: Final = "m.login.msisdn"
+ RECAPTCHA: Final = "m.login.recaptcha"
+ TERMS: Final = "m.login.terms"
+ SSO: Final = "m.login.sso"
+ DUMMY: Final = "m.login.dummy"
+ REGISTRATION_TOKEN: Final = "org.matrix.msc3231.login.registration_token"
# This is used in the `type` parameter for /register when called by
# an appservice to register a new user.
-APP_SERVICE_REGISTRATION_TYPE = "m.login.application_service"
+APP_SERVICE_REGISTRATION_TYPE: Final = "m.login.application_service"
class EventTypes:
- Member = "m.room.member"
- Create = "m.room.create"
- Tombstone = "m.room.tombstone"
- JoinRules = "m.room.join_rules"
- PowerLevels = "m.room.power_levels"
- Aliases = "m.room.aliases"
- Redaction = "m.room.redaction"
- ThirdPartyInvite = "m.room.third_party_invite"
- RelatedGroups = "m.room.related_groups"
-
- RoomHistoryVisibility = "m.room.history_visibility"
- CanonicalAlias = "m.room.canonical_alias"
- Encrypted = "m.room.encrypted"
- RoomAvatar = "m.room.avatar"
- RoomEncryption = "m.room.encryption"
- GuestAccess = "m.room.guest_access"
+ Member: Final = "m.room.member"
+ Create: Final = "m.room.create"
+ Tombstone: Final = "m.room.tombstone"
+ JoinRules: Final = "m.room.join_rules"
+ PowerLevels: Final = "m.room.power_levels"
+ Aliases: Final = "m.room.aliases"
+ Redaction: Final = "m.room.redaction"
+ ThirdPartyInvite: Final = "m.room.third_party_invite"
+ RelatedGroups: Final = "m.room.related_groups"
+
+ RoomHistoryVisibility: Final = "m.room.history_visibility"
+ CanonicalAlias: Final = "m.room.canonical_alias"
+ Encrypted: Final = "m.room.encrypted"
+ RoomAvatar: Final = "m.room.avatar"
+ RoomEncryption: Final = "m.room.encryption"
+ GuestAccess: Final = "m.room.guest_access"
# These are used for validation
- Message = "m.room.message"
- Topic = "m.room.topic"
- Name = "m.room.name"
+ Message: Final = "m.room.message"
+ Topic: Final = "m.room.topic"
+ Name: Final = "m.room.name"
- ServerACL = "m.room.server_acl"
- Pinned = "m.room.pinned_events"
+ ServerACL: Final = "m.room.server_acl"
+ Pinned: Final = "m.room.pinned_events"
- Retention = "m.room.retention"
+ Retention: Final = "m.room.retention"
- Dummy = "org.matrix.dummy_event"
+ Dummy: Final = "org.matrix.dummy_event"
- SpaceChild = "m.space.child"
- SpaceParent = "m.space.parent"
+ SpaceChild: Final = "m.space.child"
+ SpaceParent: Final = "m.space.parent"
- MSC2716_INSERTION = "org.matrix.msc2716.insertion"
- MSC2716_BATCH = "org.matrix.msc2716.batch"
- MSC2716_MARKER = "org.matrix.msc2716.marker"
+ MSC2716_INSERTION: Final = "org.matrix.msc2716.insertion"
+ MSC2716_BATCH: Final = "org.matrix.msc2716.batch"
+ MSC2716_MARKER: Final = "org.matrix.msc2716.marker"
class ToDeviceEventTypes:
- RoomKeyRequest = "m.room_key_request"
+ RoomKeyRequest: Final = "m.room_key_request"
class DeviceKeyAlgorithms:
"""Spec'd algorithms for the generation of per-device keys"""
- ED25519 = "ed25519"
- CURVE25519 = "curve25519"
- SIGNED_CURVE25519 = "signed_curve25519"
+ ED25519: Final = "ed25519"
+ CURVE25519: Final = "curve25519"
+ SIGNED_CURVE25519: Final = "signed_curve25519"
class EduTypes:
- Presence = "m.presence"
+ Presence: Final = "m.presence"
class RejectedReason:
- AUTH_ERROR = "auth_error"
+ AUTH_ERROR: Final = "auth_error"
class RoomCreationPreset:
- PRIVATE_CHAT = "private_chat"
- PUBLIC_CHAT = "public_chat"
- TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
+ PRIVATE_CHAT: Final = "private_chat"
+ PUBLIC_CHAT: Final = "public_chat"
+ TRUSTED_PRIVATE_CHAT: Final = "trusted_private_chat"
class ThirdPartyEntityKind:
- USER = "user"
- LOCATION = "location"
+ USER: Final = "user"
+ LOCATION: Final = "location"
-ServerNoticeMsgType = "m.server_notice"
-ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
+ServerNoticeMsgType: Final = "m.server_notice"
+ServerNoticeLimitReached: Final = "m.server_notice.usage_limit_reached"
class UserTypes:
@@ -165,91 +167,91 @@ class UserTypes:
'admin' and 'guest' users should also be UserTypes. Normal users are type None
"""
- SUPPORT = "support"
- BOT = "bot"
- ALL_USER_TYPES = (SUPPORT, BOT)
+ SUPPORT: Final = "support"
+ BOT: Final = "bot"
+ ALL_USER_TYPES: Final = (SUPPORT, BOT)
class RelationTypes:
"""The types of relations known to this server."""
- ANNOTATION = "m.annotation"
- REPLACE = "m.replace"
- REFERENCE = "m.reference"
- THREAD = "io.element.thread"
+ ANNOTATION: Final = "m.annotation"
+ REPLACE: Final = "m.replace"
+ REFERENCE: Final = "m.reference"
+ THREAD: Final = "io.element.thread"
class LimitBlockingTypes:
"""Reasons that a server may be blocked"""
- MONTHLY_ACTIVE_USER = "monthly_active_user"
- HS_DISABLED = "hs_disabled"
+ MONTHLY_ACTIVE_USER: Final = "monthly_active_user"
+ HS_DISABLED: Final = "hs_disabled"
class EventContentFields:
"""Fields found in events' content, regardless of type."""
# Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
- LABELS = "org.matrix.labels"
+ LABELS: Final = "org.matrix.labels"
# Timestamp to delete the event after
# cf https://github.com/matrix-org/matrix-doc/pull/2228
- SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
+ SELF_DESTRUCT_AFTER: Final = "org.matrix.self_destruct_after"
# cf https://github.com/matrix-org/matrix-doc/pull/1772
- ROOM_TYPE = "type"
+ ROOM_TYPE: Final = "type"
# Whether a room can federate.
- FEDERATE = "m.federate"
+ FEDERATE: Final = "m.federate"
# The creator of the room, as used in `m.room.create` events.
- ROOM_CREATOR = "creator"
+ ROOM_CREATOR: Final = "creator"
# Used in m.room.guest_access events.
- GUEST_ACCESS = "guest_access"
+ GUEST_ACCESS: Final = "guest_access"
# Used on normal messages to indicate they were historically imported after the fact
- MSC2716_HISTORICAL = "org.matrix.msc2716.historical"
+ MSC2716_HISTORICAL: Final = "org.matrix.msc2716.historical"
# For "insertion" events to indicate what the next batch ID should be in
# order to connect to it
- MSC2716_NEXT_BATCH_ID = "org.matrix.msc2716.next_batch_id"
+ MSC2716_NEXT_BATCH_ID: Final = "org.matrix.msc2716.next_batch_id"
# Used on "batch" events to indicate which insertion event it connects to
- MSC2716_BATCH_ID = "org.matrix.msc2716.batch_id"
+ MSC2716_BATCH_ID: Final = "org.matrix.msc2716.batch_id"
# For "marker" events
- MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion"
+ MSC2716_MARKER_INSERTION: Final = "org.matrix.msc2716.marker.insertion"
# The authorising user for joining a restricted room.
- AUTHORISING_USER = "join_authorised_via_users_server"
+ AUTHORISING_USER: Final = "join_authorised_via_users_server"
class RoomTypes:
"""Understood values of the room_type field of m.room.create events."""
- SPACE = "m.space"
+ SPACE: Final = "m.space"
class RoomEncryptionAlgorithms:
- MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
- DEFAULT = MEGOLM_V1_AES_SHA2
+ MEGOLM_V1_AES_SHA2: Final = "m.megolm.v1.aes-sha2"
+ DEFAULT: Final = MEGOLM_V1_AES_SHA2
class AccountDataTypes:
- DIRECT = "m.direct"
- IGNORED_USER_LIST = "m.ignored_user_list"
+ DIRECT: Final = "m.direct"
+ IGNORED_USER_LIST: Final = "m.ignored_user_list"
class HistoryVisibility:
- INVITED = "invited"
- JOINED = "joined"
- SHARED = "shared"
- WORLD_READABLE = "world_readable"
+ INVITED: Final = "invited"
+ JOINED: Final = "joined"
+ SHARED: Final = "shared"
+ WORLD_READABLE: Final = "world_readable"
class GuestAccess:
- CAN_JOIN = "can_join"
+ CAN_JOIN: Final = "can_join"
# anything that is not "can_join" is considered "forbidden", but for completeness:
- FORBIDDEN = "forbidden"
+ FORBIDDEN: Final = "forbidden"
class ReadReceiptEventFields:
- MSC2285_HIDDEN = "org.matrix.msc2285.hidden"
+ MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 502cc8e8d1..b4bed5bf40 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -113,6 +113,7 @@ from synapse.storage.databases.main.monthly_active_users import (
)
from synapse.storage.databases.main.presence import PresenceStore
from synapse.storage.databases.main.room import RoomWorkerStore
+from synapse.storage.databases.main.room_batch import RoomBatchStore
from synapse.storage.databases.main.search import SearchStore
from synapse.storage.databases.main.session import SessionStore
from synapse.storage.databases.main.stats import StatsStore
@@ -240,6 +241,7 @@ class GenericWorkerSlavedStore(
SlavedEventStore,
SlavedKeyStore,
RoomWorkerStore,
+ RoomBatchStore,
DirectoryStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 7e09530ad2..52541faab2 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -194,6 +194,7 @@ class SynapseHomeServer(HomeServer):
{
"/_matrix/client/api/v1": client_resource,
"/_matrix/client/r0": client_resource,
+ "/_matrix/client/v1": client_resource,
"/_matrix/client/v3": client_resource,
"/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource,
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 61e569d412..1ddad7cb70 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError
@@ -113,32 +114,24 @@ class RegistrationConfig(Config):
self.session_lifetime = session_lifetime
# The `refreshable_access_token_lifetime` applies for tokens that can be renewed
- # using a refresh token, as per MSC2918. If it is `None`, the refresh
- # token mechanism is disabled.
- #
- # Since it is incompatible with the `session_lifetime` mechanism, it is set to
- # `None` by default if a `session_lifetime` is set.
+ # using a refresh token, as per MSC2918.
+ # If it is `None`, the refresh token mechanism is disabled.
refreshable_access_token_lifetime = config.get(
"refreshable_access_token_lifetime",
- "5m" if session_lifetime is None else None,
+ "5m",
)
if refreshable_access_token_lifetime is not None:
refreshable_access_token_lifetime = self.parse_duration(
refreshable_access_token_lifetime
)
- self.refreshable_access_token_lifetime = refreshable_access_token_lifetime
-
- if (
- session_lifetime is not None
- and refreshable_access_token_lifetime is not None
- ):
- raise ConfigError(
- "The refresh token mechanism is incompatible with the "
- "`session_lifetime` option. Consider disabling the "
- "`session_lifetime` option or disabling the refresh token "
- "mechanism by removing the `refreshable_access_token_lifetime` "
- "option."
- )
+ self.refreshable_access_token_lifetime: Optional[
+ int
+ ] = refreshable_access_token_lifetime
+
+ refresh_token_lifetime = config.get("refresh_token_lifetime")
+ if refresh_token_lifetime is not None:
+ refresh_token_lifetime = self.parse_duration(refresh_token_lifetime)
+ self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime
# The fallback template used for authenticating using a registration token
self.registration_token_template = self.read_template("registration_token.html")
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 4cda439ad9..993b04099e 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -667,21 +667,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
perspective_name,
)
+ request: JsonDict = {}
+ for queue_value in keys_to_fetch:
+ # there may be multiple requests for each server, so we have to merge
+ # them intelligently.
+ request_for_server = {
+ key_id: {
+ "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
+ }
+ for key_id in queue_value.key_ids
+ }
+ request.setdefault(queue_value.server_name, {}).update(request_for_server)
+
+ logger.debug("Request to notary server %s: %s", perspective_name, request)
+
try:
query_response = await self.client.post_json(
destination=perspective_name,
path="/_matrix/key/v2/query",
- data={
- "server_keys": {
- queue_value.server_name: {
- key_id: {
- "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
- }
- for key_id in queue_value.key_ids
- }
- for queue_value in keys_to_fetch
- }
- },
+ data={"server_keys": request},
)
except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve upon
@@ -689,6 +693,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,))
+ logger.debug(
+ "Response from notary server %s: %s", perspective_name, query_response
+ )
+
keys: Dict[str, Dict[str, FetchKeyResult]] = {}
added_keys: List[Tuple[str, str, FetchKeyResult]] = []
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index d7527008c4..f251402ed8 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -322,6 +322,11 @@ class _AsyncEventContextImpl(EventContext):
attributes by loading from the database.
"""
if self.state_group is None:
+ # No state group means the event is an outlier. Usually the state_ids dicts are also
+ # pre-set to empty dicts, but they get reset when the context is serialized, so set
+ # them to empty dicts again here.
+ self._current_state_ids = {}
+ self._prev_state_ids = {}
return
current_state_ids = await self._storage.state.get_state_ids_for_group(
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 3b85b135e0..bc3f96c1fc 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1395,11 +1395,28 @@ class FederationClient(FederationBase):
async def send_request(
destination: str,
) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
- res = await self.transport_layer.get_room_hierarchy(
- destination=destination,
- room_id=room_id,
- suggested_only=suggested_only,
- )
+ try:
+ res = await self.transport_layer.get_room_hierarchy(
+ destination=destination,
+ room_id=room_id,
+ suggested_only=suggested_only,
+ )
+ except HttpResponseException as e:
+            # If an error is received that is due to an unrecognised endpoint,
+            # fall back to the unstable endpoint. Otherwise, consider it a
+            # legitimate error and raise.
+ if not self._is_unknown_endpoint(e):
+ raise
+
+ logger.debug(
+ "Couldn't fetch room hierarchy with the v1 API, falling back to the unstable API"
+ )
+
+ res = await self.transport_layer.get_room_hierarchy_unstable(
+ destination=destination,
+ room_id=room_id,
+ suggested_only=suggested_only,
+ )
room = res.get("room")
if not isinstance(room, dict):
@@ -1449,6 +1466,10 @@ class FederationClient(FederationBase):
if e.code != 502:
raise
+ logger.debug(
+ "Couldn't fetch room hierarchy, falling back to the spaces API"
+ )
+
# Fallback to the old federation API and translate the results if
# no servers implement the new API.
#
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 9a8758e9a6..8fbc75aa65 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -613,8 +613,11 @@ class FederationServer(FederationBase):
state = await self.store.get_events(state_ids)
time_now = self._clock.time_msec()
+ event_json = event.get_pdu_json()
return {
- "org.matrix.msc3083.v2.event": event.get_pdu_json(),
+ # TODO Remove the unstable prefix when servers have updated.
+ "org.matrix.msc3083.v2.event": event_json,
+ "event": event_json,
"state": [p.get_pdu_json(time_now) for p in state.values()],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
}
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 10b5aa5af8..fe29bcfd4b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1192,10 +1192,24 @@ class TransportLayerClient:
)
async def get_room_hierarchy(
- self,
- destination: str,
- room_id: str,
- suggested_only: bool,
+ self, destination: str, room_id: str, suggested_only: bool
+ ) -> JsonDict:
+ """
+ Args:
+ destination: The remote server
+ room_id: The room ID to ask about.
+ suggested_only: if True, only suggested rooms will be returned
+ """
+ path = _create_v1_path("/hierarchy/%s", room_id)
+
+ return await self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"suggested_only": "true" if suggested_only else "false"},
+ )
+
+ async def get_room_hierarchy_unstable(
+ self, destination: str, room_id: str, suggested_only: bool
) -> JsonDict:
"""
Args:
@@ -1317,15 +1331,26 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
prefix + "auth_chain.item",
use_float=True,
)
- self._coro_event = ijson.kvitems_coro(
+ # TODO Remove the unstable prefix when servers have updated.
+ #
+        # By re-using the same event dictionary, this will cause the parsing of
+ # org.matrix.msc3083.v2.event and event to stomp over each other.
+ # Generally this should be fine.
+ self._coro_unstable_event = ijson.kvitems_coro(
_event_parser(self._response.event_dict),
prefix + "org.matrix.msc3083.v2.event",
use_float=True,
)
+ self._coro_event = ijson.kvitems_coro(
+ _event_parser(self._response.event_dict),
+ prefix + "event",
+ use_float=True,
+ )
def write(self, data: bytes) -> int:
self._coro_state.send(data)
self._coro_auth.send(data)
+ self._coro_unstable_event.send(data)
self._coro_event.send(data)
return len(data)
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 2fdf6cc99e..66e915228c 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -611,7 +611,6 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
class FederationRoomHierarchyServlet(BaseFederationServlet):
- PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
PATH = "/hierarchy/(?P<room_id>[^/]*)"
def __init__(
@@ -637,6 +636,10 @@ class FederationRoomHierarchyServlet(BaseFederationServlet):
)
+class FederationRoomHierarchyUnstableServlet(FederationRoomHierarchyServlet):
+ PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
+
+
class RoomComplexityServlet(BaseFederationServlet):
"""
Indicates to other servers how complex (and therefore likely
@@ -701,6 +704,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
RoomComplexityServlet,
FederationSpaceSummaryServlet,
FederationRoomHierarchyServlet,
+ FederationRoomHierarchyUnstableServlet,
FederationV1SendKnockServlet,
FederationMakeKnockServlet,
)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4b66a9862f..4d9c4e5834 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,6 +18,7 @@ import time
import unicodedata
import urllib.parse
from binascii import crc32
+from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
@@ -756,53 +757,109 @@ class AuthHandler:
async def refresh_token(
self,
refresh_token: str,
- valid_until_ms: Optional[int],
- ) -> Tuple[str, str]:
+ access_token_valid_until_ms: Optional[int],
+ refresh_token_valid_until_ms: Optional[int],
+ ) -> Tuple[str, str, Optional[int]]:
"""
Consumes a refresh token and generate both a new access token and a new refresh token from it.
The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
+ The lifetime of both the access token and refresh token will be capped so that they
+ do not exceed the session's ultimate expiry time, if applicable.
+
Args:
refresh_token: The token to consume.
- valid_until_ms: The expiration timestamp of the new access token.
-
+ access_token_valid_until_ms: The expiration timestamp of the new access token.
+ None if the access token does not expire.
+ refresh_token_valid_until_ms: The expiration timestamp of the new refresh token.
+ None if the refresh token does not expire.
Returns:
- A tuple containing the new access token and refresh token
+ A tuple containing:
+ - the new access token
+ - the new refresh token
+ - the actual expiry time of the access token, which may be earlier than
+ `access_token_valid_until_ms`.
"""
# Verify the token signature first before looking up the token
if not self._verify_refresh_token(refresh_token):
- raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
+ raise SynapseError(
+ HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN
+ )
existing_token = await self.store.lookup_refresh_token(refresh_token)
if existing_token is None:
- raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
+ raise SynapseError(
+ HTTPStatus.UNAUTHORIZED,
+ "refresh token does not exist",
+ Codes.UNKNOWN_TOKEN,
+ )
if (
existing_token.has_next_access_token_been_used
or existing_token.has_next_refresh_token_been_refreshed
):
raise SynapseError(
- 403, "refresh token isn't valid anymore", Codes.FORBIDDEN
+ HTTPStatus.FORBIDDEN,
+ "refresh token isn't valid anymore",
+ Codes.FORBIDDEN,
+ )
+
+ now_ms = self._clock.time_msec()
+
+ if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
+
+ raise SynapseError(
+ HTTPStatus.FORBIDDEN,
+ "The supplied refresh token has expired",
+ Codes.FORBIDDEN,
)
+ if existing_token.ultimate_session_expiry_ts is not None:
+ # This session has a bounded lifetime, even across refreshes.
+
+ if access_token_valid_until_ms is not None:
+ access_token_valid_until_ms = min(
+ access_token_valid_until_ms,
+ existing_token.ultimate_session_expiry_ts,
+ )
+ else:
+ access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+
+ if refresh_token_valid_until_ms is not None:
+ refresh_token_valid_until_ms = min(
+ refresh_token_valid_until_ms,
+ existing_token.ultimate_session_expiry_ts,
+ )
+ else:
+ refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+ if existing_token.ultimate_session_expiry_ts < now_ms:
+ raise SynapseError(
+ HTTPStatus.FORBIDDEN,
+ "The session has expired and can no longer be refreshed",
+ Codes.FORBIDDEN,
+ )
+
(
new_refresh_token,
new_refresh_token_id,
) = await self.create_refresh_token_for_user_id(
- user_id=existing_token.user_id, device_id=existing_token.device_id
+ user_id=existing_token.user_id,
+ device_id=existing_token.device_id,
+ expiry_ts=refresh_token_valid_until_ms,
+ ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts,
)
access_token = await self.create_access_token_for_user_id(
user_id=existing_token.user_id,
device_id=existing_token.device_id,
- valid_until_ms=valid_until_ms,
+ valid_until_ms=access_token_valid_until_ms,
refresh_token_id=new_refresh_token_id,
)
await self.store.replace_refresh_token(
existing_token.token_id, new_refresh_token_id
)
- return access_token, new_refresh_token
+ return access_token, new_refresh_token, access_token_valid_until_ms
def _verify_refresh_token(self, token: str) -> bool:
"""
@@ -836,6 +893,8 @@ class AuthHandler:
self,
user_id: str,
device_id: str,
+ expiry_ts: Optional[int],
+ ultimate_session_expiry_ts: Optional[int],
) -> Tuple[str, int]:
"""
Creates a new refresh token for the user with the given user ID.
@@ -843,6 +902,13 @@ class AuthHandler:
Args:
user_id: canonical user ID
device_id: the device ID to associate with the token.
+ expiry_ts (milliseconds since the epoch): Time after which the
+ refresh token cannot be used.
+ If None, the refresh token never expires until it has been used.
+ ultimate_session_expiry_ts (milliseconds since the epoch):
+ Time at which the session will end and can not be extended any
+ further.
+ If None, the session can be refreshed indefinitely.
Returns:
The newly created refresh token and its ID in the database
@@ -852,6 +918,8 @@ class AuthHandler:
user_id=user_id,
token=refresh_token,
device_id=device_id,
+ expiry_ts=expiry_ts,
+ ultimate_session_expiry_ts=ultimate_session_expiry_ts,
)
return refresh_token, refresh_token_id
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 448a36108e..24ca11b924 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -119,6 +119,7 @@ class RegistrationHandler:
self.refreshable_access_token_lifetime = (
hs.config.registration.refreshable_access_token_lifetime
)
+ self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
init_counters_for_auth_provider("")
@@ -793,13 +794,13 @@ class RegistrationHandler:
class and RegisterDeviceReplicationServlet.
"""
assert not self.hs.config.worker.worker_app
- valid_until_ms = None
+ access_token_expiry = None
if self.session_lifetime is not None:
if is_guest:
raise Exception(
"session_lifetime is not currently implemented for guest access"
)
- valid_until_ms = self.clock.time_msec() + self.session_lifetime
+ access_token_expiry = self.clock.time_msec() + self.session_lifetime
refresh_token = None
refresh_token_id = None
@@ -808,25 +809,57 @@ class RegistrationHandler:
user_id, device_id, initial_display_name
)
if is_guest:
- assert valid_until_ms is None
+ assert access_token_expiry is None
access_token = self.macaroon_gen.generate_guest_access_token(user_id)
else:
if should_issue_refresh_token:
+ # A refreshable access token lifetime must be configured
+ # since we're told to issue a refresh token (the caller checks
+ # that this value is set before setting this flag).
+ assert self.refreshable_access_token_lifetime is not None
+
+ now_ms = self.clock.time_msec()
+
+ # Set the expiry time of the refreshable access token
+ access_token_expiry = now_ms + self.refreshable_access_token_lifetime
+
+ # Set the refresh token expiry time (if configured)
+ refresh_token_expiry = None
+ if self.refresh_token_lifetime is not None:
+ refresh_token_expiry = now_ms + self.refresh_token_lifetime
+
+ # Set an ultimate session expiry time (if configured)
+ ultimate_session_expiry_ts = None
+ if self.session_lifetime is not None:
+ ultimate_session_expiry_ts = now_ms + self.session_lifetime
+
+ # Also ensure that the issued tokens don't outlive the
+ # session.
+ # (It would be weird to configure a homeserver with a shorter
+ # session lifetime than token lifetime, but may as well handle
+ # it.)
+ access_token_expiry = min(
+ access_token_expiry, ultimate_session_expiry_ts
+ )
+ if refresh_token_expiry is not None:
+ refresh_token_expiry = min(
+ refresh_token_expiry, ultimate_session_expiry_ts
+ )
+
(
refresh_token,
refresh_token_id,
) = await self._auth_handler.create_refresh_token_for_user_id(
user_id,
device_id=registered_device_id,
- )
- valid_until_ms = (
- self.clock.time_msec() + self.refreshable_access_token_lifetime
+ expiry_ts=refresh_token_expiry,
+ ultimate_session_expiry_ts=ultimate_session_expiry_ts,
)
access_token = await self._auth_handler.create_access_token_for_user_id(
user_id,
device_id=registered_device_id,
- valid_until_ms=valid_until_ms,
+ valid_until_ms=access_token_expiry,
is_appservice_ghost=is_appservice_ghost,
refresh_token_id=refresh_token_id,
)
@@ -834,7 +867,7 @@ class RegistrationHandler:
return {
"device_id": registered_device_id,
"access_token": access_token,
- "valid_until_ms": valid_until_ms,
+ "valid_until_ms": access_token_expiry,
"refresh_token": refresh_token,
}
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 8181cc0b52..b2cfe537df 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -36,8 +36,9 @@ from synapse.api.errors import (
SynapseError,
UnsupportedRoomVersionError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
-from synapse.types import JsonDict
+from synapse.types import JsonDict, Requester
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@@ -93,6 +94,9 @@ class RoomSummaryHandler:
self._event_serializer = hs.get_event_client_serializer()
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()
+ self._ratelimiter = Ratelimiter(
+ store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10
+ )
# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
@@ -249,7 +253,7 @@ class RoomSummaryHandler:
async def get_room_hierarchy(
self,
- requester: str,
+ requester: Requester,
requested_room_id: str,
suggested_only: bool = False,
max_depth: Optional[int] = None,
@@ -276,6 +280,8 @@ class RoomSummaryHandler:
Returns:
The JSON hierarchy dictionary.
"""
+ await self._ratelimiter.ratelimit(requester)
+
# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
#
@@ -283,7 +289,7 @@ class RoomSummaryHandler:
# to process multiple requests for the same page will result in errors.
return await self._pagination_response_cache.wrap(
(
- requester,
+ requester.user.to_string(),
requested_room_id,
suggested_only,
max_depth,
@@ -291,7 +297,7 @@ class RoomSummaryHandler:
from_token,
),
self._get_room_hierarchy,
- requester,
+ requester.user.to_string(),
requested_room_id,
suggested_only,
max_depth,
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 96d7a8f2a9..a8154168be 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -24,6 +24,7 @@ from typing import (
List,
Optional,
Tuple,
+ TypeVar,
Union,
)
@@ -81,10 +82,19 @@ from synapse.http.server import (
)
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import (
+ defer_to_thread,
+ make_deferred_yieldable,
+ run_in_background,
+)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client.login import LoginResponse
from synapse.storage import DataStore
+from synapse.storage.background_updates import (
+ DEFAULT_BATCH_SIZE_CALLBACK,
+ MIN_BATCH_SIZE_CALLBACK,
+ ON_UPDATE_CALLBACK,
+)
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.state import StateFilter
@@ -104,6 +114,9 @@ if TYPE_CHECKING:
from synapse.app.generic_worker import GenericWorkerSlavedStore
from synapse.server import HomeServer
+
+T = TypeVar("T")
+
"""
This package defines the 'stable' API which can be used by extension modules which
are loaded into Synapse.
@@ -307,7 +320,25 @@ class ModuleApi:
auth_checkers=auth_checkers,
)
- def register_web_resource(self, path: str, resource: Resource):
+ def register_background_update_controller_callbacks(
+ self,
+ on_update: ON_UPDATE_CALLBACK,
+ default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+ min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None,
+ ) -> None:
+ """Registers background update controller callbacks.
+
+ Added in Synapse v1.49.0.
+ """
+
+ for db in self._hs.get_datastores().databases:
+ db.updates.register_update_controller_callbacks(
+ on_update=on_update,
+ default_batch_size=default_batch_size,
+ min_batch_size=min_batch_size,
+ )
+
+ def register_web_resource(self, path: str, resource: Resource) -> None:
"""Registers a web resource to be served at the given path.
This function should be called during initialisation of the module.
@@ -432,7 +463,7 @@ class ModuleApi:
username: provided user id
Returns:
- str: qualified @user:id
+ qualified @user:id
"""
if username.startswith("@"):
return username
@@ -468,7 +499,7 @@ class ModuleApi:
"""
return await self._store.user_get_threepids(user_id)
- def check_user_exists(self, user_id: str):
+ def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]":
"""Check if user exists.
Added in Synapse v0.25.0.
@@ -477,13 +508,18 @@ class ModuleApi:
user_id: Complete @user:id
Returns:
- Deferred[str|None]: Canonical (case-corrected) user_id, or None
+ Canonical (case-corrected) user_id, or None
if the user is not registered.
"""
return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id))
@defer.inlineCallbacks
- def register(self, localpart, displayname=None, emails: Optional[List[str]] = None):
+ def register(
+ self,
+ localpart: str,
+ displayname: Optional[str] = None,
+ emails: Optional[List[str]] = None,
+ ) -> Generator["defer.Deferred[Any]", Any, Tuple[str, str]]:
"""Registers a new user with given localpart and optional displayname, emails.
Also returns an access token for the new user.
@@ -495,12 +531,12 @@ class ModuleApi:
Added in Synapse v0.25.0.
Args:
- localpart (str): The localpart of the new user.
- displayname (str|None): The displayname of the new user.
- emails (List[str]): Emails to bind to the new user.
+ localpart: The localpart of the new user.
+ displayname: The displayname of the new user.
+ emails: Emails to bind to the new user.
Returns:
- Deferred[tuple[str, str]]: a 2-tuple of (user_id, access_token)
+ a 2-tuple of (user_id, access_token)
"""
logger.warning(
"Using deprecated ModuleApi.register which creates a dummy user device."
@@ -510,23 +546,26 @@ class ModuleApi:
return user_id, access_token
def register_user(
- self, localpart, displayname=None, emails: Optional[List[str]] = None
- ):
+ self,
+ localpart: str,
+ displayname: Optional[str] = None,
+ emails: Optional[List[str]] = None,
+ ) -> "defer.Deferred[str]":
"""Registers a new user with given localpart and optional displayname, emails.
Added in Synapse v1.2.0.
Args:
- localpart (str): The localpart of the new user.
- displayname (str|None): The displayname of the new user.
- emails (List[str]): Emails to bind to the new user.
+ localpart: The localpart of the new user.
+ displayname: The displayname of the new user.
+ emails: Emails to bind to the new user.
Raises:
SynapseError if there is an error performing the registration. Check the
'errcode' property for more information on the reason for failure
Returns:
- defer.Deferred[str]: user_id
+ user_id
"""
return defer.ensureDeferred(
self._hs.get_registration_handler().register_user(
@@ -536,20 +575,25 @@ class ModuleApi:
)
)
- def register_device(self, user_id, device_id=None, initial_display_name=None):
+ def register_device(
+ self,
+ user_id: str,
+ device_id: Optional[str] = None,
+ initial_display_name: Optional[str] = None,
+ ) -> "defer.Deferred[Tuple[str, str, Optional[int], Optional[str]]]":
"""Register a device for a user and generate an access token.
Added in Synapse v1.2.0.
Args:
- user_id (str): full canonical @user:id
- device_id (str|None): The device ID to check, or None to generate
+ user_id: full canonical @user:id
+ device_id: The device ID to check, or None to generate
a new one.
- initial_display_name (str|None): An optional display name for the
+ initial_display_name: An optional display name for the
device.
Returns:
- defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
+ Tuple of device ID, access token, access token expiration time and refresh token
"""
return defer.ensureDeferred(
self._hs.get_registration_handler().register_device(
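# Editor's sketch: with the annotations above, register_user and register_device
# return Deferreds, so an async module can await them directly. The localpart
# and display name are illustrative only.
async def create_bot(module_api) -> str:
    user_id = await module_api.register_user("examplebot", displayname="Example Bot")
    device_id, access_token, _expires_in, _refresh_token = await module_api.register_device(user_id)
    return access_token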
@@ -603,7 +647,9 @@ class ModuleApi:
)
@defer.inlineCallbacks
- def invalidate_access_token(self, access_token):
+ def invalidate_access_token(
+ self, access_token: str
+ ) -> Generator["defer.Deferred[Any]", Any, None]:
"""Invalidate an access token for a user
Added in Synapse v0.25.0.
@@ -635,14 +681,20 @@ class ModuleApi:
self._auth_handler.delete_access_token(access_token)
)
- def run_db_interaction(self, desc, func, *args, **kwargs):
+ def run_db_interaction(
+ self,
+ desc: str,
+ func: Callable[..., T],
+ *args: Any,
+ **kwargs: Any,
+ ) -> "defer.Deferred[T]":
"""Run a function with a database connection
Added in Synapse v0.25.0.
Args:
- desc (str): description for the transaction, for metrics etc
- func (func): function to be run. Passed a database cursor object
+ desc: description for the transaction, for metrics etc
+ func: function to be run. Passed a database cursor object
as well as *args and **kwargs
*args: positional args to be passed to func
**kwargs: named args to be passed to func
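# Editor's sketch of run_db_interaction with the new generic return type:
# `func` is handed a database cursor (plus any *args/**kwargs), and the
# resulting Deferred resolves to whatever `func` returns (an int here).
# The SQL below is illustrative, not taken from this patch.
def _count_users(txn) -> int:
    txn.execute("SELECT COUNT(*) FROM users")
    row = txn.fetchone()
    return row[0] if row else 0

async def count_users(module_api) -> int:
    return await module_api.run_db_interaction("count_users", _count_users)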
@@ -656,7 +708,7 @@ class ModuleApi:
def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
- ):
+ ) -> None:
"""Complete a SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirect them to that
URL with a token directly if the URL matches with one of the whitelisted clients.
@@ -686,7 +738,7 @@ class ModuleApi:
client_redirect_url: str,
new_user: bool = False,
auth_provider_id: str = "<unknown>",
- ):
+ ) -> None:
"""Complete a SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirect them to that
URL with a token directly if the URL matches with one of the whitelisted clients.
@@ -925,11 +977,11 @@ class ModuleApi:
self,
f: Callable,
msec: float,
- *args,
+ *args: object,
desc: Optional[str] = None,
run_on_all_instances: bool = False,
- **kwargs,
- ):
+ **kwargs: object,
+ ) -> None:
"""Wraps a function as a background process and calls it repeatedly.
NOTE: Will only run on the instance that is configured to run
@@ -970,13 +1022,18 @@ class ModuleApi:
f,
)
+ async def sleep(self, seconds: float) -> None:
+ """Sleeps for the given number of seconds."""
+
+ await self._clock.sleep(seconds)
+
async def send_mail(
self,
recipient: str,
subject: str,
html: str,
text: str,
- ):
+ ) -> None:
"""Send an email on behalf of the homeserver.
Added in Synapse v1.39.0.
@@ -1124,6 +1181,26 @@ class ModuleApi:
return {key: state_events[event_id] for key, event_id in state_ids.items()}
+ async def defer_to_thread(
+ self,
+ f: Callable[..., T],
+ *args: Any,
+ **kwargs: Any,
+ ) -> T:
+ """Runs the given function in a separate thread from Synapse's thread pool.
+
+ Added in Synapse v1.49.0.
+
+ Args:
+ f: The function to run.
+ args: The function's arguments.
+ kwargs: The function's keyword arguments.
+
+ Returns:
+ The return value of the function, once run in a thread.
+ """
+ return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs)
+
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index cf5abdfbda..4f13c0418a 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -21,6 +21,8 @@ from twisted.internet.interfaces import IDelayedCall
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfig, PusherConfigException, ThrottleParams
from synapse.push.mailer import Mailer
+from synapse.push.push_types import EmailReason
+from synapse.storage.databases.main.event_push_actions import EmailPushAction
from synapse.util.threepids import validate_email
if TYPE_CHECKING:
@@ -190,7 +192,7 @@ class EmailPusher(Pusher):
# we then consider all previously outstanding notifications
# to be delivered.
- reason = {
+ reason: EmailReason = {
"room_id": push_action["room_id"],
"now": self.clock.time_msec(),
"received_at": received_at,
@@ -275,7 +277,7 @@ class EmailPusher(Pusher):
return may_send_at
async def sent_notif_update_throttle(
- self, room_id: str, notified_push_action: dict
+ self, room_id: str, notified_push_action: EmailPushAction
) -> None:
# We have sent a notification, so update the throttle accordingly.
# If the event that triggered the notif happened more than
@@ -315,7 +317,9 @@ class EmailPusher(Pusher):
self.pusher_id, room_id, self.throttle_params[room_id]
)
- async def send_notification(self, push_actions: List[dict], reason: dict) -> None:
+ async def send_notification(
+ self, push_actions: List[EmailPushAction], reason: EmailReason
+ ) -> None:
logger.info("Sending notif email for user %r", self.user_id)
await self.mailer.send_notification_mail(
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index dbf4ad7f97..3fa603ccb7 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -26,6 +26,7 @@ from synapse.events import EventBase
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfig, PusherConfigException
+from synapse.storage.databases.main.event_push_actions import HttpPushAction
from . import push_rule_evaluator, push_tools
@@ -273,7 +274,7 @@ class HttpPusher(Pusher):
)
break
- async def _process_one(self, push_action: dict) -> bool:
+ async def _process_one(self, push_action: HttpPushAction) -> bool:
if "notify" not in push_action["actions"]:
return True
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index ce299ba3da..ba4f866487 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -14,7 +14,7 @@
import logging
import urllib.parse
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar
import bleach
import jinja2
@@ -28,6 +28,14 @@ from synapse.push.presentable_names import (
descriptor_from_member_events,
name_from_member_event,
)
+from synapse.push.push_types import (
+ EmailReason,
+ MessageVars,
+ NotifVars,
+ RoomVars,
+ TemplateVars,
+)
+from synapse.storage.databases.main.event_push_actions import EmailPushAction
from synapse.storage.state import StateFilter
from synapse.types import StateMap, UserID
from synapse.util.async_helpers import concurrently_execute
@@ -135,7 +143,7 @@ class Mailer:
% urllib.parse.urlencode(params)
)
- template_vars = {"link": link}
+ template_vars: TemplateVars = {"link": link}
await self.send_email(
email_address,
@@ -165,7 +173,7 @@ class Mailer:
% urllib.parse.urlencode(params)
)
- template_vars = {"link": link}
+ template_vars: TemplateVars = {"link": link}
await self.send_email(
email_address,
@@ -196,7 +204,7 @@ class Mailer:
% urllib.parse.urlencode(params)
)
- template_vars = {"link": link}
+ template_vars: TemplateVars = {"link": link}
await self.send_email(
email_address,
@@ -210,8 +218,8 @@ class Mailer:
app_id: str,
user_id: str,
email_address: str,
- push_actions: Iterable[Dict[str, Any]],
- reason: Dict[str, Any],
+ push_actions: Iterable[EmailPushAction],
+ reason: EmailReason,
) -> None:
"""
Send email regarding a user's room notifications
@@ -230,7 +238,7 @@ class Mailer:
[pa["event_id"] for pa in push_actions]
)
- notifs_by_room: Dict[str, List[Dict[str, Any]]] = {}
+ notifs_by_room: Dict[str, List[EmailPushAction]] = {}
for pa in push_actions:
notifs_by_room.setdefault(pa["room_id"], []).append(pa)
@@ -258,7 +266,7 @@ class Mailer:
# actually sort our so-called rooms_in_order list, most recent room first
rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
- rooms: List[Dict[str, Any]] = []
+ rooms: List[RoomVars] = []
for r in rooms_in_order:
roomvars = await self._get_room_vars(
@@ -289,7 +297,7 @@ class Mailer:
notifs_by_room, state_by_room, notif_events, reason
)
- template_vars = {
+ template_vars: TemplateVars = {
"user_display_name": user_display_name,
"unsubscribe_link": self._make_unsubscribe_link(
user_id, app_id, email_address
@@ -302,10 +310,10 @@ class Mailer:
await self.send_email(email_address, summary_text, template_vars)
async def send_email(
- self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
+ self, email_address: str, subject: str, extra_template_vars: TemplateVars
) -> None:
"""Send an email with the given information and template text"""
- template_vars = {
+ template_vars: TemplateVars = {
"app_name": self.app_name,
"server_name": self.hs.config.server.server_name,
}
@@ -327,10 +335,10 @@ class Mailer:
self,
room_id: str,
user_id: str,
- notifs: Iterable[Dict[str, Any]],
+ notifs: Iterable[EmailPushAction],
notif_events: Dict[str, EventBase],
room_state_ids: StateMap[str],
- ) -> Dict[str, Any]:
+ ) -> RoomVars:
"""
Generate the variables for notifications on a per-room basis.
@@ -356,7 +364,7 @@ class Mailer:
room_name = await calculate_room_name(self.store, room_state_ids, user_id)
- room_vars: Dict[str, Any] = {
+ room_vars: RoomVars = {
"title": room_name,
"hash": string_ordinal_total(room_id), # See sender avatar hash
"notifs": [],
@@ -417,11 +425,11 @@ class Mailer:
async def _get_notif_vars(
self,
- notif: Dict[str, Any],
+ notif: EmailPushAction,
user_id: str,
notif_event: EventBase,
room_state_ids: StateMap[str],
- ) -> Dict[str, Any]:
+ ) -> NotifVars:
"""
Generate the variables for a single notification.
@@ -442,7 +450,7 @@ class Mailer:
after_limit=CONTEXT_AFTER,
)
- ret = {
+ ret: NotifVars = {
"link": self._make_notif_link(notif),
"ts": notif["received_ts"],
"messages": [],
@@ -461,8 +469,8 @@ class Mailer:
return ret
async def _get_message_vars(
- self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
- ) -> Optional[Dict[str, Any]]:
+ self, notif: EmailPushAction, event: EventBase, room_state_ids: StateMap[str]
+ ) -> Optional[MessageVars]:
"""
Generate the variables for a single event, if possible.
@@ -494,7 +502,9 @@ class Mailer:
if sender_state_event:
sender_name = name_from_member_event(sender_state_event)
- sender_avatar_url = sender_state_event.content.get("avatar_url")
+ sender_avatar_url: Optional[str] = sender_state_event.content.get(
+ "avatar_url"
+ )
else:
# No state could be found, fallback to the MXID.
sender_name = event.sender
@@ -504,7 +514,7 @@ class Mailer:
# sender_hash % the number of default images to choose from
sender_hash = string_ordinal_total(event.sender)
- ret = {
+ ret: MessageVars = {
"event_type": event.type,
"is_historical": event.event_id != notif["event_id"],
"id": event.event_id,
@@ -519,6 +529,8 @@ class Mailer:
return ret
msgtype = event.content.get("msgtype")
+ if not isinstance(msgtype, str):
+ msgtype = None
ret["msgtype"] = msgtype
@@ -533,7 +545,7 @@ class Mailer:
return ret
def _add_text_message_vars(
- self, messagevars: Dict[str, Any], event: EventBase
+ self, messagevars: MessageVars, event: EventBase
) -> None:
"""
Potentially add a sanitised message body to the message variables.
@@ -543,8 +555,8 @@ class Mailer:
event: The event under consideration.
"""
msgformat = event.content.get("format")
-
- messagevars["format"] = msgformat
+ if not isinstance(msgformat, str):
+ msgformat = None
formatted_body = event.content.get("formatted_body")
body = event.content.get("body")
@@ -555,7 +567,7 @@ class Mailer:
messagevars["body_text_html"] = safe_text(body)
def _add_image_message_vars(
- self, messagevars: Dict[str, Any], event: EventBase
+ self, messagevars: MessageVars, event: EventBase
) -> None:
"""
Potentially add an image URL to the message variables.
@@ -570,7 +582,7 @@ class Mailer:
async def _make_summary_text_single_room(
self,
room_id: str,
- notifs: List[Dict[str, Any]],
+ notifs: List[EmailPushAction],
room_state_ids: StateMap[str],
notif_events: Dict[str, EventBase],
user_id: str,
@@ -685,10 +697,10 @@ class Mailer:
async def _make_summary_text(
self,
- notifs_by_room: Dict[str, List[Dict[str, Any]]],
+ notifs_by_room: Dict[str, List[EmailPushAction]],
room_state_ids: Dict[str, StateMap[str]],
notif_events: Dict[str, EventBase],
- reason: Dict[str, Any],
+ reason: EmailReason,
) -> str:
"""
Make a summary text for the email when multiple rooms have notifications.
@@ -718,7 +730,7 @@ class Mailer:
async def _make_summary_text_from_member_events(
self,
room_id: str,
- notifs: List[Dict[str, Any]],
+ notifs: List[EmailPushAction],
room_state_ids: StateMap[str],
notif_events: Dict[str, EventBase],
) -> str:
@@ -805,7 +817,7 @@ class Mailer:
base_url = "https://matrix.to/#"
return "%s/%s" % (base_url, room_id)
- def _make_notif_link(self, notif: Dict[str, str]) -> str:
+ def _make_notif_link(self, notif: EmailPushAction) -> str:
"""
Generate a link to open an event in the web client.
diff --git a/synapse/push/push_types.py b/synapse/push/push_types.py
new file mode 100644
index 0000000000..8d16ab62ce
--- /dev/null
+++ b/synapse/push/push_types.py
@@ -0,0 +1,136 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional
+
+from typing_extensions import TypedDict
+
+
+class EmailReason(TypedDict, total=False):
+ """
+ Information on the event that triggered the email to be sent
+
+ room_id: the ID of the room the event was sent in
+ now: timestamp in ms when the email is being sent out
+ room_name: a human-readable name for the room the event was sent in
+ received_at: the time in milliseconds at which the event was received
+ delay_before_mail_ms: the amount of time in milliseconds Synapse always waits
+ before ever emailing about a notification (to give the user a chance to respond
+ to another push notification or to notice the window)
+ last_sent_ts: the time in milliseconds at which a notification was last sent
+ for an event in this room
+ throttle_ms: the minimum amount of time in milliseconds that must pass between
+ two notifications sent for this room
+ """
+
+ room_id: str
+ now: int
+ room_name: Optional[str]
+ received_at: int
+ delay_before_mail_ms: int
+ last_sent_ts: int
+ throttle_ms: int
+
+
+class MessageVars(TypedDict, total=False):
+ """
+ Details about a specific message to include in a notification
+
+ event_type: the type of the event
+ is_historical: a boolean, which is `False` if the message is the one
+ that triggered the notification, `True` otherwise
+ id: the ID of the event
+ ts: the time in milliseconds at which the event was sent
+ sender_name: the display name for the event's sender
+ sender_avatar_url: the avatar URL (as a `mxc://` URL) for the event's
+ sender
+ sender_hash: a hash of the user ID of the sender
+ msgtype: the type of the message
+ body_text_html: HTML representation of the message
+ body_text_plain: plaintext representation of the message
+ image_url: the `mxc://` URL of an image, when "msgtype" is "m.image"
+ """
+
+ event_type: str
+ is_historical: bool
+ id: str
+ ts: int
+ sender_name: str
+ sender_avatar_url: Optional[str]
+ sender_hash: int
+ msgtype: Optional[str]
+ body_text_html: str
+ body_text_plain: str
+ image_url: str
+
+
+class NotifVars(TypedDict):
+ """
+ Details about an event we are about to include in a notification
+
+ link: a `matrix.to` link to the event
+ ts: the time in milliseconds at which the event was received
+ messages: a list of messages containing one message before the event, the
+ message in the event, and one message after the event.
+ """
+
+ link: str
+ ts: Optional[int]
+ messages: List[MessageVars]
+
+
+class RoomVars(TypedDict):
+ """
+ Represents a room containing events to include in the email.
+
+ title: a human-readable name for the room
+ hash: a hash of the ID of the room
+ invite: a boolean, which is `True` if the room is an invite the user hasn't
+ accepted yet, `False` otherwise
+ notifs: a list of events, or an empty list if `invite` is `True`.
+ link: a `matrix.to` link to the room
+ avatar_url: the URL of the room's avatar
+ """
+
+ title: Optional[str]
+ hash: int
+ invite: bool
+ notifs: List[NotifVars]
+ link: str
+ avatar_url: Optional[str]
+
+
+class TemplateVars(TypedDict, total=False):
+ """
+ Generic structure passed to the email sender; it can hold all the fields used in email templates.
+
+ app_name: name of the app/service this homeserver is associated with
+ server_name: name of our own homeserver
+ link: a link to include in the email to be sent
+ user_display_name: the display name for the user receiving the notification
+ unsubscribe_link: the link users can click to unsubscribe from email notifications
+ summary_text: a summary of the notification(s). The text used can be customised
+ by configuring the various settings in the `email.subjects` section of the
+ configuration file.
+ rooms: a list of rooms containing events to include in the email
+ reason: information on the event that triggered the email to be sent
+ """
+
+ app_name: str
+ server_name: str
+ link: str
+ user_display_name: str
+ unsubscribe_link: str
+ summary_text: str
+ rooms: List[RoomVars]
+ reason: EmailReason
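# Editor's sketch showing how the mailer code above can populate these
# TypedDicts. TemplateVars, MessageVars and EmailReason are declared with
# total=False, so they may be built incrementally; NotifVars and RoomVars
# require all of their keys. Values below are illustrative only.
template_vars: TemplateVars = {
    "app_name": "Matrix",
    "server_name": "example.com",
}
template_vars["summary_text"] = "You have new messages"

reason: EmailReason = {
    "room_id": "!somewhere:example.com",
    "now": 1_638_000_000_000,
}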
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 154e5b7028..7d26954244 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -86,7 +86,7 @@ REQUIREMENTS = [
# We enforce that we have a `cryptography` version that bundles an `openssl`
# with the latest security patches.
"cryptography>=3.4.7",
- "ijson>=3.0",
+ "ijson>=3.1",
]
CONDITIONAL_REQUIREMENTS = {
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 8c1bf9227a..fa132d10b4 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -14,10 +14,18 @@
from typing import List, Optional, Tuple
from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.util.id_generators import _load_current_id
+from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id
-class SlavedIdTracker:
+class SlavedIdTracker(AbstractStreamIdTracker):
+ """Tracks the "current" stream ID of a stream with a single writer.
+
+ See `AbstractStreamIdTracker` for more details.
+
+ Note that this class does not work correctly when there are multiple
+ writers.
+ """
+
def __init__(
self,
db_conn: LoggingDatabaseConnection,
@@ -36,17 +44,7 @@ class SlavedIdTracker:
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self) -> int:
- """
-
- Returns:
- int
- """
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer.
-
- For streams with single writers this is equivalent to
- `get_current_token`.
- """
return self.get_current_token()
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 4d5f862862..7541e21de9 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@@ -25,9 +24,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- # We assert this for the benefit of mypy
- assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
-
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index a030e9299e..a390cfcb74 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -14,7 +14,7 @@
# limitations under the License.
import heapq
from collections.abc import Iterable
-from typing import TYPE_CHECKING, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Tuple, Type
import attr
@@ -157,7 +157,7 @@ class EventsStream(Stream):
# now we fetch up to that many rows from the events table
- event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
+ event_rows = await self._store.get_all_new_forward_event_rows(
instance_name, from_token, current_token, target_row_count
)
@@ -191,7 +191,7 @@ class EventsStream(Stream):
# finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit.
- ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
+ ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
instance_name, from_token, upper_limit
)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index ee4a5e481b..c51a029bf3 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -17,6 +17,7 @@
import logging
import platform
+from http import HTTPStatus
from typing import TYPE_CHECKING, Optional, Tuple
import synapse
@@ -98,7 +99,7 @@ class VersionServlet(RestServlet):
}
def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- return 200, self.res
+ return HTTPStatus.OK, self.res
class PurgeHistoryRestServlet(RestServlet):
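# Editor's note (not part of the patch): HTTPStatus members are an IntEnum, so
# replacing the bare numeric literals with named constants does not change the
# status codes sent on the wire.
from http import HTTPStatus

assert HTTPStatus.OK == 200
assert HTTPStatus.BAD_REQUEST == 400 and HTTPStatus.FORBIDDEN == 403
assert int(HTTPStatus.NOT_FOUND) == 404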
@@ -130,7 +131,7 @@ class PurgeHistoryRestServlet(RestServlet):
event = await self.store.get_event(event_id)
if event.room_id != room_id:
- raise SynapseError(400, "Event is for wrong room.")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Event is for wrong room.")
# RoomStreamToken expects [int] not Optional[int]
assert event.internal_metadata.stream_ordering is not None
@@ -144,7 +145,9 @@ class PurgeHistoryRestServlet(RestServlet):
ts = body["purge_up_to_ts"]
if not isinstance(ts, int):
raise SynapseError(
- 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON
+ HTTPStatus.BAD_REQUEST,
+ "purge_up_to_ts must be an int",
+ errcode=Codes.BAD_JSON,
)
stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
@@ -160,7 +163,9 @@ class PurgeHistoryRestServlet(RestServlet):
stream_ordering,
)
raise SynapseError(
- 404, "there is no event to be purged", errcode=Codes.NOT_FOUND
+ HTTPStatus.NOT_FOUND,
+ "there is no event to be purged",
+ errcode=Codes.NOT_FOUND,
)
(stream, topo, _event_id) = r
token = "t%d-%d" % (topo, stream)
@@ -173,7 +178,7 @@ class PurgeHistoryRestServlet(RestServlet):
)
else:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"must specify purge_up_to_event_id or purge_up_to_ts",
errcode=Codes.BAD_JSON,
)
@@ -182,7 +187,7 @@ class PurgeHistoryRestServlet(RestServlet):
room_id, token, delete_local_events=delete_local_events
)
- return 200, {"purge_id": purge_id}
+ return HTTPStatus.OK, {"purge_id": purge_id}
class PurgeHistoryStatusRestServlet(RestServlet):
@@ -201,7 +206,7 @@ class PurgeHistoryStatusRestServlet(RestServlet):
if purge_status is None:
raise NotFoundError("purge id '%s' not found" % purge_id)
- return 200, purge_status.asdict()
+ return HTTPStatus.OK, purge_status.asdict()
########################################################################################
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index d9a2f6ca15..399b205aaf 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -13,6 +13,7 @@
# limitations under the License.
import re
+from http import HTTPStatus
from typing import Iterable, Pattern
from synapse.api.auth import Auth
@@ -62,4 +63,4 @@ async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
"""
is_admin = await auth.is_server_admin(user_id)
if not is_admin:
- raise AuthError(403, "You are not a server admin")
+ raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 80fbf32f17..2e5a6600d3 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import NotFoundError, SynapseError
@@ -53,7 +54,7 @@ class DeviceRestServlet(RestServlet):
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only lookup local users")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
if u is None:
@@ -62,7 +63,7 @@ class DeviceRestServlet(RestServlet):
device = await self.device_handler.get_device(
target_user.to_string(), device_id
)
- return 200, device
+ return HTTPStatus.OK, device
async def on_DELETE(
self, request: SynapseRequest, user_id: str, device_id: str
@@ -71,14 +72,14 @@ class DeviceRestServlet(RestServlet):
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only lookup local users")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
if u is None:
raise NotFoundError("Unknown user")
await self.device_handler.delete_device(target_user.to_string(), device_id)
- return 200, {}
+ return HTTPStatus.OK, {}
async def on_PUT(
self, request: SynapseRequest, user_id: str, device_id: str
@@ -87,7 +88,7 @@ class DeviceRestServlet(RestServlet):
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only lookup local users")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
if u is None:
@@ -97,7 +98,7 @@ class DeviceRestServlet(RestServlet):
await self.device_handler.update_device(
target_user.to_string(), device_id, body
)
- return 200, {}
+ return HTTPStatus.OK, {}
class DevicesRestServlet(RestServlet):
@@ -124,14 +125,14 @@ class DevicesRestServlet(RestServlet):
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only lookup local users")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
if u is None:
raise NotFoundError("Unknown user")
devices = await self.device_handler.get_devices_by_user(target_user.to_string())
- return 200, {"devices": devices, "total": len(devices)}
+ return HTTPStatus.OK, {"devices": devices, "total": len(devices)}
class DeleteDevicesRestServlet(RestServlet):
@@ -155,7 +156,7 @@ class DeleteDevicesRestServlet(RestServlet):
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only lookup local users")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
if u is None:
@@ -167,4 +168,4 @@ class DeleteDevicesRestServlet(RestServlet):
await self.device_handler.delete_devices(
target_user.to_string(), body["devices"]
)
- return 200, {}
+ return HTTPStatus.OK, {}
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index bbfcaf723b..5ee8b11110 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -13,6 +13,7 @@
# limitations under the License.
import logging
+from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -66,21 +67,23 @@ class EventReportsRestServlet(RestServlet):
if start < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"The start parameter must be a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"The limit parameter must be a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if direction not in ("f", "b"):
raise SynapseError(
- 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+ HTTPStatus.BAD_REQUEST,
+ "Unknown direction: %s" % (direction,),
+ errcode=Codes.INVALID_PARAM,
)
event_reports, total = await self.store.get_event_reports_paginate(
@@ -90,7 +93,7 @@ class EventReportsRestServlet(RestServlet):
if (start + limit) < total:
ret["next_token"] = start + len(event_reports)
- return 200, ret
+ return HTTPStatus.OK, ret
class EventReportDetailRestServlet(RestServlet):
@@ -127,13 +130,17 @@ class EventReportDetailRestServlet(RestServlet):
try:
resolved_report_id = int(report_id)
except ValueError:
- raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+ )
if resolved_report_id < 0:
- raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+ )
ret = await self.store.get_event_report(resolved_report_id)
if not ret:
raise NotFoundError("Event report not found")
- return 200, ret
+ return HTTPStatus.OK, ret
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index 68a3ba3cb7..a27110388f 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import SynapseError
@@ -43,7 +44,7 @@ class DeleteGroupAdminRestServlet(RestServlet):
await assert_user_is_admin(self.auth, requester.user)
if not self.is_mine_id(group_id):
- raise SynapseError(400, "Can only delete local groups")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local groups")
await self.group_server.delete_group(group_id, requester.user.to_string())
- return 200, {}
+ return HTTPStatus.OK, {}
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 30a687d234..9e23e2d8fc 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
@@ -62,7 +63,7 @@ class QuarantineMediaInRoom(RestServlet):
room_id, requester.user.to_string()
)
- return 200, {"num_quarantined": num_quarantined}
+ return HTTPStatus.OK, {"num_quarantined": num_quarantined}
class QuarantineMediaByUser(RestServlet):
@@ -89,7 +90,7 @@ class QuarantineMediaByUser(RestServlet):
user_id, requester.user.to_string()
)
- return 200, {"num_quarantined": num_quarantined}
+ return HTTPStatus.OK, {"num_quarantined": num_quarantined}
class QuarantineMediaByID(RestServlet):
@@ -118,7 +119,7 @@ class QuarantineMediaByID(RestServlet):
server_name, media_id, requester.user.to_string()
)
- return 200, {}
+ return HTTPStatus.OK, {}
class UnquarantineMediaByID(RestServlet):
@@ -147,7 +148,7 @@ class UnquarantineMediaByID(RestServlet):
# Remove from quarantine this media id
await self.store.quarantine_media_by_id(server_name, media_id, None)
- return 200, {}
+ return HTTPStatus.OK, {}
class ProtectMediaByID(RestServlet):
@@ -170,7 +171,7 @@ class ProtectMediaByID(RestServlet):
# Protect this media id
await self.store.mark_local_media_as_safe(media_id, safe=True)
- return 200, {}
+ return HTTPStatus.OK, {}
class UnprotectMediaByID(RestServlet):
@@ -193,7 +194,7 @@ class UnprotectMediaByID(RestServlet):
# Unprotect this media id
await self.store.mark_local_media_as_safe(media_id, safe=False)
- return 200, {}
+ return HTTPStatus.OK, {}
class ListMediaInRoom(RestServlet):
@@ -211,11 +212,11 @@ class ListMediaInRoom(RestServlet):
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
- raise AuthError(403, "You are not a server admin")
+ raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
- return 200, {"local": local_mxcs, "remote": remote_mxcs}
+ return HTTPStatus.OK, {"local": local_mxcs, "remote": remote_mxcs}
class PurgeMediaCacheRestServlet(RestServlet):
@@ -233,13 +234,13 @@ class PurgeMediaCacheRestServlet(RestServlet):
if before_ts < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter before_ts must be a positive integer.",
errcode=Codes.INVALID_PARAM,
)
elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter before_ts you provided is from the year 1970. "
+ "Double check that you are providing a timestamp in milliseconds.",
errcode=Codes.INVALID_PARAM,
@@ -247,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
ret = await self.media_repository.delete_old_remote_media(before_ts)
- return 200, ret
+ return HTTPStatus.OK, ret
class DeleteMediaByID(RestServlet):
@@ -267,7 +268,7 @@ class DeleteMediaByID(RestServlet):
await assert_requester_is_admin(self.auth, request)
if self.server_name != server_name:
- raise SynapseError(400, "Can only delete local media")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
if await self.store.get_local_media(media_id) is None:
raise NotFoundError("Unknown media")
@@ -277,7 +278,7 @@ class DeleteMediaByID(RestServlet):
deleted_media, total = await self.media_repository.delete_local_media_ids(
[media_id]
)
- return 200, {"deleted_media": deleted_media, "total": total}
+ return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
class DeleteMediaByDateSize(RestServlet):
@@ -304,26 +305,26 @@ class DeleteMediaByDateSize(RestServlet):
if before_ts < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter before_ts must be a positive integer.",
errcode=Codes.INVALID_PARAM,
)
elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter before_ts you provided is from the year 1970. "
+ "Double check that you are providing a timestamp in milliseconds.",
errcode=Codes.INVALID_PARAM,
)
if size_gt < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter size_gt must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if self.server_name != server_name:
- raise SynapseError(400, "Can only delete local media")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
logging.info(
"Deleting local media by timestamp: %s, size larger than: %s, keep profile media: %s"
@@ -333,7 +334,7 @@ class DeleteMediaByDateSize(RestServlet):
deleted_media, total = await self.media_repository.delete_old_local_media(
before_ts, size_gt, keep_profiles
)
- return 200, {"deleted_media": deleted_media, "total": total}
+ return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
class UserMediaRestServlet(RestServlet):
@@ -369,7 +370,7 @@ class UserMediaRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)):
- raise SynapseError(400, "Can only look up local users")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
user = await self.store.get_user_by_id(user_id)
if user is None:
@@ -380,14 +381,14 @@ class UserMediaRestServlet(RestServlet):
if start < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
@@ -425,7 +426,7 @@ class UserMediaRestServlet(RestServlet):
if (start + limit) < total:
ret["next_token"] = start + len(media)
- return 200, ret
+ return HTTPStatus.OK, ret
async def on_DELETE(
self, request: SynapseRequest, user_id: str
@@ -436,7 +437,7 @@ class UserMediaRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)):
- raise SynapseError(400, "Can only look up local users")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
user = await self.store.get_user_by_id(user_id)
if user is None:
@@ -447,14 +448,14 @@ class UserMediaRestServlet(RestServlet):
if start < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
@@ -492,7 +493,7 @@ class UserMediaRestServlet(RestServlet):
([row["media_id"] for row in media])
)
- return 200, {"deleted_media": deleted_media, "total": total}
+ return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) -> None:
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index aba48f6e7b..891b98c088 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -14,6 +14,7 @@
import logging
import string
+from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -77,7 +78,7 @@ class ListRegistrationTokensRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
valid = parse_boolean(request, "valid")
token_list = await self.store.get_registration_tokens(valid)
- return 200, {"registration_tokens": token_list}
+ return HTTPStatus.OK, {"registration_tokens": token_list}
class NewRegistrationTokenRestServlet(RestServlet):
@@ -123,16 +124,20 @@ class NewRegistrationTokenRestServlet(RestServlet):
if "token" in body:
token = body["token"]
if not isinstance(token, str):
- raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "token must be a string",
+ Codes.INVALID_PARAM,
+ )
if not (0 < len(token) <= 64):
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"token must not be empty and must not be longer than 64 characters",
Codes.INVALID_PARAM,
)
if not set(token).issubset(self.allowed_chars_set):
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"token must consist only of characters matched by the regex [A-Za-z0-9-_]",
Codes.INVALID_PARAM,
)
@@ -142,11 +147,13 @@ class NewRegistrationTokenRestServlet(RestServlet):
length = body.get("length", 16)
if not isinstance(length, int):
raise SynapseError(
- 400, "length must be an integer", Codes.INVALID_PARAM
+ HTTPStatus.BAD_REQUEST,
+ "length must be an integer",
+ Codes.INVALID_PARAM,
)
if not (0 < length <= 64):
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"length must be greater than zero and not greater than 64",
Codes.INVALID_PARAM,
)
@@ -162,7 +169,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
):
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"uses_allowed must be a non-negative integer or null",
Codes.INVALID_PARAM,
)
@@ -170,11 +177,15 @@ class NewRegistrationTokenRestServlet(RestServlet):
expiry_time = body.get("expiry_time", None)
if not isinstance(expiry_time, (int, type(None))):
raise SynapseError(
- 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+ HTTPStatus.BAD_REQUEST,
+ "expiry_time must be an integer or null",
+ Codes.INVALID_PARAM,
)
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
raise SynapseError(
- 400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+ HTTPStatus.BAD_REQUEST,
+ "expiry_time must not be in the past",
+ Codes.INVALID_PARAM,
)
created = await self.store.create_registration_token(
@@ -182,7 +193,9 @@ class NewRegistrationTokenRestServlet(RestServlet):
)
if not created:
raise SynapseError(
- 400, f"Token already exists: {token}", Codes.INVALID_PARAM
+ HTTPStatus.BAD_REQUEST,
+ f"Token already exists: {token}",
+ Codes.INVALID_PARAM,
)
resp = {
@@ -192,7 +205,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
"completed": 0,
"expiry_time": expiry_time,
}
- return 200, resp
+ return HTTPStatus.OK, resp
class RegistrationTokenRestServlet(RestServlet):
@@ -261,7 +274,7 @@ class RegistrationTokenRestServlet(RestServlet):
if token_info is None:
raise NotFoundError(f"No such registration token: {token}")
- return 200, token_info
+ return HTTPStatus.OK, token_info
async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
"""Update a registration token."""
@@ -277,7 +290,7 @@ class RegistrationTokenRestServlet(RestServlet):
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
):
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"uses_allowed must be a non-negative integer or null",
Codes.INVALID_PARAM,
)
@@ -287,11 +300,15 @@ class RegistrationTokenRestServlet(RestServlet):
expiry_time = body["expiry_time"]
if not isinstance(expiry_time, (int, type(None))):
raise SynapseError(
- 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+ HTTPStatus.BAD_REQUEST,
+ "expiry_time must be an integer or null",
+ Codes.INVALID_PARAM,
)
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
raise SynapseError(
- 400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+ HTTPStatus.BAD_REQUEST,
+ "expiry_time must not be in the past",
+ Codes.INVALID_PARAM,
)
new_attributes["expiry_time"] = expiry_time
@@ -307,7 +324,7 @@ class RegistrationTokenRestServlet(RestServlet):
if token_info is None:
raise NotFoundError(f"No such registration token: {token}")
- return 200, token_info
+ return HTTPStatus.OK, token_info
async def on_DELETE(
self, request: SynapseRequest, token: str
@@ -316,6 +333,6 @@ class RegistrationTokenRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
if await self.store.delete_registration_token(token):
- return 200, {}
+ return HTTPStatus.OK, {}
raise NotFoundError(f"No such registration token: {token}")
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index a89dda1ba5..6bbc5510f0 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -102,7 +102,9 @@ class RoomRestV2Servlet(RestServlet):
)
if not RoomID.is_valid(room_id):
- raise SynapseError(400, "%s is not a legal room ID" % (room_id,))
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
+ )
if not await self._store.get_room(room_id):
raise NotFoundError("Unknown room id %s" % (room_id,))
@@ -118,7 +120,7 @@ class RoomRestV2Servlet(RestServlet):
force_purge=force_purge,
)
- return 200, {"delete_id": delete_id}
+ return HTTPStatus.OK, {"delete_id": delete_id}
class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
@@ -137,7 +139,9 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
await assert_requester_is_admin(self._auth, request)
if not RoomID.is_valid(room_id):
- raise SynapseError(400, "%s is not a legal room ID" % (room_id,))
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
+ )
delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id)
if delete_ids is None:
@@ -153,7 +157,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
**delete.asdict(),
}
]
- return 200, {"results": cast(JsonDict, response)}
+ return HTTPStatus.OK, {"results": cast(JsonDict, response)}
class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
@@ -175,7 +179,7 @@ class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
if delete_status is None:
raise NotFoundError("delete id '%s' not found" % delete_id)
- return 200, cast(JsonDict, delete_status.asdict())
+ return HTTPStatus.OK, cast(JsonDict, delete_status.asdict())
class ListRoomRestServlet(RestServlet):
@@ -217,7 +221,7 @@ class ListRoomRestServlet(RestServlet):
RoomSortOrder.STATE_EVENTS.value,
):
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Unknown value for order_by: %s" % (order_by,),
errcode=Codes.INVALID_PARAM,
)
@@ -225,7 +229,7 @@ class ListRoomRestServlet(RestServlet):
search_term = parse_string(request, "search_term", encoding="utf-8")
if search_term == "":
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"search_term cannot be an empty string",
errcode=Codes.INVALID_PARAM,
)
@@ -233,7 +237,9 @@ class ListRoomRestServlet(RestServlet):
direction = parse_string(request, "dir", default="f")
if direction not in ("f", "b"):
raise SynapseError(
- 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+ HTTPStatus.BAD_REQUEST,
+ "Unknown direction: %s" % (direction,),
+ errcode=Codes.INVALID_PARAM,
)
reverse_order = True if direction == "b" else False
@@ -265,7 +271,7 @@ class ListRoomRestServlet(RestServlet):
else:
response["prev_batch"] = 0
- return 200, response
+ return HTTPStatus.OK, response
class RoomRestServlet(RestServlet):
@@ -310,7 +316,7 @@ class RoomRestServlet(RestServlet):
members = await self.store.get_users_in_room(room_id)
ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
- return 200, ret
+ return HTTPStatus.OK, ret
async def on_DELETE(
self, request: SynapseRequest, room_id: str
@@ -386,7 +392,7 @@ class RoomRestServlet(RestServlet):
# See https://github.com/python/mypy/issues/4976#issuecomment-579883622
# for some discussion on why this is necessary. Either way,
# `ret` is an opaque dictionary blob as far as the rest of the app cares.
- return 200, cast(JsonDict, ret)
+ return HTTPStatus.OK, cast(JsonDict, ret)
class RoomMembersRestServlet(RestServlet):
@@ -413,7 +419,7 @@ class RoomMembersRestServlet(RestServlet):
members = await self.store.get_users_in_room(room_id)
ret = {"members": members, "total": len(members)}
- return 200, ret
+ return HTTPStatus.OK, ret
class RoomStateRestServlet(RestServlet):
@@ -452,7 +458,7 @@ class RoomStateRestServlet(RestServlet):
)
ret = {"state": room_state}
- return 200, ret
+ return HTTPStatus.OK, ret
class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
@@ -481,7 +487,10 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
target_user = UserID.from_string(content["user_id"])
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "This endpoint can only be used with local users")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "This endpoint can only be used with local users",
+ )
if not await self.admin_handler.get_user(target_user):
raise NotFoundError("User not found")
@@ -527,7 +536,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
ratelimit=False,
)
- return 200, {"room_id": room_id}
+ return HTTPStatus.OK, {"room_id": room_id}
class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
@@ -568,7 +577,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
# Figure out which local users currently have power in the room, if any.
room_state = await self.state_handler.get_current_state(room_id)
if not room_state:
- raise SynapseError(400, "Server not in room")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
create_event = room_state[(EventTypes.Create, "")]
power_levels = room_state.get((EventTypes.PowerLevels, ""))
@@ -582,7 +591,9 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
admin_users.sort(key=lambda user: user_power[user])
if not admin_users:
- raise SynapseError(400, "No local admin user in room")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "No local admin user in room"
+ )
admin_user_id = None
@@ -599,7 +610,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
if not admin_user_id:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"No local admin user in room",
)
@@ -610,7 +621,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
admin_user_id = create_event.sender
if not self.is_mine_id(admin_user_id):
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"No local admin user in room",
)
@@ -639,7 +650,8 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
except AuthError:
# The admin user we found turned out not to have enough power.
raise SynapseError(
- 400, "No local admin user in room with power to update power levels."
+ HTTPStatus.BAD_REQUEST,
+ "No local admin user in room with power to update power levels.",
)
# Now we check if the user we're granting admin rights to is already in
@@ -653,7 +665,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
)
if is_joined:
- return 200, {}
+ return HTTPStatus.OK, {}
join_rules = room_state.get((EventTypes.JoinRules, ""))
is_public = False
@@ -661,7 +673,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
if is_public:
- return 200, {}
+ return HTTPStatus.OK, {}
await self.room_member_handler.update_membership(
fake_requester,
@@ -670,7 +682,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
action=Membership.INVITE,
)
- return 200, {}
+ return HTTPStatus.OK, {}
class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
@@ -702,7 +714,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
room_id, _ = await self.resolve_room_id(room_identifier)
deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
- return 200, {"deleted": deleted_count}
+ return HTTPStatus.OK, {"deleted": deleted_count}
async def on_GET(
self, request: SynapseRequest, room_identifier: str
@@ -713,7 +725,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
room_id, _ = await self.resolve_room_id(room_identifier)
extremities = await self.store.get_forward_extremities_for_room(room_id)
- return 200, {"count": len(extremities), "results": extremities}
+ return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
class RoomEventContextServlet(RestServlet):
@@ -762,7 +774,9 @@ class RoomEventContextServlet(RestServlet):
)
if not results:
- raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+ raise SynapseError(
+ HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND
+ )
time_now = self.clock.time_msec()
results["events_before"] = await self._event_serializer.serialize_events(
@@ -781,7 +795,7 @@ class RoomEventContextServlet(RestServlet):
bundle_relations=False,
)
- return 200, results
+ return HTTPStatus.OK, results
class BlockRoomRestServlet(RestServlet):
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 19f84f33f2..b295fb078b 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from http import HTTPStatus
from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
from synapse.api.constants import EventTypes
@@ -82,11 +83,15 @@ class SendServerNoticeServlet(RestServlet):
# but worker processes still need to initialise SendServerNoticeServlet (as it is part of the
# admin api).
if not self.server_notices_manager.is_enabled():
- raise SynapseError(400, "Server notices are not enabled on this server")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Server notices are not enabled on this server"
+ )
target_user = UserID.from_string(body["user_id"])
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Server notices can only be sent to local users")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users"
+ )
if not await self.admin_handler.get_user(target_user):
raise NotFoundError("User not found")
@@ -99,7 +104,7 @@ class SendServerNoticeServlet(RestServlet):
txn_id=txn_id,
)
- return 200, {"event_id": event.event_id}
+ return HTTPStatus.OK, {"event_id": event.event_id}
def on_PUT(
self, request: SynapseRequest, txn_id: str
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index 948de94ccd..ca41fd45f2 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -13,6 +13,7 @@
# limitations under the License.
import logging
+from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError
@@ -53,7 +54,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
UserSortOrder.DISPLAYNAME.value,
):
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Unknown value for order_by: %s" % (order_by,),
errcode=Codes.INVALID_PARAM,
)
@@ -61,7 +62,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
start = parse_integer(request, "from", default=0)
if start < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
@@ -69,7 +70,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
limit = parse_integer(request, "limit", default=100)
if limit < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
@@ -77,7 +78,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
from_ts = parse_integer(request, "from_ts", default=0)
if from_ts < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter from_ts must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
@@ -86,13 +87,13 @@ class UserMediaStatisticsRestServlet(RestServlet):
if until_ts is not None:
if until_ts < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter until_ts must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if until_ts <= from_ts:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter until_ts must be greater than from_ts.",
errcode=Codes.INVALID_PARAM,
)
@@ -100,7 +101,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
search_term = parse_string(request, "search_term")
if search_term == "":
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter search_term cannot be an empty string.",
errcode=Codes.INVALID_PARAM,
)
@@ -108,7 +109,9 @@ class UserMediaStatisticsRestServlet(RestServlet):
direction = parse_string(request, "dir", default="f")
if direction not in ("f", "b"):
raise SynapseError(
- 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+ HTTPStatus.BAD_REQUEST,
+ "Unknown direction: %s" % (direction,),
+ errcode=Codes.INVALID_PARAM,
)
users_media, total = await self.store.get_users_media_usage_paginate(
@@ -118,4 +121,4 @@ class UserMediaStatisticsRestServlet(RestServlet):
if (start + limit) < total:
ret["next_token"] = start + len(users_media)
- return 200, ret
+ return HTTPStatus.OK, ret
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index ccd9a2a175..2a60b602b1 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -79,14 +79,14 @@ class UsersRestServletV2(RestServlet):
if start < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
@@ -122,7 +122,7 @@ class UsersRestServletV2(RestServlet):
if (start + limit) < total:
ret["next_token"] = str(start + len(users))
- return 200, ret
+ return HTTPStatus.OK, ret
class UserRestServletV2(RestServlet):
@@ -172,14 +172,14 @@ class UserRestServletV2(RestServlet):
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only look up local users")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
ret = await self.admin_handler.get_user(target_user)
if not ret:
raise NotFoundError("User not found")
- return 200, ret
+ return HTTPStatus.OK, ret
async def on_PUT(
self, request: SynapseRequest, user_id: str
@@ -191,7 +191,10 @@ class UserRestServletV2(RestServlet):
body = parse_json_object_from_request(request)
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "This endpoint can only be used with local users")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "This endpoint can only be used with local users",
+ )
user = await self.admin_handler.get_user(target_user)
user_id = target_user.to_string()
@@ -210,7 +213,7 @@ class UserRestServletV2(RestServlet):
user_type = body.get("user_type", None)
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
- raise SynapseError(400, "Invalid user type")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type")
set_admin_to = body.get("admin", False)
if not isinstance(set_admin_to, bool):
@@ -223,11 +226,13 @@ class UserRestServletV2(RestServlet):
password = body.get("password", None)
if password is not None:
if not isinstance(password, str) or len(password) > 512:
- raise SynapseError(400, "Invalid password")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
deactivate = body.get("deactivated", False)
if not isinstance(deactivate, bool):
- raise SynapseError(400, "'deactivated' parameter is not of type boolean")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
+ )
# convert List[Dict[str, str]] into List[Tuple[str, str]]
if external_ids is not None:
@@ -282,7 +287,9 @@ class UserRestServletV2(RestServlet):
user_id,
)
except ExternalIDReuseException:
- raise SynapseError(409, "External id is already in use.")
+ raise SynapseError(
+ HTTPStatus.CONFLICT, "External id is already in use."
+ )
if "avatar_url" in body and isinstance(body["avatar_url"], str):
await self.profile_handler.set_avatar_url(
@@ -293,7 +300,9 @@ class UserRestServletV2(RestServlet):
if set_admin_to != user["admin"]:
auth_user = requester.user
if target_user == auth_user and not set_admin_to:
- raise SynapseError(400, "You may not demote yourself.")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "You may not demote yourself."
+ )
await self.store.set_server_admin(target_user, set_admin_to)
@@ -319,7 +328,8 @@ class UserRestServletV2(RestServlet):
and self.auth_handler.can_change_password()
):
raise SynapseError(
- 400, "Must provide a password to re-activate an account."
+ HTTPStatus.BAD_REQUEST,
+ "Must provide a password to re-activate an account.",
)
await self.deactivate_account_handler.activate_account(
@@ -332,7 +342,7 @@ class UserRestServletV2(RestServlet):
user = await self.admin_handler.get_user(target_user)
assert user is not None
- return 200, user
+ return HTTPStatus.OK, user
else: # create user
displayname = body.get("displayname", None)
@@ -381,7 +391,9 @@ class UserRestServletV2(RestServlet):
user_id,
)
except ExternalIDReuseException:
- raise SynapseError(409, "External id is already in use.")
+ raise SynapseError(
+ HTTPStatus.CONFLICT, "External id is already in use."
+ )
if "avatar_url" in body and isinstance(body["avatar_url"], str):
await self.profile_handler.set_avatar_url(
@@ -429,51 +441,61 @@ class UserRegisterServlet(RestServlet):
nonce = secrets.token_hex(64)
self.nonces[nonce] = int(self.reactor.seconds())
- return 200, {"nonce": nonce}
+ return HTTPStatus.OK, {"nonce": nonce}
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
self._clear_old_nonces()
if not self.hs.config.registration.registration_shared_secret:
- raise SynapseError(400, "Shared secret registration is not enabled")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Shared secret registration is not enabled"
+ )
body = parse_json_object_from_request(request)
if "nonce" not in body:
- raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "nonce must be specified",
+ errcode=Codes.BAD_JSON,
+ )
nonce = body["nonce"]
if nonce not in self.nonces:
- raise SynapseError(400, "unrecognised nonce")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "unrecognised nonce")
# Delete the nonce, so it can't be reused, even if it's invalid
del self.nonces[nonce]
if "username" not in body:
raise SynapseError(
- 400, "username must be specified", errcode=Codes.BAD_JSON
+ HTTPStatus.BAD_REQUEST,
+ "username must be specified",
+ errcode=Codes.BAD_JSON,
)
else:
if not isinstance(body["username"], str) or len(body["username"]) > 512:
- raise SynapseError(400, "Invalid username")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username")
username = body["username"].encode("utf-8")
if b"\x00" in username:
- raise SynapseError(400, "Invalid username")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username")
if "password" not in body:
raise SynapseError(
- 400, "password must be specified", errcode=Codes.BAD_JSON
+ HTTPStatus.BAD_REQUEST,
+ "password must be specified",
+ errcode=Codes.BAD_JSON,
)
else:
password = body["password"]
if not isinstance(password, str) or len(password) > 512:
- raise SynapseError(400, "Invalid password")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
password_bytes = password.encode("utf-8")
if b"\x00" in password_bytes:
- raise SynapseError(400, "Invalid password")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
password_hash = await self.auth_handler.hash(password)
@@ -482,10 +504,12 @@ class UserRegisterServlet(RestServlet):
displayname = body.get("displayname", None)
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
- raise SynapseError(400, "Invalid user type")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type")
if "mac" not in body:
- raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "mac must be specified", errcode=Codes.BAD_JSON
+ )
got_mac = body["mac"]
@@ -507,7 +531,7 @@ class UserRegisterServlet(RestServlet):
want_mac = want_mac_builder.hexdigest()
if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
- raise SynapseError(403, "HMAC incorrect")
+ raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect")
# Reuse the parts of RegisterRestServlet to reduce code duplication
from synapse.rest.client.register import RegisterRestServlet
@@ -524,7 +548,7 @@ class UserRegisterServlet(RestServlet):
)
result = await register._create_registration_details(user_id, body)
- return 200, result
+ return HTTPStatus.OK, result
class WhoisRestServlet(RestServlet):
@@ -552,11 +576,11 @@ class WhoisRestServlet(RestServlet):
await assert_user_is_admin(self.auth, auth_user)
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only whois a local user")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
ret = await self.admin_handler.get_whois(target_user)
- return 200, ret
+ return HTTPStatus.OK, ret
class DeactivateAccountRestServlet(RestServlet):
@@ -575,7 +599,9 @@ class DeactivateAccountRestServlet(RestServlet):
await assert_user_is_admin(self.auth, requester.user)
if not self.is_mine(UserID.from_string(target_user_id)):
- raise SynapseError(400, "Can only deactivate local users")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Can only deactivate local users"
+ )
if not await self.store.get_user_by_id(target_user_id):
raise NotFoundError("User not found")
@@ -597,7 +623,7 @@ class DeactivateAccountRestServlet(RestServlet):
else:
id_server_unbind_result = "no-support"
- return 200, {"id_server_unbind_result": id_server_unbind_result}
+ return HTTPStatus.OK, {"id_server_unbind_result": id_server_unbind_result}
class AccountValidityRenewServlet(RestServlet):
@@ -620,7 +646,7 @@ class AccountValidityRenewServlet(RestServlet):
if "user_id" not in body:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"Missing property 'user_id' in the request body",
)
@@ -631,7 +657,7 @@ class AccountValidityRenewServlet(RestServlet):
)
res = {"expiration_ts": expiration_ts}
- return 200, res
+ return HTTPStatus.OK, res
class ResetPasswordRestServlet(RestServlet):
@@ -678,7 +704,7 @@ class ResetPasswordRestServlet(RestServlet):
await self._set_password_handler.set_password(
target_user_id, new_password_hash, logout_devices, requester
)
- return 200, {}
+ return HTTPStatus.OK, {}
class SearchUsersRestServlet(RestServlet):
@@ -712,16 +738,16 @@ class SearchUsersRestServlet(RestServlet):
# To allow all users to get the users list
# if not is_admin and target_user != auth_user:
- # raise AuthError(403, "You are not a server admin")
+ # raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only users a local user")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user")
term = parse_string(request, "term", required=True)
logger.info("term: %s ", term)
ret = await self.store.search_users(term)
- return 200, ret
+ return HTTPStatus.OK, ret
class UserAdminServlet(RestServlet):
@@ -765,11 +791,14 @@ class UserAdminServlet(RestServlet):
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Only local users can be admins of this homeserver")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Only local users can be admins of this homeserver",
+ )
is_admin = await self.store.is_server_admin(target_user)
- return 200, {"admin": is_admin}
+ return HTTPStatus.OK, {"admin": is_admin}
async def on_PUT(
self, request: SynapseRequest, user_id: str
@@ -785,16 +814,19 @@ class UserAdminServlet(RestServlet):
assert_params_in_dict(body, ["admin"])
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Only local users can be admins of this homeserver")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Only local users can be admins of this homeserver",
+ )
set_admin_to = bool(body["admin"])
if target_user == auth_user and not set_admin_to:
- raise SynapseError(400, "You may not demote yourself.")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "You may not demote yourself.")
await self.store.set_server_admin(target_user, set_admin_to)
- return 200, {}
+ return HTTPStatus.OK, {}
class UserMembershipRestServlet(RestServlet):
@@ -816,7 +848,7 @@ class UserMembershipRestServlet(RestServlet):
room_ids = await self.store.get_rooms_for_user(user_id)
ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
- return 200, ret
+ return HTTPStatus.OK, ret
class PushersRestServlet(RestServlet):
@@ -845,7 +877,7 @@ class PushersRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)):
- raise SynapseError(400, "Can only look up local users")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
if not await self.store.get_user_by_id(user_id):
raise NotFoundError("User not found")
@@ -854,7 +886,10 @@ class PushersRestServlet(RestServlet):
filtered_pushers = [p.as_dict() for p in pushers]
- return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
+ return HTTPStatus.OK, {
+ "pushers": filtered_pushers,
+ "total": len(filtered_pushers),
+ }
class UserTokenRestServlet(RestServlet):
@@ -887,16 +922,22 @@ class UserTokenRestServlet(RestServlet):
auth_user = requester.user
if not self.hs.is_mine_id(user_id):
- raise SynapseError(400, "Only local users can be logged in as")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Only local users can be logged in as"
+ )
body = parse_json_object_from_request(request, allow_empty_body=True)
valid_until_ms = body.get("valid_until_ms")
if valid_until_ms and not isinstance(valid_until_ms, int):
- raise SynapseError(400, "'valid_until_ms' parameter must be an int")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int"
+ )
if auth_user.to_string() == user_id:
- raise SynapseError(400, "Cannot use admin API to login as self")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Cannot use admin API to login as self"
+ )
token = await self.auth_handler.create_access_token_for_user_id(
user_id=auth_user.to_string(),
@@ -905,7 +946,7 @@ class UserTokenRestServlet(RestServlet):
puppets_user_id=user_id,
)
- return 200, {"access_token": token}
+ return HTTPStatus.OK, {"access_token": token}
class ShadowBanRestServlet(RestServlet):
@@ -947,11 +988,13 @@ class ShadowBanRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):
- raise SynapseError(400, "Only local users can be shadow-banned")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
+ )
await self.store.set_shadow_banned(UserID.from_string(user_id), True)
- return 200, {}
+ return HTTPStatus.OK, {}
async def on_DELETE(
self, request: SynapseRequest, user_id: str
@@ -959,11 +1002,13 @@ class ShadowBanRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):
- raise SynapseError(400, "Only local users can be shadow-banned")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
+ )
await self.store.set_shadow_banned(UserID.from_string(user_id), False)
- return 200, {}
+ return HTTPStatus.OK, {}
class RateLimitRestServlet(RestServlet):
@@ -995,7 +1040,7 @@ class RateLimitRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):
- raise SynapseError(400, "Can only look up local users")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
if not await self.store.get_user_by_id(user_id):
raise NotFoundError("User not found")
@@ -1016,7 +1061,7 @@ class RateLimitRestServlet(RestServlet):
else:
ret = {}
- return 200, ret
+ return HTTPStatus.OK, ret
async def on_POST(
self, request: SynapseRequest, user_id: str
@@ -1024,7 +1069,9 @@ class RateLimitRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):
- raise SynapseError(400, "Only local users can be ratelimited")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
+ )
if not await self.store.get_user_by_id(user_id):
raise NotFoundError("User not found")
@@ -1036,14 +1083,14 @@ class RateLimitRestServlet(RestServlet):
if not isinstance(messages_per_second, int) or messages_per_second < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"%r parameter must be a positive int" % (messages_per_second,),
errcode=Codes.INVALID_PARAM,
)
if not isinstance(burst_count, int) or burst_count < 0:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"%r parameter must be a positive int" % (burst_count,),
errcode=Codes.INVALID_PARAM,
)
@@ -1059,7 +1106,7 @@ class RateLimitRestServlet(RestServlet):
"burst_count": ratelimit.burst_count,
}
- return 200, ret
+ return HTTPStatus.OK, ret
async def on_DELETE(
self, request: SynapseRequest, user_id: str
@@ -1067,11 +1114,13 @@ class RateLimitRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):
- raise SynapseError(400, "Only local users can be ratelimited")
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
+ )
if not await self.store.get_user_by_id(user_id):
raise NotFoundError("User not found")
await self.store.delete_ratelimit_for_user(user_id)
- return 200, {}
+ return HTTPStatus.OK, {}
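The admin servlets above now return and raise `http.HTTPStatus` members instead of bare integers. A quick, purely illustrative sketch (not part of the patch): `HTTPStatus` is an `IntEnum`, so existing callers that compare against plain integers keep working while the intent becomes self-documenting.

    from http import HTTPStatus

    def lookup_user(user_id: str, local_users: set) -> tuple:
        # Enum members read better than bare 400/200 and still behave as ints.
        if user_id not in local_users:
            return HTTPStatus.BAD_REQUEST, {"error": "Can only look up local users"}
        return HTTPStatus.OK, {"user_id": user_id}

    assert HTTPStatus.OK == 200            # IntEnum: equality with plain ints holds
    assert int(HTTPStatus.BAD_REQUEST) == 400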
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 67e03dca04..09f378f919 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -14,7 +14,17 @@
import logging
import re
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
from typing_extensions import TypedDict
@@ -28,7 +38,6 @@ from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
- parse_boolean,
parse_bytes_from_args,
parse_json_object_from_request,
parse_string,
@@ -155,11 +164,14 @@ class LoginRestServlet(RestServlet):
login_submission = parse_json_object_from_request(request)
if self._msc2918_enabled:
- # Check if this login should also issue a refresh token, as per
- # MSC2918
- should_issue_refresh_token = parse_boolean(
- request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
+ # Check if this login should also issue a refresh token, as per MSC2918
+ should_issue_refresh_token = login_submission.get(
+ "org.matrix.msc2918.refresh_token", False
)
+ if not isinstance(should_issue_refresh_token, bool):
+ raise SynapseError(
+ 400, "`org.matrix.msc2918.refresh_token` should be true or false."
+ )
else:
should_issue_refresh_token = False
@@ -458,6 +470,7 @@ class RefreshTokenServlet(RestServlet):
self.refreshable_access_token_lifetime = (
hs.config.registration.refreshable_access_token_lifetime
)
+ self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
refresh_submission = parse_json_object_from_request(request)
@@ -467,22 +480,33 @@ class RefreshTokenServlet(RestServlet):
if not isinstance(token, str):
raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
- valid_until_ms = (
- self._clock.time_msec() + self.refreshable_access_token_lifetime
- )
- access_token, refresh_token = await self._auth_handler.refresh_token(
- token, valid_until_ms
- )
- expires_in_ms = valid_until_ms - self._clock.time_msec()
- return (
- 200,
- {
- "access_token": access_token,
- "refresh_token": refresh_token,
- "expires_in_ms": expires_in_ms,
- },
+ now = self._clock.time_msec()
+ access_valid_until_ms = None
+ if self.refreshable_access_token_lifetime is not None:
+ access_valid_until_ms = now + self.refreshable_access_token_lifetime
+ refresh_valid_until_ms = None
+ if self.refresh_token_lifetime is not None:
+ refresh_valid_until_ms = now + self.refresh_token_lifetime
+
+ (
+ access_token,
+ refresh_token,
+ actual_access_token_expiry,
+ ) = await self._auth_handler.refresh_token(
+ token, access_valid_until_ms, refresh_valid_until_ms
)
+ response: Dict[str, Union[str, int]] = {
+ "access_token": access_token,
+ "refresh_token": refresh_token,
+ }
+
+ # expires_in_ms is only present if the token expires
+ if actual_access_token_expiry is not None:
+ response["expires_in_ms"] = actual_access_token_expiry - now
+
+ return 200, response
+
class SsoRedirectServlet(RestServlet):
PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
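Two behavioural notes on the login/refresh changes above, with a hedged client-side sketch (the homeserver URL, credentials and the exact unstable endpoint path are placeholders, not taken from the patch): the MSC2918 `org.matrix.msc2918.refresh_token` flag is now read from the JSON body rather than a query parameter, and the `/refresh` response only includes `expires_in_ms` when the new access token actually expires.

    import requests  # assumption: any HTTP client would do

    HS = "https://homeserver.example"  # hypothetical homeserver

    # Request a refresh token at login time via the body, per this patch.
    login = requests.post(
        f"{HS}/_matrix/client/r0/login",
        json={
            "type": "m.login.password",
            "identifier": {"type": "m.id.user", "user": "alice"},
            "password": "s3cret",
            "org.matrix.msc2918.refresh_token": True,  # must be a JSON boolean, else 400
        },
    ).json()

    # Rotate the tokens later; the unstable path shown here is illustrative only.
    refreshed = requests.post(
        f"{HS}/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
        json={"refresh_token": login["refresh_token"]},
    ).json()
    lifetime_ms = refreshed.get("expires_in_ms")  # absent for non-expiring access tokens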
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index d2b11e39d9..11fd6cd24d 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -41,7 +41,6 @@ from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
- parse_boolean,
parse_json_object_from_request,
parse_string,
)
@@ -449,9 +448,13 @@ class RegisterRestServlet(RestServlet):
if self._msc2918_enabled:
# Check if this registration should also issue a refresh token, as
# per MSC2918
- should_issue_refresh_token = parse_boolean(
- request, name="org.matrix.msc2918.refresh_token", default=False
+ should_issue_refresh_token = body.get(
+ "org.matrix.msc2918.refresh_token", False
)
+ if not isinstance(should_issue_refresh_token, bool):
+ raise SynapseError(
+ 400, "`org.matrix.msc2918.refresh_token` should be true or false."
+ )
else:
should_issue_refresh_token = False
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 955d4e8641..73d0f7c950 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -1138,12 +1138,12 @@ class RoomSpaceSummaryRestServlet(RestServlet):
class RoomHierarchyRestServlet(RestServlet):
- PATTERNS = (
+ PATTERNS = [
re.compile(
- "^/_matrix/client/unstable/org.matrix.msc2946"
+ "^/_matrix/client/(v1|unstable/org.matrix.msc2946)"
"/rooms/(?P<room_id>[^/]*)/hierarchy$"
),
- )
+ ]
def __init__(self, hs: "HomeServer"):
super().__init__()
@@ -1168,7 +1168,7 @@ class RoomHierarchyRestServlet(RestServlet):
)
return 200, await self._room_summary_handler.get_room_hierarchy(
- requester.user.to_string(),
+ requester,
room_id,
suggested_only=parse_boolean(request, "suggested_only", default=False),
max_depth=max_depth,
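The `RoomHierarchyRestServlet` change above accepts the stable `/v1` prefix alongside the unstable MSC2946 one (and passes the whole `requester` down). A small self-contained sketch, assuming the pattern is as shown in the patch, demonstrating that both URL forms match:

    import re

    pattern = re.compile(
        "^/_matrix/client/(v1|unstable/org.matrix.msc2946)"
        "/rooms/(?P<room_id>[^/]*)/hierarchy$"
    )

    assert pattern.match("/_matrix/client/v1/rooms/!room:example.org/hierarchy")
    assert pattern.match(
        "/_matrix/client/unstable/org.matrix.msc2946/rooms/!room:example.org/hierarchy"
    )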
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1605411b00..446204dbe5 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -764,7 +764,7 @@ class StateResolutionStore:
store: "DataStore"
def get_events(
- self, event_ids: Iterable[str], allow_rejected: bool = False
+ self, event_ids: Collection[str], allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 6edadea550..499a328201 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -17,6 +17,7 @@ import logging
from typing import (
Awaitable,
Callable,
+ Collection,
Dict,
Iterable,
List,
@@ -44,7 +45,7 @@ async def resolve_events_with_store(
room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
- state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
+ state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
) -> StateMap[str]:
"""
Args:
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0623da9aa1..3056e64ff5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
-from synapse.types import StreamToken, get_domain_from_id
+from synapse.types import get_domain_from_id
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self,
stream_name: str,
instance_name: str,
- token: StreamToken,
+ token: int,
rows: Iterable[Any],
) -> None:
pass
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index bc8364400d..d64910aded 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,12 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
+from typing import (
+ TYPE_CHECKING,
+ AsyncContextManager,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ Optional,
+)
+
+import attr
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.types import Connection
from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import Clock, json_encoder
from . import engines
@@ -28,6 +38,45 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
+DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _BackgroundUpdateHandler:
+ """A handler for a given background update.
+
+ Attributes:
+ callback: The function to call to make progress on the background
+ update.
+        oneshot: Whether the update is likely to happen all in one go, ignoring
+ the supplied target duration, e.g. index creation. This is used by
+ the update controller to help correctly schedule the update.
+ """
+
+ callback: Callable[[JsonDict, int], Awaitable[int]]
+ oneshot: bool = False
+
+
+class _BackgroundUpdateContextManager:
+ BACKGROUND_UPDATE_INTERVAL_MS = 1000
+ BACKGROUND_UPDATE_DURATION_MS = 100
+
+ def __init__(self, sleep: bool, clock: Clock):
+ self._sleep = sleep
+ self._clock = clock
+
+ async def __aenter__(self) -> int:
+ if self._sleep:
+ await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
+
+ return self.BACKGROUND_UPDATE_DURATION_MS
+
+ async def __aexit__(self, *exc) -> None:
+ pass
+
+
class BackgroundUpdatePerformance:
"""Tracks the how long a background update is taking to update its items"""
@@ -84,20 +133,22 @@ class BackgroundUpdater:
MINIMUM_BACKGROUND_BATCH_SIZE = 1
DEFAULT_BACKGROUND_BATCH_SIZE = 100
- BACKGROUND_UPDATE_INTERVAL_MS = 1000
- BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
+ self._database_name = database.name()
+
# if a background update is currently running, its name.
self._current_background_update: Optional[str] = None
+ self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
+ self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
+ self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
+
self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
- self._background_update_handlers: Dict[
- str, Callable[[JsonDict, int], Awaitable[int]]
- ] = {}
+ self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
self._all_done = False
# Whether we're currently running updates
@@ -107,6 +158,83 @@ class BackgroundUpdater:
# enable/disable background updates via the admin API.
self.enabled = True
+ def register_update_controller_callbacks(
+ self,
+ on_update: ON_UPDATE_CALLBACK,
+ default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+ min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+ ) -> None:
+ """Register callbacks from a module for each hook."""
+ if self._on_update_callback is not None:
+ logger.warning(
+ "More than one module tried to register callbacks for controlling"
+ " background updates. Only the callbacks registered by the first module"
+ " (in order of appearance in Synapse's configuration file) that tried to"
+ " do so will be called."
+ )
+
+ return
+
+ self._on_update_callback = on_update
+
+ if default_batch_size is not None:
+ self._default_batch_size_callback = default_batch_size
+
+ if min_batch_size is not None:
+ self._min_batch_size_callback = min_batch_size
+
+ def _get_context_manager_for_update(
+ self,
+ sleep: bool,
+ update_name: str,
+ database_name: str,
+ oneshot: bool,
+ ) -> AsyncContextManager[int]:
+ """Get a context manager to run a background update with.
+
+        If a module has registered an `on_update` callback, use the context manager
+ it returns.
+
+ Otherwise, returns a context manager that will return a default value, optionally
+ sleeping if needed.
+
+ Args:
+ sleep: Whether we can sleep between updates.
+ update_name: The name of the update.
+ database_name: The name of the database the update is being run on.
+ oneshot: Whether the update will complete all in one go, e.g. index creation.
+ In such cases the returned target duration is ignored.
+
+ Returns:
+            A context manager which, on entry, returns the target duration (in
+            milliseconds) that the background update should run for.
+
+ Note: this is a *target*, and an iteration may take substantially longer or
+ shorter.
+ """
+ if self._on_update_callback is not None:
+ return self._on_update_callback(update_name, database_name, oneshot)
+
+ return _BackgroundUpdateContextManager(sleep, self._clock)
+
+ async def _default_batch_size(self, update_name: str, database_name: str) -> int:
+ """The batch size to use for the first iteration of a new background
+ update.
+ """
+ if self._default_batch_size_callback is not None:
+ return await self._default_batch_size_callback(update_name, database_name)
+
+ return self.DEFAULT_BACKGROUND_BATCH_SIZE
+
+ async def _min_batch_size(self, update_name: str, database_name: str) -> int:
+ """A lower bound on the batch size of a new background update.
+
+ Used to ensure that progress is always made. Must be greater than 0.
+ """
+ if self._min_batch_size_callback is not None:
+ return await self._min_batch_size_callback(update_name, database_name)
+
+ return self.MINIMUM_BACKGROUND_BATCH_SIZE
+
def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
"""Returns the current background update, if any."""
@@ -135,13 +263,8 @@ class BackgroundUpdater:
try:
logger.info("Starting background schema updates")
while self.enabled:
- if sleep:
- await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
-
try:
- result = await self.do_next_background_update(
- self.BACKGROUND_UPDATE_DURATION_MS
- )
+ result = await self.do_next_background_update(sleep)
except Exception:
logger.exception("Error doing update")
else:
@@ -203,13 +326,15 @@ class BackgroundUpdater:
return not update_exists
- async def do_next_background_update(self, desired_duration_ms: float) -> bool:
+ async def do_next_background_update(self, sleep: bool = True) -> bool:
"""Does some amount of work on the next queued background update
Returns once some amount of work is done.
Args:
- desired_duration_ms: How long we want to spend updating.
+ sleep: Whether to limit how quickly we run background updates or
+ not.
+
Returns:
True if we have finished running all the background updates, otherwise False
"""
@@ -252,7 +377,19 @@ class BackgroundUpdater:
self._current_background_update = upd["update_name"]
- await self._do_background_update(desired_duration_ms)
+ # We have a background update to run, otherwise we would have returned
+ # early.
+ assert self._current_background_update is not None
+ update_info = self._background_update_handlers[self._current_background_update]
+
+ async with self._get_context_manager_for_update(
+ sleep=sleep,
+ update_name=self._current_background_update,
+ database_name=self._database_name,
+ oneshot=update_info.oneshot,
+ ) as desired_duration_ms:
+ await self._do_background_update(desired_duration_ms)
+
return False
async def _do_background_update(self, desired_duration_ms: float) -> int:
@@ -260,7 +397,7 @@ class BackgroundUpdater:
update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name)
- update_handler = self._background_update_handlers[update_name]
+ update_handler = self._background_update_handlers[update_name].callback
performance = self._background_update_performance.get(update_name)
@@ -273,9 +410,14 @@ class BackgroundUpdater:
if items_per_ms is not None:
batch_size = int(desired_duration_ms * items_per_ms)
# Clamp the batch size so that we always make progress
- batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
+ batch_size = max(
+ batch_size,
+ await self._min_batch_size(update_name, self._database_name),
+ )
else:
- batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
+ batch_size = await self._default_batch_size(
+ update_name, self._database_name
+ )
progress_json = await self.db_pool.simple_select_one_onecol(
"background_updates",
@@ -294,6 +436,8 @@ class BackgroundUpdater:
duration_ms = time_stop - time_start
+ performance.update(items_updated, duration_ms)
+
logger.info(
"Running background update %r. Processed %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
@@ -306,8 +450,6 @@ class BackgroundUpdater:
batch_size,
)
- performance.update(items_updated, duration_ms)
-
return len(self._background_update_performance)
def register_background_update_handler(
@@ -331,7 +473,9 @@ class BackgroundUpdater:
update_name: The name of the update that this code handles.
update_handler: The function that does the update.
"""
- self._background_update_handlers[update_name] = update_handler
+ self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+ update_handler
+ )
def register_noop_background_update(self, update_name: str) -> None:
"""Register a noop handler for a background update.
@@ -453,7 +597,9 @@ class BackgroundUpdater:
await self._end_background_update(update_name)
return 1
- self.register_background_update_handler(update_name, updater)
+ self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+ updater, oneshot=True
+ )
async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue.
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index d957e770dc..3efdd0c920 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr
+from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -37,6 +38,20 @@ DEFAULT_HIGHLIGHT_ACTION = [
]
+class BasePushAction(TypedDict):
+ event_id: str
+ actions: List[Union[dict, str]]
+
+
+class HttpPushAction(BasePushAction):
+ room_id: str
+ stream_ordering: int
+
+
+class EmailPushAction(HttpPushAction):
+ received_ts: Optional[int]
+
+
def _serialize_action(actions, is_highlight):
"""Custom serializer for actions. This allows us to "compress" common actions.
@@ -221,7 +236,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
- ) -> List[dict]:
+ ) -> List[HttpPushAction]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the httppusher.
@@ -326,7 +341,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
- ) -> List[dict]:
+ ) -> List[EmailPushAction]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the emailpusher
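The `TypedDict`s introduced above replace the loose `List[dict]` return types on the push-action queries: HTTP pushers get `HttpPushAction` rows and email pushers additionally receive `received_ts`. A small, self-contained sketch of a row shaped like `EmailPushAction` (the classes are re-declared here only so the snippet runs on its own; the values are made up):

    from typing import List, Optional, Union
    from typing_extensions import TypedDict

    class BasePushAction(TypedDict):
        event_id: str
        actions: List[Union[dict, str]]

    class HttpPushAction(BasePushAction):
        room_id: str
        stream_ordering: int

    class EmailPushAction(HttpPushAction):
        received_ts: Optional[int]

    row: EmailPushAction = {
        "event_id": "$abc123",
        "actions": ["notify", {"set_tweak": "highlight", "value": False}],
        "room_id": "!room:example.org",
        "stream_ordering": 42,
        "received_ts": 1_637_000_000_000,
    }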
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 06832221ad..4171b904eb 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -15,7 +15,7 @@
# limitations under the License.
import itertools
import logging
-from collections import OrderedDict, namedtuple
+from collections import OrderedDict
from typing import (
TYPE_CHECKING,
Any,
@@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
@@ -64,9 +65,6 @@ event_counter = Counter(
)
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -108,23 +106,30 @@ class PersistEventsStore:
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
- # Ideally we'd move these ID gens here, unfortunately some other ID
- # generators are chained off them so doing so is a bit of a PITA.
- self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
- self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
-
# This should only exist on instances that are configured to write
assert (
hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
+ # Since we have been configured to write, we ought to have id generators,
+ # rather than id trackers.
+ assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
+ assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
+
+ # Ideally we'd move these ID gens here, unfortunately some other ID
+ # generators are chained off them so doing so is a bit of a PITA.
+ self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
+ self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
+ *,
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
- backfilled: bool = False,
+ use_negative_stream_ordering: bool = False,
+ inhibit_local_membership_updates: bool = False,
) -> None:
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -137,7 +142,14 @@ class PersistEventsStore:
room state
new_forward_extremities: Map from room_id to list of event IDs
that are the new forward extremities of the room.
- backfilled
+ use_negative_stream_ordering: Whether to start stream_ordering on
+ the negative side and decrement. This should be set as True
+ for backfilled events because backfilled events get a negative
+ stream ordering so they don't come down incremental `/sync`.
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
Returns:
Resolves when the events have been persisted
@@ -159,7 +171,7 @@ class PersistEventsStore:
#
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
- if backfilled:
+ if use_negative_stream_ordering:
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
@@ -176,13 +188,13 @@ class PersistEventsStore:
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
- backfilled=backfilled,
+ inhibit_local_membership_updates=inhibit_local_membership_updates,
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc(len(events_and_contexts))
- if not backfilled:
+ if stream < 0:
# backfilled events have negative stream orderings, so we don't
# want to set the event_persisted_position to that.
synapse.metrics.event_persisted_position.set(
@@ -316,8 +328,9 @@ class PersistEventsStore:
def _persist_events_txn(
self,
txn: LoggingTransaction,
+ *,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- backfilled: bool,
+ inhibit_local_membership_updates: bool = False,
state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
):
@@ -330,7 +343,10 @@ class PersistEventsStore:
Args:
txn
events_and_contexts: events to persist
- backfilled: True if the events were backfilled
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
delete_existing True to purge existing table rows for the events
from the database. This is useful when retrying due to
IntegrityError.
@@ -363,9 +379,7 @@ class PersistEventsStore:
events_and_contexts
)
- self._update_room_depths_txn(
- txn, events_and_contexts=events_and_contexts, backfilled=backfilled
- )
+ self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
# _update_outliers_txn filters out any events which have already been
# persisted, and returns the filtered list.
@@ -398,7 +412,7 @@ class PersistEventsStore:
txn,
events_and_contexts=events_and_contexts,
all_events_and_contexts=all_events_and_contexts,
- backfilled=backfilled,
+ inhibit_local_membership_updates=inhibit_local_membership_updates,
)
# We call this last as it assumes we've inserted the events into
@@ -1200,7 +1214,6 @@ class PersistEventsStore:
self,
txn,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- backfilled: bool,
):
"""Update min_depth for each room
@@ -1208,13 +1221,18 @@ class PersistEventsStore:
txn (twisted.enterprise.adbapi.Connection): db connection
events_and_contexts (list[(EventBase, EventContext)]): events
we are persisting
- backfilled (bool): True if the events were backfilled
"""
depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
- if not backfilled:
+ # Then update the `stream_ordering` position to mark the latest
+ # event as the front of the room. This should not be done for
+ # backfilled events because backfilled events have negative
+ # stream_ordering and happened in the past so we know that we don't
+ # need to update the stream_ordering tip/front for the room.
+ assert event.internal_metadata.stream_ordering is not None
+ if event.internal_metadata.stream_ordering >= 0:
txn.call_after(
self.store._events_stream_cache.entity_has_changed,
event.room_id,
@@ -1427,7 +1445,12 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
def _update_metadata_tables_txn(
- self, txn, events_and_contexts, all_events_and_contexts, backfilled
+ self,
+ txn,
+ *,
+ events_and_contexts,
+ all_events_and_contexts,
+ inhibit_local_membership_updates: bool = False,
):
"""Update all the miscellaneous tables for new events
@@ -1439,7 +1462,10 @@ class PersistEventsStore:
events that we were going to persist. This includes events
we've already persisted, etc, that wouldn't appear in
events_and_context.
- backfilled (bool): True if the events were backfilled
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
"""
# Insert all the push actions into the event_push_actions table.
@@ -1513,7 +1539,7 @@ class PersistEventsStore:
for event, _ in events_and_contexts
if event.type == EventTypes.Member
],
- backfilled=backfilled,
+ inhibit_local_membership_updates=inhibit_local_membership_updates,
)
# Insert event_reference_hashes table.
@@ -1553,11 +1579,13 @@ class PersistEventsStore:
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
- to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
+ to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
def prefill():
for cache_entry in to_prefill:
- self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
+ self.store._get_event_cache.set(
+ (cache_entry.event.event_id,), cache_entry
+ )
txn.call_after(prefill)
@@ -1638,8 +1666,19 @@ class PersistEventsStore:
txn, table="event_reference_hashes", values=vals
)
- def _store_room_members_txn(self, txn, events, backfilled):
- """Store a room member in the database."""
+ def _store_room_members_txn(
+ self, txn, events, *, inhibit_local_membership_updates: bool = False
+ ):
+ """
+ Store a room member in the database.
+ Args:
+ txn: The transaction to use.
+ events: List of events to store.
+ inhibit_local_membership_updates: Stop the local_current_membership
+ from being updated by these events. This should be set to True
+ for backfilled events because backfilled events in the past do
+ not affect the current local state.
+ """
def non_null_str_or_none(val: Any) -> Optional[str]:
return val if isinstance(val, str) and "\u0000" not in val else None
@@ -1682,7 +1721,7 @@ class PersistEventsStore:
# band membership", like a remote invite or a rejection of a remote invite.
if (
self.is_mine_id(event.state_key)
- and not backfilled
+ and not inhibit_local_membership_updates
and event.internal_metadata.is_outlier()
and event.internal_metadata.is_out_of_band_membership()
):
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c6bf316d5b..4cefc0a07e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,14 +15,18 @@
import logging
import threading
from typing import (
+ TYPE_CHECKING,
+ Any,
Collection,
Container,
Dict,
Iterable,
List,
+ NoReturn,
Optional,
Set,
Tuple,
+ cast,
overload,
)
@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
+ RoomVersion,
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
@@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# These values are used in the `_enqueue_events` and `_fetch_loop` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thing air to make initial sync run faster
# on jki.re
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
@attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
event: EventBase
redacted_event: Optional[EventBase]
@@ -129,7 +145,7 @@ class _EventRow:
json: str
internal_metadata: str
format_version: Optional[int]
- room_version_id: Optional[int]
+ room_version_id: Optional[str]
rejected_reason: Optional[str]
redactions: List[str]
outlier: bool
@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
# options controlling this.
USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
+ self._stream_id_gen: AbstractStreamIdTracker
+ self._backfill_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
5 * 60 * 1000,
)
- self._get_event_cache = LruCache(
+ self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
)
@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
# ID to cache entry. Note that the returned dict may not have the
# requested event in it if the event isn't in the DB.
self._current_event_fetches: Dict[
- str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+ str, ObservableDeferred[Dict[str, EventCacheEntry]]
] = {}
self._event_fetch_lock = threading.Condition()
- self._event_fetch_list = []
+ self._event_fetch_list: List[
+ Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
+ ] = []
self._event_fetch_ongoing = 0
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
- def get_chain_id_txn(txn):
+ def get_chain_id_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
- return txn.fetchone()[0]
+ return cast(Tuple[int], txn.fetchone())[0]
self.event_chain_id_gen = build_sequence_generator(
db_conn,
@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id",
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = False,
- allow_rejected: bool = False,
- allow_none: Literal[False] = False,
- check_room_id: Optional[str] = None,
+ get_prev_content: bool = ...,
+ allow_rejected: bool = ...,
+ allow_none: Literal[False] = ...,
+ check_room_id: Optional[str] = ...,
) -> EventBase:
...
@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = False,
- allow_rejected: bool = False,
- allow_none: Literal[True] = False,
- check_room_id: Optional[str] = None,
+ get_prev_content: bool = ...,
+ allow_rejected: bool = ...,
+ allow_none: Literal[True] = ...,
+ check_room_id: Optional[str] = ...,
) -> Optional[EventBase]:
...
@@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_events(
self,
- event_ids: Iterable[str],
+ event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
@@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
- ) -> Dict[str, _EventCacheEntry]:
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
# same dict into itself N times).
already_fetching_ids: Set[str] = set()
already_fetching_deferreds: Set[
- ObservableDeferred[Dict[str, _EventCacheEntry]]
+ ObservableDeferred[Dict[str, EventCacheEntry]]
] = set()
for event_id in missing_events_ids:
@@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore):
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
- Dict[str, _EventCacheEntry]
- ] = ObservableDeferred(defer.Deferred())
+ Dict[str, EventCacheEntry]
+ ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
@@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
- def _invalidate_get_event_cache(self, event_id):
+ def _invalidate_get_event_cache(self, event_id: str) -> None:
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
- ) -> Dict[str, _EventCacheEntry]:
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch events from the caches.
May return rejected events.
@@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore):
for e in state_to_include.values()
]
- def _do_fetch(self, conn: Connection) -> None:
+ def _maybe_start_fetch_thread(self) -> None:
+ """Starts an event fetch thread if we are not yet at the maximum number."""
+ with self._event_fetch_lock:
+ if (
+ self._event_fetch_list
+ and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
+ ):
+ self._event_fetch_ongoing += 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ # `_event_fetch_ongoing` is decremented in `_fetch_thread`.
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ run_as_background_process("fetch_events", self._fetch_thread)
+
+ async def _fetch_thread(self) -> None:
+ """Services requests for events from `_event_fetch_list`."""
+ exc = None
+ try:
+ await self.db_pool.runWithConnection(self._fetch_loop)
+ except BaseException as e:
+ exc = e
+ raise
+ finally:
+ should_restart = False
+ event_fetches_to_fail = []
+ with self._event_fetch_lock:
+ self._event_fetch_ongoing -= 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+
+ # There may still be work remaining in `_event_fetch_list` if we
+ # failed, or it was added in between us deciding to exit and
+ # decrementing `_event_fetch_ongoing`.
+ if self._event_fetch_list:
+ if exc is None:
+ # We decided to exit, but then some more work was added
+ # before `_event_fetch_ongoing` was decremented.
+ # If a new event fetch thread was not started, we should
+ # restart ourselves since the remaining event fetch threads
+ # may take a while to get around to the new work.
+ #
+ # Unfortunately it is not possible to tell whether a new
+ # event fetch thread was started, so we restart
+ # unconditionally. If we are unlucky, we will end up with
+ # an idle fetch thread, but it will time out after
+ # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
+ # in any case.
+ #
+ # Note that multiple fetch threads may run down this path at
+ # the same time.
+ should_restart = True
+ elif isinstance(exc, Exception):
+ if self._event_fetch_ongoing == 0:
+ # We were the last remaining fetcher and failed.
+ # Fail any outstanding fetches since no one else will
+ # handle them.
+ event_fetches_to_fail = self._event_fetch_list
+ self._event_fetch_list = []
+ else:
+ # We weren't the last remaining fetcher, so another
+ # fetcher will pick up the work. This will either happen
+ # after their existing work, however long that takes,
+ # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
+ # they are idle.
+ pass
+ else:
+ # The exception is a `SystemExit`, `KeyboardInterrupt` or
+ # `GeneratorExit`. Don't try to do anything clever here.
+ pass
+
+ if should_restart:
+ # We exited cleanly but noticed more work.
+ self._maybe_start_fetch_thread()
+
+ if event_fetches_to_fail:
+ # We were the last remaining fetcher and failed.
+ # Fail any outstanding fetches since no one else will handle them.
+ assert exc is not None
+ with PreserveLoggingContext():
+ for _, deferred in event_fetches_to_fail:
+ deferred.errback(exc)
+
+ def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
- try:
- i = 0
- while True:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if (
- not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
- or single_threaded
- or i > EVENT_QUEUE_ITERATIONS
- ):
- break
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
+ i = 0
+ while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ # There are no requests waiting. If we haven't yet reached the
+ # maximum iteration limit, wait for some more requests to turn up.
+ # Otherwise, bail out.
+ single_threaded = self.database_engine.single_threaded
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
+ return
+
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
- self._fetch_event_list(conn, event_list)
- finally:
- self._event_fetch_ongoing -= 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ self._fetch_event_list(conn, event_list)
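
The reworked fetch-thread lifecycle above hinges on one race: more work can land on `_event_fetch_list` between a fetcher deciding to exit and it decrementing `_event_fetch_ongoing`, so the exiting fetcher re-checks under the lock and restarts a thread if needed. Below is a minimal, self-contained sketch of that pattern using plain threads and invented names (`WorkQueue`, `MAX_WORKERS`); it is not Synapse's database machinery.

```python
# Minimal sketch of the "restart on race" pattern, with invented names
# (WorkQueue, MAX_WORKERS). Not Synapse's implementation: no database
# connections, metrics or logging contexts are involved here.
import threading
from collections import deque
from typing import Callable, Deque

MAX_WORKERS = 3


class WorkQueue:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._items: Deque[Callable[[], None]] = deque()
        self._running = 0

    def submit(self, item: Callable[[], None]) -> None:
        with self._lock:
            self._items.append(item)
        self._maybe_start_worker()

    def _maybe_start_worker(self) -> None:
        # Mirror of `_maybe_start_fetch_thread`: decide under the lock, start outside it.
        with self._lock:
            if self._items and self._running < MAX_WORKERS:
                self._running += 1
                should_start = True
            else:
                should_start = False
        if should_start:
            threading.Thread(target=self._worker_loop).start()

    def _worker_loop(self) -> None:
        try:
            while True:
                with self._lock:
                    if not self._items:
                        return  # nothing left; the counter is decremented in `finally`
                    item = self._items.popleft()
                item()
        finally:
            should_restart = False
            with self._lock:
                self._running -= 1
                # Work may have arrived between deciding to exit and decrementing the
                # counter above; if so, make sure somebody services it.
                if self._items:
                    should_restart = True
            if should_restart:
                self._maybe_start_worker()


if __name__ == "__main__":
    queue = WorkQueue()
    for n in range(5):
        queue.submit(lambda n=n: print("handled", n))
```
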
def _fetch_event_list(
- self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+ self,
+ conn: LoggingDatabaseConnection,
+ event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
) -> None:
"""Handle a load of requests from the _event_fetch_list queue
@@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
)
# We only want to resolve deferreds from the main thread
- def fire():
+ def fire() -> None:
for _, d in event_list:
d.callback(row_dict)
@@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
- def fire(evs, exc):
- for _, d in evs:
- if not d.called:
- with PreserveLoggingContext():
- d.errback(exc)
+ def fire_errback(exc: Exception) -> None:
+ for _, d in event_list:
+ d.errback(exc)
with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list, e)
+ self.hs.get_reactor().callFromThread(fire_errback, e)
async def _get_events_from_db(
- self, event_ids: Iterable[str]
- ) -> Dict[str, _EventCacheEntry]:
+ self, event_ids: Collection[str]
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the database.
May return rejected events.
@@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
map from event id to result. May return extra events which
weren't asked for.
"""
- fetched_events = {}
+ fetched_event_ids: Set[str] = set()
+ fetched_events: Dict[str, _EventRow] = {}
events_to_fetch = event_ids
while events_to_fetch:
row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
- redaction_ids = set()
+ redaction_ids: Set[str] = set()
for event_id in events_to_fetch:
row = row_map.get(event_id)
- fetched_events[event_id] = row
+ fetched_event_ids.add(event_id)
if row:
+ fetched_events[event_id] = row
redaction_ids.update(row.redactions)
- events_to_fetch = redaction_ids.difference(fetched_events.keys())
+ events_to_fetch = redaction_ids.difference(fetched_event_ids)
if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch)
# build a map from event_id to EventBase
- event_map = {}
+ event_map: Dict[str, EventBase] = {}
for event_id, row in fetched_events.items():
- if not row:
- continue
assert row.event_id == event_id
rejected_reason = row.rejected_reason
@@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row.room_version_id
+ room_version: Optional[RoomVersion]
if not room_version_id:
# this should only happen for out-of-band membership events which
# arrived before #6983 landed. For all other events, we should have
@@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
# finally, we can decide whether each one needs redacting, and build
# the cache entries.
- result_map = {}
+ result_map: Dict[str, EventCacheEntry] = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
- cache_entry = _EventCacheEntry(
+ cache_entry = EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)
@@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
+ async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore):
that weren't requested.
"""
- events_d = defer.Deferred()
+ events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))
-
self._event_fetch_lock.notify()
- if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
- self._event_fetch_ongoing += 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
- should_start = True
- else:
- should_start = False
-
- if should_start:
- run_as_background_process(
- "fetch_events", self.db_pool.runWithConnection, self._do_fetch
- )
+ self._maybe_start_fetch_thread()
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
@@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
- async def have_events_in_timeline(self, event_ids):
+ async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
@@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
event_ids: events we are looking for
Returns:
- set[str]: The events we have already seen.
+ The set of events we have already seen.
"""
res = await self._have_seen_events_dict(
(room_id, event_id) for event_id in event_ids
@@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
}
results = {x: True for x in cache_results}
- def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+ def have_seen_events_txn(
+ txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
+ ) -> None:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
- async def have_seen_event(self, room_id: str, event_id: str):
+ async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
# this only exists for the benefit of the @cachedList descriptor on
# _have_seen_events_dict
raise NotImplementedError()
- def _get_current_state_event_counts_txn(self, txn, room_id):
+ def _get_current_state_event_counts_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> int:
"""
See get_current_state_event_counts.
"""
@@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
- async def get_room_complexity(self, room_id):
+ async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
more resources.
Args:
- room_id (str)
+ room_id: The room ID to query.
Returns:
- dict[str:int] of complexity version to complexity.
+ dict[str:float] of complexity version to complexity.
"""
state_events = await self.get_current_state_event_counts(room_id)
@@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
- def get_current_events_token(self):
+ def get_current_events_token(self) -> int:
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> List[Tuple]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns new events, for the Events replication stream
Args:
@@ -1295,7 +1403,9 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_all_new_forward_event_rows(txn):
+ def get_all_new_forward_event_rows(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
@@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
- return txn.fetchall()
+ return cast(
+ List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
@@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
- ) -> List[Tuple]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns de-outliered events, for the Events replication stream
Args:
@@ -1332,7 +1444,9 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_ex_outlier_stream_rows_txn(txn):
+ def get_ex_outlier_stream_rows_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
@@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name))
- return txn.fetchall()
+ return cast(
+ List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_backfill_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, list]], int, bool]:
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
"""Get updates for backfill replication stream, including all new
backfilled events and events that have gone from being outliers to not.
@@ -1386,7 +1502,9 @@ class EventsWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_new_backfill_event_rows(txn):
+ def get_all_new_backfill_event_rows(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
@@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, instance_name, limit))
- new_event_updates = [(row[0], row[1:]) for row in txn]
+ new_event_updates: List[
+ Tuple[int, Tuple[str, str, str, str, str, str]]
+ ] = []
+ row: Tuple[int, str, str, str, str, str, str]
+ # Type safety: iterating over `txn` yields `Tuple`, i.e.
+ # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+ # variadic tuple to a fixed length tuple and flags it up as an error.
+ for row in txn: # type: ignore[assignment]
+ new_event_updates.append((row[0], row[1:]))
limited = False
if len(new_event_updates) == limit:
@@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound, instance_name))
- new_event_updates.extend((row[0], row[1:]) for row in txn)
+ # Type safety: iterating over `txn` yields `Tuple`, i.e.
+ # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+ # variadic tuple to a fixed length tuple and flags it up as an error.
+ for row in txn: # type: ignore[assignment]
+ new_event_updates.append((row[0], row[1:]))
if len(new_event_updates) >= limit:
upper_bound = new_event_updates[-1][0]
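
The two `# type: ignore[assignment]` comments above work around the same limitation: a DB-API cursor yields `Tuple[Any, ...]`, which mypy will not narrow to a fixed-length tuple. A tiny illustrative sketch of the two options used in these hunks, an ignore on the loop variable or a `cast` over the whole `fetchall()` result, is shown below; the function names are invented.

```python
# Illustrative only: `cursor` stands in for a DB-API cursor, whose rows are typed
# as variable-length tuples. Function names are invented for this sketch.
from typing import Any, Iterable, List, Tuple, cast


def rows_via_ignore(cursor: Iterable[Tuple[Any, ...]]) -> List[Tuple[int, str]]:
    out: List[Tuple[int, str]] = []
    row: Tuple[int, str]
    # mypy flags assigning a variadic tuple to a fixed-length one, so silence it per loop.
    for row in cursor:  # type: ignore[assignment]
        out.append(row)
    return out


def rows_via_cast(fetched: List[Tuple[Any, ...]]) -> List[Tuple[int, str]]:
    # Alternatively, cast the whole result set, as the hunks above do for `txn.fetchall()`.
    return cast(List[Tuple[int, str]], fetched)
```
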
@@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_updated_current_state_deltas(
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
- ) -> Tuple[List[Tuple], int, bool]:
+ ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
"""Fetch updates from current_state_delta_stream
Args:
@@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
* `limited` is whether there are more updates to fetch.
"""
- def get_all_updated_current_state_deltas_txn(txn):
+ def get_all_updated_current_state_deltas_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
@@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
- def get_deltas_for_stream_id_txn(txn, stream_id):
+ def get_deltas_for_stream_id_txn(
+ txn: LoggingTransaction, stream_id: int
+ ) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE stream_id = ?
"""
txn.execute(sql, [stream_id])
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
- rows: List[Tuple] = await self.db_pool.runInteraction(
+ rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
@@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
- async def is_event_after(self, event_id1, event_id2):
+ async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
"""Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cached(max_entries=5000)
- async def get_event_ordering(self, event_id):
+ async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
@@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
None otherwise.
"""
- def get_next_event_to_expire_txn(txn):
+ def get_next_event_to_expire_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, int]]:
txn.execute(
"""
SELECT event_id, expiry_ts FROM event_expiry
@@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
)
- return txn.fetchone()
+ return cast(Optional[Tuple[str, int]], txn.fetchone())
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
return mapping
@wrap_as_background_process("_cleanup_old_transaction_ids")
- async def _cleanup_old_transaction_ids(self):
+ async def _cleanup_old_transaction_ids(self) -> None:
"""Cleans out transaction id mappings older than 24hrs."""
- def _cleanup_old_transaction_ids_txn(txn):
+ def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
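
Earlier in this file's diff, `have_seen_event` is typed as `NoReturn`: its body only raises, because it exists solely to define the per-key cache that the batched `_have_seen_events_dict` lookup populates. A rough sketch of that "stub single-item method plus batched fetch" shape, hand-rolled with a plain dict rather than Synapse's `@cached`/`@cachedList` descriptors, might look like this; every name and detail is an assumption made for the sketch.

```python
# Simplified illustration of a stub single-item method paired with a batched lookup.
# The caching here is a plain dict; Synapse uses its @cached/@cachedList descriptors.
from typing import Dict, Iterable, List, NoReturn, Tuple

Key = Tuple[str, str]  # (room_id, event_id)


class SeenEventsCache:
    def __init__(self) -> None:
        self._cache: Dict[Key, bool] = {}

    def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
        # Never called for its body: it only defines the per-key cache shape.
        raise NotImplementedError()

    def have_seen_events_dict(self, keys: Iterable[Key]) -> Dict[Key, bool]:
        wanted: List[Key] = list(keys)
        results = {k: self._cache[k] for k in wanted if k in self._cache}
        missing = [k for k in wanted if k not in results]
        if missing:
            fetched = self._query_database(missing)
            self._cache.update(fetched)  # fills the per-key cache for next time
            results.update(fetched)
        return results

    def _query_database(self, keys: Iterable[Key]) -> Dict[Key, bool]:
        # Stand-in for the index-only SQL lookup in the real store.
        return {k: False for k in keys}
```
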
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fa782023d4..3b63267395 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ StreamIdGenerator,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +85,9 @@ class PushRulesWorkerStore(
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen: Union[
- StreamIdGenerator, SlavedIdTracker
- ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
+ self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ db_conn, "push_rules_stream", "stream_id"
+ )
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 0e8c168667..e1ddf06916 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -106,6 +106,15 @@ class RefreshTokenLookupResult:
has_next_access_token_been_used: bool
"""True if the next access token was already used at least once."""
+ expiry_ts: Optional[int]
+ """The time at which the refresh token expires and can not be used.
+ If None, the refresh token doesn't expire."""
+
+ ultimate_session_expiry_ts: Optional[int]
+ """The time at which the session comes to an end and can no longer be
+ refreshed.
+ If None, the session can be refreshed indefinitely."""
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
@@ -1626,8 +1635,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
rt.user_id,
rt.device_id,
rt.next_token_id,
- (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
- at.used has_next_access_token_been_used
+ (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
+ at.used AS has_next_access_token_been_used,
+ rt.expiry_ts,
+ rt.ultimate_session_expiry_ts
FROM refresh_tokens rt
LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
@@ -1648,6 +1659,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
has_next_refresh_token_been_refreshed=row[4],
# This column is nullable, ensure it's a boolean
has_next_access_token_been_used=(row[5] or False),
+ expiry_ts=row[6],
+ ultimate_session_expiry_ts=row[7],
)
return await self.db_pool.runInteraction(
@@ -1915,6 +1928,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
user_id: str,
token: str,
device_id: Optional[str],
+ expiry_ts: Optional[int],
+ ultimate_session_expiry_ts: Optional[int],
) -> int:
"""Adds a refresh token for the given user.
@@ -1922,6 +1937,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
user_id: The user ID.
token: The new access token to add.
device_id: ID of the device to associate with the refresh token.
+ expiry_ts (milliseconds since the epoch): Time after which the
+ refresh token cannot be used.
+ If None, the refresh token never expires until it has been used.
+ ultimate_session_expiry_ts (milliseconds since the epoch):
+ Time at which the session will end and can not be extended any
+ further.
+ If None, the session can be refreshed indefinitely.
Raises:
StoreError if there was a problem adding this.
Returns:
@@ -1937,6 +1959,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"device_id": device_id,
"token": token,
"next_token_id": None,
+ "expiry_ts": expiry_ts,
+ "ultimate_session_expiry_ts": ultimate_session_expiry_ts,
},
desc="add_refresh_token_to_user",
)
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 402f134d89..428d66a617 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -583,7 +583,8 @@ class EventsPersistenceStorage:
current_state_for_room=current_state_for_room,
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
- backfilled=backfilled,
+ use_negative_stream_ordering=backfilled,
+ inhibit_local_membership_updates=backfilled,
)
await self._handle_potentially_left_users(potentially_left_users)
diff --git a/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql
new file mode 100644
index 0000000000..bdc491c817
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql
@@ -0,0 +1,28 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+ALTER TABLE refresh_tokens
+ -- We add an expiry_ts column (in milliseconds since the Epoch) to refresh tokens.
+ -- They may not be used after they have expired.
+ -- If null, then the refresh token's lifetime is unlimited.
+ ADD COLUMN expiry_ts BIGINT DEFAULT NULL;
+
+ALTER TABLE refresh_tokens
+ -- We also add an ultimate session expiry time (in milliseconds since the Epoch).
+ -- No matter how much the access and refresh tokens are refreshed, they cannot
+ -- be extended past this time.
+ -- If null, then the session length is unlimited.
+ ADD COLUMN ultimate_session_expiry_ts BIGINT DEFAULT NULL;
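
Together, the new `expiry_ts` and `ultimate_session_expiry_ts` columns give a refresh token two independent deadlines: one for the token itself and one for the session as a whole. A hedged sketch of how the two nullable timestamps could be combined when a token is presented (the dataclass mirrors the new `RefreshTokenLookupResult` fields; the `can_refresh()` helper is hypothetical and not Synapse's auth-handler code):

```python
# Sketch only: how the two nullable expiry timestamps could be interpreted when a
# refresh token is presented. can_refresh() is an invented helper for illustration.
from dataclasses import dataclass
from typing import Optional


@dataclass
class RefreshToken:
    expiry_ts: Optional[int]  # ms since the epoch; None = the token itself never expires
    ultimate_session_expiry_ts: Optional[int]  # ms since the epoch; None = unbounded session


def can_refresh(token: RefreshToken, now_ms: int) -> bool:
    # The token itself may not be used after expiry_ts.
    if token.expiry_ts is not None and now_ms >= token.expiry_ts:
        return False
    # Even a still-valid token cannot extend the session past its ultimate end time.
    if (
        token.ultimate_session_expiry_ts is not None
        and now_ms >= token.ultimate_session_expiry_ts
    ):
        return False
    return True
```
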
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ac56bc9a05..4ff3013908 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -89,31 +89,77 @@ def _load_current_id(
return (max if step > 0 else min)(current_id, step)
-class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
- @abc.abstractmethod
- def get_next(self) -> AsyncContextManager[int]:
- raise NotImplementedError()
+class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
+ """Tracks the "current" stream ID of a stream that may have multiple writers.
+
+ Stream IDs are monotonically increasing or decreasing integers representing write
+ transactions. The "current" stream ID is the stream ID such that all transactions
+ with equal or smaller stream IDs have completed. Since transactions may complete out
+ of order, this is not the same as the stream ID of the last completed transaction.
+
+ Completed transactions include both committed transactions and transactions that
+ have been rolled back.
+ """
@abc.abstractmethod
- def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+ def advance(self, instance_name: str, new_id: int) -> None:
+ """Advance the position of the named writer to the given ID, if greater
+ than existing entry.
+ """
raise NotImplementedError()
@abc.abstractmethod
def get_current_token(self) -> int:
+ """Returns the maximum stream id such that all stream ids less than or
+ equal to it have been successfully persisted.
+
+ Returns:
+ The maximum stream id.
+ """
raise NotImplementedError()
@abc.abstractmethod
def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
+
+ For streams with single writers this is equivalent to `get_current_token`.
+ """
+ raise NotImplementedError()
+
+
+class AbstractStreamIdGenerator(AbstractStreamIdTracker):
+ """Generates stream IDs for a stream that may have multiple writers.
+
+ Each stream ID represents a write transaction, whose completion is tracked
+ so that the "current" stream ID of the stream can be determined.
+
+ See `AbstractStreamIdTracker` for more details.
+ """
+
+ @abc.abstractmethod
+ def get_next(self) -> AsyncContextManager[int]:
+ """
+ Usage:
+ async with stream_id_gen.get_next() as stream_id:
+ # ... persist event ...
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+ """
+ Usage:
+ async with stream_id_gen.get_next(n) as stream_ids:
+ # ... persist events ...
+ """
raise NotImplementedError()
class StreamIdGenerator(AbstractStreamIdGenerator):
- """Used to generate new stream ids when persisting events while keeping
- track of which transactions have been completed.
+ """Generates and tracks stream IDs for a stream with a single writer.
- This allows us to get the "current" stream id, i.e. the stream id such that
- all ids less than or equal to it have completed. This handles the fact that
- persistence of events can complete out of order.
+ This class must only be used when the current Synapse process is the sole
+ writer for a stream.
Args:
db_conn(connection): A database connection to use to fetch the
@@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
+ def advance(self, instance_name: str, new_id: int) -> None:
+ # `StreamIdGenerator` should only be used when there is a single writer,
+ # so replication should never happen.
+ raise Exception("Replication is not supported by StreamIdGenerator")
+
def get_next(self) -> AsyncContextManager[int]:
- """
- Usage:
- async with stream_id_gen.get_next() as stream_id:
- # ... persist event ...
- """
with self._lock:
self._current += self._step
next_id = self._current
@@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
- """
- Usage:
- async with stream_id_gen.get_next(n) as stream_ids:
- # ... persist events ...
- """
with self._lock:
next_ids = range(
self._current + self._step,
@@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
- """Returns the maximum stream id such that all stream ids less than or
- equal to it have been successfully persisted.
-
- Returns:
- The maximum stream id.
- """
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step
@@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer.
-
- For streams with single writers this is equivalent to
- `get_current_token`.
- """
return self.get_current_token()
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
- """An ID generator that tracks a stream that can have multiple writers.
+ """Generates and tracks stream IDs for a stream with multiple writers.
Uses a Postgres sequence to coordinate ID assignment, but positions of other
writers will only get updated when `advance` is called (by replication).
@@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return stream_ids
def get_next(self) -> AsyncContextManager[int]:
- """
- Usage:
- async with stream_id_gen.get_next() as stream_id:
- # ... persist event ...
- """
-
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
- """
- Usage:
- async with stream_id_gen.get_next_mult(5) as stream_ids:
- # ... persist events ...
- """
-
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._add_persisted_position(next_id)
def get_current_token(self) -> int:
- """Returns the maximum stream id such that all stream ids less than or
- equal to it have been successfully persisted.
- """
-
return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer."""
-
# If we don't have an entry for the given instance name, we assume it's a
# new writer.
#
@@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
}
def advance(self, instance_name: str, new_id: int) -> None:
- """Advance the position of the named writer to the given ID, if greater
- than existing entry.
- """
-
new_id *= self._return_factor
with self._lock:
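
The docstrings added to `AbstractStreamIdTracker` pin down what the "current" stream ID means: the largest ID such that every ID at or below it has completed, which is not the same as the last ID handed out when writes finish out of order. A toy tracker (not the real `StreamIdGenerator`, which also handles steps, context managers and multiple instances) makes that distinction concrete:

```python
# Toy illustration of the "current token" semantics described above: completion may
# happen out of order, and the current token only passes an ID once every smaller ID
# has also completed. Not Synapse's implementation.
import threading
from typing import Set


class ToyStreamIdTracker:
    def __init__(self, start: int = 0) -> None:
        self._lock = threading.Lock()
        self._last_allocated = start
        self._unfinished: Set[int] = set()

    def allocate(self) -> int:
        with self._lock:
            self._last_allocated += 1
            self._unfinished.add(self._last_allocated)
            return self._last_allocated

    def mark_finished(self, stream_id: int) -> None:
        with self._lock:
            self._unfinished.discard(stream_id)

    def get_current_token(self) -> int:
        with self._lock:
            if self._unfinished:
                return min(self._unfinished) - 1
            return self._last_allocated


tracker = ToyStreamIdTracker()
first, second = tracker.allocate(), tracker.allocate()  # 1 and 2
tracker.mark_finished(second)                           # 2 completes first
assert tracker.get_current_token() == 0                 # 1 is still outstanding
tracker.mark_finished(first)
assert tracker.get_current_token() == 2
```
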
diff --git a/synctl b/synctl
index 90559ded62..08709f21ab 100755
--- a/synctl
+++ b/synctl
@@ -24,7 +24,7 @@ import signal
import subprocess
import sys
import time
-from typing import Iterable
+from typing import Iterable, Optional
import yaml
@@ -109,15 +109,14 @@ def start(pidfile: str, app: str, config_files: Iterable[str], daemonize: bool)
return False
-def stop(pidfile: str, app: str) -> bool:
+def stop(pidfile: str, app: str) -> Optional[int]:
"""Attempts to kill a synapse worker from the pidfile.
Args:
pidfile: path to file containing worker's pid
app: name of the worker's appservice
Returns:
- True if the process stopped successfully
- False if process was already stopped or an error occured
+ process id, or None if the process was not running
"""
if os.path.exists(pidfile):
@@ -125,7 +124,7 @@ def stop(pidfile: str, app: str) -> bool:
try:
os.kill(pid, signal.SIGTERM)
write("stopped %s" % (app,), colour=GREEN)
- return True
+ return pid
except OSError as err:
if err.errno == errno.ESRCH:
write("%s not running" % (app,), colour=YELLOW)
@@ -133,14 +132,13 @@ def stop(pidfile: str, app: str) -> bool:
abort("Cannot stop %s: Operation not permitted" % (app,))
else:
abort("Cannot stop %s: Unknown error" % (app,))
- return False
else:
write(
"No running worker of %s found (from %s)\nThe process might be managed by another controller (e.g. systemd)"
% (app, pidfile),
colour=YELLOW,
)
- return False
+ return None
Worker = collections.namedtuple(
@@ -288,32 +286,23 @@ def main():
action = options.action
if action == "stop" or action == "restart":
- has_stopped = True
+ running_pids = []
for worker in workers:
- if not stop(worker.pidfile, worker.app):
- # A worker could not be stopped.
- has_stopped = False
+ pid = stop(worker.pidfile, worker.app)
+ if pid is not None:
+ running_pids.append(pid)
if start_stop_synapse:
- if not stop(pidfile, MAIN_PROCESS):
- has_stopped = False
- if not has_stopped and action == "stop":
- sys.exit(1)
+ pid = stop(pidfile, MAIN_PROCESS)
+ if pid is not None:
+ running_pids.append(pid)
- # Wait for synapse to actually shutdown before starting it again
- if action == "restart":
- running_pids = []
- if start_stop_synapse and os.path.exists(pidfile):
- running_pids.append(int(open(pidfile).read()))
- for worker in workers:
- if os.path.exists(worker.pidfile):
- running_pids.append(int(open(worker.pidfile).read()))
if len(running_pids) > 0:
- write("Waiting for process to exit before restarting...")
+ write("Waiting for processes to exit...")
for running_pid in running_pids:
while pid_running(running_pid):
time.sleep(0.2)
- write("All processes exited; now restarting...")
+ write("All processes exited")
if action == "start" or action == "restart":
error = False
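
With `stop()` now returning the PID it signalled, the restart path no longer re-reads pidfiles; it simply polls the collected PIDs until they exit. A rough sketch of that signal-then-poll idea using generic POSIX calls (the `pid_running()` re-implementation below is an assumption about how such a check typically works, not necessarily synctl's exact helper):

```python
# Rough sketch of "signal, then poll until the processes are gone". pid_running()
# is re-implemented here with kill(pid, 0) so the example is self-contained.
import errno
import os
import signal
import time
from typing import List


def pid_running(pid: int) -> bool:
    try:
        os.kill(pid, 0)  # signal 0 checks existence/permission without killing
        return True
    except OSError as err:
        return err.errno != errno.ESRCH  # EPERM etc. still means the process exists


def stop_and_wait(pids: List[int], poll_interval: float = 0.2) -> None:
    for pid in pids:
        try:
            os.kill(pid, signal.SIGTERM)
        except OSError as err:
            if err.errno != errno.ESRCH:  # a process that is already gone is fine
                raise
    for pid in pids:
        while pid_running(pid):
            time.sleep(poll_interval)
```
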
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 4d1e154578..17a9fb63a1 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -22,6 +22,7 @@ import signedjson.sign
from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
+from twisted.internet import defer
from twisted.internet.defer import Deferred, ensureDeferred
from synapse.api.errors import SynapseError
@@ -577,6 +578,76 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
+ def test_get_multiple_keys_from_perspectives(self):
+ """Check that we can correctly request multiple keys for the same server"""
+
+ fetcher = PerspectivesKeyFetcher(self.hs)
+
+ SERVER_NAME = "server2"
+
+ testkey1 = signedjson.key.generate_signing_key("ver1")
+ testverifykey1 = signedjson.key.get_verify_key(testkey1)
+ testverifykey1_id = "ed25519:ver1"
+
+ testkey2 = signedjson.key.generate_signing_key("ver2")
+ testverifykey2 = signedjson.key.get_verify_key(testkey2)
+ testverifykey2_id = "ed25519:ver2"
+
+ VALID_UNTIL_TS = 200 * 1000
+
+ response1 = self.build_perspectives_response(
+ SERVER_NAME,
+ testkey1,
+ VALID_UNTIL_TS,
+ )
+ response2 = self.build_perspectives_response(
+ SERVER_NAME,
+ testkey2,
+ VALID_UNTIL_TS,
+ )
+
+ async def post_json(destination, path, data, **kwargs):
+ self.assertEqual(destination, self.mock_perspective_server.server_name)
+ self.assertEqual(path, "/_matrix/key/v2/query")
+
+ # check that the request is for the expected keys
+ q = data["server_keys"]
+
+ self.assertEqual(
+ list(q[SERVER_NAME].keys()), [testverifykey1_id, testverifykey2_id]
+ )
+ return {"server_keys": [response1, response2]}
+
+ self.http_client.post_json.side_effect = post_json
+
+ # fire off two separate requests; they should get merged together into a
+ # single HTTP hit.
+ request1_d = defer.ensureDeferred(
+ fetcher.get_keys(SERVER_NAME, [testverifykey1_id], 0)
+ )
+ request2_d = defer.ensureDeferred(
+ fetcher.get_keys(SERVER_NAME, [testverifykey2_id], 0)
+ )
+
+ keys1 = self.get_success(request1_d)
+ self.assertIn(testverifykey1_id, keys1)
+ k = keys1[testverifykey1_id]
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey1)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver1")
+
+ keys2 = self.get_success(request2_d)
+ self.assertIn(testverifykey2_id, keys2)
+ k = keys2[testverifykey2_id]
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey2)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver2")
+
+ # finally, ensure that only one request was sent
+ self.assertEqual(self.http_client.post_json.call_count, 1)
+
def test_get_perspectives_own_key(self):
"""Check that we can get the perspectives server's own keys
diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py
new file mode 100644
index 0000000000..a7031a55f2
--- /dev/null
+++ b/tests/federation/transport/test_client.py
@@ -0,0 +1,64 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.api.room_versions import RoomVersions
+from synapse.federation.transport.client import SendJoinParser
+
+from tests.unittest import TestCase
+
+
+class SendJoinParserTestCase(TestCase):
+ def test_two_writes(self) -> None:
+ """Test that the parser can sensibly deserialise an input given in two slices."""
+ parser = SendJoinParser(RoomVersions.V1, True)
+ parent_event = {
+ "content": {
+ "see_room_version_spec": "The event format changes depending on the room version."
+ },
+ "event_id": "$authparent",
+ "room_id": "!somewhere:example.org",
+ "type": "m.room.minimal_pdu",
+ }
+ state = {
+ "content": {
+ "see_room_version_spec": "The event format changes depending on the room version."
+ },
+ "event_id": "$DoNotThinkAboutTheEvent",
+ "room_id": "!somewhere:example.org",
+ "type": "m.room.minimal_pdu",
+ }
+ response = [
+ 200,
+ {
+ "auth_chain": [parent_event],
+ "origin": "matrix.org",
+ "state": [state],
+ },
+ ]
+ serialised_response = json.dumps(response).encode()
+
+ # Send data to the parser
+ parser.write(serialised_response[:100])
+ parser.write(serialised_response[100:])
+
+ # Retrieve the parsed SendJoinResponse
+ parsed_response = parser.finish()
+
+ # Sanity check the parsing gave us sensible data.
+ self.assertEqual(len(parsed_response.auth_events), 1, parsed_response)
+ self.assertEqual(len(parsed_response.state), 1, parsed_response)
+ self.assertEqual(parsed_response.event_dict, {}, parsed_response)
+ self.assertIsNone(parsed_response.event, parsed_response)
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index 7b95844b55..e5a6a6c747 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -32,7 +32,7 @@ from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEnt
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, UserID, create_requester
from tests import unittest
@@ -249,7 +249,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_rooms(result, expected)
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
self._assert_hierarchy(result, expected)
@@ -263,7 +263,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
expected = [(self.space, [self.room]), (self.room, ())]
self._assert_rooms(result, expected)
- result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(user2), self.space)
+ )
self._assert_hierarchy(result, expected)
# If the space is made invite-only, it should no longer be viewable.
@@ -274,7 +276,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
tok=self.token,
)
self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
- self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError)
+ self.get_failure(
+ self.handler.get_room_hierarchy(create_requester(user2), self.space),
+ AuthError,
+ )
# If the space is made world-readable it should return a result.
self.helper.send_state(
@@ -286,7 +291,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result = self.get_success(self.handler.get_space_summary(user2, self.space))
self._assert_rooms(result, expected)
- result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(user2), self.space)
+ )
self._assert_hierarchy(result, expected)
# Make it not world-readable again and confirm it results in an error.
@@ -297,7 +304,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
tok=self.token,
)
self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
- self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError)
+ self.get_failure(
+ self.handler.get_room_hierarchy(create_requester(user2), self.space),
+ AuthError,
+ )
# Join the space and results should be returned.
self.helper.invite(self.space, targ=user2, tok=self.token)
@@ -305,7 +315,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result = self.get_success(self.handler.get_space_summary(user2, self.space))
self._assert_rooms(result, expected)
- result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(user2), self.space)
+ )
self._assert_hierarchy(result, expected)
# Attempting to view an unknown room returns the same error.
@@ -314,7 +326,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
AuthError,
)
self.get_failure(
- self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname),
+ self.handler.get_room_hierarchy(
+ create_requester(user2), "#not-a-space:" + self.hs.hostname
+ ),
AuthError,
)
@@ -322,10 +336,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
"""In-flight room hierarchy requests are deduplicated."""
# Run two `get_room_hierarchy` calls up until they block.
deferred1 = ensureDeferred(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
deferred2 = ensureDeferred(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
# Complete the two calls.
@@ -340,7 +354,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
# A subsequent `get_room_hierarchy` call should not reuse the result.
result3 = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
self._assert_hierarchy(result3, expected)
self.assertIsNot(result1, result3)
@@ -359,9 +373,11 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
# Run two `get_room_hierarchy` calls for different users up until they block.
deferred1 = ensureDeferred(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ deferred2 = ensureDeferred(
+ self.handler.get_room_hierarchy(create_requester(user2), self.space)
)
- deferred2 = ensureDeferred(self.handler.get_room_hierarchy(user2, self.space))
# Complete the two calls.
result1 = self.get_success(deferred1)
@@ -465,7 +481,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
]
self._assert_rooms(result, expected)
- result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(user2), self.space)
+ )
self._assert_hierarchy(result, expected)
def test_complex_space(self):
@@ -507,7 +525,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_rooms(result, expected)
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
self._assert_hierarchy(result, expected)
@@ -522,7 +540,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
room_ids.append(self.room)
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space, limit=7)
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, limit=7
+ )
)
# The result should have the space and all of the links, plus some of the
# rooms and a pagination token.
@@ -534,7 +554,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
# Check the next page.
result = self.get_success(
self.handler.get_room_hierarchy(
- self.user, self.space, limit=5, from_token=result["next_batch"]
+ create_requester(self.user),
+ self.space,
+ limit=5,
+ from_token=result["next_batch"],
)
)
# The result should have the space and the room in it, along with a link
@@ -554,20 +577,22 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
room_ids.append(self.room)
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space, limit=7)
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, limit=7
+ )
)
self.assertIn("next_batch", result)
# Changing the room ID, suggested-only, or max-depth causes an error.
self.get_failure(
self.handler.get_room_hierarchy(
- self.user, self.room, from_token=result["next_batch"]
+ create_requester(self.user), self.room, from_token=result["next_batch"]
),
SynapseError,
)
self.get_failure(
self.handler.get_room_hierarchy(
- self.user,
+ create_requester(self.user),
self.space,
suggested_only=True,
from_token=result["next_batch"],
@@ -576,14 +601,19 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self.get_failure(
self.handler.get_room_hierarchy(
- self.user, self.space, max_depth=0, from_token=result["next_batch"]
+ create_requester(self.user),
+ self.space,
+ max_depth=0,
+ from_token=result["next_batch"],
),
SynapseError,
)
# An invalid token is ignored.
self.get_failure(
- self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"),
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, from_token="foo"
+ ),
SynapseError,
)
@@ -609,14 +639,18 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
# Test just the space itself.
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space, max_depth=0)
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, max_depth=0
+ )
)
expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])]
self._assert_hierarchy(result, expected)
# A single additional layer.
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space, max_depth=1)
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, max_depth=1
+ )
)
expected += [
(rooms[0], ()),
@@ -626,7 +660,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
# A few layers.
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space, max_depth=3)
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, max_depth=3
+ )
)
expected += [
(rooms[1], ()),
@@ -657,7 +693,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_rooms(result, expected)
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
self._assert_hierarchy(result, expected)
@@ -739,7 +775,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
new=summarize_remote_room_hierarchy,
):
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
self._assert_hierarchy(result, expected)
@@ -906,7 +942,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
new=summarize_remote_room_hierarchy,
):
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
self._assert_hierarchy(result, expected)
@@ -964,7 +1000,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
new=summarize_remote_room_hierarchy,
):
result = self.get_success(
- self.handler.get_room_hierarchy(self.user, self.space)
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
)
self._assert_hierarchy(result, expected)
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 90f800e564..f8cba7b645 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -128,6 +128,7 @@ class EmailPusherTests(HomeserverTestCase):
)
self.auth_handler = hs.get_auth_handler()
+ self.store = hs.get_datastore()
def test_need_validated_email(self):
"""Test that we can only add an email pusher if the user has validated
@@ -408,13 +409,7 @@ class EmailPusherTests(HomeserverTestCase):
self.hs.get_datastore().db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
- while not self.get_success(
- self.hs.get_datastore().db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.hs.get_datastore().db_pool.updates.do_next_background_update(100),
- by=0.1,
- )
+ self.wait_for_background_updates()
# Check that all pushers with unlinked addresses were deleted
pushers = self.get_success(
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 0a6e4795ee..596ba5a0c9 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -17,6 +17,7 @@ from unittest.mock import patch
from synapse.api.room_versions import RoomVersion
from synapse.rest import admin
from synapse.rest.client import login, room, sync
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
@@ -193,7 +194,10 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
#
# Worker2's event stream position will not advance until we call
# __aexit__ again.
- actx = worker_hs2.get_datastore()._stream_id_gen.get_next()
+ worker_store2 = worker_hs2.get_datastore()
+ assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator)
+
+ actx = worker_store2._stream_id_gen.get_next()
self.get_success(actx.__aenter__())
response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index af849bd471..3adadcb46b 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import os
import urllib.parse
+from http import HTTPStatus
from unittest.mock import Mock
from twisted.internet.defer import Deferred
@@ -41,7 +41,7 @@ class VersionTestCase(unittest.HomeserverTestCase):
def test_version_string(self):
channel = self.make_request("GET", self.url, shorthand=False)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(
{"server_version", "python_version"}, set(channel.json_body.keys())
)
@@ -70,11 +70,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
content={"localpart": "test"},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
group_id = channel.json_body["group_id"]
- self._check_group(group_id, expect_code=200)
+ self._check_group(group_id, expect_code=HTTPStatus.OK)
# Invite/join another user
@@ -82,13 +82,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
url = "/groups/%s/self/accept_invite" % (group_id,)
channel = self.make_request(
"PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Check other user knows they're in the group
self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
@@ -103,10 +103,10 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
content={"localpart": "test"},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- # Check group returns 404
- self._check_group(group_id, expect_code=404)
+ # Check group returns HTTPStatus.NOT_FOUND
+ self._check_group(group_id, expect_code=HTTPStatus.NOT_FOUND)
# Check users don't think they're in the group
self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
@@ -122,15 +122,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
+ self.assertEqual(expect_code, channel.code, msg=channel.json_body)
def _get_groups_user_is_in(self, access_token):
"""Returns the list of groups the user is in (given their access token)"""
channel = self.make_request("GET", b"/joined_groups", access_token=access_token)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
return channel.json_body["groups"]
@@ -210,10 +208,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Should be quarantined
self.assertEqual(
- 404,
- int(channel.code),
+ HTTPStatus.NOT_FOUND,
+ channel.code,
msg=(
- "Expected to receive a 404 on accessing quarantined media: %s"
+ "Expected to receive a HTTPStatus.NOT_FOUND on accessing quarantined media: %s"
% server_and_media_id
),
)
@@ -232,8 +230,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Expect a forbidden error
self.assertEqual(
- 403,
- int(channel.result["code"]),
+ HTTPStatus.FORBIDDEN,
+ channel.code,
msg="Expected forbidden on quarantining media as a non-admin",
)
@@ -247,8 +245,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Expect a forbidden error
self.assertEqual(
- 403,
- int(channel.result["code"]),
+ HTTPStatus.FORBIDDEN,
+ channel.code,
msg="Expected forbidden on quarantining media as a non-admin",
)
@@ -279,7 +277,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
)
# Should be successful
- self.assertEqual(200, int(channel.code), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code)
# Quarantine the media
url = "/_synapse/admin/v1/media/quarantine/%s/%s" % (
@@ -292,7 +290,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(200, int(channel.code), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Attempt to access the media
self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
@@ -348,11 +346,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(200, int(channel.code), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(
- json.loads(channel.result["body"].decode("utf-8")),
- {"num_quarantined": 2},
- "Expected 2 quarantined items",
+ channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
)
# Convert mxc URLs to server/media_id strings
@@ -396,11 +392,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(
- json.loads(channel.result["body"].decode("utf-8")),
- {"num_quarantined": 2},
- "Expected 2 quarantined items",
+ channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
)
# Attempt to access each piece of media
@@ -432,7 +426,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
channel = self.make_request("POST", url, access_token=admin_user_tok)
self.pump(1.0)
- self.assertEqual(200, int(channel.code), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
@@ -444,11 +438,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(
- json.loads(channel.result["body"].decode("utf-8")),
- {"num_quarantined": 1},
- "Expected 1 quarantined item",
+ channel.json_body, {"num_quarantined": 1}, "Expected 1 quarantined item"
)
# Attempt to access each piece of media, the first should fail, the
@@ -467,10 +459,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Shouldn't be quarantined
self.assertEqual(
- 200,
- int(channel.code),
+ HTTPStatus.OK,
+ channel.code,
msg=(
- "Expected to receive a 200 on accessing not-quarantined media: %s"
+ "Expected to receive a HTTPStatus.OK on accessing not-quarantined media: %s"
% server_and_media_id_2
),
)
@@ -499,7 +491,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
def test_purge_history(self):
"""
Simple test of purge history API.
- Test only that is is possible to call, get status 200 and purge_id.
+ Test only that it is possible to call, get status HTTPStatus.OK and purge_id.
"""
channel = self.make_request(
@@ -509,7 +501,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertIn("purge_id", channel.json_body)
purge_id = channel.json_body["purge_id"]
@@ -520,5 +512,5 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("complete", channel.json_body["status"])
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index cd5c60b65c..a5423af652 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -46,7 +46,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
)
def test_requester_is_no_admin(self, method: str, url: str):
"""
- If the user is not a server admin, an error 403 is returned.
+ If the user is not a server admin, an HTTPStatus.FORBIDDEN error is returned.
"""
self.register_user("user", "pass", admin=False)
@@ -135,7 +135,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self._register_bg_update()
self.store.db_pool.updates.start_doing_background_updates()
- self.reactor.pump([1.0, 1.0])
+ self.reactor.pump([1.0, 1.0, 1.0])
channel = self.make_request(
"GET",
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index a3679be205..baff057c56 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -13,6 +13,7 @@
# limitations under the License.
import urllib.parse
+from http import HTTPStatus
from parameterized import parameterized
@@ -53,7 +54,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(method, self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "PUT", "DELETE"])
@@ -67,13 +72,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "PUT", "DELETE"])
def test_user_does_not_exist(self, method: str):
"""
- Tests that a lookup for a user that does not exist returns a 404
+ Tests that a lookup for a user that does not exist returns HTTPStatus.NOT_FOUND
"""
url = (
"/_synapse/admin/v2/users/@unknown_person:test/devices/%s"
@@ -86,13 +95,13 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(["GET", "PUT", "DELETE"])
def test_user_is_not_local(self, method: str):
"""
- Tests that a lookup for a user that is not a local returns a 400
+ Tests that a lookup for a user that is not local returns HTTPStatus.BAD_REQUEST
"""
url = (
"/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s"
@@ -105,12 +114,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_unknown_device(self):
"""
- Tests that a lookup for a device that does not exist returns either 404 or 200.
+ Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK.
"""
url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote(
self.other_user
@@ -122,7 +131,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
channel = self.make_request(
@@ -131,7 +140,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
channel = self.make_request(
"DELETE",
@@ -139,8 +148,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- # Delete unknown device returns status 200
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ # Deleting an unknown device returns HTTPStatus.OK
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
def test_update_device_too_long_display_name(self):
"""
@@ -167,7 +176,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
content=update,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
@@ -177,12 +186,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
def test_update_no_display_name(self):
"""
- Tests that a update for a device without JSON returns a 200
+ Tests that an update for a device without JSON returns HTTPStatus.OK
"""
# Set iniital display name.
update = {"display_name": "new display"}
@@ -198,7 +207,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Ensure the display name was not updated.
channel = self.make_request(
@@ -207,7 +216,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
def test_update_display_name(self):
@@ -222,7 +231,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
content={"display_name": "new displayname"},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Check new display_name
channel = self.make_request(
@@ -231,7 +240,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("new displayname", channel.json_body["display_name"])
def test_get_device(self):
@@ -244,7 +253,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
# Check that all fields are available
self.assertIn("user_id", channel.json_body)
@@ -269,7 +278,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Ensure that the number of devices is decreased
res = self.get_success(self.handler.get_devices_by_user(self.other_user))
@@ -299,7 +308,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -314,12 +327,16 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self):
"""
- Tests that a lookup for a user that does not exist returns a 404
+ Tests that a lookup for a user that does not exist returns HTTPStatus.NOT_FOUND
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
channel = self.make_request(
@@ -328,12 +345,12 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self):
"""
- Tests that a lookup for a user that is not a local returns a 400
+ Tests that a lookup for a user that is not local returns HTTPStatus.BAD_REQUEST
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
@@ -343,7 +360,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_user_has_no_devices(self):
@@ -359,7 +376,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["devices"]))
@@ -379,7 +396,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(number_devices, channel.json_body["total"])
self.assertEqual(number_devices, len(channel.json_body["devices"]))
self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"])
@@ -417,7 +434,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -432,12 +453,16 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self):
"""
- Tests that a lookup for a user that does not exist returns a 404
+ Tests that a lookup for a user that does not exist returns HTTPStatus.NOT_FOUND
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
channel = self.make_request(
@@ -446,12 +471,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self):
"""
- Tests that a lookup for a user that is not a local returns a 400
+ Tests that a lookup for a user that is not local returns HTTPStatus.BAD_REQUEST
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
@@ -461,12 +486,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_unknown_devices(self):
"""
- Tests that a remove of a device that does not exist returns 200.
+ Tests that removing a device that does not exist returns HTTPStatus.OK.
"""
channel = self.make_request(
"POST",
@@ -475,8 +500,8 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
content={"devices": ["unknown_device1", "unknown_device2"]},
)
- # Delete unknown devices returns status 200
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ # Deleting unknown devices returns HTTPStatus.OK
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
def test_delete_devices(self):
"""
@@ -505,7 +530,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
content={"devices": device_ids},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
res = self.get_success(self.handler.get_devices_by_user(self.other_user))
self.assertEqual(0, len(res))
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index e9ef89731f..a9c46ec62d 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
+from http import HTTPStatus
import synapse.rest.admin
from synapse.api.errors import Codes
@@ -76,12 +76,16 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
"""
- If the user is not a server admin, an error 403 is returned.
+ If the user is not a server admin, an HTTPStatus.FORBIDDEN error is returned.
"""
channel = self.make_request(
@@ -90,7 +94,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_default_success(self):
@@ -104,7 +112,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
@@ -121,7 +129,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
@@ -138,7 +146,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -155,7 +163,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
@@ -172,7 +180,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
self.assertNotIn("next_token", channel.json_body)
@@ -192,7 +200,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
self.assertNotIn("next_token", channel.json_body)
@@ -212,7 +220,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 5)
self.assertEqual(len(channel.json_body["event_reports"]), 5)
self.assertNotIn("next_token", channel.json_body)
@@ -234,7 +242,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
report = 1
@@ -252,7 +260,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
report = 1
@@ -265,7 +273,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
def test_invalid_search_order(self):
"""
- Testing that a invalid search order returns a 400
+ Testing that an invalid search order returns HTTPStatus.BAD_REQUEST
"""
channel = self.make_request(
@@ -274,13 +282,17 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual("Unknown direction: bar", channel.json_body["error"])
def test_limit_is_negative(self):
"""
- Testing that a negative limit parameter returns a 400
+ Testing that a negative limit parameter returns HTTPStatus.BAD_REQUEST
"""
channel = self.make_request(
@@ -289,12 +301,16 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_from_is_negative(self):
"""
- Testing that a negative from parameter returns a 400
+ Testing that a negative from parameter returns HTTPStatus.BAD_REQUEST
"""
channel = self.make_request(
@@ -303,7 +319,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_next_token(self):
@@ -319,7 +339,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
@@ -332,7 +352,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
@@ -345,7 +365,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -359,7 +379,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -372,10 +392,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
"rooms/%s/report/%s" % (room_id, event_id),
- json.dumps({"score": -100, "reason": "this makes me sad"}),
+ {"score": -100, "reason": "this makes me sad"},
access_token=user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
def _create_event_and_report_without_parameters(self, room_id, user_tok):
"""Create and report an event, but omit reason and score"""
@@ -385,10 +405,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
"rooms/%s/report/%s" % (room_id, event_id),
- json.dumps({}),
+ {},
access_token=user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
def _check_fields(self, content):
"""Checks that all attributes are present in an event report"""
@@ -439,12 +459,16 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
"""
- If the user is not a server admin, an error 403 is returned.
+ If the user is not a server admin, an HTTPStatus.FORBIDDEN error is returned.
"""
channel = self.make_request(
@@ -453,7 +477,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_default_success(self):
@@ -467,12 +495,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self._check_fields(channel.json_body)
def test_invalid_report_id(self):
"""
- Testing that an invalid `report_id` returns a 400.
+ Testing that an invalid `report_id` returns HTTPStatus.BAD_REQUEST.
"""
# `report_id` is negative
@@ -482,7 +510,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -496,7 +528,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -510,7 +546,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -519,7 +559,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
def test_report_id_not_found(self):
"""
- Testing that a not existing `report_id` returns a 404.
+ Testing that a non-existent `report_id` returns HTTPStatus.NOT_FOUND.
"""
channel = self.make_request(
@@ -528,7 +568,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.NOT_FOUND,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
self.assertEqual("Event report not found", channel.json_body["error"])
@@ -540,10 +584,10 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
"rooms/%s/report/%s" % (room_id, event_id),
- json.dumps({"score": -100, "reason": "this makes me sad"}),
+ {"score": -100, "reason": "this makes me sad"},
access_token=user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
def _check_fields(self, content):
"""Checks that all attributes are present in a event report"""
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index db0e78c039..6618279dd1 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import os
+from http import HTTPStatus
from parameterized import parameterized
@@ -56,7 +56,11 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
channel = self.make_request("DELETE", url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -74,12 +78,16 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_media_does_not_exist(self):
"""
- Tests that a lookup for a media that does not exist returns a 404
+ Tests that a lookup for media that does not exist returns HTTPStatus.NOT_FOUND
"""
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
@@ -89,12 +97,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_media_is_not_local(self):
"""
- Tests that a lookup for a media that is not a local returns a 400
+ Tests that a lookup for media that is not local returns HTTPStatus.BAD_REQUEST
"""
url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
@@ -104,7 +112,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
def test_delete_media(self):
@@ -117,7 +125,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
+ upload_resource,
+ SMALL_PNG,
+ tok=self.admin_user_tok,
+ expect_code=HTTPStatus.OK,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -137,10 +148,11 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Should be successful
self.assertEqual(
- 200,
+ HTTPStatus.OK,
channel.code,
msg=(
- "Expected to receive a 200 on accessing media: %s" % server_and_media_id
+ "Expected to receive a HTTPStatus.OK on accessing media: %s"
+ % server_and_media_id
),
)
@@ -157,7 +169,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
media_id,
@@ -174,10 +186,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(
- 404,
+ HTTPStatus.NOT_FOUND,
channel.code,
msg=(
- "Expected to receive a 404 on accessing deleted media: %s"
+ "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s"
% server_and_media_id
),
)
@@ -216,7 +228,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -232,12 +248,16 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_media_is_not_local(self):
"""
- Tests that a lookup for media that is not local returns a 400
+ Tests that a lookup for media that is not local returns HTTPStatus.BAD_REQUEST
"""
url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
@@ -247,7 +267,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
def test_missing_parameter(self):
@@ -260,7 +280,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Missing integer query parameter 'before_ts'", channel.json_body["error"]
@@ -276,7 +300,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts must be a positive integer.",
@@ -289,7 +317,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts you provided is from the year 1970. "
@@ -303,7 +335,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter size_gt must be a string representing a positive integer.",
@@ -316,7 +352,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual(
"Boolean query parameter 'keep_profiles' must be one of ['true', 'false']",
@@ -345,7 +385,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
media_id,
@@ -370,7 +410,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -382,7 +422,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -406,7 +446,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&size_gt=67",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -417,7 +457,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&size_gt=66",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -439,10 +479,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.admin_user,),
- content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}),
+ content={"avatar_url": "mxc://%s" % (server_and_media_id,)},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
channel = self.make_request(
@@ -450,7 +490,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -461,7 +501,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -484,10 +524,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
"/rooms/%s/state/m.room.avatar" % (room_id,),
- content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}),
+ content={"url": "mxc://%s" % (server_and_media_id,)},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
channel = self.make_request(
@@ -495,7 +535,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -506,7 +546,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -523,7 +563,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
+ upload_resource,
+ SMALL_PNG,
+ tok=self.admin_user_tok,
+ expect_code=HTTPStatus.OK,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -554,10 +597,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
if expect_success:
self.assertEqual(
- 200,
+ HTTPStatus.OK,
channel.code,
msg=(
- "Expected to receive a 200 on accessing media: %s"
+ "Expected to receive a HTTPStatus.OK on accessing media: %s"
% server_and_media_id
),
)
@@ -565,10 +608,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertTrue(os.path.exists(local_path))
else:
self.assertEqual(
- 404,
+ HTTPStatus.NOT_FOUND,
channel.code,
msg=(
- "Expected to receive a 404 on accessing deleted media: %s"
+ "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s"
% (server_and_media_id)
),
)
@@ -597,7 +640,10 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
+ upload_resource,
+ SMALL_PNG,
+ tok=self.admin_user_tok,
+ expect_code=HTTPStatus.OK,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -617,7 +663,11 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
b"{}",
)
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["quarantine", "unquarantine"])
@@ -634,7 +684,11 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_quarantine_media(self):
@@ -652,7 +706,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -665,7 +719,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -690,7 +744,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
# verify that is not in quarantine
@@ -718,7 +772,10 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
+ upload_resource,
+ SMALL_PNG,
+ tok=self.admin_user_tok,
+ expect_code=HTTPStatus.OK,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -734,7 +791,11 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url % (action, self.media_id), b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["protect", "unprotect"])
@@ -751,7 +812,11 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_protect_media(self):
@@ -769,7 +834,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -782,7 +847,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -816,7 +881,11 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self):
@@ -832,7 +901,11 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self):
@@ -845,7 +918,11 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts must be a positive integer.",
@@ -858,7 +935,11 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts you provided is from the year 1970. "
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 9bac423ae0..63087955f2 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -14,6 +14,7 @@
import random
import string
+from http import HTTPStatus
import synapse.rest.admin
from synapse.api.errors import Codes
@@ -63,7 +64,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
def test_create_no_auth(self):
"""Try to create a token without authentication."""
channel = self.make_request("POST", self.url + "/new", {})
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_create_requester_not_admin(self):
@@ -74,7 +79,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_create_using_defaults(self):
@@ -86,7 +95,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 16)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -110,7 +119,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["token"], token)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
@@ -131,7 +140,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 16)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -149,7 +158,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_token_invalid_chars(self):
@@ -165,7 +178,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_token_already_exists(self):
@@ -180,7 +197,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
data,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body)
channel2 = self.make_request(
"POST",
@@ -188,7 +205,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
data,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body)
self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_unable_to_generate_token(self):
@@ -220,7 +237,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 1},
access_token=self.admin_user_tok,
)
- self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.INTERNAL_SERVER_ERROR, channel.code, msg=channel.json_body)
def test_create_uses_allowed(self):
"""Check you can only create a token with good values for uses_allowed."""
@@ -231,7 +248,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 0)
# Should fail with negative integer
@@ -241,7 +258,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": -5},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with float
@@ -251,7 +272,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_expiry_time(self):
@@ -263,7 +288,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": self.clock.time_msec() - 10000},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with float
@@ -273,7 +302,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": self.clock.time_msec() + 1000000.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_length(self):
@@ -285,7 +318,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 64},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 64)
# Should fail with 0
@@ -295,7 +328,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a negative integer
@@ -305,7 +342,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": -5},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a float
@@ -315,7 +356,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 8.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with 65
@@ -325,7 +370,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 65},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# UPDATING
@@ -337,7 +386,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_update_requester_not_admin(self):
@@ -348,7 +401,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_update_non_existent(self):
@@ -360,7 +417,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.NOT_FOUND,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_update_uses_allowed(self):
@@ -375,7 +436,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertIsNone(channel.json_body["expiry_time"])
@@ -386,7 +447,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 0)
self.assertIsNone(channel.json_body["expiry_time"])
@@ -397,7 +458,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": None},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -408,7 +469,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a negative integer
@@ -418,7 +483,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": -5},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_update_expiry_time(self):
@@ -434,7 +503,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": new_expiry_time},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
self.assertIsNone(channel.json_body["uses_allowed"])
@@ -445,7 +514,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": None},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertIsNone(channel.json_body["expiry_time"])
self.assertIsNone(channel.json_body["uses_allowed"])
@@ -457,7 +526,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": past_time},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a float
@@ -467,7 +540,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": new_expiry_time + 0.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_update_both(self):
@@ -488,7 +565,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
@@ -509,7 +586,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# DELETING
@@ -521,7 +602,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_delete_requester_not_admin(self):
@@ -532,7 +617,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_delete_non_existent(self):
@@ -544,7 +633,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.NOT_FOUND,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_delete(self):
@@ -559,7 +652,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# GETTING ONE
@@ -570,7 +663,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_get_requester_not_admin(self):
@@ -581,7 +678,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_get_non_existent(self):
@@ -593,7 +694,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.NOT_FOUND,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_get(self):
@@ -608,7 +713,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["token"], token)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -620,7 +725,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
def test_list_no_auth(self):
"""Try to list tokens without authentication."""
channel = self.make_request("GET", self.url, {})
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_list_requester_not_admin(self):
@@ -631,7 +740,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_list_all(self):
@@ -646,7 +759,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
token_info = channel.json_body["registration_tokens"][0]
self.assertEqual(token_info["token"], token)
@@ -664,7 +777,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
def _test_list_query_parameter(self, valid: str):
"""Helper used to test both valid=true and valid=false."""
@@ -696,7 +813,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
token_info_1 = channel.json_body["registration_tokens"][0]
token_info_2 = channel.json_body["registration_tokens"][1]
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 07077aff78..56b7a438b6 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -66,7 +66,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
def test_requester_is_no_admin(self):
"""
- If the user is not a server admin, an error 403 is returned.
+ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
"""
channel = self.make_request(
@@ -76,12 +76,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_room_does_not_exist(self):
"""
- Check that unknown rooms/server return error 404.
+ Check that unknown rooms/servers return error HTTPStatus.NOT_FOUND.
"""
url = "/_synapse/admin/v1/rooms/%s" % "!unknown:test"
@@ -92,12 +92,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_room_is_not_valid(self):
"""
- Check that invalid room names, return an error 400.
+ Check that invalid room names return an error HTTPStatus.BAD_REQUEST.
"""
url = "/_synapse/admin/v1/rooms/%s" % "invalidroom"
@@ -108,7 +108,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -127,7 +127,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertIn("new_room_id", channel.json_body)
self.assertIn("kicked_users", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -146,7 +146,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(
"User must be our own: @not:exist.bla",
channel.json_body["error"],
@@ -165,7 +165,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_is_not_bool(self):
@@ -181,7 +181,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_room_and_block(self):
@@ -207,7 +207,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -240,7 +240,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -274,7 +274,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -305,9 +305,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
)
# The room is now blocked.
- self.assertEqual(
- HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"]
- )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self._is_blocked(room_id)
def test_shutdown_room_consent(self):
@@ -327,7 +325,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Assert that the user is getting consent error
self.helper.send(
- self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
+ self.room_id,
+ body="foo",
+ tok=self.other_user_tok,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# Test that room is not purged
@@ -345,7 +346,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -374,7 +375,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
json.dumps({"history_visibility": "world_readable"}),
access_token=self.other_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Test that room is not purged
with self.assertRaises(AssertionError):
@@ -391,7 +392,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -406,7 +407,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self._has_no_members(self.room_id)
# Assert we can no longer peek into the room
- self._assert_peek(self.room_id, expect_code=403)
+ self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
def _is_blocked(self, room_id, expect=True):
"""Assert that the room is blocked or not"""
@@ -502,7 +503,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
)
def test_requester_is_no_admin(self, method: str, url: str):
"""
- If the user is not a server admin, an error 403 is returned.
+ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
"""
channel = self.make_request(
@@ -524,7 +525,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
)
def test_room_does_not_exist(self, method: str, url: str):
"""
- Check that unknown rooms/server return error 404.
+ Check that unknown rooms/servers return error HTTPStatus.NOT_FOUND.
"""
channel = self.make_request(
@@ -545,7 +546,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
)
def test_room_is_not_valid(self, method: str, url: str):
"""
- Check that invalid room names, return an error 400.
+ Check that invalid room names return an error HTTPStatus.BAD_REQUEST.
"""
channel = self.make_request(
@@ -854,7 +855,10 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
# Assert that the user is getting consent error
self.helper.send(
- self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
+ self.room_id,
+ body="foo",
+ tok=self.other_user_tok,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# Test that room is not purged
@@ -951,7 +955,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self._has_no_members(self.room_id)
# Assert we can no longer peek into the room
- self._assert_peek(self.room_id, expect_code=403)
+ self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
def _is_blocked(self, room_id: str, expect: bool = True) -> None:
"""Assert that the room is blocked or not"""
@@ -1094,7 +1098,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
# Check request completed successfully
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Check that response json body contains a "rooms" key
self.assertTrue(
@@ -1178,7 +1182,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertTrue("rooms" in channel.json_body)
for r in channel.json_body["rooms"]:
@@ -1218,7 +1222,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
def test_correct_room_attributes(self):
"""Test the correct attributes for a room are returned"""
@@ -1241,7 +1245,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -1273,7 +1277,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Check that rooms were returned
self.assertTrue("rooms" in channel.json_body)
@@ -1328,7 +1332,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Check that rooms were returned
self.assertTrue("rooms" in channel.json_body)
@@ -1467,7 +1471,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
def _search_test(
expected_room_id: Optional[str],
search_term: str,
- expected_http_code: int = 200,
+ expected_http_code: int = HTTPStatus.OK,
):
"""Search for a room and check that the returned room's id is a match
@@ -1485,7 +1489,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
- if expected_http_code != 200:
+ if expected_http_code != HTTPStatus.OK:
return
# Check that rooms were returned
@@ -1528,7 +1532,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo")
_search_test(None, "bar")
- _search_test(None, "", expected_http_code=400)
+ _search_test(None, "", expected_http_code=HTTPStatus.BAD_REQUEST)
# Test that the whole room id returns the room
_search_test(room_id_1, room_id_1)
@@ -1565,7 +1569,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id"))
self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name"))
@@ -1598,7 +1602,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertIn("room_id", channel.json_body)
self.assertIn("name", channel.json_body)
@@ -1630,7 +1634,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["joined_local_devices"])
# Have another user join the room
@@ -1644,7 +1648,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(2, channel.json_body["joined_local_devices"])
# leave room
@@ -1656,7 +1660,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["joined_local_devices"])
def test_room_members(self):
@@ -1687,7 +1691,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertCountEqual(
["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
@@ -1700,7 +1704,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertCountEqual(
["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
@@ -1718,7 +1722,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertIn("state", channel.json_body)
# testing that the state events match is painful and not done here. We assume that
# the create_room already does the right thing, so no need to verify that we got
@@ -1733,7 +1737,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -1776,7 +1780,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
def test_requester_is_no_admin(self):
"""
- If the user is not a server admin, an error 403 is returned.
+ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
"""
body = json.dumps({"user_id": self.second_user_id})
@@ -1787,7 +1791,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.second_tok,
)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self):
@@ -1803,12 +1807,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_local_user_does_not_exist(self):
"""
- Tests that a lookup for a user that does not exist returns a 404
+ Tests that a lookup for a user that does not exist returns HTTPStatus.NOT_FOUND
"""
body = json.dumps({"user_id": "@unknown:test"})
@@ -1819,7 +1823,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_remote_user(self):
@@ -1835,7 +1839,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(
"This endpoint can only be used with local users",
channel.json_body["error"],
@@ -1843,7 +1847,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
def test_room_does_not_exist(self):
"""
- Check that unknown rooms/server return error 404.
+ Check that unknown rooms/servers return error HTTPStatus.NOT_FOUND.
"""
body = json.dumps({"user_id": self.second_user_id})
url = "/_synapse/admin/v1/join/!unknown:test"
@@ -1855,12 +1859,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual("No known servers", channel.json_body["error"])
def test_room_is_not_valid(self):
"""
- Check that invalid room names, return an error 400.
+ Check that invalid room names return an error HTTPStatus.BAD_REQUEST.
"""
body = json.dumps({"user_id": self.second_user_id})
url = "/_synapse/admin/v1/join/invalidroom"
@@ -1872,7 +1876,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom was not legal room ID or room alias",
channel.json_body["error"],
@@ -1891,7 +1895,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1901,7 +1905,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(200, channel.code, msg=channel.json_body)
+ self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_not_member(self):
@@ -1922,7 +1926,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_join_private_room_if_member(self):
@@ -1950,7 +1954,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.admin_user_tok,
)
- self.assertEquals(200, channel.code, msg=channel.json_body)
+ self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
# Join user to room.
@@ -1964,7 +1968,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1974,7 +1978,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(200, channel.code, msg=channel.json_body)
+ self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_owner(self):
@@ -1995,7 +1999,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -2005,7 +2009,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(200, channel.code, msg=channel.json_body)
+ self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_context_as_non_admin(self):
@@ -2039,7 +2043,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=tok,
)
- self.assertEquals(403, channel.code, msg=channel.json_body)
+ self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_context_as_admin(self):
@@ -2069,7 +2073,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=self.admin_user_tok,
)
- self.assertEquals(200, channel.code, msg=channel.json_body)
+ self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEquals(
channel.json_body["event"]["event_id"], events[midway]["event_id"]
)
@@ -2128,7 +2132,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
@@ -2155,7 +2159,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Now we test that we can join the room (we should have received an
# invite) and can ban a user.
@@ -2181,7 +2185,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
@@ -2215,11 +2219,11 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- # We expect this to fail with a 400 as there are no room admins.
+ # We expect this to fail with HTTPStatus.BAD_REQUEST as there are no room admins.
#
# (Note we assert the error message to ensure that it's not denied for
# some other reason)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body["error"],
"No local admin user in room with power to update power levels.",
@@ -2249,7 +2253,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
@parameterized.expand([("PUT",), ("GET",)])
def test_requester_is_no_admin(self, method: str):
- """If the user is not a server admin, an error 403 is returned."""
+ """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned."""
channel = self.make_request(
method,
@@ -2263,7 +2267,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
@parameterized.expand([("PUT",), ("GET",)])
def test_room_is_not_valid(self, method: str):
- """Check that invalid room names, return an error 400."""
+ """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST."""
channel = self.make_request(
method,
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index fbceba3254..0b9da4c732 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from http import HTTPStatus
from typing import List
import synapse.rest.admin
@@ -52,7 +53,11 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"""Try to send a server notice without authentication."""
channel = self.make_request("POST", self.url)
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -63,12 +68,16 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_user_does_not_exist(self):
- """Tests that a lookup for a user that does not exist returns a 404"""
+ """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
channel = self.make_request(
"POST",
self.url,
@@ -76,13 +85,13 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": "@unknown_person:test", "content": ""},
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_user_is_not_local(self):
"""
- Tests that a lookup for a user that is not a local returns a 400
+ Tests that a lookup for a user that is not local returns HTTPStatus.BAD_REQUEST
"""
channel = self.make_request(
"POST",
@@ -94,7 +103,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(
"Server notices can only be sent to local users", channel.json_body["error"]
)
@@ -110,7 +119,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
# no content
@@ -121,7 +130,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# no body
@@ -132,7 +141,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user, "content": ""},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("'body' not in content", channel.json_body["error"])
@@ -144,7 +153,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user, "content": {"body": ""}},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("'msgtype' not in content", channel.json_body["error"])
@@ -160,7 +169,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual(
"Server notices are not enabled on this server", channel.json_body["error"]
@@ -185,7 +194,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -216,7 +225,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# user has no new invites or memberships
self._check_invite_and_join_status(self.other_user, 0, 1)
@@ -250,7 +259,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -293,7 +302,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -333,7 +342,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -382,7 +391,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -440,7 +449,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "/_matrix/client/r0/sync", access_token=token
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, HTTPStatus.OK)
# Get the messages
room = channel.json_body["rooms"]["join"][room_id]
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index ece89a65ac..43d8ca032b 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
+from http import HTTPStatus
from typing import Any, Dict, List, Optional
import synapse.rest.admin
@@ -47,21 +47,29 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.UNAUTHORIZED,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
"""
- If the user is not a server admin, an error 403 is returned.
+ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
"""
channel = self.make_request(
"GET",
self.url,
- json.dumps({}),
+ {},
access_token=self.other_user_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.FORBIDDEN,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self):
@@ -75,7 +83,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -85,7 +97,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative limit
@@ -95,7 +111,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from_ts
@@ -105,7 +125,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative until_ts
@@ -115,7 +139,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# until_ts smaller from_ts
@@ -125,7 +153,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# empty search term
@@ -135,7 +167,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -145,7 +181,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_limit(self):
@@ -160,7 +200,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
@@ -178,7 +218,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -196,7 +236,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["users"]), 10)
@@ -218,7 +258,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -231,7 +271,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -244,7 +284,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -257,7 +297,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -274,7 +314,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["users"]))
@@ -371,7 +411,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media starting at `ts1` after creating first media
@@ -381,7 +421,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?from_ts=%s" % (ts1,),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 0)
self._create_media(self.other_user_tok, 3)
@@ -396,7 +436,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media until `ts2` and earlier
@@ -405,7 +445,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?until_ts=%s" % (ts2,),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 6)
def test_search_term(self):
@@ -417,7 +457,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
# filter user 1 and 10-19 by `user_id`
@@ -426,7 +466,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=foo_user_1",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 11)
# filter on this user in `displayname`
@@ -435,7 +475,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=bar_user_10",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10")
self.assertEqual(channel.json_body["total"], 1)
@@ -445,7 +485,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=foobar",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 0)
def _create_users_with_media(self, number_users: int, media_per_user: int):
@@ -471,7 +511,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
for _ in range(number_media):
# Upload some media into the room
self.helper.upload_media(
- upload_resource, SMALL_PNG, tok=user_token, expect_code=200
+ upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK
)
def _check_fields(self, content: List[Dict[str, Any]]):
@@ -505,7 +545,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list))
returned_order = [row["user_id"] for row in channel.json_body["users"]]
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 5011e54563..03aa689ace 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -17,6 +17,7 @@ import hmac
import os
import urllib.parse
from binascii import unhexlify
+from http import HTTPStatus
from typing import List, Optional
from unittest.mock import Mock, patch
@@ -74,7 +75,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(
"Shared secret registration is not enabled", channel.json_body["error"]
)
@@ -106,7 +107,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = {"nonce": nonce}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# 61 seconds
@@ -114,7 +115,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, body)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_register_incorrect_nonce(self):
@@ -137,7 +138,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual("HMAC incorrect", channel.json_body["error"])
def test_register_correct_nonce(self):
@@ -164,7 +165,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
def test_nonce_reuse(self):
@@ -187,13 +188,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
channel = self.make_request("POST", self.url, body)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_missing_parts(self):
@@ -214,7 +215,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be an empty body present
channel = self.make_request("POST", self.url, {})
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("nonce must be specified", channel.json_body["error"])
#
@@ -224,28 +225,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
channel = self.make_request("POST", self.url, {"nonce": nonce()})
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# Must be a string
body = {"nonce": nonce(), "username": 1234}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "abcd\u0000"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Super long
body = {"nonce": nonce(), "username": "a" * 1000}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
#
@@ -256,28 +257,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = {"nonce": nonce(), "username": "a"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("password must be specified", channel.json_body["error"])
# Must be a string
body = {"nonce": nonce(), "username": "a", "password": 1234}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Super long
body = {"nonce": nonce(), "username": "a", "password": "A" * 1000}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
#
@@ -293,7 +294,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
def test_displayname(self):
@@ -318,11 +319,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, body)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@bob1:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob1:test/displayname")
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("bob1", channel.json_body["displayname"])
# displayname is None
@@ -342,11 +343,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@bob2:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob2:test/displayname")
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("bob2", channel.json_body["displayname"])
# displayname is empty
@@ -366,11 +367,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@bob3:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob3:test/displayname")
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
# set displayname
channel = self.make_request("GET", self.url)
@@ -389,11 +390,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@bob4:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob4:test/displayname")
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("Bob's Name", channel.json_body["displayname"])
@override_config(
@@ -437,7 +438,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -461,7 +462,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -473,7 +474,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", self.url, access_token=other_user_token)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_all_users(self):
@@ -489,7 +490,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(3, len(channel.json_body["users"]))
self.assertEqual(3, channel.json_body["total"])
@@ -503,7 +504,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
expected_user_id: Optional[str],
search_term: str,
search_field: Optional[str] = "name",
- expected_http_code: Optional[int] = 200,
+ expected_http_code: Optional[int] = HTTPStatus.OK,
):
"""Search for a user and check that the returned user's id is a match
@@ -525,7 +526,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
- if expected_http_code != 200:
+ if expected_http_code != HTTPStatus.OK:
return
# Check that users were returned
@@ -586,7 +587,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -596,7 +597,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid guests
@@ -606,7 +607,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# invalid deactivated
@@ -616,7 +617,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# unknown order_by
@@ -626,7 +627,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# invalid search order
@@ -636,7 +637,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
def test_limit(self):
@@ -654,7 +655,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], "5")
@@ -675,7 +676,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -696,7 +697,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(channel.json_body["next_token"], "15")
self.assertEqual(len(channel.json_body["users"]), 10)
@@ -719,7 +720,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -732,7 +733,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -745,7 +746,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], "19")
@@ -759,7 +760,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -862,7 +863,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list))
returned_order = [row["name"] for row in channel.json_body["users"]]
@@ -936,7 +937,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self):
@@ -947,7 +948,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", url, access_token=self.other_user_token)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
@@ -957,12 +958,12 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content=b"{}",
)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
def test_user_does_not_exist(self):
"""
- Tests that deactivation for a user that does not exist returns a 404
+ Tests that deactivation for a user that does not exist returns HTTPStatus.NOT_FOUND
"""
channel = self.make_request(
@@ -971,7 +972,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_erase_is_not_bool(self):
@@ -986,18 +987,18 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_user_is_not_local(self):
"""
- Tests that deactivation for a user that is not a local returns a 400
+ Tests that deactivation for a user that is not local returns HTTPStatus.BAD_REQUEST
"""
url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
channel = self.make_request("POST", url, access_token=self.admin_user_tok)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only deactivate local users", channel.json_body["error"])
def test_deactivate_user_erase_true(self):
@@ -1012,7 +1013,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1027,7 +1028,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": True},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1036,7 +1037,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1057,7 +1058,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1072,7 +1073,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": False},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1081,7 +1082,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1111,7 +1112,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1126,7 +1127,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": True},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1135,7 +1136,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1195,7 +1196,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
@@ -1205,12 +1206,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=b"{}",
)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
def test_user_does_not_exist(self):
"""
- Tests that a lookup for a user that does not exist returns a 404
+ Tests that a lookup for a user that does not exist returns HTTPStatus.NOT_FOUND
"""
channel = self.make_request(
@@ -1219,7 +1220,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
def test_invalid_parameter(self):
@@ -1234,7 +1235,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"admin": "not_bool"},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
# deactivated not bool
@@ -1244,7 +1245,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": "not_bool"},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# password not str
@@ -1254,7 +1255,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"password": True},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# password not length
@@ -1264,7 +1265,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"password": "x" * 513},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# user_type not valid
@@ -1274,7 +1275,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"user_type": "new type"},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# external_ids not valid
@@ -1286,7 +1287,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"external_ids": {"auth_provider": "prov", "wrong_external_id": "id"}
},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
channel = self.make_request(
@@ -1295,7 +1296,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"external_ids": {"external_id": "id"}},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# threepids not valid
@@ -1305,7 +1306,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": {"medium": "email", "wrong_address": "id"}},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
channel = self.make_request(
@@ -1314,7 +1315,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": {"address": "value"}},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_get_user(self):
@@ -1327,7 +1328,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("User", channel.json_body["displayname"])
self._check_fields(channel.json_body)
@@ -1370,7 +1371,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1433,7 +1434,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1461,9 +1462,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# before limit of monthly active users is reached
channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok)
- if channel.code != 200:
+ if channel.code != HTTPStatus.OK:
raise HttpResponseException(
- channel.code, channel.result["reason"], channel.result["body"]
+ channel.code, channel.result["reason"], channel.json_body
)
# Set monthly active users to the limit
@@ -1625,7 +1626,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "hahaha"},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self._check_fields(channel.json_body)
def test_set_displayname(self):
@@ -1641,7 +1642,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"displayname": "foobar"},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
@@ -1652,7 +1653,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
@@ -1674,7 +1675,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1700,7 +1701,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1716,7 +1717,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1732,7 +1733,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": []},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
@@ -1759,7 +1760,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1778,7 +1779,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1800,7 +1801,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# other user has these two threepids
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1819,7 +1820,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url_first_user,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
@@ -1848,7 +1849,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1880,7 +1881,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -1899,7 +1900,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -1918,7 +1919,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"external_ids": []},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["external_ids"]))
@@ -1947,7 +1948,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -1973,7 +1974,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2005,7 +2006,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# must fail
- self.assertEqual(409, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.CONFLICT, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("External id is already in use.", channel.json_body["error"])
@@ -2016,7 +2017,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2034,7 +2035,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2065,7 +2066,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -2080,7 +2081,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"deactivated": True},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
@@ -2096,7 +2097,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
@@ -2123,7 +2124,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"deactivated": True},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
@@ -2139,7 +2140,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"displayname": "Foobar"},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertEqual("Foobar", channel.json_body["displayname"])
@@ -2163,7 +2164,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
# Reactivate the user.
channel = self.make_request(
@@ -2172,7 +2173,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertIsNotNone(channel.json_body["password_hash"])
@@ -2194,7 +2195,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
@@ -2204,7 +2205,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
@@ -2226,7 +2227,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
@@ -2236,7 +2237,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
@@ -2255,7 +2256,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"admin": True},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
@@ -2266,7 +2267,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
@@ -2283,7 +2284,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"user_type": UserTypes.SUPPORT},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
@@ -2294,7 +2295,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
@@ -2306,7 +2307,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"user_type": None},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
@@ -2317,7 +2318,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
@@ -2347,7 +2348,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
self.assertEqual(0, channel.json_body["deactivated"])
@@ -2360,7 +2361,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123", "deactivated": "false"},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
# Check user is not deactivated
channel = self.make_request(
@@ -2369,7 +2370,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
@@ -2394,7 +2395,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": True},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
self._is_erased(user_id, False)
@@ -2445,7 +2446,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -2460,7 +2461,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self):
@@ -2474,7 +2475,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2490,7 +2491,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2506,7 +2507,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2527,7 +2528,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(number_rooms, channel.json_body["total"])
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
@@ -2574,7 +2575,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"])
@@ -2603,7 +2604,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -2618,12 +2619,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self):
"""
- Tests that a lookup for a user that does not exist returns a 404
+ Tests that a lookup for a user that does not exist returns HTTPStatus.NOT_FOUND
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
channel = self.make_request(
@@ -2632,12 +2633,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self):
"""
- Tests that a lookup for a user that is not a local returns a 400
+ Tests that a lookup for a user that is not local returns HTTPStatus.BAD_REQUEST
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
@@ -2647,7 +2648,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_get_pushers(self):
@@ -2662,7 +2663,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
# Register the pusher
@@ -2693,7 +2694,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
for p in channel.json_body["pushers"]:
@@ -2732,7 +2733,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""Try to list media of an user without authentication."""
channel = self.make_request(method, self.url, {})
- self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
@@ -2746,12 +2747,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
def test_user_does_not_exist(self, method: str):
- """Tests that a lookup for a user that does not exist returns a 404"""
+ """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
url = "/_synapse/admin/v1/users/@unknown_person:test/media"
channel = self.make_request(
method,
@@ -2759,12 +2760,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
def test_user_is_not_local(self, method: str):
- """Tests that a lookup for a user that is not a local returns a 400"""
+ """Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
channel = self.make_request(
@@ -2773,7 +2774,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_limit_GET(self):
@@ -2789,7 +2790,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
@@ -2808,7 +2809,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 5)
self.assertEqual(len(channel.json_body["deleted_media"]), 5)
@@ -2825,7 +2826,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -2844,7 +2845,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 15)
self.assertEqual(len(channel.json_body["deleted_media"]), 15)
@@ -2861,7 +2862,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["media"]), 10)
@@ -2880,7 +2881,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["deleted_media"]), 10)
@@ -2894,7 +2895,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# invalid search order
@@ -2904,7 +2905,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# negative limit
@@ -2914,7 +2915,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -2924,7 +2925,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_next_token(self):
@@ -2947,7 +2948,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body)
@@ -2960,7 +2961,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body)
@@ -2973,7 +2974,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -2987,7 +2988,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -3004,7 +3005,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["media"]))
@@ -3019,7 +3020,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["deleted_media"]))
@@ -3036,7 +3037,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"])
self.assertEqual(number_media, len(channel.json_body["media"]))
self.assertNotIn("next_token", channel.json_body)
@@ -3062,7 +3063,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"])
self.assertEqual(number_media, len(channel.json_body["deleted_media"]))
self.assertCountEqual(channel.json_body["deleted_media"], media_ids)
@@ -3207,7 +3208,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, image_data, user_token, filename, expect_code=200
+ upload_resource, image_data, user_token, filename, expect_code=HTTPStatus.OK
)
# Extract media ID from the response
@@ -3225,10 +3226,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- 200,
+ HTTPStatus.OK,
channel.code,
msg=(
- f"Expected to receive a 200 on accessing media: {server_and_media_id}"
+ f"Expected to receive a HTTPStatus.OK on accessing media: {server_and_media_id}"
),
)
@@ -3274,7 +3275,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_media_list))
returned_order = [row["media_id"] for row in channel.json_body["media"]]
@@ -3310,14 +3311,14 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", self.url, b"{}", access_token=self.admin_user_tok
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
return channel.json_body["access_token"]
def test_no_auth(self):
"""Try to login as a user without authentication."""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_not_admin(self):
@@ -3326,7 +3327,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
"POST", self.url, b"{}", access_token=self.other_user_tok
)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
def test_send_event(self):
"""Test that sending event as a user works."""
@@ -3351,7 +3352,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# We should only see the one device (from the login in `prepare`)
self.assertEqual(len(channel.json_body["devices"]), 1)
@@ -3363,21 +3364,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Logout with the puppet token
channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
def test_user_logout_all(self):
"""Tests that the target user calling `/logout/all` does *not* expire
@@ -3388,23 +3389,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Logout all with the real user token
channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# The puppet token should still work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# .. but the real user's tokens shouldn't
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
def test_admin_logout_all(self):
"""Tests that the admin user calling `/logout/all` does expire the
@@ -3415,23 +3416,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Logout all with the admin user token
channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.admin_user_tok
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
@unittest.override_config(
{
@@ -3459,7 +3460,10 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Now unaccept it and check that we can't send an event
self.get_success(self.store.user_set_consent_version(self.other_user, "0.0"))
self.helper.send_event(
- room_id, "com.example.test", tok=self.other_user_tok, expect_code=403
+ room_id,
+ "com.example.test",
+ tok=self.other_user_tok,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# Login in as the user
@@ -3477,7 +3481,10 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Trying to join as the other user should fail due to reaching MAU limit.
self.helper.join(
- room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403
+ room_id,
+ user=self.other_user,
+ tok=self.other_user_tok,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# Logging in as the other user and joining a room should work, even
@@ -3512,7 +3519,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
Try to get information of an user without authentication.
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self):
@@ -3527,12 +3534,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=other_user2_token,
)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_is_not_local(self):
"""
- Tests that a lookup for a user that is not a local returns a 400
+ Tests that a lookup for a user that is not local returns an HTTPStatus.BAD_REQUEST
"""
url = self.url_prefix % "@unknown_person:unknown_domain"
@@ -3541,7 +3548,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only whois a local user", channel.json_body["error"])
def test_get_whois_admin(self):
@@ -3553,7 +3560,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
@@ -3568,7 +3575,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=other_user_token,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
@@ -3598,7 +3605,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
Try to get information of an user without authentication.
"""
channel = self.make_request(method, self.url)
- self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["POST", "DELETE"])
@@ -3609,18 +3616,18 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
other_user_token = self.login("user", "pass")
channel = self.make_request(method, self.url, access_token=other_user_token)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["POST", "DELETE"])
def test_user_is_not_local(self, method: str):
"""
- Tests that shadow-banning for a user that is not a local returns a 400
+ Tests that shadow-banning a user that is not local returns an HTTPStatus.BAD_REQUEST
"""
url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
channel = self.make_request(method, url, access_token=self.admin_user_tok)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
def test_success(self):
"""
@@ -3632,7 +3639,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
self.assertFalse(result.shadow_banned)
channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual({}, channel.json_body)
# Ensure the user is shadow-banned (and the cache was cleared).
@@ -3643,7 +3650,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE", self.url, access_token=self.admin_user_tok
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual({}, channel.json_body)
# Ensure the user is no longer shadow-banned (and the cache was cleared).
@@ -3677,7 +3684,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(method, self.url, b"{}")
- self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "POST", "DELETE"])
@@ -3693,13 +3700,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "POST", "DELETE"])
def test_user_does_not_exist(self, method: str):
"""
- Tests that a lookup for a user that does not exist returns a 404
+ Tests that a lookup for a user that does not exist returns an HTTPStatus.NOT_FOUND
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"
@@ -3709,7 +3716,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(
@@ -3721,7 +3728,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
)
def test_user_is_not_local(self, method: str, error_msg: str):
"""
- Tests that a lookup for a user that is not a local returns a 400
+ Tests that a lookup for a user that is not local returns an HTTPStatus.BAD_REQUEST
"""
url = (
"/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit"
@@ -3733,7 +3740,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(error_msg, channel.json_body["error"])
def test_invalid_parameter(self):
@@ -3748,7 +3755,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": "string"},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# messages_per_second is negative
@@ -3759,7 +3766,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": -1},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is a string
@@ -3770,7 +3777,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": "string"},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is negative
@@ -3781,7 +3788,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": -1},
)
- self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_return_zero_when_null(self):
@@ -3806,7 +3813,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["messages_per_second"])
self.assertEqual(0, channel.json_body["burst_count"])
@@ -3820,7 +3827,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
@@ -3831,7 +3838,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"messages_per_second": 10, "burst_count": 11},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(10, channel.json_body["messages_per_second"])
self.assertEqual(11, channel.json_body["burst_count"])
@@ -3842,7 +3849,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"messages_per_second": 20, "burst_count": 21},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"])
@@ -3852,7 +3859,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"])
@@ -3862,7 +3869,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
@@ -3872,6 +3879,6 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 4e1c49c28b..7978626e71 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from http import HTTPStatus
+
import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError
from synapse.rest.client import login
@@ -33,30 +35,38 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
async def check_username(username):
if username == "allowed":
return True
- raise SynapseError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "User ID already taken.",
+ errcode=Codes.USER_IN_USE,
+ )
handler = self.hs.get_registration_handler()
handler.check_username = check_username
def test_username_available(self):
"""
- The endpoint should return a 200 response if the username does not exist
+ The endpoint should return an HTTPStatus.OK response if the username does not exist
"""
url = "%s?username=%s" % (self.url, "allowed")
channel = self.make_request("GET", url, None, self.admin_user_tok)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["available"])
def test_username_unavailable(self):
"""
- The endpoint should return a 200 response if the username does not exist
+ The endpoint should return an HTTPStatus.BAD_REQUEST response if the username is already in use
"""
url = "%s?username=%s" % (self.url, "disallowed")
channel = self.make_request("GET", url, None, self.admin_user_tok)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST,
+ channel.code,
+ msg=channel.json_body,
+ )
self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE")
self.assertEqual(channel.json_body["error"], "User ID already taken.")
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 8552671431..d8a94f4c12 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from http import HTTPStatus
from typing import Optional, Union
from twisted.internet.defer import succeed
@@ -513,12 +514,26 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
+ def use_refresh_token(self, refresh_token: str) -> FakeChannel:
+ """
+ Helper that makes a request to use a refresh token.
+ """
+ return self.make_request(
+ "POST",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+ {"refresh_token": refresh_token},
+ )
+
def test_login_issue_refresh_token(self):
"""
A login response should include a refresh_token only if asked.
"""
# Test login
- body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
+ body = {
+ "type": "m.login.password",
+ "user": "test",
+ "password": self.user_pass,
+ }
login_without_refresh = self.make_request(
"POST", "/_matrix/client/r0/login", body
@@ -528,8 +543,8 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
login_with_refresh = self.make_request(
"POST",
- "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
- body,
+ "/_matrix/client/r0/login",
+ {"org.matrix.msc2918.refresh_token": True, **body},
)
self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result)
self.assertIn("refresh_token", login_with_refresh.json_body)
@@ -555,11 +570,12 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
register_with_refresh = self.make_request(
"POST",
- "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true",
+ "/_matrix/client/r0/register",
{
"username": "test3",
"password": self.user_pass,
"auth": {"type": LoginType.DUMMY},
+ "org.matrix.msc2918.refresh_token": True,
},
)
self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result)
@@ -570,10 +586,15 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
"""
A refresh token can be used to issue a new access token.
"""
- body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
+ body = {
+ "type": "m.login.password",
+ "user": "test",
+ "password": self.user_pass,
+ "org.matrix.msc2918.refresh_token": True,
+ }
login_response = self.make_request(
"POST",
- "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
+ "/_matrix/client/r0/login",
body,
)
self.assertEqual(login_response.code, 200, login_response.result)
@@ -599,14 +620,19 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
)
@override_config({"refreshable_access_token_lifetime": "1m"})
- def test_refresh_token_expiration(self):
+ def test_refreshable_access_token_expiration(self):
"""
The access token should have some time as specified in the config.
"""
- body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
+ body = {
+ "type": "m.login.password",
+ "user": "test",
+ "password": self.user_pass,
+ "org.matrix.msc2918.refresh_token": True,
+ }
login_response = self.make_request(
"POST",
- "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
+ "/_matrix/client/r0/login",
body,
)
self.assertEqual(login_response.code, 200, login_response.result)
@@ -623,6 +649,128 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.assertApproximates(
refresh_response.json_body["expires_in_ms"], 60 * 1000, 100
)
+ access_token = refresh_response.json_body["access_token"]
+
+ # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry)
+ self.reactor.advance(59.0)
+ # Check that our token is valid
+ self.assertEqual(
+ self.make_request(
+ "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
+ ).code,
+ HTTPStatus.OK,
+ )
+
+ # Advance 2 more seconds (just past the time of expiry)
+ self.reactor.advance(2.0)
+ # Check that our token is invalid
+ self.assertEqual(
+ self.make_request(
+ "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
+ ).code,
+ HTTPStatus.UNAUTHORIZED,
+ )
+
+ @override_config(
+ {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
+ )
+ def test_refresh_token_expiry(self):
+ """
+ The refresh token can be configured to have a limited lifetime.
+ When that lifetime has ended, the refresh token can no longer be used to
+ refresh the session.
+ """
+
+ body = {
+ "type": "m.login.password",
+ "user": "test",
+ "password": self.user_pass,
+ "org.matrix.msc2918.refresh_token": True,
+ }
+ login_response = self.make_request(
+ "POST",
+ "/_matrix/client/r0/login",
+ body,
+ )
+ self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
+ refresh_token1 = login_response.json_body["refresh_token"]
+
+ # Advance 119 seconds in the future (just shy of 2 minutes)
+ self.reactor.advance(119.0)
+
+ # Refresh our session. The refresh token should still JUST be valid right now.
+ # By doing so, we get a new access token and a new refresh token.
+ refresh_response = self.use_refresh_token(refresh_token1)
+ self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
+ self.assertIn(
+ "refresh_token",
+ refresh_response.json_body,
+ "No new refresh token returned after refresh.",
+ )
+ refresh_token2 = refresh_response.json_body["refresh_token"]
+
+ # Advance 121 seconds in the future (just a bit more than 2 minutes)
+ self.reactor.advance(121.0)
+
+ # Try to refresh our session, but instead notice that the refresh token is
+ # not valid (it just expired).
+ refresh_response = self.use_refresh_token(refresh_token2)
+ self.assertEqual(
+ refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
+ )
+
+ @override_config(
+ {
+ "refreshable_access_token_lifetime": "2m",
+ "refresh_token_lifetime": "2m",
+ "session_lifetime": "3m",
+ }
+ )
+ def test_ultimate_session_expiry(self):
+ """
+ The session can be configured to have an ultimate, limited lifetime.
+ """
+
+ body = {
+ "type": "m.login.password",
+ "user": "test",
+ "password": self.user_pass,
+ "org.matrix.msc2918.refresh_token": True,
+ }
+ login_response = self.make_request(
+ "POST",
+ "/_matrix/client/r0/login",
+ body,
+ )
+ self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
+ refresh_token = login_response.json_body["refresh_token"]
+
+ # Advance shy of 2 minutes into the future
+ self.reactor.advance(119.0)
+
+ # Refresh our session. The refresh token should still be valid right now.
+ refresh_response = self.use_refresh_token(refresh_token)
+ self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
+ self.assertIn(
+ "refresh_token",
+ refresh_response.json_body,
+ "No new refresh token returned after refresh.",
+ )
+ # Notice that our access token lifetime has been diminished to match the
+ # session lifetime.
+ # 3 minutes - 119 seconds = 61 seconds.
+ self.assertEqual(refresh_response.json_body["expires_in_ms"], 61_000)
+ refresh_token = refresh_response.json_body["refresh_token"]
+
+ # Advance 61 seconds into the future. Our session should have expired
+ # now, because we've had our 3 minutes.
+ self.reactor.advance(61.0)
+
+ # Try to issue a new, refreshed, access token.
+ # This should fail because the refresh token's lifetime has also been
+ # diminished as our session expired.
+ refresh_response = self.use_refresh_token(refresh_token)
+ self.assertEqual(refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result)
def test_refresh_token_invalidation(self):
"""Refresh tokens are invalidated after first use of the next token.
@@ -640,10 +788,15 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|-> fourth_refresh (fails)
"""
- body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
+ body = {
+ "type": "m.login.password",
+ "user": "test",
+ "password": self.user_pass,
+ "org.matrix.msc2918.refresh_token": True,
+ }
login_response = self.make_request(
"POST",
- "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
+ "/_matrix/client/r0/login",
body,
)
self.assertEqual(login_response.code, 200, login_response.result)
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index a649e8c618..5ae491ff5a 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -12,11 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+from contextlib import contextmanager
+from typing import Generator
+from twisted.enterprise.adbapi import ConnectionPool
+from twisted.internet.defer import ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.room_versions import EventFormatVersions, RoomVersions
from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
-from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.server import HomeServer
+from synapse.storage.databases.main.events_worker import (
+ EVENT_QUEUE_THREADS,
+ EventsWorkerStore,
+)
+from synapse.storage.types import Connection
+from synapse.util import Clock
from synapse.util.async_helpers import yieldable_gather_results
from tests import unittest
@@ -144,3 +157,127 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+
+class DatabaseOutageTestCase(unittest.HomeserverTestCase):
+ """Test event fetching during a database outage."""
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ self.store: EventsWorkerStore = hs.get_datastore()
+
+ self.room_id = f"!room:{hs.hostname}"
+ self.event_ids = [f"event{i}" for i in range(20)]
+
+ self._populate_events()
+
+ def _populate_events(self) -> None:
+ """Ensure that there are test events in the database.
+
+ When testing with the in-memory SQLite database, all the events are lost during
+ the simulated outage.
+
+ To ensure consistency between `room_id`s and `event_id`s before and after the
+ outage, rows are built and inserted manually.
+
+ Upserts are used to handle the non-SQLite case where events are not lost.
+ """
+ self.get_success(
+ self.store.db_pool.simple_upsert(
+ "rooms",
+ {"room_id": self.room_id},
+ {"room_version": RoomVersions.V4.identifier},
+ )
+ )
+
+ self.event_ids = [f"event{i}" for i in range(20)]
+ for idx, event_id in enumerate(self.event_ids):
+ self.get_success(
+ self.store.db_pool.simple_upsert(
+ "events",
+ {"event_id": event_id},
+ {
+ "event_id": event_id,
+ "room_id": self.room_id,
+ "topological_ordering": idx,
+ "stream_ordering": idx,
+ "type": "test",
+ "processed": True,
+ "outlier": False,
+ },
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_upsert(
+ "event_json",
+ {"event_id": event_id},
+ {
+ "room_id": self.room_id,
+ "json": json.dumps({"type": "test", "room_id": self.room_id}),
+ "internal_metadata": "{}",
+ "format_version": EventFormatVersions.V3,
+ },
+ )
+ )
+
+ @contextmanager
+ def _outage(self) -> Generator[None, None, None]:
+ """Simulate a database outage.
+
+ Returns:
+ A context manager. While the context is active, any attempts to connect to
+ the database will fail.
+ """
+ connection_pool = self.store.db_pool._db_pool
+
+ # Close all connections and shut down the database `ThreadPool`.
+ connection_pool.close()
+
+ # Restart the database `ThreadPool`.
+ connection_pool.start()
+
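+ # `ConnectionPool` creates connections lazily, so with the existing connections
+ # closed, every new query has to open a fresh connection via `connectionFactory`,
+ # which is patched below so that each attempt fails.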
+ original_connection_factory = connection_pool.connectionFactory
+
+ def connection_factory(_pool: ConnectionPool) -> Connection:
+ raise Exception("Could not connect to the database.")
+
+ connection_pool.connectionFactory = connection_factory # type: ignore[assignment]
+ try:
+ yield
+ finally:
+ connection_pool.connectionFactory = original_connection_factory
+
+ # If the in-memory SQLite database is being used, all the events are gone.
+ # Restore the test data.
+ self._populate_events()
+
+ def test_failure(self) -> None:
+ """Test that event fetches do not get stuck during a database outage."""
+ with self._outage():
+ failure = self.get_failure(
+ self.store.get_event(self.event_ids[0]), Exception
+ )
+ self.assertEqual(str(failure.value), "Could not connect to the database.")
+
+ def test_recovery(self) -> None:
+ """Test that event fetchers recover after a database outage."""
+ with self._outage():
+ # Kick off a bunch of event fetches but do not pump the reactor
+ event_deferreds = []
+ for event_id in self.event_ids:
+ event_deferreds.append(ensureDeferred(self.store.get_event(event_id)))
+
+ # We should have maxed out on event fetcher threads
+ self.assertEqual(self.store._event_fetch_ongoing, EVENT_QUEUE_THREADS)
+
+ # All the event fetchers will fail
+ self.pump()
+ self.assertEqual(self.store._event_fetch_ongoing, 0)
+
+ for event_deferred in event_deferreds:
+ failure = self.get_failure(event_deferred, Exception)
+ self.assertEqual(
+ str(failure.value), "Could not connect to the database."
+ )
+
+ # This next event fetch should succeed
+ self.get_success(self.store.get_event(self.event_ids[0]))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index a5f5ebad41..216d816d56 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,8 +1,11 @@
from unittest.mock import Mock
+
+from twisted.internet.defer import Deferred, ensureDeferred
from synapse.storage.background_updates import BackgroundUpdater
from tests import unittest
+from tests.test_utils import make_awaitable
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
@@ -20,10 +23,10 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def test_do_background_update(self):
# the time we claim it takes to update one item when running the update
- duration_ms = 4200
+ duration_ms = 10
# the target runtime for each bg update
- target_background_update_duration_ms = 5000000
+ target_background_update_duration_ms = 100
store = self.hs.get_datastore()
self.get_success(
@@ -48,10 +51,8 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = update
self.update_handler.reset_mock()
res = self.get_success(
- self.updates.do_next_background_update(
- target_background_update_duration_ms
- ),
- by=0.1,
+ self.updates.do_next_background_update(False),
+ by=0.01,
)
self.assertFalse(res)
@@ -74,16 +75,93 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = update
self.update_handler.reset_mock()
- result = self.get_success(
- self.updates.do_next_background_update(target_background_update_duration_ms)
- )
+ result = self.get_success(self.updates.do_next_background_update(False))
self.assertFalse(result)
self.update_handler.assert_called_once()
# third step: we don't expect to be called any more
self.update_handler.reset_mock()
- result = self.get_success(
- self.updates.do_next_background_update(target_background_update_duration_ms)
- )
+ result = self.get_success(self.updates.do_next_background_update(False))
self.assertTrue(result)
self.assertFalse(self.update_handler.called)
+
+
+class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
+ # the base test class should have run the real bg updates for us
+ self.assertTrue(
+ self.get_success(self.updates.has_completed_background_updates())
+ )
+
+ self.update_deferred = Deferred()
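+ # A background update handler under the test's control: the mocked update only
+ # completes once `self.update_deferred` is resolved by the test.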
+ self.update_handler = Mock(return_value=self.update_deferred)
+ self.updates.register_background_update_handler(
+ "test_update", self.update_handler
+ )
+
+ # Mock out the AsyncContextManager
+ self._update_ctx_manager = Mock(spec=["__aenter__", "__aexit__"])
+ self._update_ctx_manager.__aenter__ = Mock(
+ return_value=make_awaitable(None),
+ )
+ self._update_ctx_manager.__aexit__ = Mock(return_value=make_awaitable(None))
+
+ # Mock out the `update_handler` callback
+ self._on_update = Mock(return_value=self._update_ctx_manager)
+
+ # Define a default batch size value that's not the same as the internal default
+ # value (100).
+ self._default_batch_size = 500
+
+ # Register the callbacks with more mocks
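+ # (once registered, the updater consults them for batch sizes and pacing
+ # instead of its built-in defaults).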
+ self.hs.get_module_api().register_background_update_controller_callbacks(
+ on_update=self._on_update,
+ min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)),
+ default_batch_size=Mock(
+ return_value=make_awaitable(self._default_batch_size),
+ ),
+ )
+
+ def test_controller(self):
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "background_updates",
+ values={"update_name": "test_update", "progress_json": "{}"},
+ )
+ )
+
+ # Set the return value for the context manager.
+ enter_defer = Deferred()
+ self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
+
+ # Start the background update.
+ do_update_d = ensureDeferred(self.updates.do_next_background_update(True))
+
+ self.pump()
+
+ # `run_update` should have been called, but the update handler won't be
+ # called until the `enter_defer` (returned by `__aenter__`) is resolved.
+ self._on_update.assert_called_once_with(
+ "test_update",  # the update's name
+ "master",  # the database name
+ False,  # one_shot
+ )
+ self.assertFalse(do_update_d.called)
+ self.assertFalse(self.update_deferred.called)
+
+ # Resolving the `enter_defer` should call the update handler, which then
+ # blocks.
+ enter_defer.callback(100)
+ self.pump()
+ self.update_handler.assert_called_once_with({}, self._default_batch_size)
+ self.assertFalse(self.update_deferred.called)
+ self._update_ctx_manager.__aexit__.assert_not_called()
+
+ # Resolving the update handler deferred should cause the
+ # `do_next_background_update` to finish and return
+ self.update_deferred.callback(100)
+ self.pump()
+ self._update_ctx_manager.__aexit__.assert_called()
+ self.get_success(do_update_d)
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index b31c5eb5ec..7b7f6c349e 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -664,7 +664,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
):
iterations += 1
self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(False), by=0.1
)
# Ensure that we did actually take multiple iterations to process the
@@ -723,7 +723,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
):
iterations += 1
self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(False), by=0.1
)
# Ensure that we did actually take multiple iterations to process the
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 37cf7bb232..7f5b28aed8 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -23,6 +23,7 @@ from synapse.rest import admin
from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.storage import DataStore
+from synapse.storage.background_updates import _BackgroundUpdateHandler
from synapse.storage.roommember import ProfileInfo
from synapse.util import Clock
@@ -391,7 +392,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
with mock.patch.dict(
self.store.db_pool.updates._background_update_handlers,
- populate_user_directory_process_users=mocked_process_users,
+ populate_user_directory_process_users=_BackgroundUpdateHandler(
+ mocked_process_users,
+ ),
):
self._purge_and_rebuild_user_dir()
diff --git a/tests/unittest.py b/tests/unittest.py
index 165aafc574..eea0903f05 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -331,17 +331,16 @@ class HomeserverTestCase(TestCase):
time.sleep(0.01)
def wait_for_background_updates(self) -> None:
- """
- Block until all background database updates have completed.
+ """Block until all background database updates have completed.
- Note that callers must ensure that's a store property created on the
+ Note that callers must ensure there's a store property created on the
testcase.
"""
while not self.get_success(
self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(False), by=0.1
)
def make_homeserver(self, reactor, clock):
@@ -500,8 +499,7 @@ class HomeserverTestCase(TestCase):
async def run_bg_updates():
with LoggingContext("run_bg_updates"):
- while not await stor.db_pool.updates.has_completed_background_updates():
- await stor.db_pool.updates.do_next_background_update(1)
+ self.get_success(stor.db_pool.updates.run_background_updates(False))
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
|