diff --git a/.ci/scripts/test_old_deps.sh b/.ci/scripts/test_old_deps.sh
index 54ec3c8b0d..b2859f7522 100755
--- a/.ci/scripts/test_old_deps.sh
+++ b/.ci/scripts/test_old_deps.sh
@@ -8,7 +8,9 @@ export DEBIAN_FRONTEND=noninteractive
set -ex
apt-get update
-apt-get install -y python3 python3-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox libjpeg-dev libwebp-dev
+apt-get install -y \
+ python3 python3-dev python3-pip python3-venv \
+ libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox libjpeg-dev libwebp-dev
export LANG="C.UTF-8"
diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml
index eb294f1619..eee3633d50 100644
--- a/.github/workflows/release-artifacts.yml
+++ b/.github/workflows/release-artifacts.yml
@@ -7,7 +7,7 @@ on:
# of things breaking (but only build one set of debs)
pull_request:
push:
- branches: ["develop"]
+ branches: ["develop", "release-*"]
# we do the full build on tags.
tags: ["v*"]
@@ -91,17 +91,7 @@ jobs:
build-sdist:
name: "Build pypi distribution files"
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v2
- - uses: actions/setup-python@v2
- - run: pip install wheel
- - run: |
- python setup.py sdist bdist_wheel
- - uses: actions/upload-artifact@v2
- with:
- name: python-dist
- path: dist/*
+ uses: "matrix-org/backend-meta/.github/workflows/packaging.yml@v1"
# if it's a tag, create a release and attach the artifacts to it
attach-assets:
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 75ac1304bf..bbf1033bdd 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -48,24 +48,10 @@ jobs:
env:
PULL_REQUEST_NUMBER: ${{ github.event.number }}
- lint-sdist:
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v2
- - uses: actions/setup-python@v2
- with:
- python-version: "3.x"
- - run: pip install wheel
- - run: python setup.py sdist bdist_wheel
- - uses: actions/upload-artifact@v2
- with:
- name: Python Distributions
- path: dist/*
-
# Dummy step to gate other tests on without repeating the whole list
linting-done:
if: ${{ !cancelled() }} # Run this even if prior jobs were skipped
- needs: [lint, lint-crlf, lint-newsfile, lint-sdist]
+ needs: [lint, lint-crlf, lint-newsfile]
runs-on: ubuntu-latest
steps:
- run: "true"
@@ -397,7 +383,6 @@ jobs:
- lint
- lint-crlf
- lint-newsfile
- - lint-sdist
- trial
- trial-olddeps
- sytest
diff --git a/changelog.d/10870.misc b/changelog.d/10870.misc
new file mode 100644
index 0000000000..3af049b969
--- /dev/null
+++ b/changelog.d/10870.misc
@@ -0,0 +1 @@
+Deduplicate in-flight requests in `_get_state_for_groups`.
diff --git a/changelog.d/11835.feature b/changelog.d/11835.feature
new file mode 100644
index 0000000000..7cee39b08c
--- /dev/null
+++ b/changelog.d/11835.feature
@@ -0,0 +1 @@
+Make a `POST` to `/rooms/<room_id>/receipt/m.read/<event_id>` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push.
diff --git a/changelog.d/11972.misc b/changelog.d/11972.misc
new file mode 100644
index 0000000000..29c38bfd82
--- /dev/null
+++ b/changelog.d/11972.misc
@@ -0,0 +1 @@
+Add tests for device list changes between local users.
\ No newline at end of file
diff --git a/changelog.d/11974.misc b/changelog.d/11974.misc
new file mode 100644
index 0000000000..1debad2361
--- /dev/null
+++ b/changelog.d/11974.misc
@@ -0,0 +1 @@
+Optimise calculating device_list changes in `/sync`.
diff --git a/changelog.d/11984.misc b/changelog.d/11984.misc
new file mode 100644
index 0000000000..8e405b9226
--- /dev/null
+++ b/changelog.d/11984.misc
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
\ No newline at end of file
diff --git a/changelog.d/11985.feature b/changelog.d/11985.feature
new file mode 100644
index 0000000000..120d888a49
--- /dev/null
+++ b/changelog.d/11985.feature
@@ -0,0 +1 @@
+Fetch images when previewing Twitter URLs. Contributed by @AndrewRyanChama.
diff --git a/changelog.d/11991.misc b/changelog.d/11991.misc
new file mode 100644
index 0000000000..34a3b3a6b9
--- /dev/null
+++ b/changelog.d/11991.misc
@@ -0,0 +1 @@
+Refactor the search code for improved readability.
diff --git a/changelog.d/11992.bugfix b/changelog.d/11992.bugfix
new file mode 100644
index 0000000000..f73c86bb25
--- /dev/null
+++ b/changelog.d/11992.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary.
diff --git a/changelog.d/11994.misc b/changelog.d/11994.misc
new file mode 100644
index 0000000000..d64297dd78
--- /dev/null
+++ b/changelog.d/11994.misc
@@ -0,0 +1 @@
+Move common deduplication code down into `_auth_and_persist_outliers`.
diff --git a/changelog.d/11996.misc b/changelog.d/11996.misc
new file mode 100644
index 0000000000..6c675fd193
--- /dev/null
+++ b/changelog.d/11996.misc
@@ -0,0 +1 @@
+Limit concurrent joins from application services.
\ No newline at end of file
diff --git a/changelog.d/11997.docker b/changelog.d/11997.docker
new file mode 100644
index 0000000000..1b3271457e
--- /dev/null
+++ b/changelog.d/11997.docker
@@ -0,0 +1 @@
+The docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage.
diff --git a/changelog.d/11999.bugfix b/changelog.d/11999.bugfix
new file mode 100644
index 0000000000..fd84095900
--- /dev/null
+++ b/changelog.d/11999.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room.
diff --git a/changelog.d/12000.feature b/changelog.d/12000.feature
new file mode 100644
index 0000000000..246cc87f0b
--- /dev/null
+++ b/changelog.d/12000.feature
@@ -0,0 +1 @@
+Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time.
diff --git a/changelog.d/12003.doc b/changelog.d/12003.doc
new file mode 100644
index 0000000000..1ac8163559
--- /dev/null
+++ b/changelog.d/12003.doc
@@ -0,0 +1 @@
+Explain the meaning of spam checker callbacks' return values.
diff --git a/changelog.d/12004.doc b/changelog.d/12004.doc
new file mode 100644
index 0000000000..0b4baef210
--- /dev/null
+++ b/changelog.d/12004.doc
@@ -0,0 +1 @@
+Clarify information about external Identity Provider IDs.
diff --git a/changelog.d/12005.misc b/changelog.d/12005.misc
new file mode 100644
index 0000000000..45e21dbe59
--- /dev/null
+++ b/changelog.d/12005.misc
@@ -0,0 +1 @@
+Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`.
diff --git a/changelog.d/12008.removal b/changelog.d/12008.removal
new file mode 100644
index 0000000000..57599d9ee9
--- /dev/null
+++ b/changelog.d/12008.removal
@@ -0,0 +1 @@
+Remove support for the legacy structured logging configuration (please see the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#legacy-structured-logging-configuration-removal) if you are using `structured: true` in the Synapse configuration).
diff --git a/changelog.d/12009.feature b/changelog.d/12009.feature
new file mode 100644
index 0000000000..c8a531481e
--- /dev/null
+++ b/changelog.d/12009.feature
@@ -0,0 +1 @@
+Enable modules to set a custom display name when registering a user.
diff --git a/changelog.d/12011.misc b/changelog.d/12011.misc
new file mode 100644
index 0000000000..258b0e389f
--- /dev/null
+++ b/changelog.d/12011.misc
@@ -0,0 +1 @@
+Preparation for faster-room-join work: parse MSC3706 fields in `send_join` response.
diff --git a/changelog.d/12013.misc b/changelog.d/12013.misc
new file mode 100644
index 0000000000..c0fca8dccb
--- /dev/null
+++ b/changelog.d/12013.misc
@@ -0,0 +1 @@
+Preparation for faster-room-join work: support for calling `/federation/v1/state` on a remote server.
diff --git a/changelog.d/12015.misc b/changelog.d/12015.misc
new file mode 100644
index 0000000000..3aa32ab4cf
--- /dev/null
+++ b/changelog.d/12015.misc
@@ -0,0 +1 @@
+Configure `tox` to use `venv` rather than `virtualenv`.
diff --git a/changelog.d/12016.misc b/changelog.d/12016.misc
new file mode 100644
index 0000000000..8856ef46a9
--- /dev/null
+++ b/changelog.d/12016.misc
@@ -0,0 +1 @@
+Fix bug in `StateFilter.return_expanded()` and add some tests.
\ No newline at end of file
diff --git a/changelog.d/12018.removal b/changelog.d/12018.removal
new file mode 100644
index 0000000000..e940b62228
--- /dev/null
+++ b/changelog.d/12018.removal
@@ -0,0 +1 @@
+Drop support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283) unstable flags now that the stable flags are supported.
diff --git a/changelog.d/12019.misc b/changelog.d/12019.misc
new file mode 100644
index 0000000000..b2186320ea
--- /dev/null
+++ b/changelog.d/12019.misc
@@ -0,0 +1 @@
+Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms.
\ No newline at end of file
diff --git a/changelog.d/12020.feature b/changelog.d/12020.feature
new file mode 100644
index 0000000000..1ac9d2060e
--- /dev/null
+++ b/changelog.d/12020.feature
@@ -0,0 +1 @@
+Advertise Matrix 1.1 support on `/_matrix/client/versions`.
\ No newline at end of file
diff --git a/changelog.d/12021.feature b/changelog.d/12021.feature
new file mode 100644
index 0000000000..01378df8ca
--- /dev/null
+++ b/changelog.d/12021.feature
@@ -0,0 +1 @@
+Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`.
\ No newline at end of file
diff --git a/changelog.d/12022.feature b/changelog.d/12022.feature
new file mode 100644
index 0000000000..188fb12570
--- /dev/null
+++ b/changelog.d/12022.feature
@@ -0,0 +1 @@
+Advertise Matrix 1.2 support on `/_matrix/client/versions`.
\ No newline at end of file
diff --git a/changelog.d/12024.bugfix b/changelog.d/12024.bugfix
new file mode 100644
index 0000000000..59bcdb93a5
--- /dev/null
+++ b/changelog.d/12024.bugfix
@@ -0,0 +1 @@
+Fix 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint.
diff --git a/changelog.d/12025.misc b/changelog.d/12025.misc
new file mode 100644
index 0000000000..d9475a7718
--- /dev/null
+++ b/changelog.d/12025.misc
@@ -0,0 +1 @@
+Update the `olddeps` CI job to use an old version of `markupsafe`.
diff --git a/changelog.d/12030.misc b/changelog.d/12030.misc
new file mode 100644
index 0000000000..607ee97ce6
--- /dev/null
+++ b/changelog.d/12030.misc
@@ -0,0 +1 @@
+Upgrade mypy to version 0.931.
diff --git a/changelog.d/12033.misc b/changelog.d/12033.misc
new file mode 100644
index 0000000000..3af049b969
--- /dev/null
+++ b/changelog.d/12033.misc
@@ -0,0 +1 @@
+Deduplicate in-flight requests in `_get_state_for_groups`.
diff --git a/changelog.d/12034.misc b/changelog.d/12034.misc
new file mode 100644
index 0000000000..8374a63220
--- /dev/null
+++ b/changelog.d/12034.misc
@@ -0,0 +1 @@
+Minor typing fixes.
diff --git a/changelog.d/12039.misc b/changelog.d/12039.misc
new file mode 100644
index 0000000000..45e21dbe59
--- /dev/null
+++ b/changelog.d/12039.misc
@@ -0,0 +1 @@
+Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`.
diff --git a/changelog.d/12051.misc b/changelog.d/12051.misc
new file mode 100644
index 0000000000..9959191352
--- /dev/null
+++ b/changelog.d/12051.misc
@@ -0,0 +1 @@
+Tidy up GitHub Actions config which builds distributions for PyPI.
\ No newline at end of file
diff --git a/changelog.d/12052.misc b/changelog.d/12052.misc
new file mode 100644
index 0000000000..fbaff67e95
--- /dev/null
+++ b/changelog.d/12052.misc
@@ -0,0 +1 @@
+Move `isort` configuration to `pyproject.toml`.
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 306f75ae56..e4c1c19b86 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -98,8 +98,6 @@ COPY --from=builder /install /usr/local
COPY ./docker/start.py /start.py
COPY ./docker/conf /conf
-VOLUME ["/data"]
-
EXPOSE 8008/tcp 8009/tcp 8448/tcp
ENTRYPOINT ["/start.py"]
diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md
index 1bbe237080..4076fcab65 100644
--- a/docs/admin_api/user_admin_api.md
+++ b/docs/admin_api/user_admin_api.md
@@ -126,7 +126,8 @@ Body parameters:
[Sample Configuration File](../usage/configuration/homeserver_sample_config.html)
section `sso` and `oidc_providers`.
- `auth_provider` - string. ID of the external identity provider. Value of `idp_id`
- in homeserver configuration.
+ in the homeserver configuration. Note that no error is raised if the provided
+ value is not in the homeserver configuration.
- `external_id` - string, user ID in the external identity provider.
- `avatar_url` - string, optional, must be a
[MXC URI](https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris).
diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md
index 88b59bb09e..ec810fd292 100644
--- a/docs/modules/password_auth_provider_callbacks.md
+++ b/docs/modules/password_auth_provider_callbacks.md
@@ -85,7 +85,7 @@ If the authentication is unsuccessful, the module must return `None`.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call
-any of the subsequent implementations of this callback. If every callback return `None`,
+any of the subsequent implementations of this callback. If every callback returns `None`,
the authentication is denied.
### `on_logged_out`
@@ -162,10 +162,38 @@ return `None`.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call
-any of the subsequent implementations of this callback. If every callback return `None`,
+any of the subsequent implementations of this callback. If every callback returns `None`,
the username provided by the user is used, if any (otherwise one is automatically
generated).
+### `get_displayname_for_registration`
+
+_First introduced in Synapse v1.54.0_
+
+```python
+async def get_displayname_for_registration(
+ uia_results: Dict[str, Any],
+ params: Dict[str, Any],
+) -> Optional[str]
+```
+
+Called when registering a new user. The module can return a display name to set for the
+user being registered by returning it as a string, or `None` if it doesn't wish to force a
+display name for this user.
+
+This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api)
+has been completed by the user. It is not called when registering a user via SSO. It is
+passed two dictionaries, which include the information that the user has provided during
+the registration process. These dictionaries are identical to the ones passed to
+[`get_username_for_registration`](#get_username_for_registration), so refer to the
+documentation of this callback for more information about them.
+
+If multiple modules implement this callback, they will be considered in order. If a
+callback returns `None`, Synapse falls through to the next one. The value of the first
+callback that does not return `None` will be used. If this happens, Synapse will not call
+any of the subsequent implementations of this callback. If every callback returns `None`,
+the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`).
+
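+Below is a minimal sketch of a module using this callback. The module name, its
+configuration and the naming policy it applies are illustrative only and not part
+of Synapse itself:
+
+```python
+from typing import Any, Dict, Optional
+
+
+class ExampleDisplaynameModule:
+    def __init__(self, config: dict, api):
+        # Register only the display name callback with Synapse's module API.
+        api.register_password_auth_provider_callbacks(
+            get_displayname_for_registration=self.get_displayname_for_registration,
+        )
+
+    async def get_displayname_for_registration(
+        self,
+        uia_results: Dict[str, Any],
+        params: Dict[str, Any],
+    ) -> Optional[str]:
+        # Use the localpart requested by the client, capitalised, as the display
+        # name. Returning None lets Synapse fall back to its default behaviour.
+        username = params.get("username")
+        return username.capitalize() if username else None
+```
+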
## `is_3pid_allowed`
_First introduced in Synapse v1.53.0_
@@ -194,8 +222,7 @@ The example module below implements authentication checkers for two different lo
- Is checked by the method: `self.check_my_login`
- `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based))
- Expects a `password` field to be sent to `/login`
- - Is checked by the method: `self.check_pass`
-
+ - Is checked by the method: `self.check_pass`
```python
from typing import Awaitable, Callable, Optional, Tuple
diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md
index 2eb9032f41..2b672b78f9 100644
--- a/docs/modules/spam_checker_callbacks.md
+++ b/docs/modules/spam_checker_callbacks.md
@@ -16,10 +16,12 @@ _First introduced in Synapse v1.37.0_
async def check_event_for_spam(event: "synapse.events.EventBase") -> Union[bool, str]
```
-Called when receiving an event from a client or via federation. The module can return
-either a `bool` to indicate whether the event must be rejected because of spam, or a `str`
-to indicate the event must be rejected because of spam and to give a rejection reason to
-forward to clients.
+Called when receiving an event from a client or via federation. The callback must return
+either:
+- an error message string, to indicate the event must be rejected because of spam and
+ give a rejection reason to forward to clients;
+- the boolean `True`, to indicate that the event is spammy, but without providing further details; or
+- the boolean `False`, to indicate that the event is not considered spammy.
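+
+As an illustration only (the word filter and policy below are invented for this
+example), a callback might look like:
+
+```python
+from typing import Union
+
+
+async def check_event_for_spam(event: "synapse.events.EventBase") -> Union[bool, str]:
+    body = event.content.get("body", "")
+    if "buy cheap medals" in body:
+        # Reject the event and forward a reason to the client.
+        return "This message looks like spam"
+    # Otherwise let the event through (or let another module decide).
+    return False
+```
+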
If multiple modules implement this callback, they will be considered in order. If a
callback returns `False`, Synapse falls through to the next one. The value of the first
@@ -35,7 +37,10 @@ async def user_may_join_room(user: str, room: str, is_invited: bool) -> bool
```
Called when a user is trying to join a room. The module must return a `bool` to indicate
-whether the user can join the room. The user is represented by their Matrix user ID (e.g.
+whether the user can join the room. Return `False` to prevent the user from joining the
+room; otherwise return `True` to permit it.
+
+The user is represented by their Matrix user ID (e.g.
`@alice:example.com`) and the room is represented by its Matrix ID (e.g.
`!room:example.com`). The module is also given a boolean to indicate whether the user
currently has a pending invite in the room.
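+
+A minimal illustrative callback (the room ID used here is made up) might be:
+
+```python
+async def user_may_join_room(user: str, room: str, is_invited: bool) -> bool:
+    # Only allow joining this particular room if the user has been invited.
+    if room == "!invite-only:example.com":
+        return is_invited
+    return True
+```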
@@ -58,7 +63,8 @@ async def user_may_invite(inviter: str, invitee: str, room_id: str) -> bool
Called when processing an invitation. The module must return a `bool` indicating whether
the inviter can invite the invitee to the given room. Both inviter and invitee are
-represented by their Matrix user ID (e.g. `@alice:example.com`).
+represented by their Matrix user ID (e.g. `@alice:example.com`). Return `False` to prevent
+the invitation; otherwise return `True` to permit it.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `True`, Synapse falls through to the next one. The value of the first
@@ -80,7 +86,8 @@ async def user_may_send_3pid_invite(
Called when processing an invitation using a third-party identifier (also called a 3PID,
e.g. an email address or a phone number). The module must return a `bool` indicating
-whether the inviter can invite the invitee to the given room.
+whether the inviter can invite the invitee to the given room. Return `False` to prevent
+the invitation; otherwise return `True` to permit it.
The inviter is represented by their Matrix user ID (e.g. `@alice:example.com`), and the
invitee is represented by its medium (e.g. "email") and its address
@@ -117,6 +124,7 @@ async def user_may_create_room(user: str) -> bool
Called when processing a room creation request. The module must return a `bool` indicating
whether the given user (represented by their Matrix user ID) is allowed to create a room.
+Return `False` to prevent room creation; otherwise return `True` to permit it.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `True`, Synapse falls through to the next one. The value of the first
@@ -133,7 +141,8 @@ async def user_may_create_room_alias(user: str, room_alias: "synapse.types.RoomA
Called when trying to associate an alias with an existing room. The module must return a
`bool` indicating whether the given user (represented by their Matrix user ID) is allowed
-to set the given alias.
+to set the given alias. Return `False` to prevent the alias creation; otherwise return
+`True` to permit it.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `True`, Synapse falls through to the next one. The value of the first
@@ -150,7 +159,8 @@ async def user_may_publish_room(user: str, room_id: str) -> bool
Called when trying to publish a room to the homeserver's public rooms directory. The
module must return a `bool` indicating whether the given user (represented by their
-Matrix user ID) is allowed to publish the given room.
+Matrix user ID) is allowed to publish the given room. Return `False` to prevent the
+room from being published; otherwise return `True` to permit its publication.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `True`, Synapse falls through to the next one. The value of the first
@@ -166,8 +176,11 @@ async def check_username_for_spam(user_profile: Dict[str, str]) -> bool
```
Called when computing search results in the user directory. The module must return a
-`bool` indicating whether the given user profile can appear in search results. The profile
-is represented as a dictionary with the following keys:
+`bool` indicating whether the given user should be excluded from user directory
+searches. Return `True` to indicate that the user is spammy and exclude them from
+search results; otherwise return `False`.
+
+The profile is represented as a dictionary with the following keys:
* `user_id`: The Matrix ID for this user.
* `display_name`: The user's display name.
@@ -225,8 +238,9 @@ async def check_media_file_for_spam(
) -> bool
```
-Called when storing a local or remote file. The module must return a boolean indicating
-whether the given file can be stored in the homeserver's media store.
+Called when storing a local or remote file. The module must return a `bool` indicating
+whether the given file should be excluded from the homeserver's media store. Return
+`True` to prevent this file from being stored; otherwise return `False`.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `False`, Synapse falls through to the next one. The value of the first
diff --git a/docs/structured_logging.md b/docs/structured_logging.md
index 14db85f587..805c867653 100644
--- a/docs/structured_logging.md
+++ b/docs/structured_logging.md
@@ -81,14 +81,12 @@ remote endpoint at 10.1.2.3:9999.
## Upgrading from legacy structured logging configuration
-Versions of Synapse prior to v1.23.0 included a custom structured logging
-configuration which is deprecated. It used a `structured: true` flag and
-configured `drains` instead of ``handlers`` and `formatters`.
-
-Synapse currently automatically converts the old configuration to the new
-configuration, but this will be removed in a future version of Synapse. The
-following reference can be used to update your configuration. Based on the drain
-`type`, we can pick a new handler:
+Versions of Synapse prior to v1.54.0 automatically converted the legacy
+structured logging configuration, which was deprecated in v1.23.0, to the standard
+library logging configuration.
+
+The following reference can be used to update your configuration. Based on the
+drain `type`, we can pick a new handler:
1. For a type of `console`, `console_json`, or `console_json_terse`: a handler
with a class of `logging.StreamHandler` and a `stream` of `ext://sys.stdout`
diff --git a/docs/upgrade.md b/docs/upgrade.md
index b722d3bb9d..f9be3ac6bc 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -85,6 +85,15 @@ process, for example:
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
```
+# Upgrading to v1.54.0
+
+## Legacy structured logging configuration removal
+
+This release removes support for the `structured: true` logging configuration
+which was deprecated in Synapse v1.23.0. If your logging configuration contains
+`structured: true` then it should be modified based on the
+[structured logging documentation](structured_logging.md).
+
# Upgrading to v1.53.0
## Dropping support for `webclient` listeners and non-HTTP(S) `web_client_location`
diff --git a/mypy.ini b/mypy.ini
index 63848d664c..610660b9b7 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -31,14 +31,11 @@ exclude = (?x)
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
- |synapse/storage/databases/main/presence.py
- |synapse/storage/databases/main/purge_events.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py
- |synapse/storage/databases/main/user_directory.py
|synapse/storage/schema/
|tests/api/test_auth.py
diff --git a/pyproject.toml b/pyproject.toml
index 963f149c6a..c9cd0cf6ec 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,3 +54,15 @@ exclude = '''
)/
)
'''
+
+[tool.isort]
+line_length = 88
+sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TWISTED", "FIRSTPARTY", "TESTS", "LOCALFOLDER"]
+default_section = "THIRDPARTY"
+known_first_party = ["synapse"]
+known_tests = ["tests"]
+known_twisted = ["twisted", "OpenSSL"]
+multi_line_output = 3
+include_trailing_comma = true
+combine_as_imports = true
+
diff --git a/setup.cfg b/setup.cfg
index e5ceb7ed19..a0506572d9 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -19,14 +19,3 @@ ignore =
# E731: do not assign a lambda expression, use a def
# E501: Line too long (black enforces this for us)
ignore=W503,W504,E203,E731,E501
-
-[isort]
-line_length = 88
-sections=FUTURE,STDLIB,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER
-default_section=THIRDPARTY
-known_first_party = synapse
-known_tests=tests
-known_twisted=twisted,OpenSSL
-multi_line_output=3
-include_trailing_comma=true
-combine_as_imports=true
diff --git a/setup.py b/setup.py
index d0511c767f..c80cb6f207 100755
--- a/setup.py
+++ b/setup.py
@@ -103,8 +103,8 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
]
CONDITIONAL_REQUIREMENTS["mypy"] = [
- "mypy==0.910",
- "mypy-zope==0.3.2",
+ "mypy==0.931",
+ "mypy-zope==0.3.5",
"types-bleach>=4.1.0",
"types-jsonschema>=3.2.0",
"types-opentracing>=2.4.2",
diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi
index 0eaef00498..344d55cce1 100644
--- a/stubs/sortedcontainers/sorteddict.pyi
+++ b/stubs/sortedcontainers/sorteddict.pyi
@@ -66,13 +66,18 @@ class SortedDict(Dict[_KT, _VT]):
def __copy__(self: _SD) -> _SD: ...
@classmethod
@overload
- def fromkeys(cls, seq: Iterable[_T_h]) -> SortedDict[_T_h, None]: ...
+ def fromkeys(
+ cls, seq: Iterable[_T_h], value: None = ...
+ ) -> SortedDict[_T_h, None]: ...
@classmethod
@overload
def fromkeys(cls, seq: Iterable[_T_h], value: _S) -> SortedDict[_T_h, _S]: ...
- def keys(self) -> SortedKeysView[_KT]: ...
- def items(self) -> SortedItemsView[_KT, _VT]: ...
- def values(self) -> SortedValuesView[_VT]: ...
+ # As of Python 3.10, `dict_{keys,items,values}` have an extra `mapping` attribute and so
+ # `Sorted{Keys,Items,Values}View` are no longer compatible with them.
+ # See https://github.com/python/typeshed/issues/6837
+ def keys(self) -> SortedKeysView[_KT]: ... # type: ignore[override]
+ def items(self) -> SortedItemsView[_KT, _VT]: ... # type: ignore[override]
+ def values(self) -> SortedValuesView[_VT]: ... # type: ignore[override]
@overload
def pop(self, key: _KT) -> _VT: ...
@overload
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 09d692d9a1..bcdeb9ee23 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -41,9 +41,6 @@ class ExperimentalConfig(Config):
# MSC3244 (room version capabilities)
self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True)
- # MSC3283 (set displayname, avatar_url and change 3pid capabilities)
- self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False)
-
# MSC3266 (room summary api)
self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False)
@@ -64,3 +61,7 @@ class ExperimentalConfig(Config):
# MSC3706 (server-side support for partial state in /send_join responses)
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)
+
+ # experimental support for faster joins over federation (msc2775, msc3706)
+ # requires a target server with msc3706_enabled set to true.
+ self.faster_joins_enabled: bool = experimental.get("faster_joins", False)
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index b7145a44ae..cbbe221965 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -33,7 +33,6 @@ from twisted.logger import (
globalLogBeginner,
)
-from synapse.logging._structured import setup_structured_logging
from synapse.logging.context import LoggingContextFilter
from synapse.logging.filter import MetadataFilter
@@ -138,6 +137,12 @@ Support for the log_file configuration option and --log-file command-line option
removed in Synapse 1.3.0. You should instead set up a separate log configuration file.
"""
+STRUCTURED_ERROR = """\
+Support for the structured configuration option was removed in Synapse 1.54.0.
+You should instead use the standard logging configuration. See
+https://matrix-org.github.io/synapse/v1.54/structured_logging.html
+"""
+
class LoggingConfig(Config):
section = "logging"
@@ -292,10 +297,9 @@ def _load_logging_config(log_config_path: str) -> None:
if not log_config:
logging.warning("Loaded a blank logging config?")
- # If the old structured logging configuration is being used, convert it to
- # the new style configuration.
+ # If the old structured logging configuration is being used, raise an error.
if "structured" in log_config and log_config.get("structured"):
- log_config = setup_structured_logging(log_config)
+ raise ConfigError(STRUCTURED_ERROR)
logging.config.dictConfig(log_config)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 243696b357..9386fa29dd 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -425,6 +425,33 @@ class EventClientSerializer:
return serialized_event
+ def _apply_edit(
+ self, orig_event: EventBase, serialized_event: JsonDict, edit: EventBase
+ ) -> None:
+ """Replace the content, preserving existing relations of the serialized event.
+
+ Args:
+ orig_event: The original event.
+ serialized_event: The original event, serialized. This is modified.
+ edit: The event which edits the above.
+ """
+
+ # Ensure we take copies of the edit content, otherwise we risk modifying
+ # the original event.
+ edit_content = edit.content.copy()
+
+ # Unfreeze the event content if necessary, so that we may modify it below
+ edit_content = unfreeze(edit_content)
+ serialized_event["content"] = edit_content.get("m.new_content", {})
+
+ # Check for existing relations
+ relates_to = orig_event.content.get("m.relates_to")
+ if relates_to:
+ # Keep the relations, ensuring we use a dict copy of the original
+ serialized_event["content"]["m.relates_to"] = relates_to.copy()
+ else:
+ serialized_event["content"].pop("m.relates_to", None)
+
def _inject_bundled_aggregations(
self,
event: EventBase,
@@ -450,26 +477,11 @@ class EventClientSerializer:
serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references
if aggregations.replace:
- # If there is an edit replace the content, preserving existing
- # relations.
+ # If there is an edit, apply it to the event.
edit = aggregations.replace
+ self._apply_edit(event, serialized_event, edit)
- # Ensure we take copies of the edit content, otherwise we risk modifying
- # the original event.
- edit_content = edit.content.copy()
-
- # Unfreeze the event content if necessary, so that we may modify it below
- edit_content = unfreeze(edit_content)
- serialized_event["content"] = edit_content.get("m.new_content", {})
-
- # Check for existing relations
- relates_to = event.content.get("m.relates_to")
- if relates_to:
- # Keep the relations, ensuring we use a dict copy of the original
- serialized_event["content"]["m.relates_to"] = relates_to.copy()
- else:
- serialized_event["content"].pop("m.relates_to", None)
-
+ # Include information about it in the relations dict.
serialized_aggregations[RelationTypes.REPLACE] = {
"event_id": edit.event_id,
"origin_server_ts": edit.origin_server_ts,
@@ -478,13 +490,22 @@ class EventClientSerializer:
# If this event is the start of a thread, include a summary of the replies.
if aggregations.thread:
+ thread = aggregations.thread
+
+ # Don't bundle aggregations as this could recurse forever.
+ serialized_latest_event = self.serialize_event(
+ thread.latest_event, time_now, bundle_aggregations=None
+ )
+ # Manually apply an edit, if one exists.
+ if thread.latest_edit:
+ self._apply_edit(
+ thread.latest_event, serialized_latest_event, thread.latest_edit
+ )
+
serialized_aggregations[RelationTypes.THREAD] = {
- # Don't bundle aggregations as this could recurse forever.
- "latest_event": self.serialize_event(
- aggregations.thread.latest_event, time_now, bundle_aggregations=None
- ),
- "count": aggregations.thread.count,
- "current_user_participated": aggregations.thread.current_user_participated,
+ "latest_event": serialized_latest_event,
+ "count": thread.count,
+ "current_user_participated": thread.current_user_participated,
}
# Include the bundled aggregations in the event.
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 896168c05c..fab6da3c08 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -47,6 +47,11 @@ class FederationBase:
) -> EventBase:
"""Checks that event is correctly signed by the sending server.
+ Also checks the content hash, and redacts the event if there is a mismatch.
+
+ Also runs the event through the spam checker; if it fails, redacts the event
+ and flags it as soft-failed.
+
Args:
room_version: The room version of the PDU
pdu: the event to be checked
@@ -55,7 +60,10 @@ class FederationBase:
* the original event if the checks pass
* a redacted version of the event (if the signature
matched but the hash did not)
- * throws a SynapseError if the signature check failed."""
+
+ Raises:
+ SynapseError if the signature check failed.
+ """
try:
await _check_sigs_on_pdu(self.keyring, room_version, pdu)
except SynapseError as e:
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 74f17aa4da..c2997997da 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1,4 +1,4 @@
-# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2015-2022 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -89,6 +89,12 @@ class SendJoinResult:
state: List[EventBase]
auth_chain: List[EventBase]
+ # True if 'state' elides non-critical membership events
+ partial_state: bool
+
+ # if 'partial_state' is set, a list of the servers in the room (otherwise empty)
+ servers_in_room: List[str]
+
class FederationClient(FederationBase):
def __init__(self, hs: "HomeServer"):
@@ -413,26 +419,90 @@ class FederationClient(FederationBase):
return state_event_ids, auth_event_ids
+ async def get_room_state(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ room_version: RoomVersion,
+ ) -> Tuple[List[EventBase], List[EventBase]]:
+ """Calls the /state endpoint to fetch the state at a particular point
+ in the room.
+
+ Any invalid events (those with incorrect or unverifiable signatures or hashes)
+ are filtered out from the response, and any duplicate events are removed.
+
+ (Size limits and other event-format checks are *not* performed.)
+
+ Note that the result is not ordered, so callers must be careful to process
+ the events in an order that handles dependencies.
+
+ Returns:
+ a tuple of (state events, auth events)
+ """
+ result = await self.transport_layer.get_room_state(
+ room_version,
+ destination,
+ room_id,
+ event_id,
+ )
+ state_events = result.state
+ auth_events = result.auth_events
+
+ # we may as well filter out any duplicates from the response, to save
+ # processing them multiple times. (In particular, events may be present in
+ # `auth_events` as well as `state`, which is redundant).
+ #
+ # We don't rely on the sort order of the events, so we can just stick them
+ # in a dict.
+ state_event_map = {event.event_id: event for event in state_events}
+ auth_event_map = {
+ event.event_id: event
+ for event in auth_events
+ if event.event_id not in state_event_map
+ }
+
+ logger.info(
+ "Processing from /state: %d state events, %d auth events",
+ len(state_event_map),
+ len(auth_event_map),
+ )
+
+ valid_auth_events = await self._check_sigs_and_hash_and_fetch(
+ destination, auth_event_map.values(), room_version
+ )
+
+ valid_state_events = await self._check_sigs_and_hash_and_fetch(
+ destination, state_event_map.values(), room_version
+ )
+
+ return valid_state_events, valid_auth_events
+
async def _check_sigs_and_hash_and_fetch(
self,
origin: str,
pdus: Collection[EventBase],
room_version: RoomVersion,
) -> List[EventBase]:
- """Takes a list of PDUs and checks the signatures and hashes of each
- one. If a PDU fails its signature check then we check if we have it in
- the database and if not then request if from the originating server of
- that PDU.
+ """Checks the signatures and hashes of a list of events.
+
+ If a PDU fails its signature check then we check if we have it in
+ the database, and if not then request it from the sender's server (if that
+ is different from `origin`). If that still fails, the event is omitted from
+ the returned list.
If a PDU fails its content hash check then it is redacted.
- The given list of PDUs are not modified, instead the function returns
+ Also runs each event through the spam checker; if it fails, redacts the event
+ and flags it as soft-failed.
+
+ The given list of PDUs are not modified; instead the function returns
a new list.
Args:
- origin
- pdu
- room_version
+ origin: The server that sent us these events
+ pdus: The events to be checked
+ room_version: the version of the room these events are in
Returns:
A list of PDUs that have valid signatures and hashes.
@@ -463,11 +533,16 @@ class FederationClient(FederationBase):
origin: str,
room_version: RoomVersion,
) -> Optional[EventBase]:
- """Takes a PDU and checks its signatures and hashes. If the PDU fails
- its signature check then we check if we have it in the database and if
- not then request if from the originating server of that PDU.
+ """Takes a PDU and checks its signatures and hashes.
+
+ If the PDU fails its signature check then we check if we have it in the
+ database; if not, we then request it from the sender's server (if that is not the
+ same as `origin`). If that still fails, we return None.
- If then PDU fails its content hash check then it is redacted.
+ If the PDU fails its content hash check, it is redacted.
+
+ Also runs the event through the spam checker; if it fails, redacts the event
+ and flags it as soft-failed.
Args:
origin
@@ -864,23 +939,32 @@ class FederationClient(FederationBase):
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)
- # double-check that the same create event has ended up in the auth chain
+ # double-check that the auth chain doesn't include a different create event
auth_chain_create_events = [
e.event_id
for e in signed_auth
if (e.type, e.state_key) == (EventTypes.Create, "")
]
- if auth_chain_create_events != [create_event.event_id]:
+ if auth_chain_create_events and auth_chain_create_events != [
+ create_event.event_id
+ ]:
raise InvalidResponseError(
"Unexpected create event(s) in auth chain: %s"
% (auth_chain_create_events,)
)
+ if response.partial_state and not response.servers_in_room:
+ raise InvalidResponseError(
+ "partial_state was set, but no servers were listed in the room"
+ )
+
return SendJoinResult(
event=event,
state=signed_state,
auth_chain=signed_auth,
origin=destination,
+ partial_state=response.partial_state,
+ servers_in_room=response.servers_in_room or [],
)
# MSC3083 defines additional error codes for room joins.
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 8152e80b88..c3132f7319 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -381,7 +381,9 @@ class PerDestinationQueue:
)
)
- if self._last_successful_stream_ordering is None:
+ last_successful_stream_ordering = self._last_successful_stream_ordering
+
+ if last_successful_stream_ordering is None:
# if it's still None, then this means we don't have the information
# in our database - we haven't successfully sent a PDU to this server
# (at least since the introduction of the feature tracking
@@ -394,8 +396,7 @@ class PerDestinationQueue:
# get at most 50 catchup room/PDUs
while True:
event_ids = await self._store.get_catch_up_room_event_ids(
- self._destination,
- self._last_successful_stream_ordering,
+ self._destination, last_successful_stream_ordering
)
if not event_ids:
@@ -403,7 +404,7 @@ class PerDestinationQueue:
# of a race condition, so we check that no new events have been
# skipped due to us being in catch-up mode
- if self._catchup_last_skipped > self._last_successful_stream_ordering:
+ if self._catchup_last_skipped > last_successful_stream_ordering:
# another event has been skipped because we were in catch-up mode
continue
@@ -470,7 +471,7 @@ class PerDestinationQueue:
# offline
if (
p.internal_metadata.stream_ordering
- < self._last_successful_stream_ordering
+ < last_successful_stream_ordering
):
continue
@@ -513,12 +514,11 @@ class PerDestinationQueue:
# from the *original* PDU, rather than the PDU(s) we actually
# send. This is because we use it to mark our position in the
# queue of missed PDUs to process.
- self._last_successful_stream_ordering = (
- pdu.internal_metadata.stream_ordering
- )
+ last_successful_stream_ordering = pdu.internal_metadata.stream_ordering
+ self._last_successful_stream_ordering = last_successful_stream_ordering
await self._store.set_destination_last_successful_stream_ordering(
- self._destination, self._last_successful_stream_ordering
+ self._destination, last_successful_stream_ordering
)
def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 8782586cd6..7e510e224a 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1,4 +1,4 @@
-# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -60,17 +60,17 @@ class TransportLayerClient:
def __init__(self, hs):
self.server_name = hs.hostname
self.client = hs.get_federation_http_client()
+ self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
) -> JsonDict:
- """Requests all state for a given room from the given server at the
- given event. Returns the state's event_id's
+ """Requests the IDs of all state for a given room at the given event.
Args:
destination: The host name of the remote homeserver we want
to get the state from.
- context: The name of the context we want the state of
+ room_id: the room we want the state of
event_id: The event we want the context at.
Returns:
@@ -86,6 +86,29 @@ class TransportLayerClient:
try_trailing_slash_on_400=True,
)
+ async def get_room_state(
+ self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
+ ) -> "StateRequestResponse":
+ """Requests the full state for a given room at the given event.
+
+ Args:
+ room_version: the version of the room (required to build the event objects)
+ destination: The host name of the remote homeserver we want
+ to get the state from.
+ room_id: the room we want the state of
+ event_id: The event we want the context at.
+
+ Returns:
+ The parsed response, containing the room state and auth chain events.
+ """
+ path = _create_v1_path("/state/%s", room_id)
+ return await self.client.get_json(
+ destination,
+ path=path,
+ args={"event_id": event_id},
+ parser=_StateParser(room_version),
+ )
+
async def get_event(
self, destination: str, event_id: str, timeout: Optional[int] = None
) -> JsonDict:
@@ -336,10 +359,15 @@ class TransportLayerClient:
content: JsonDict,
) -> "SendJoinResponse":
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
+ query_params: Dict[str, str] = {}
+ if self._faster_joins_enabled:
+ # lazy-load state on join
+ query_params["org.matrix.msc3706.partial_state"] = "true"
return await self.client.put_json(
destination=destination,
path=path,
+ args=query_params,
data=content,
parser=SendJoinParser(room_version, v1_api=False),
max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
@@ -1271,6 +1299,20 @@ class SendJoinResponse:
# "event" is not included in the response.
event: Optional[EventBase] = None
+ # The room state is incomplete
+ partial_state: bool = False
+
+ # List of servers in the room
+ servers_in_room: Optional[List[str]] = None
+
+
+@attr.s(slots=True, auto_attribs=True)
+class StateRequestResponse:
+ """The parsed response of a `/state` request."""
+
+ auth_events: List[EventBase]
+ state: List[EventBase]
+
@ijson.coroutine
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
@@ -1297,6 +1339,32 @@ def _event_list_parser(
events.append(event)
+@ijson.coroutine
+def _partial_state_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
+ """Helper function for use with `ijson.items_coro`
+
+ Parses the partial_state field in send_join responses
+ """
+ while True:
+ val = yield
+ if not isinstance(val, bool):
+ raise TypeError("partial_state must be a boolean")
+ response.partial_state = val
+
+
+@ijson.coroutine
+def _servers_in_room_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
+ """Helper function for use with `ijson.items_coro`
+
+ Parses the servers_in_room field in send_join responses
+ """
+ while True:
+ val = yield
+ if not isinstance(val, list) or any(not isinstance(x, str) for x in val):
+ raise TypeError("servers_in_room must be a list of strings")
+ response.servers_in_room = val
+
+
class SendJoinParser(ByteParser[SendJoinResponse]):
"""A parser for the response to `/send_join` requests.
@@ -1308,44 +1376,62 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
CONTENT_TYPE = "application/json"
def __init__(self, room_version: RoomVersion, v1_api: bool):
- self._response = SendJoinResponse([], [], {})
+ self._response = SendJoinResponse([], [], event_dict={})
self._room_version = room_version
+ self._coros = []
# The V1 API has the shape of `[200, {...}]`, which we handle by
# prefixing with `item.*`.
prefix = "item." if v1_api else ""
- self._coro_state = ijson.items_coro(
- _event_list_parser(room_version, self._response.state),
- prefix + "state.item",
- use_float=True,
- )
- self._coro_auth = ijson.items_coro(
- _event_list_parser(room_version, self._response.auth_events),
- prefix + "auth_chain.item",
- use_float=True,
- )
- # TODO Remove the unstable prefix when servers have updated.
- #
- # By re-using the same event dictionary this will cause the parsing of
- # org.matrix.msc3083.v2.event and event to stomp over each other.
- # Generally this should be fine.
- self._coro_unstable_event = ijson.kvitems_coro(
- _event_parser(self._response.event_dict),
- prefix + "org.matrix.msc3083.v2.event",
- use_float=True,
- )
- self._coro_event = ijson.kvitems_coro(
- _event_parser(self._response.event_dict),
- prefix + "event",
- use_float=True,
- )
+ self._coros = [
+ ijson.items_coro(
+ _event_list_parser(room_version, self._response.state),
+ prefix + "state.item",
+ use_float=True,
+ ),
+ ijson.items_coro(
+ _event_list_parser(room_version, self._response.auth_events),
+ prefix + "auth_chain.item",
+ use_float=True,
+ ),
+ # TODO Remove the unstable prefix when servers have updated.
+ #
+ # By re-using the same event dictionary this will cause the parsing of
+ # org.matrix.msc3083.v2.event and event to stomp over each other.
+ # Generally this should be fine.
+ ijson.kvitems_coro(
+ _event_parser(self._response.event_dict),
+ prefix + "org.matrix.msc3083.v2.event",
+ use_float=True,
+ ),
+ ijson.kvitems_coro(
+ _event_parser(self._response.event_dict),
+ prefix + "event",
+ use_float=True,
+ ),
+ ]
+
+ if not v1_api:
+ self._coros.append(
+ ijson.items_coro(
+ _partial_state_parser(self._response),
+ "org.matrix.msc3706.partial_state",
+ use_float="True",
+ )
+ )
+
+ self._coros.append(
+ ijson.items_coro(
+ _servers_in_room_parser(self._response),
+ "org.matrix.msc3706.servers_in_room",
+ use_float="True",
+ )
+ )
def write(self, data: bytes) -> int:
- self._coro_state.send(data)
- self._coro_auth.send(data)
- self._coro_unstable_event.send(data)
- self._coro_event.send(data)
+ for c in self._coros:
+ c.send(data)
return len(data)
@@ -1355,3 +1441,37 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
self._response.event_dict, self._room_version
)
return self._response
+
+
+class _StateParser(ByteParser[StateRequestResponse]):
+ """A parser for the response to `/state` requests.
+
+ Args:
+ room_version: The version of the room.
+ """
+
+ CONTENT_TYPE = "application/json"
+
+ def __init__(self, room_version: RoomVersion):
+ self._response = StateRequestResponse([], [])
+ self._room_version = room_version
+ self._coros = [
+ ijson.items_coro(
+ _event_list_parser(room_version, self._response.state),
+ "pdus.item",
+ use_float=True,
+ ),
+ ijson.items_coro(
+ _event_list_parser(room_version, self._response.auth_events),
+ "auth_chain.item",
+ use_float=True,
+ ),
+ ]
+
+ def write(self, data: bytes) -> int:
+ for c in self._coros:
+ c.send(data)
+ return len(data)
+
+ def finish(self) -> StateRequestResponse:
+ return self._response
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6959d1aa7e..572f54b1e3 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2064,6 +2064,10 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict],
Awaitable[Optional[str]],
]
+GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
+ [JsonDict, JsonDict],
+ Awaitable[Optional[str]],
+]
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
@@ -2080,6 +2084,9 @@ class PasswordAuthProvider:
self.get_username_for_registration_callbacks: List[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = []
+ self.get_displayname_for_registration_callbacks: List[
+ GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+ ] = []
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
# Mapping from login type to login parameters
@@ -2099,6 +2106,9 @@ class PasswordAuthProvider:
get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None,
+ get_displayname_for_registration: Optional[
+ GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+ ] = None,
) -> None:
# Register check_3pid_auth callback
if check_3pid_auth is not None:
@@ -2148,6 +2158,11 @@ class PasswordAuthProvider:
get_username_for_registration,
)
+ if get_displayname_for_registration is not None:
+ self.get_displayname_for_registration_callbacks.append(
+ get_displayname_for_registration,
+ )
+
if is_3pid_allowed is not None:
self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
@@ -2350,6 +2365,49 @@ class PasswordAuthProvider:
return None
+ async def get_displayname_for_registration(
+ self,
+ uia_results: JsonDict,
+ params: JsonDict,
+ ) -> Optional[str]:
+ """Defines the display name to use when registering the user, using the
+ credentials and parameters provided during the UIA flow.
+
+ Stops at the first callback that returns a string.
+
+ Args:
+ uia_results: The credentials provided during the UIA flow.
+ params: The parameters provided by the registration request.
+
+ Returns:
+ The display name to use for the user, or None if no callback returned one.
+ """
+ for callback in self.get_displayname_for_registration_callbacks:
+ try:
+ res = await callback(uia_results, params)
+
+ if isinstance(res, str):
+ return res
+ elif res is not None:
+ # mypy complains that this line is unreachable because it assumes the
+ # data returned by the module fits the expected type. We just want
+ # to make sure this is the case.
+ logger.warning( # type: ignore[unreachable]
+ "Ignoring non-string value returned by"
+ " get_displayname_for_registration callback %s: %s",
+ callback,
+ res,
+ )
+ except Exception as e:
+ logger.error(
+ "Module raised an exception in get_displayname_for_registration: %s",
+ e,
+ )
+ raise SynapseError(code=500, msg="Internal Server Error")
+
+ return None
+
async def is_3pid_allowed(
self,
medium: str,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c0f642005f..c8356f233d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -516,7 +516,7 @@ class FederationHandler:
await self.store.upsert_room_on_join(
room_id=room_id,
room_version=room_version_obj,
- auth_events=auth_chain,
+ state_events=state,
)
max_stream_id = await self._federation_event_handler.process_remote_join(
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 9edc7369d6..7683246bef 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -419,10 +419,8 @@ class FederationEventHandler:
Raises:
SynapseError if the response is in some way invalid.
"""
- event_map = {e.event_id: e for e in itertools.chain(auth_events, state)}
-
create_event = None
- for e in auth_events:
+ for e in state:
if (e.type, e.state_key) == (EventTypes.Create, ""):
create_event = e
break
@@ -439,11 +437,6 @@ class FederationEventHandler:
if room_version.identifier != room_version_id:
raise SynapseError(400, "Room version mismatch")
- # filter out any events we have already seen
- seen_remotes = await self._store.have_seen_events(room_id, event_map.keys())
- for s in seen_remotes:
- event_map.pop(s, None)
-
# persist the auth chain and state events.
#
# any invalid events here will be marked as rejected, and we'll carry on.
@@ -455,7 +448,9 @@ class FederationEventHandler:
# signatures right now doesn't mean that we will *never* be able to, so it
# is premature to reject them.
#
- await self._auth_and_persist_outliers(room_id, event_map.values())
+ await self._auth_and_persist_outliers(
+ room_id, itertools.chain(auth_events, state)
+ )
# and now persist the join event itself.
logger.info("Peristing join-via-remote %s", event)
@@ -1245,6 +1240,16 @@ class FederationEventHandler:
"""
event_map = {event.event_id: event for event in events}
+ # filter out any events we have already seen. This might happen because
+ # the events were eagerly pushed to us (eg, during a room join), or because
+ # another thread has raced against us since we decided to request the event.
+ #
+ # This is just an optimisation, so it doesn't need to be watertight - the event
+ # persister does another round of deduplication.
+ seen_remotes = await self._store.have_seen_events(room_id, event_map.keys())
+ for s in seen_remotes:
+ event_map.pop(s, None)
+
# XXX: it might be possible to kick this process off in parallel with fetching
# the events.
while event_map:
@@ -1717,31 +1722,22 @@ class FederationEventHandler:
event_id: the event for which we are lacking auth events
"""
try:
- remote_event_map = {
- e.event_id: e
- for e in await self._federation_client.get_event_auth(
- destination, room_id, event_id
- )
- }
+ remote_events = await self._federation_client.get_event_auth(
+ destination, room_id, event_id
+ )
+
except RequestSendFailed as e1:
# The other side isn't around or doesn't implement the
# endpoint, so lets just bail out.
logger.info("Failed to get event auth from remote: %s", e1)
return
- logger.info("/event_auth returned %i events", len(remote_event_map))
+ logger.info("/event_auth returned %i events", len(remote_events))
# `event` may be returned, but we should not yet process it.
- remote_event_map.pop(event_id, None)
-
- # nor should we reprocess any events we have already seen.
- seen_remotes = await self._store.have_seen_events(
- room_id, remote_event_map.keys()
- )
- for s in seen_remotes:
- remote_event_map.pop(s, None)
+ remote_auth_events = (e for e in remote_events if e.event_id != event_id)
- await self._auth_and_persist_outliers(room_id, remote_event_map.values())
+ await self._auth_and_persist_outliers(room_id, remote_auth_events)
async def _update_context_for_auth_events(
self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9267e586a8..4d0da84287 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -550,10 +550,11 @@ class EventCreationHandler:
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version_id = event_dict["content"]["room_version"]
- room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
- if not room_version_obj:
+ maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
+ if not maybe_room_version_obj:
# this can happen if support is withdrawn for a room version
raise UnsupportedRoomVersionError(room_version_id)
+ room_version_obj = maybe_room_version_obj
else:
try:
room_version_obj = await self.store.get_room_version(
@@ -1145,12 +1146,13 @@ class EventCreationHandler:
room_version_id = event.content.get(
"room_version", RoomVersions.V1.identifier
)
- room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
- if not room_version_obj:
+ maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
+ if not maybe_room_version_obj:
raise UnsupportedRoomVersionError(
"Attempt to create a room with unsupported room version %s"
% (room_version_id,)
)
+ room_version_obj = maybe_room_version_obj
else:
room_version_obj = await self.store.get_room_version(event.room_id)
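The rename to maybe_room_version_obj exists so mypy can narrow the Optional result of the lookup and leave room_version_obj non-Optional for the rest of the function. A tiny illustration of the same narrowing pattern, using invented names:

from typing import Dict, Optional

KNOWN_VERSIONS: Dict[str, int] = {"9": 9, "10": 10}


def lookup_version(version_id: str) -> int:
    # .get() returns Optional[int]; keep it under its own name so the check
    # below narrows it without fighting later reassignments.
    maybe_version: Optional[int] = KNOWN_VERSIONS.get(version_id)
    if maybe_version is None:
        raise ValueError(f"unsupported version {version_id}")
    # After the check, mypy knows maybe_version is an int, so this assignment
    # gives us a plainly non-Optional variable to use from here on.
    version: int = maybe_version
    return version


print(lookup_version("10"))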
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 067c43ae47..b223b72623 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -204,25 +204,27 @@ class BasePresenceHandler(abc.ABC):
Returns:
dict: `user_id` -> `UserPresenceState`
"""
- states = {
- user_id: self.user_to_current_state.get(user_id, None)
- for user_id in user_ids
- }
+ states = {}
+ missing = []
+ for user_id in user_ids:
+ state = self.user_to_current_state.get(user_id, None)
+ if state:
+ states[user_id] = state
+ else:
+ missing.append(user_id)
- missing = [user_id for user_id, state in states.items() if not state]
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = await self.store.get_presence_for_users(missing)
states.update(res)
- missing = [user_id for user_id, state in states.items() if not state]
- if missing:
- new = {
- user_id: UserPresenceState.default(user_id) for user_id in missing
- }
- states.update(new)
- self.user_to_current_state.update(new)
+ for user_id in missing:
+ # If the user has no state in the database, create a default state
+ if not res.get(user_id, None):
+ new_state = UserPresenceState.default(user_id)
+ states[user_id] = new_state
+ self.user_to_current_state[user_id] = new_state
return states
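A simplified sketch of the lookup order the rewritten presence code follows (in-memory cache, then database, then a default that also warms the cache); plain dicts and strings stand in for the store and UserPresenceState here:

import asyncio
from typing import Dict, Iterable, List

DEFAULT = "offline"  # stands in for UserPresenceState.default(user_id)


async def get_states(
    cache: Dict[str, str],
    db: Dict[str, str],
    user_ids: Iterable[str],
) -> Dict[str, str]:
    states: Dict[str, str] = {}
    missing: List[str] = []

    # 1. In-memory cache first.
    for user_id in user_ids:
        state = cache.get(user_id)
        if state:
            states[user_id] = state
        else:
            missing.append(user_id)

    if missing:
        # 2. Database for anything not cached (a plain dict in this sketch).
        db_states = {u: db[u] for u in missing if u in db}
        states.update(db_states)

        # 3. Default state for anyone still unknown, and warm the cache.
        for user_id in missing:
            if not db_states.get(user_id):
                states[user_id] = DEFAULT
                cache[user_id] = DEFAULT

    return states


async def main() -> None:
    cache = {"@a:x": "online"}
    db = {"@b:x": "unavailable"}
    print(await get_states(cache, db, ["@a:x", "@b:x", "@c:x"]))
    # {'@a:x': 'online', '@b:x': 'unavailable', '@c:x': 'offline'}


asyncio.run(main())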
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a719d5eef3..80320d2c07 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -320,12 +320,12 @@ class RegistrationHandler:
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")
- localpart = await self.store.generate_user_id()
- user = UserID(localpart, self.hs.hostname)
+ generated_localpart = await self.store.generate_user_id()
+ user = UserID(generated_localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
if generate_display_name:
- default_display_name = localpart
+ default_display_name = generated_localpart
try:
await self.register_with_store(
user_id=user_id,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index bf1a47efb0..b2adc0f48b 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -82,6 +82,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.event_auth_handler = hs.get_event_auth_handler()
self.member_linearizer: Linearizer = Linearizer(name="member")
+ self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()
@@ -500,25 +501,32 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
key = (room_id,)
- with (await self.member_linearizer.queue(key)):
- result = await self.update_membership_locked(
- requester,
- target,
- room_id,
- action,
- txn_id=txn_id,
- remote_room_hosts=remote_room_hosts,
- third_party_signed=third_party_signed,
- ratelimit=ratelimit,
- content=content,
- new_room=new_room,
- require_consent=require_consent,
- outlier=outlier,
- historical=historical,
- allow_no_prev_events=allow_no_prev_events,
- prev_event_ids=prev_event_ids,
- auth_event_ids=auth_event_ids,
- )
+ as_id = object()
+ if requester.app_service:
+ as_id = requester.app_service.id
+
+ # We first linearise by the application service (to try to limit concurrent joins
+ # by application services), and then by room ID.
+ with (await self.member_as_limiter.queue(as_id)):
+ with (await self.member_linearizer.queue(key)):
+ result = await self.update_membership_locked(
+ requester,
+ target,
+ room_id,
+ action,
+ txn_id=txn_id,
+ remote_room_hosts=remote_room_hosts,
+ third_party_signed=third_party_signed,
+ ratelimit=ratelimit,
+ content=content,
+ new_room=new_room,
+ require_consent=require_consent,
+ outlier=outlier,
+ historical=historical,
+ allow_no_prev_events=allow_no_prev_events,
+ prev_event_ids=prev_event_ids,
+ auth_event_ids=auth_event_ids,
+ )
return result
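A rough asyncio-only sketch of the layered limiting introduced above, with a Semaphore standing in for the max_count Linearizer and a Lock for the per-room one; it is illustrative only and does not use Synapse's Linearizer:

import asyncio
from collections import defaultdict
from typing import DefaultDict, Optional

# Outer limiter: at most 10 concurrent membership changes per application service.
as_limiters: DefaultDict[object, asyncio.Semaphore] = defaultdict(
    lambda: asyncio.Semaphore(10)
)
# Inner limiter: one membership change at a time per room.
room_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)


async def update_membership(as_id: Optional[str], room_id: str) -> None:
    # Non-appservice requesters get a unique key (like the `object()` sentinel
    # above), so the outer limiter only ever bites for application services.
    key: object = as_id if as_id is not None else object()
    async with as_limiters[key]:
        async with room_locks[room_id]:
            await asyncio.sleep(0.01)  # the locked membership update would go here


async def main() -> None:
    await asyncio.gather(
        *(update_membership("appservice-1", f"!r{i}:example.org") for i in range(20))
    )


asyncio.run(main())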
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 41cb809078..0e0e58de02 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -14,8 +14,9 @@
import itertools
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+import attr
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.constants import EventTypes, Membership
@@ -32,6 +33,20 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _SearchResult:
+ # The count of results.
+ count: int
+ # A mapping of event ID to the rank of that event.
+ rank_map: Dict[str, int]
+ # A list of the resulting events.
+ allowed_events: List[EventBase]
+ # A map of room ID to results.
+ room_groups: Dict[str, JsonDict]
+ # A set of event IDs to highlight.
+ highlights: Set[str]
+
+
class SearchHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
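For reference, a small sketch of how the new frozen attrs container behaves; the field values are invented and plain event IDs stand in for EventBase instances:

from typing import Dict, List, Set

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class SearchResultSketch:
    count: int
    rank_map: Dict[str, int]
    allowed_events: List[str]  # event IDs here, EventBase in Synapse
    room_groups: Dict[str, dict]
    highlights: Set[str]


res = SearchResultSketch(
    count=1,
    rank_map={"$ev1": 10},
    allowed_events=["$ev1"],
    room_groups={"!room:example.org": {"results": ["$ev1"], "order": 10}},
    highlights={"foo"},
)
print(res.count, res.highlights)
# Frozen: attempting `res.count = 2` raises attr.exceptions.FrozenInstanceError.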
@@ -100,7 +115,7 @@ class SearchHandler:
"""Performs a full text search for a user.
Args:
- user
+ user: The user performing the search.
content: Search parameters
batch: The next_batch parameter. Used for pagination.
@@ -156,6 +171,8 @@ class SearchHandler:
# Include context around each event?
event_context = room_cat.get("event_context", None)
+ before_limit = after_limit = None
+ include_profile = False
# Group results together? May allow clients to paginate within a
# group
@@ -182,6 +199,73 @@ class SearchHandler:
% (set(group_keys) - {"room_id", "sender"},),
)
+ return await self._search(
+ user,
+ batch_group,
+ batch_group_key,
+ batch_token,
+ search_term,
+ keys,
+ filter_dict,
+ order_by,
+ include_state,
+ group_keys,
+ event_context,
+ before_limit,
+ after_limit,
+ include_profile,
+ )
+
+ async def _search(
+ self,
+ user: UserID,
+ batch_group: Optional[str],
+ batch_group_key: Optional[str],
+ batch_token: Optional[str],
+ search_term: str,
+ keys: List[str],
+ filter_dict: JsonDict,
+ order_by: str,
+ include_state: bool,
+ group_keys: List[str],
+ event_context: Optional[bool],
+ before_limit: Optional[int],
+ after_limit: Optional[int],
+ include_profile: bool,
+ ) -> JsonDict:
+ """Performs a full text search for a user.
+
+ Args:
+ user: The user performing the search.
+ batch_group: Pagination information.
+ batch_group_key: Pagination information.
+ batch_token: Pagination information.
+ search_term: Search term to search for
+ keys: List of keys to search in, currently supports
+ "content.body", "content.name", "content.topic"
+ filter_dict: The JSON to build a filter out of.
+            order_by: How to order the results. Valid values are "rank" and "recent".
+ include_state: True if the state of the room at each result should
+ be included.
+ group_keys: A list of ways to group the results. Valid values are
+ "room_id" and "sender".
+ event_context: True to include contextual events around results.
+ before_limit:
+ The number of events before a result to include as context.
+
+ Only used if event_context is True.
+ after_limit:
+ The number of events after a result to include as context.
+
+ Only used if event_context is True.
+ include_profile: True if historical profile information should be
+ included in the event context.
+
+ Only used if event_context is True.
+
+ Returns:
+ dict to be returned to the client with results of search
+ """
search_filter = Filter(self.hs, filter_dict)
# TODO: Search through left rooms too
@@ -216,278 +300,399 @@ class SearchHandler:
}
}
- rank_map = {} # event_id -> rank of event
- allowed_events = []
- # Holds result of grouping by room, if applicable
- room_groups: Dict[str, JsonDict] = {}
- # Holds result of grouping by sender, if applicable
- sender_group: Dict[str, JsonDict] = {}
+ sender_group: Optional[Dict[str, JsonDict]]
- # Holds the next_batch for the entire result set if one of those exists
- global_next_batch = None
-
- highlights = set()
+ if order_by == "rank":
+ search_result, sender_group = await self._search_by_rank(
+ user, room_ids, search_term, keys, search_filter
+ )
+ # Unused return values for rank search.
+ global_next_batch = None
+ elif order_by == "recent":
+ search_result, global_next_batch = await self._search_by_recent(
+ user,
+ room_ids,
+ search_term,
+ keys,
+ search_filter,
+ batch_group,
+ batch_group_key,
+ batch_token,
+ )
+ # Unused return values for recent search.
+ sender_group = None
+ else:
+ # We should never get here due to the guard earlier.
+ raise NotImplementedError()
- count = None
+ logger.info("Found %d events to return", len(search_result.allowed_events))
- if order_by == "rank":
- search_result = await self.store.search_msgs(room_ids, search_term, keys)
+ # If client has asked for "context" for each event (i.e. some surrounding
+ # events and state), fetch that
+ if event_context is not None:
+ # Note that before and after limit must be set in this case.
+ assert before_limit is not None
+ assert after_limit is not None
+
+ contexts = await self._calculate_event_contexts(
+ user,
+ search_result.allowed_events,
+ before_limit,
+ after_limit,
+ include_profile,
+ )
+ else:
+ contexts = {}
- count = search_result["count"]
+ # TODO: Add a limit
- if search_result["highlights"]:
- highlights.update(search_result["highlights"])
+ state_results = {}
+ if include_state:
+ for room_id in {e.room_id for e in search_result.allowed_events}:
+ state = await self.state_handler.get_current_state(room_id)
+ state_results[room_id] = list(state.values())
- results = search_result["results"]
+ aggregations = None
+ if self._msc3666_enabled:
+ aggregations = await self.store.get_bundled_aggregations(
+ # Generate an iterable of EventBase for all the events that will be
+ # returned, including contextual events.
+ itertools.chain(
+ # The events_before and events_after for each context.
+ itertools.chain.from_iterable(
+ itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type]
+ for context in contexts.values()
+ ),
+ # The returned events.
+ search_result.allowed_events,
+ ),
+ user.to_string(),
+ )
- rank_map.update({r["event"].event_id: r["rank"] for r in results})
+ # We're now about to serialize the events. We should not make any
+ # blocking calls after this. Otherwise, the 'age' will be wrong.
- filtered_events = await search_filter.filter([r["event"] for r in results])
+ time_now = self.clock.time_msec()
- events = await filter_events_for_client(
- self.storage, user.to_string(), filtered_events
+ for context in contexts.values():
+ context["events_before"] = self._event_serializer.serialize_events(
+ context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
+ )
+ context["events_after"] = self._event_serializer.serialize_events(
+ context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
)
- events.sort(key=lambda e: -rank_map[e.event_id])
- allowed_events = events[: search_filter.limit]
+ results = [
+ {
+ "rank": search_result.rank_map[e.event_id],
+ "result": self._event_serializer.serialize_event(
+ e, time_now, bundle_aggregations=aggregations
+ ),
+ "context": contexts.get(e.event_id, {}),
+ }
+ for e in search_result.allowed_events
+ ]
- for e in allowed_events:
- rm = room_groups.setdefault(
- e.room_id, {"results": [], "order": rank_map[e.event_id]}
- )
- rm["results"].append(e.event_id)
+ rooms_cat_res: JsonDict = {
+ "results": results,
+ "count": search_result.count,
+ "highlights": list(search_result.highlights),
+ }
- s = sender_group.setdefault(
- e.sender, {"results": [], "order": rank_map[e.event_id]}
- )
- s["results"].append(e.event_id)
+ if state_results:
+ rooms_cat_res["state"] = {
+ room_id: self._event_serializer.serialize_events(state_events, time_now)
+ for room_id, state_events in state_results.items()
+ }
- elif order_by == "recent":
- room_events: List[EventBase] = []
- i = 0
-
- pagination_token = batch_token
-
- # We keep looping and we keep filtering until we reach the limit
- # or we run out of things.
- # But only go around 5 times since otherwise synapse will be sad.
- while len(room_events) < search_filter.limit and i < 5:
- i += 1
- search_result = await self.store.search_rooms(
- room_ids,
- search_term,
- keys,
- search_filter.limit * 2,
- pagination_token=pagination_token,
- )
+ if search_result.room_groups and "room_id" in group_keys:
+ rooms_cat_res.setdefault("groups", {})[
+ "room_id"
+ ] = search_result.room_groups
- if search_result["highlights"]:
- highlights.update(search_result["highlights"])
+ if sender_group and "sender" in group_keys:
+ rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
- count = search_result["count"]
+ if global_next_batch:
+ rooms_cat_res["next_batch"] = global_next_batch
- results = search_result["results"]
+ return {"search_categories": {"room_events": rooms_cat_res}}
- results_map = {r["event"].event_id: r for r in results}
+ async def _search_by_rank(
+ self,
+ user: UserID,
+ room_ids: Collection[str],
+ search_term: str,
+ keys: Iterable[str],
+ search_filter: Filter,
+ ) -> Tuple[_SearchResult, Dict[str, JsonDict]]:
+ """
+ Performs a full text search for a user ordering by rank.
- rank_map.update({r["event"].event_id: r["rank"] for r in results})
+ Args:
+ user: The user performing the search.
+ room_ids: List of room ids to search in
+ search_term: Search term to search for
+ keys: List of keys to search in, currently supports
+ "content.body", "content.name", "content.topic"
+ search_filter: The event filter to use.
- filtered_events = await search_filter.filter(
- [r["event"] for r in results]
- )
+ Returns:
+ A tuple of:
+ The search results.
+ A map of sender ID to results.
+ """
+ rank_map = {} # event_id -> rank of event
+ # Holds result of grouping by room, if applicable
+ room_groups: Dict[str, JsonDict] = {}
+ # Holds result of grouping by sender, if applicable
+ sender_group: Dict[str, JsonDict] = {}
- events = await filter_events_for_client(
- self.storage, user.to_string(), filtered_events
- )
+ search_result = await self.store.search_msgs(room_ids, search_term, keys)
- room_events.extend(events)
- room_events = room_events[: search_filter.limit]
+ if search_result["highlights"]:
+ highlights = search_result["highlights"]
+ else:
+ highlights = set()
- if len(results) < search_filter.limit * 2:
- pagination_token = None
- break
- else:
- pagination_token = results[-1]["pagination_token"]
-
- for event in room_events:
- group = room_groups.setdefault(event.room_id, {"results": []})
- group["results"].append(event.event_id)
-
- if room_events and len(room_events) >= search_filter.limit:
- last_event_id = room_events[-1].event_id
- pagination_token = results_map[last_event_id]["pagination_token"]
-
- # We want to respect the given batch group and group keys so
- # that if people blindly use the top level `next_batch` token
- # it returns more from the same group (if applicable) rather
- # than reverting to searching all results again.
- if batch_group and batch_group_key:
- global_next_batch = encode_base64(
- (
- "%s\n%s\n%s"
- % (batch_group, batch_group_key, pagination_token)
- ).encode("ascii")
- )
- else:
- global_next_batch = encode_base64(
- ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
- )
+ results = search_result["results"]
- for room_id, group in room_groups.items():
- group["next_batch"] = encode_base64(
- ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
- "ascii"
- )
- )
+ # event_id -> rank of event
+ rank_map = {r["event"].event_id: r["rank"] for r in results}
- allowed_events.extend(room_events)
+ filtered_events = await search_filter.filter([r["event"] for r in results])
- else:
- # We should never get here due to the guard earlier.
- raise NotImplementedError()
+ events = await filter_events_for_client(
+ self.storage, user.to_string(), filtered_events
+ )
- logger.info("Found %d events to return", len(allowed_events))
+ events.sort(key=lambda e: -rank_map[e.event_id])
+ allowed_events = events[: search_filter.limit]
- # If client has asked for "context" for each event (i.e. some surrounding
- # events and state), fetch that
- if event_context is not None:
- now_token = self.hs.get_event_sources().get_current_token()
+ for e in allowed_events:
+ rm = room_groups.setdefault(
+ e.room_id, {"results": [], "order": rank_map[e.event_id]}
+ )
+ rm["results"].append(e.event_id)
- contexts = {}
- for event in allowed_events:
- res = await self.store.get_events_around(
- event.room_id, event.event_id, before_limit, after_limit
- )
+ s = sender_group.setdefault(
+ e.sender, {"results": [], "order": rank_map[e.event_id]}
+ )
+ s["results"].append(e.event_id)
+
+ return (
+ _SearchResult(
+ search_result["count"],
+ rank_map,
+ allowed_events,
+ room_groups,
+ highlights,
+ ),
+ sender_group,
+ )
- logger.info(
- "Context for search returned %d and %d events",
- len(res.events_before),
- len(res.events_after),
- )
+ async def _search_by_recent(
+ self,
+ user: UserID,
+ room_ids: Collection[str],
+ search_term: str,
+ keys: Iterable[str],
+ search_filter: Filter,
+ batch_group: Optional[str],
+ batch_group_key: Optional[str],
+ batch_token: Optional[str],
+ ) -> Tuple[_SearchResult, Optional[str]]:
+ """
+ Performs a full text search for a user ordering by recent.
- events_before = await filter_events_for_client(
- self.storage, user.to_string(), res.events_before
- )
+ Args:
+ user: The user performing the search.
+ room_ids: List of room ids to search in
+ search_term: Search term to search for
+ keys: List of keys to search in, currently supports
+ "content.body", "content.name", "content.topic"
+ search_filter: The event filter to use.
+ batch_group: Pagination information.
+ batch_group_key: Pagination information.
+ batch_token: Pagination information.
- events_after = await filter_events_for_client(
- self.storage, user.to_string(), res.events_after
- )
+ Returns:
+ A tuple of:
+ The search results.
+ Optionally, a pagination token.
+ """
+ rank_map = {} # event_id -> rank of event
+ # Holds result of grouping by room, if applicable
+ room_groups: Dict[str, JsonDict] = {}
- context = {
- "events_before": events_before,
- "events_after": events_after,
- "start": await now_token.copy_and_replace(
- "room_key", res.start
- ).to_string(self.store),
- "end": await now_token.copy_and_replace(
- "room_key", res.end
- ).to_string(self.store),
- }
+ # Holds the next_batch for the entire result set if one of those exists
+ global_next_batch = None
- if include_profile:
- senders = {
- ev.sender
- for ev in itertools.chain(events_before, [event], events_after)
- }
+ highlights = set()
- if events_after:
- last_event_id = events_after[-1].event_id
- else:
- last_event_id = event.event_id
+ room_events: List[EventBase] = []
+ i = 0
+
+ pagination_token = batch_token
+
+ # We keep looping and we keep filtering until we reach the limit
+ # or we run out of things.
+ # But only go around 5 times since otherwise synapse will be sad.
+ while len(room_events) < search_filter.limit and i < 5:
+ i += 1
+ search_result = await self.store.search_rooms(
+ room_ids,
+ search_term,
+ keys,
+ search_filter.limit * 2,
+ pagination_token=pagination_token,
+ )
- state_filter = StateFilter.from_types(
- [(EventTypes.Member, sender) for sender in senders]
- )
+ if search_result["highlights"]:
+ highlights.update(search_result["highlights"])
+
+ count = search_result["count"]
+
+ results = search_result["results"]
+
+ results_map = {r["event"].event_id: r for r in results}
+
+ rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+ filtered_events = await search_filter.filter([r["event"] for r in results])
- state = await self.state_store.get_state_for_event(
- last_event_id, state_filter
+ events = await filter_events_for_client(
+ self.storage, user.to_string(), filtered_events
+ )
+
+ room_events.extend(events)
+ room_events = room_events[: search_filter.limit]
+
+ if len(results) < search_filter.limit * 2:
+ break
+ else:
+ pagination_token = results[-1]["pagination_token"]
+
+ for event in room_events:
+ group = room_groups.setdefault(event.room_id, {"results": []})
+ group["results"].append(event.event_id)
+
+ if room_events and len(room_events) >= search_filter.limit:
+ last_event_id = room_events[-1].event_id
+ pagination_token = results_map[last_event_id]["pagination_token"]
+
+ # We want to respect the given batch group and group keys so
+ # that if people blindly use the top level `next_batch` token
+ # it returns more from the same group (if applicable) rather
+ # than reverting to searching all results again.
+ if batch_group and batch_group_key:
+ global_next_batch = encode_base64(
+ (
+ "%s\n%s\n%s" % (batch_group, batch_group_key, pagination_token)
+ ).encode("ascii")
+ )
+ else:
+ global_next_batch = encode_base64(
+ ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
+ )
+
+ for room_id, group in room_groups.items():
+ group["next_batch"] = encode_base64(
+ ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
+ "ascii"
)
+ )
- context["profile_info"] = {
- s.state_key: {
- "displayname": s.content.get("displayname", None),
- "avatar_url": s.content.get("avatar_url", None),
- }
- for s in state.values()
- if s.type == EventTypes.Member and s.state_key in senders
- }
+ return (
+ _SearchResult(count, rank_map, room_events, room_groups, highlights),
+ global_next_batch,
+ )
- contexts[event.event_id] = context
- else:
- contexts = {}
+ async def _calculate_event_contexts(
+ self,
+ user: UserID,
+ allowed_events: List[EventBase],
+ before_limit: int,
+ after_limit: int,
+ include_profile: bool,
+ ) -> Dict[str, JsonDict]:
+ """
+ Calculates the contextual events for any search results.
- # TODO: Add a limit
+ Args:
+ user: The user performing the search.
+ allowed_events: The search results.
+ before_limit:
+ The number of events before a result to include as context.
+ after_limit:
+ The number of events after a result to include as context.
+ include_profile: True if historical profile information should be
+ included in the event context.
- time_now = self.clock.time_msec()
+ Returns:
+ A map of event ID to contextual information.
+ """
+ now_token = self.hs.get_event_sources().get_current_token()
- aggregations = None
- if self._msc3666_enabled:
- aggregations = await self.store.get_bundled_aggregations(
- # Generate an iterable of EventBase for all the events that will be
- # returned, including contextual events.
- itertools.chain(
- # The events_before and events_after for each context.
- itertools.chain.from_iterable(
- itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type]
- for context in contexts.values()
- ),
- # The returned events.
- allowed_events,
- ),
- user.to_string(),
+ contexts = {}
+ for event in allowed_events:
+ res = await self.store.get_events_around(
+ event.room_id, event.event_id, before_limit, after_limit
)
- for context in contexts.values():
- context["events_before"] = self._event_serializer.serialize_events(
- context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
+ logger.info(
+ "Context for search returned %d and %d events",
+ len(res.events_before),
+ len(res.events_after),
)
- context["events_after"] = self._event_serializer.serialize_events(
- context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
+
+ events_before = await filter_events_for_client(
+ self.storage, user.to_string(), res.events_before
)
- state_results = {}
- if include_state:
- for room_id in {e.room_id for e in allowed_events}:
- state = await self.state_handler.get_current_state(room_id)
- state_results[room_id] = list(state.values())
+ events_after = await filter_events_for_client(
+ self.storage, user.to_string(), res.events_after
+ )
- # We're now about to serialize the events. We should not make any
- # blocking calls after this. Otherwise the 'age' will be wrong
+ context: JsonDict = {
+ "events_before": events_before,
+ "events_after": events_after,
+ "start": await now_token.copy_and_replace(
+ "room_key", res.start
+ ).to_string(self.store),
+ "end": await now_token.copy_and_replace("room_key", res.end).to_string(
+ self.store
+ ),
+ }
- results = []
- for e in allowed_events:
- results.append(
- {
- "rank": rank_map[e.event_id],
- "result": self._event_serializer.serialize_event(
- e, time_now, bundle_aggregations=aggregations
- ),
- "context": contexts.get(e.event_id, {}),
+ if include_profile:
+ senders = {
+ ev.sender
+ for ev in itertools.chain(events_before, [event], events_after)
}
- )
- rooms_cat_res = {
- "results": results,
- "count": count,
- "highlights": list(highlights),
- }
+ if events_after:
+ last_event_id = events_after[-1].event_id
+ else:
+ last_event_id = event.event_id
- if state_results:
- s = {}
- for room_id, state_events in state_results.items():
- s[room_id] = self._event_serializer.serialize_events(
- state_events, time_now
+ state_filter = StateFilter.from_types(
+ [(EventTypes.Member, sender) for sender in senders]
)
- rooms_cat_res["state"] = s
-
- if room_groups and "room_id" in group_keys:
- rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
+ state = await self.state_store.get_state_for_event(
+ last_event_id, state_filter
+ )
- if sender_group and "sender" in group_keys:
- rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
+ context["profile_info"] = {
+ s.state_key: {
+ "displayname": s.content.get("displayname", None),
+ "avatar_url": s.content.get("avatar_url", None),
+ }
+ for s in state.values()
+ if s.type == EventTypes.Member and s.state_key in senders
+ }
- if global_next_batch:
- rooms_cat_res["next_batch"] = global_next_batch
+ contexts[event.event_id] = context
- return {"search_categories": {"room_events": rooms_cat_res}}
+ return contexts
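The next_batch tokens built in _search_by_recent are three newline-separated fields (the grouping in effect, the group key, and the store's pagination token), unpadded-base64 encoded. A small sketch of encoding and decoding such a token, mirroring the format used above; the helper names are invented:

from typing import Optional, Tuple

from unpaddedbase64 import decode_base64, encode_base64


def make_next_batch(group: str, group_key: str, pagination_token: str) -> str:
    # e.g. group="room_id", group_key="!room:example.org"
    return encode_base64(
        ("%s\n%s\n%s" % (group, group_key, pagination_token)).encode("ascii")
    )


def parse_next_batch(token: str) -> Tuple[Optional[str], Optional[str], str]:
    group, group_key, pagination_token = decode_base64(token).decode("ascii").split("\n")
    # "all" with an empty key means the token is not scoped to any group.
    if group == "all":
        return None, None, pagination_token
    return group, group_key, pagination_token


tok = make_next_batch("room_id", "!room:example.org", "t123-456")
print(parse_next_batch(tok))  # ('room_id', '!room:example.org', 't123-456')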
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index aa9a76f8a9..e6050cbce6 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1289,23 +1289,54 @@ class SyncHandler:
# room with by looking at all users that have left a room plus users
# that were in a room we've left.
- users_who_share_room = await self.store.get_users_who_share_room_with_user(
- user_id
- )
-
- # Always tell the user about their own devices. We check as the user
- # ID is almost certainly already included (unless they're not in any
- # rooms) and taking a copy of the set is relatively expensive.
- if user_id not in users_who_share_room:
- users_who_share_room = set(users_who_share_room)
- users_who_share_room.add(user_id)
+ users_that_have_changed = set()
- tracked_users = users_who_share_room
+ joined_rooms = sync_result_builder.joined_room_ids
- # Step 1a, check for changes in devices of users we share a room with
- users_that_have_changed = await self.store.get_users_whose_devices_changed(
- since_token.device_list_key, tracked_users
+ # Step 1a, check for changes in devices of users we share a room
+ # with
+ #
+ # We do this in two different ways depending on what we have cached.
+        # If we already have a list of all the users that have changed since
+ # the last sync then it's likely more efficient to compare the rooms
+ # they're in with the rooms the syncing user is in.
+ #
+ # If we don't have that info cached then we get all the users that
+ # share a room with our user and check if those users have changed.
+ changed_users = self.store.get_cached_device_list_changes(
+ since_token.device_list_key
)
+ if changed_users is not None:
+ result = await self.store.get_rooms_for_users_with_stream_ordering(
+ changed_users
+ )
+
+ for changed_user_id, entries in result.items():
+ # Check if the changed user shares any rooms with the user,
+ # or if the changed user is the syncing user (as we always
+ # want to include device list updates of their own devices).
+ if user_id == changed_user_id or any(
+ e.room_id in joined_rooms for e in entries
+ ):
+ users_that_have_changed.add(changed_user_id)
+ else:
+ users_who_share_room = (
+ await self.store.get_users_who_share_room_with_user(user_id)
+ )
+
+ # Always tell the user about their own devices. We check as the user
+ # ID is almost certainly already included (unless they're not in any
+ # rooms) and taking a copy of the set is relatively expensive.
+ if user_id not in users_who_share_room:
+ users_who_share_room = set(users_who_share_room)
+ users_who_share_room.add(user_id)
+
+ tracked_users = users_who_share_room
+ users_that_have_changed = (
+ await self.store.get_users_whose_devices_changed(
+ since_token.device_list_key, tracked_users
+ )
+ )
# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
@@ -1329,7 +1360,14 @@ class SyncHandler:
newly_left_users.update(left_users)
# Remove any users that we still share a room with.
- newly_left_users -= users_who_share_room
+ left_users_rooms = (
+ await self.store.get_rooms_for_users_with_stream_ordering(
+ newly_left_users
+ )
+ )
+ for user_id, entries in left_users_rooms.items():
+ if any(e.room_id in joined_rooms for e in entries):
+ newly_left_users.discard(user_id)
return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
else:
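A condensed sketch of the two code paths the device-list step now takes: the cheap path when the stream-change cache can enumerate the changed users, and the fallback that asks which of the users we share a room with have changed. Sets and dicts stand in for the store calls; all names are illustrative:

from typing import Dict, Optional, Set


def changed_device_users(
    user_id: str,
    joined_rooms: Set[str],
    cached_changed_users: Optional[Set[str]],  # None => cache cannot answer
    rooms_for_user: Dict[str, Set[str]],       # user_id -> rooms they are in
    users_sharing_a_room: Set[str],            # fallback: users sharing a room with us
    users_whose_devices_changed: Set[str],     # fallback: users with device changes
) -> Set[str]:
    changed: Set[str] = set()
    if cached_changed_users is not None:
        # Cheap path: we know exactly who changed; keep only those we share a
        # room with (plus ourselves, so we always hear about our own devices).
        for changed_user in cached_changed_users:
            if changed_user == user_id or (
                rooms_for_user.get(changed_user, set()) & joined_rooms
            ):
                changed.add(changed_user)
    else:
        # Fallback path: look at everyone we share a room with and ask the
        # store which of them have changed devices since the token.
        tracked = set(users_sharing_a_room) | {user_id}
        changed = users_whose_devices_changed & tracked
    return changed


print(
    changed_device_users(
        "@me:x",
        joined_rooms={"!a:x"},
        cached_changed_users={"@me:x", "@other:x", "@stranger:x"},
        rooms_for_user={"@other:x": {"!a:x"}, "@stranger:x": {"!b:x"}},
        users_sharing_a_room=set(),
        users_whose_devices_changed=set(),
    )
)  # {'@me:x', '@other:x'}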
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c5f8fcbb2a..e7656fbb9f 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -958,6 +958,7 @@ class MatrixFederationHttpClient:
)
return body
+ @overload
async def get_json(
self,
destination: str,
@@ -967,7 +968,38 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
+ parser: Literal[None] = None,
+ max_response_size: Optional[int] = None,
) -> Union[JsonDict, list]:
+ ...
+
+ @overload
+ async def get_json(
+ self,
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = ...,
+ retry_on_dns_fail: bool = ...,
+ timeout: Optional[int] = ...,
+ ignore_backoff: bool = ...,
+ try_trailing_slash_on_400: bool = ...,
+ parser: ByteParser[T] = ...,
+ max_response_size: Optional[int] = ...,
+ ) -> T:
+ ...
+
+ async def get_json(
+ self,
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = None,
+ retry_on_dns_fail: bool = True,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ parser: Optional[ByteParser] = None,
+ max_response_size: Optional[int] = None,
+ ):
"""GETs some json from the given host homeserver and path
Args:
@@ -992,6 +1024,13 @@ class MatrixFederationHttpClient:
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
+
+ parser: The parser to use to decode the response. Defaults to
+ parsing as JSON.
+
+ max_response_size: The maximum size to read from the response. If None,
+ uses the default.
+
Returns:
Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
@@ -1026,8 +1065,17 @@ class MatrixFederationHttpClient:
else:
_sec_timeout = self.default_timeout
+ if parser is None:
+ parser = JsonParser()
+
body = await _handle_response(
- self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
+ self.reactor,
+ _sec_timeout,
+ request,
+ response,
+ start_ms,
+ parser=parser,
+ max_response_size=max_response_size,
)
return body
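The pair of @overload signatures above lets callers that omit parser keep the old Union[JsonDict, list] return type, while callers supplying a ByteParser[T] get T back. A stripped-down illustration of the same typing trick with an invented Parser class:

from typing import Generic, List, Literal, Optional, TypeVar, Union, overload

T = TypeVar("T")
JsonLike = Union[dict, list]


class Parser(Generic[T]):
    """Stands in for ByteParser[T]: turns raw response bytes into a T."""

    def parse(self, data: bytes) -> T:
        raise NotImplementedError


class IntListParser(Parser[List[int]]):
    def parse(self, data: bytes) -> List[int]:
        return [int(x) for x in data.split()]


@overload
def fetch(path: str, parser: Literal[None] = None) -> JsonLike:
    ...


@overload
def fetch(path: str, parser: Parser[T] = ...) -> T:
    ...


def fetch(path: str, parser: Optional[Parser] = None):
    raw = b"1 2 3"  # pretend this came off the wire
    if parser is None:
        return {"raw": raw.decode()}  # default behaviour: plain JSON-ish value
    return parser.parse(raw)


print(fetch("/versions"))                   # typed as Union[dict, list]
print(fetch("/versions", IntListParser()))  # typed as List[int]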
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
deleted file mode 100644
index b9933a1528..0000000000
--- a/synapse/logging/_structured.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os.path
-from typing import Any, Dict, Generator, Optional, Tuple
-
-from constantly import NamedConstant, Names
-
-from synapse.config._base import ConfigError
-
-
-class DrainType(Names):
- CONSOLE = NamedConstant()
- CONSOLE_JSON = NamedConstant()
- CONSOLE_JSON_TERSE = NamedConstant()
- FILE = NamedConstant()
- FILE_JSON = NamedConstant()
- NETWORK_JSON_TERSE = NamedConstant()
-
-
-DEFAULT_LOGGERS = {"synapse": {"level": "info"}}
-
-
-def parse_drain_configs(
- drains: dict,
-) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
- """
- Parse the drain configurations.
-
- Args:
- drains (dict): A list of drain configurations.
-
- Yields:
- dict instances representing a logging handler.
-
- Raises:
- ConfigError: If any of the drain configuration items are invalid.
- """
-
- for name, config in drains.items():
- if "type" not in config:
- raise ConfigError("Logging drains require a 'type' key.")
-
- try:
- logging_type = DrainType.lookupByName(config["type"].upper())
- except ValueError:
- raise ConfigError(
- "%s is not a known logging drain type." % (config["type"],)
- )
-
- # Either use the default formatter or the tersejson one.
- if logging_type in (
- DrainType.CONSOLE_JSON,
- DrainType.FILE_JSON,
- ):
- formatter: Optional[str] = "json"
- elif logging_type in (
- DrainType.CONSOLE_JSON_TERSE,
- DrainType.NETWORK_JSON_TERSE,
- ):
- formatter = "tersejson"
- else:
- # A formatter of None implies using the default formatter.
- formatter = None
-
- if logging_type in [
- DrainType.CONSOLE,
- DrainType.CONSOLE_JSON,
- DrainType.CONSOLE_JSON_TERSE,
- ]:
- location = config.get("location")
- if location is None or location not in ["stdout", "stderr"]:
- raise ConfigError(
- (
- "The %s drain needs the 'location' key set to "
- "either 'stdout' or 'stderr'."
- )
- % (logging_type,)
- )
-
- yield name, {
- "class": "logging.StreamHandler",
- "formatter": formatter,
- "stream": "ext://sys." + location,
- }
-
- elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]:
- if "location" not in config:
- raise ConfigError(
- "The %s drain needs the 'location' key set." % (logging_type,)
- )
-
- location = config.get("location")
- if os.path.abspath(location) != location:
- raise ConfigError(
- "File paths need to be absolute, '%s' is a relative path"
- % (location,)
- )
-
- yield name, {
- "class": "logging.FileHandler",
- "formatter": formatter,
- "filename": location,
- }
-
- elif logging_type in [DrainType.NETWORK_JSON_TERSE]:
- host = config.get("host")
- port = config.get("port")
- maximum_buffer = config.get("maximum_buffer", 1000)
-
- yield name, {
- "class": "synapse.logging.RemoteHandler",
- "formatter": formatter,
- "host": host,
- "port": port,
- "maximum_buffer": maximum_buffer,
- }
-
- else:
- raise ConfigError(
- "The %s drain type is currently not implemented."
- % (config["type"].upper(),)
- )
-
-
-def setup_structured_logging(
- log_config: dict,
-) -> dict:
- """
- Convert a legacy structured logging configuration (from Synapse < v1.23.0)
- to one compatible with the new standard library handlers.
- """
- if "drains" not in log_config:
- raise ConfigError("The logging configuration requires a list of drains.")
-
- new_config = {
- "version": 1,
- "formatters": {
- "json": {"class": "synapse.logging.JsonFormatter"},
- "tersejson": {"class": "synapse.logging.TerseJsonFormatter"},
- },
- "handlers": {},
- "loggers": log_config.get("loggers", DEFAULT_LOGGERS),
- "root": {"handlers": []},
- }
-
- for handler_name, handler in parse_drain_configs(log_config["drains"]):
- new_config["handlers"][handler_name] = handler
-
- # Add each handler to the root logger.
- new_config["root"]["handlers"].append(handler_name)
-
- return new_config
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d4fca36923..07020bfb8d 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -70,6 +70,7 @@ from synapse.handlers.account_validity import (
from synapse.handlers.auth import (
CHECK_3PID_AUTH_CALLBACK,
CHECK_AUTH_CALLBACK,
+ GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
GET_USERNAME_FOR_REGISTRATION_CALLBACK,
IS_3PID_ALLOWED_CALLBACK,
ON_LOGGED_OUT_CALLBACK,
@@ -317,6 +318,9 @@ class ModuleApi:
get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None,
+ get_displayname_for_registration: Optional[
+ GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+ ] = None,
) -> None:
"""Registers callbacks for password auth provider capabilities.
@@ -328,6 +332,7 @@ class ModuleApi:
is_3pid_allowed=is_3pid_allowed,
auth_checkers=auth_checkers,
get_username_for_registration=get_username_for_registration,
+ get_displayname_for_registration=get_displayname_for_registration,
)
def register_background_update_controller_callbacks(
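With the new callback wired through, a password auth provider module can choose the display name at registration time, much as get_username_for_registration already chooses the localpart. A hedged sketch of such a module (the module and class names are invented; the callback receives the UI auth results and the registration parameters, and may return None to fall back to the default):

from typing import Any, Dict, Optional

from synapse.module_api import ModuleApi


class DisplaynameProviderModule:
    def __init__(self, config: dict, api: ModuleApi):
        api.register_password_auth_provider_callbacks(
            get_displayname_for_registration=self.get_displayname_for_registration,
        )

    @staticmethod
    def parse_config(config: dict) -> dict:
        return config

    async def get_displayname_for_registration(
        self,
        uia_results: Dict[str, Any],
        params: Dict[str, Any],
    ) -> Optional[str]:
        # Prefer a display name supplied with the /register parameters;
        # returning None falls back to Synapse's default behaviour.
        return params.get("displayname")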
@@ -648,7 +653,11 @@ class ModuleApi:
Added in Synapse v1.9.0.
Args:
- auth_provider: identifier for the remote auth provider
+ auth_provider: identifier for the remote auth provider, see `sso` and
+ `oidc_providers` in the homeserver configuration.
+
+ Note that no error is raised if the provided value is not in the
+ homeserver configuration.
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
diff --git a/synapse/notifier.py b/synapse/notifier.py
index e0fad2da66..753dd6b6a5 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -138,7 +138,7 @@ class _NotifierUserStream:
self.current_token = self.current_token.copy_and_advance(stream_key, stream_id)
self.last_notified_token = self.current_token
self.last_notified_ms = time_now_ms
- noify_deferred = self.notify_deferred
+ notify_deferred = self.notify_deferred
log_kv(
{
@@ -153,7 +153,7 @@ class _NotifierUserStream:
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
- noify_deferred.callback(self.current_token)
+ notify_deferred.callback(self.current_token)
def remove(self, notifier: "Notifier") -> None:
"""Remove this listener from all the indexes in the Notifier
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 910b05c0da..832eaa34e9 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -130,7 +130,9 @@ def make_base_prepend_rules(
return rules
-BASE_APPEND_CONTENT_RULES = [
+# We have to annotate these types, otherwise mypy infers them as
+# `List[Dict[str, Sequence[Collection[str]]]]`.
+BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
{
"rule_id": "global/content/.m.rule.contains_user_name",
"conditions": [
@@ -149,7 +151,7 @@ BASE_APPEND_CONTENT_RULES = [
]
-BASE_PREPEND_OVERRIDE_RULES = [
+BASE_PREPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
{
"rule_id": "global/override/.m.rule.master",
"enabled": False,
@@ -159,7 +161,7 @@ BASE_PREPEND_OVERRIDE_RULES = [
]
-BASE_APPEND_OVERRIDE_RULES = [
+BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
{
"rule_id": "global/override/.m.rule.suppress_notices",
"conditions": [
@@ -278,7 +280,7 @@ BASE_APPEND_OVERRIDE_RULES = [
]
-BASE_APPEND_UNDERRIDE_RULES = [
+BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
{
"rule_id": "global/underride/.m.rule.call",
"conditions": [
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 96559081d0..52c7ff3572 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -109,6 +109,7 @@ class HttpPusher(Pusher):
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url["url"]
+ self.badge_count_last_call: Optional[int] = None
def on_started(self, should_check_for_notifs: bool) -> None:
"""Called when this pusher has been started.
@@ -136,7 +137,9 @@ class HttpPusher(Pusher):
self.user_id,
group_by_room=self._group_unread_count_by_room,
)
- await self._send_badge(badge)
+ if self.badge_count_last_call is None or self.badge_count_last_call != badge:
+ self.badge_count_last_call = badge
+ await self._send_badge(badge)
def on_timer(self) -> None:
self._start_processing()
@@ -322,7 +325,7 @@ class HttpPusher(Pusher):
# This was checked in the __init__, but mypy doesn't seem to know that.
assert self.data is not None
if self.data.get("format") == "event_id_only":
- d = {
+ d: Dict[str, Any] = {
"notification": {
"event_id": event.event_id,
"room_id": event.room_id,
@@ -402,6 +405,8 @@ class HttpPusher(Pusher):
rejected = []
if "rejected" in resp:
rejected = resp["rejected"]
+ else:
+ self.badge_count_last_call = badge
return rejected
async def _send_badge(self, badge: int) -> None:
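A standalone sketch of the "only push the badge when it changed" behaviour added here; BadgeSender and its list of sent badges are invented stand-ins for the HTTP pusher and the push gateway:

from typing import List, Optional


class BadgeSender:
    """Sketch of suppressing badge pushes when the count is unchanged."""

    def __init__(self) -> None:
        self.badge_count_last_call: Optional[int] = None
        self.sent: List[int] = []  # stands in for HTTP pushes to the gateway

    def update_badge(self, badge: int) -> None:
        # Skip the network round-trip if nothing changed since the last
        # successful push.
        if self.badge_count_last_call is None or self.badge_count_last_call != badge:
            self.badge_count_last_call = badge
            self.sent.append(badge)


sender = BadgeSender()
for badge in (3, 3, 3, 4, 4, 0):
    sender.update_badge(badge)
print(sender.sent)  # [3, 4, 0] - duplicates suppressed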
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 86162e0f2c..f43fbb5842 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -87,7 +87,8 @@ REQUIREMENTS = [
# We enforce that we have a `cryptography` version that bundles an `openssl`
# with the latest security patches.
"cryptography>=3.4.7",
- "ijson>=3.1",
+ # ijson 3.1.4 fixes a bug with "." in property names
+ "ijson>=3.1.4",
"matrix-common~=1.1.0",
]
diff --git a/synapse/res/providers.json b/synapse/res/providers.json
index f1838f9559..7b9958e454 100644
--- a/synapse/res/providers.json
+++ b/synapse/res/providers.json
@@ -5,8 +5,6 @@
"endpoints": [
{
"schemes": [
- "https://twitter.com/*/status/*",
- "https://*.twitter.com/*/status/*",
"https://twitter.com/*/moments/*",
"https://*.twitter.com/*/moments/*"
],
@@ -14,4 +12,4 @@
}
]
}
-]
\ No newline at end of file
+]
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index cfa2aee76d..efe299e698 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -883,7 +883,9 @@ class WhoamiRestServlet(RestServlet):
response = {
"user_id": requester.user.to_string(),
# MSC: https://github.com/matrix-org/matrix-doc/pull/3069
+ # Entered spec in Matrix 1.2
"org.matrix.msc3069.is_guest": bool(requester.is_guest),
+ "is_guest": bool(requester.is_guest),
}
# Appservices and similar accounts do not have device IDs
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 9c15a04338..e0b2b80e5b 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -62,7 +62,7 @@ class AuthRestServlet(RestServlet):
if stagetype == LoginType.RECAPTCHA:
html = self.recaptcha_template.render(
session=session,
- myurl="%s/r0/auth/%s/fallback/web"
+ myurl="%s/v3/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
sitekey=self.hs.config.captcha.recaptcha_public_key,
)
@@ -74,7 +74,7 @@ class AuthRestServlet(RestServlet):
self.hs.config.server.public_baseurl,
self.hs.config.consent.user_consent_version,
),
- myurl="%s/r0/auth/%s/fallback/web"
+ myurl="%s/v3/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
)
@@ -118,7 +118,7 @@ class AuthRestServlet(RestServlet):
# Authentication failed, let user try again
html = self.recaptcha_template.render(
session=session,
- myurl="%s/r0/auth/%s/fallback/web"
+ myurl="%s/v3/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
sitekey=self.hs.config.captcha.recaptcha_public_key,
error=e.msg,
@@ -143,7 +143,7 @@ class AuthRestServlet(RestServlet):
self.hs.config.server.public_baseurl,
self.hs.config.consent.user_consent_version,
),
- myurl="%s/r0/auth/%s/fallback/web"
+ myurl="%s/v3/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
error=e.msg,
)
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 6682da077a..e05c926b6f 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -72,20 +72,6 @@ class CapabilitiesRestServlet(RestServlet):
"org.matrix.msc3244.room_capabilities"
] = MSC3244_CAPABILITIES
- # Must be removed in later versions.
- # Is only included for migration.
- # Also the parts in `synapse/config/experimental.py`.
- if self.config.experimental.msc3283_enabled:
- response["capabilities"]["org.matrix.msc3283.set_displayname"] = {
- "enabled": self.config.registration.enable_set_displayname
- }
- response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = {
- "enabled": self.config.registration.enable_set_avatar_url
- }
- response["capabilities"]["org.matrix.msc3283.3pid_changes"] = {
- "enabled": self.config.registration.enable_3pid_changes
- }
-
if self.config.experimental.msc3440_enabled:
response["capabilities"]["io.element.thread"] = {"enabled": True}
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index c965e2bda2..b8a5135e02 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -694,11 +694,18 @@ class RegisterRestServlet(RestServlet):
session_id
)
+ display_name = await (
+ self.password_auth_provider.get_displayname_for_registration(
+ auth_result, params
+ )
+ )
+
registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
password_hash=password_hash,
guest_access_token=guest_access_token,
threepid=threepid,
+ default_display_name=display_name,
address=client_addr,
user_agent_ips=entries,
)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 2290c57c12..00f29344a8 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -73,6 +73,8 @@ class VersionsRestServlet(RestServlet):
"r0.5.0",
"r0.6.0",
"r0.6.1",
+ "v1.1",
+ "v1.2",
],
# as per MSC1497:
"unstable_features": {
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 8d3d1e54dc..c08b60d10a 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -402,7 +402,15 @@ class PreviewUrlResource(DirectServeJsonResource):
url,
output_stream=output_stream,
max_size=self.max_spider_size,
- headers={"Accept-Language": self.url_preview_accept_language},
+ headers={
+ b"Accept-Language": self.url_preview_accept_language,
+ # Use a custom user agent for the preview because some sites will only return
+ # Open Graph metadata to crawler user agents. Omit the Synapse version
+ # string to avoid leaking information.
+ b"User-Agent": [
+ "Synapse (bot; +https://github.com/matrix-org/synapse)"
+ ],
+ },
is_allowed_content_type=_is_previewable,
)
except SynapseError:
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 8d845fe951..3b3a089b76 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -670,6 +670,16 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices
}
+ def get_cached_device_list_changes(
+ self,
+ from_key: int,
+ ) -> Optional[Set[str]]:
+ """Get set of users whose devices have changed since `from_key`, or None
+ if that information is not in our cache.
+ """
+
+ return self._device_list_stream_cache.get_all_entities_changed(from_key)
+
async def get_users_whose_devices_changed(
self, from_key: int, user_ids: Iterable[str]
) -> Set[str]:
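get_cached_device_list_changes simply exposes the stream-change cache, which can only answer when the requested position is still inside the window it tracks and otherwise returns None, sending callers (such as the sync handler above) down the slower per-user path. A toy model of that contract:

from typing import Dict, Optional, Set


class TinyStreamChangeCache:
    """Toy version of the contract get_cached_device_list_changes relies on."""

    def __init__(self, earliest_known_key: int) -> None:
        self._earliest = earliest_known_key
        self._changes: Dict[int, str] = {}  # stream position -> user_id

    def entity_has_changed(self, user_id: str, key: int) -> None:
        self._changes[key] = user_id

    def get_all_entities_changed(self, from_key: int) -> Optional[Set[str]]:
        if from_key < self._earliest:
            # The token predates our window: we cannot enumerate the changes.
            return None
        return {u for k, u in self._changes.items() if k > from_key}


cache = TinyStreamChangeCache(earliest_known_key=100)
cache.entity_has_changed("@alice:example.org", 105)
cache.entity_has_changed("@bob:example.org", 110)

print(cache.get_all_entities_changed(104))  # {'@alice:example.org', '@bob:example.org'}
print(cache.get_all_entities_changed(50))   # None - fall back to the slow path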
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5246fccad5..a1d7a9b413 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -975,6 +975,17 @@ class PersistEventsStore:
to_delete = delta_state.to_delete
to_insert = delta_state.to_insert
+ # Figure out the changes of membership to invalidate the
+ # `get_rooms_for_user` cache.
+ # We find out which membership events we may have deleted
+ # and which we have added, then we invalidate the caches for all
+ # those users.
+ members_changed = {
+ state_key
+ for ev_type, state_key in itertools.chain(to_delete, to_insert)
+ if ev_type == EventTypes.Member
+ }
+
if delta_state.no_longer_in_room:
# Server is no longer in the room so we delete the room from
# current_state_events, being careful we've already updated the
@@ -993,6 +1004,11 @@ class PersistEventsStore:
"""
txn.execute(sql, (stream_id, self._instance_name, room_id))
+ # We also want to invalidate the membership caches for users
+ # that were in the room.
+ users_in_room = self.store.get_users_in_room_txn(txn, room_id)
+ members_changed.update(users_in_room)
+
self.db_pool.simple_delete_txn(
txn,
table="current_state_events",
@@ -1102,17 +1118,6 @@ class PersistEventsStore:
# Invalidate the various caches
- # Figure out the changes of membership to invalidate the
- # `get_rooms_for_user` cache.
- # We find out which membership events we may have deleted
- # and which we have added, then we invalidate the caches for all
- # those users.
- members_changed = {
- state_key
- for ev_type, state_key in itertools.chain(to_delete, to_insert)
- if ev_type == EventTypes.Member
- }
-
for member in members_changed:
txn.call_after(
self.store.get_rooms_for_user_with_stream_ordering.invalidate,
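Hoisting the members_changed calculation above lets the "server has left the room" branch add all remaining room members before the invalidation loop runs. A simplified sketch of deriving the affected users from the state delta; the types and sample data are invented:

import itertools
from typing import Iterable, Set, Tuple

StateKey = Tuple[str, str]  # (event type, state_key)

M_ROOM_MEMBER = "m.room.member"


def members_to_invalidate(
    to_delete: Iterable[StateKey],
    to_insert: Iterable[StateKey],
    users_in_room_if_leaving: Iterable[str] = (),
) -> Set[str]:
    # Any membership event added to or removed from current state means that
    # user's room-membership caches are stale.
    changed = {
        state_key
        for ev_type, state_key in itertools.chain(to_delete, to_insert)
        if ev_type == M_ROOM_MEMBER
    }
    # If the server is leaving the room entirely, every user still in the
    # room needs their caches invalidated too.
    changed.update(users_in_room_if_leaving)
    return changed


print(
    members_to_invalidate(
        to_delete=[(M_ROOM_MEMBER, "@bob:example.org")],
        to_insert=[(M_ROOM_MEMBER, "@alice:example.org"), ("m.room.topic", "")],
        users_in_room_if_leaving={"@carol:example.org"},
    )
)
# {'@alice:example.org', '@bob:example.org', '@carol:example.org'}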
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 8d4287045a..2a255d1031 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -408,7 +408,7 @@ class EventsWorkerStore(SQLBaseStore):
include the previous states content in the unsigned field.
allow_rejected: If True, return rejected events. Otherwise,
- omits rejeted events from the response.
+ omits rejected events from the response.
Returns:
A mapping from event_id to event.
@@ -1854,7 +1854,7 @@ class EventsWorkerStore(SQLBaseStore):
forward_edge_query = """
SELECT 1 FROM event_edges
/* Check to make sure the event referencing our event in question is not rejected */
- LEFT JOIN rejections ON event_edges.event_id == rejections.event_id
+ LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
WHERE
event_edges.room_id = ?
AND event_edges.prev_event_id = ?
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 4f05811a77..d3c4611686 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,15 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
@@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
- ):
+ ) -> None:
super().__init__(database, db_conn, hs)
# Used by `PresenceStore._get_active_presence()`
@@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
- ):
+ ) -> None:
super().__init__(database, db_conn, hs)
+ self._instance_name = hs.get_instance_name()
+ self._presence_id_gen: AbstractStreamIdGenerator
+
self._can_persist_presence = (
- hs.get_instance_name() in hs.config.worker.writers.presence
+ self._instance_name in hs.config.worker.writers.presence
)
if isinstance(database.engine, PostgresEngine):
@@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return stream_orderings[-1], self._presence_id_gen.get_current_token()
- def _update_presence_txn(self, txn, stream_orderings, presence_states):
+ def _update_presence_txn(
+ self, txn: LoggingTransaction, stream_orderings, presence_states
+ ) -> None:
for stream_id, state in zip(stream_orderings, presence_states):
txn.call_after(
self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
@@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore):
if last_id == current_id:
return [], current_id, False
- def get_all_presence_updates_txn(txn):
+ def get_all_presence_updates_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """
SELECT stream_id, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts,
- status_msg,
- currently_active
+ status_msg, currently_active
FROM presence_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [(row[0], row[1:]) for row in txn]
+ updates = cast(
+ List[Tuple[int, list]],
+ [(row[0], row[1:]) for row in txn],
+ )
upper_bound = current_id
limited = False
@@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
)
@cached()
- def _get_presence_for_user(self, user_id):
+ def _get_presence_for_user(self, user_id: str) -> None:
raise NotImplementedError()
@cachedList(
@@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
list_name="user_ids",
num_args=1,
)
- async def get_presence_for_users(self, user_ids):
+ async def get_presence_for_users(
+ self, user_ids: Iterable[str]
+ ) -> Dict[str, UserPresenceState]:
rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
@@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
True if the user should have full presence sent to them, False otherwise.
"""
- def _should_user_receive_full_presence_with_token_txn(txn):
+ def _should_user_receive_full_presence_with_token_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
sql = """
SELECT 1 FROM users_to_send_full_presence_to
WHERE user_id = ?
@@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
_should_user_receive_full_presence_with_token_txn,
)
- async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
+ async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
"""Adds to the list of users who should receive a full snapshot of presence
upon their next sync.
@@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return users_to_state
- def get_current_presence_token(self):
+ def get_current_presence_token(self) -> int:
return self._presence_id_gen.get_current_token()
- def _get_active_presence(self, db_conn: Connection):
+ def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
"""
@@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return [UserPresenceState(**row) for row in rows]
- def take_presence_startup_info(self):
+ def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup
- self._presence_on_startup = None
+ self._presence_on_startup = []
return active_on_startup
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token)
for row in rows:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index e87a8fb85d..2e3818e432 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -13,9 +13,10 @@
# limitations under the License.
import logging
-from typing import Any, List, Set, Tuple
+from typing import Any, List, Set, Tuple, cast
from synapse.api.errors import SynapseError
+from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken
@@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
)
def _purge_history_txn(
- self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ token: RoomStreamToken,
+ delete_local_events: bool,
) -> Set[int]:
# Tables that should be pruned:
# event_auth
@@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
""",
(room_id,),
)
- (min_depth,) = txn.fetchone()
+ (min_depth,) = cast(Tuple[int], txn.fetchone())
logger.info("[purge] updating room_depth to %d", min_depth)
@@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"purge_room", self._purge_room_txn, room_id
)
- def _purge_room_txn(self, txn, room_id: str) -> List[int]:
+ def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
# First we fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index aac94fa464..17110bb033 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -622,10 +622,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) -> None:
"""Record a mapping from an external user id to a mxid
+ See notes in _record_user_external_id_txn about what constitutes valid data.
+
Args:
auth_provider: identifier for the remote auth provider
external_id: id on that system
user_id: complete mxid that it is mapped to
+
Raises:
ExternalIDReuseException if the new external_id could not be mapped.
"""
@@ -648,6 +651,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
external_id: str,
user_id: str,
) -> None:
+ """
+ Record a mapping from an external user id to a mxid.
+
+ Note that the auth provider IDs (and the external IDs) are not validated
+ against configured IdPs as Synapse does not know its relationship to
+ external systems. For example, it might be useful to pre-configure users
+ before enabling a new IdP or an IdP might be temporarily offline, but
+ still valid.
+
+ Args:
+ txn: The database transaction.
+ auth_provider: identifier for the remote auth provider
+ external_id: id on that system
+ user_id: complete mxid that it is mapped to
+ """
self.db_pool.simple_insert_txn(
txn,
@@ -687,10 +705,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Replace mappings from external user ids to a mxid in a single transaction.
All mappings are deleted and the new ones are created.
+ See notes in _record_user_external_id_txn about what constitutes valid data.
+
Args:
record_external_ids:
List with tuple of auth_provider and external_id to record
user_id: complete mxid that it is mapped to
+
Raises:
ExternalIDReuseException if the new external_id could not be mapped.
"""
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index e2c27e594b..36aa1092f6 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -53,8 +53,13 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
+ # The latest event in the thread.
latest_event: EventBase
+ # The latest edit to the latest event in the thread.
+ latest_edit: Optional[EventBase]
+ # The total number of events in the thread.
count: int
+ # True if the current user has sent an event to the thread.
current_user_participated: bool
@@ -461,8 +466,8 @@ class RelationsWorkerStore(SQLBaseStore):
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
async def _get_thread_summaries(
self, event_ids: Collection[str]
- ) -> Dict[str, Optional[Tuple[int, EventBase]]]:
- """Get the number of threaded replies and the latest reply (if any) for the given event.
+ ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
+ """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
Args:
event_ids: Summarize the thread related to this event ID.
@@ -471,8 +476,10 @@ class RelationsWorkerStore(SQLBaseStore):
A map of the thread summary each event. A missing event implies there
are no threaded replies.
- Each summary includes the number of items in the thread and the most
- recent response.
+ Each summary is a tuple of:
+ The number of events in the thread.
+ The most recent event in the thread.
+ The most recent edit to the most recent event in the thread, if applicable.
"""
def _get_thread_summaries_txn(
@@ -482,7 +489,7 @@ class RelationsWorkerStore(SQLBaseStore):
# TODO Should this only allow m.room.message events.
if isinstance(self.database_engine, PostgresEngine):
# The `DISTINCT ON` clause will pick the *first* row it encounters,
- # so ordering by topologica ordering + stream ordering desc will
+ # so ordering by topological ordering + stream ordering desc will
# ensure we get the latest event in the thread.
sql = """
SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child
@@ -558,6 +565,9 @@ class RelationsWorkerStore(SQLBaseStore):
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
+ # Check to see if any of those events are edited.
+ latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+
# Map to the event IDs to the thread summary.
#
# There might not be a summary due to there not being a thread or
@@ -568,7 +578,8 @@ class RelationsWorkerStore(SQLBaseStore):
summary = None
if latest_event:
- summary = (counts[parent_event_id], latest_event)
+ latest_edit = latest_edits.get(latest_event_id)
+ summary = (counts[parent_event_id], latest_event, latest_edit)
summaries[parent_event_id] = summary
return summaries
@@ -828,11 +839,12 @@ class RelationsWorkerStore(SQLBaseStore):
)
for event_id, summary in summaries.items():
if summary:
- thread_count, latest_thread_event = summary
+ thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
+ latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 95167116c9..0416df64ce 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1498,7 +1498,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
async def upsert_room_on_join(
- self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
+ self, room_id: str, room_version: RoomVersion, state_events: List[EventBase]
) -> None:
"""Ensure that the room is stored in the table
@@ -1511,7 +1511,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
has_auth_chain_index = await self.has_auth_chain_index(room_id)
create_event = None
- for e in auth_events:
+ for e in state_events:
if (e.type, e.state_key) == (EventTypes.Create, ""):
create_event = e
break
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4489732fda..e48ec5f495 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -504,6 +504,68 @@ class RoomMemberWorkerStore(EventsWorkerStore):
for room_id, instance, stream_id in txn
)
+ @cachedList(
+ cached_method_name="get_rooms_for_user_with_stream_ordering",
+ list_name="user_ids",
+ )
+ async def get_rooms_for_users_with_stream_ordering(
+ self, user_ids: Collection[str]
+ ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
+ """A batched version of `get_rooms_for_user_with_stream_ordering`.
+
+ Returns:
+ Map from user_id to the set of rooms the user is currently in.
+ """
+ return await self.db_pool.runInteraction(
+ "get_rooms_for_users_with_stream_ordering",
+ self._get_rooms_for_users_with_stream_ordering_txn,
+ user_ids,
+ )
+
+ def _get_rooms_for_users_with_stream_ordering_txn(
+ self, txn, user_ids: Collection[str]
+ ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "c.state_key",
+ user_ids,
+ )
+
+ if self._current_state_events_membership_up_to_date:
+ sql = f"""
+ SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND c.membership = ?
+ AND {clause}
+ """
+ else:
+ sql = f"""
+ SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN room_memberships AS m USING (room_id, event_id)
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND m.membership = ?
+ AND {clause}
+ """
+
+ txn.execute(sql, [Membership.JOIN] + args)
+
+ result = {user_id: set() for user_id in user_ids}
+ for user_id, room_id, instance, stream_id in txn:
+ result[user_id].add(
+ GetRoomsForUserWithStreamOrdering(
+ room_id, PersistedEventPosition(instance, stream_id)
+ )
+ )
+
+ return {user_id: frozenset(v) for user_id, v in result.items()}
+
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
) -> Set[str]:
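One detail of the batched lookup above is worth calling out: the result map is seeded with an empty set for every requested user before the query rows are folded in, so each input user gets an entry even if they are in no rooms, and the sets are frozen before being returned. A minimal sketch of that "seed every key, then fill from rows" shape, with a generic row source standing in for the real transaction (illustrative names only, not Synapse's query plumbing):

    from typing import Collection, Dict, FrozenSet, Iterable, Tuple

    def batch_rooms_by_user(
        user_ids: Collection[str],
        rows: Iterable[Tuple[str, str]],  # (user_id, room_id) pairs, e.g. from a txn
    ) -> Dict[str, FrozenSet[str]]:
        # Seed an entry per requested user so callers (and a cachedList-style
        # cache) can rely on exactly one result per input key.
        result: Dict[str, set] = {user_id: set() for user_id in user_ids}
        for user_id, room_id in rows:
            result[user_id].add(room_id)
        # Freeze the sets so the cached values are immutable.
        return {user_id: frozenset(rooms) for user_id, rooms in result.items()}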
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 2d085a5764..acea300ed3 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -28,6 +28,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -381,17 +382,19 @@ class SearchStore(SearchBackgroundUpdateStore):
):
super().__init__(database, db_conn, hs)
- async def search_msgs(self, room_ids, search_term, keys):
+ async def search_msgs(
+ self, room_ids: Collection[str], search_term: str, keys: Iterable[str]
+ ) -> JsonDict:
"""Performs a full text search over events with given keys.
Args:
- room_ids (list): List of room ids to search in
- search_term (str): Search term to search for
- keys (list): List of keys to search in, currently supports
+ room_ids: List of room ids to search in
+ search_term: Search term to search for
+ keys: List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
Returns:
- list of dicts
+ Dictionary of results
"""
clauses = []
@@ -499,10 +502,10 @@ class SearchStore(SearchBackgroundUpdateStore):
self,
room_ids: Collection[str],
search_term: str,
- keys: List[str],
+ keys: Iterable[str],
limit,
pagination_token: Optional[str] = None,
- ) -> List[dict]:
+ ) -> JsonDict:
"""Performs a full text search over events with given keys.
Args:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f7c778bdf2..e7fddd2426 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
- ):
+ ) -> None:
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
processed_event_count = 0
for room_id, event_count in rooms_to_work_on:
- is_in_room = await self.is_host_joined(room_id, self.server_name)
+ is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined]
if is_in_room:
- users_with_profile = await self.get_users_in_room_with_profiles(room_id)
+ users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined]
# Throw away users excluded from the directory.
users_with_profile = {
user_id: profile
@@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
for user_id in users_to_work_on:
if await self.should_include_local_user_in_dir(user_id):
- profile = await self.get_profileinfo(get_localpart_from_id(user_id))
+ profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined]
await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
@@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# technically it could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice sender can be
# contacted.
- if self.get_app_service_by_user_id(user) is not None:
+ if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined]
return False
# We're opting to exclude appservice users (anyone matching the user
@@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# they could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice users can be
# contacted.
- if self.get_if_app_services_interested_in_user(user):
+ if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined]
# TODO we might want to make this configurable for each app service
return False
# Support users are for diagnostics and should not appear in the user directory.
- if await self.is_support_user(user):
+ if await self.is_support_user(user): # type: ignore[attr-defined]
return False
# Deactivated users aren't contactable, so should not appear in the user directory.
try:
- if await self.get_user_deactivated_status(user):
+ if await self.get_user_deactivated_status(user): # type: ignore[attr-defined]
return False
except StoreError:
# No such user in the users table. No need to do this when calling
@@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
(EventTypes.RoomHistoryVisibility, ""),
)
- current_state_ids = await self.get_filtered_current_state_ids(
+ current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
room_id, StateFilter.from_types(types_to_filter)
)
join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
if join_rules_id:
- join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
+ join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined]
if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
return True
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id:
- hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
+ hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined]
if hist_vis_ev:
if (
hist_vis_ev.content.get("history_visibility")
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 7614d76ac6..3af69a2076 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -13,11 +13,23 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+)
import attr
+from twisted.internet import defer
+
from synapse.api.constants import EventTypes
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -29,6 +41,12 @@ from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateKey, StateMap
+from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import (
+ AbstractObservableDeferred,
+ ObservableDeferred,
+ yieldable_gather_results,
+)
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -37,7 +55,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-
MAX_STATE_DELTA_HOPS = 100
@@ -106,6 +123,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
500000,
)
+ # Current ongoing get_state_for_groups in-flight requests
+ # {group ID -> {StateFilter -> ObservableDeferred}}
+ self._state_group_inflight_requests: Dict[
+ int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
+ ] = {}
+
def get_max_state_group_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
return txn.fetchone()[0] # type: ignore
@@ -157,7 +180,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
async def _get_state_groups_from_groups(
- self, groups: List[int], state_filter: StateFilter
+ self, groups: Sequence[int], state_filter: StateFilter
) -> Dict[int, StateMap[str]]:
"""Returns the state groups for a given set of groups from the
database, filtering on types of state events.
@@ -228,6 +251,150 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
+ def _get_state_for_group_gather_inflight_requests(
+ self, group: int, state_filter_left_over: StateFilter
+ ) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]:
+ """
+ Attempts to gather in-flight requests and re-use them to retrieve state
+ for the given state group, filtered with the given state filter.
+
+ Used as part of _get_state_for_group_using_inflight_cache.
+
+ Returns:
+ Tuple of two values:
+ A sequence of ObservableDeferreds to observe
+ A StateFilter representing what else needs to be requested to fulfill the request
+ """
+
+ inflight_requests = self._state_group_inflight_requests.get(group)
+ if inflight_requests is None:
+ # no requests for this group, need to retrieve it all ourselves
+ return (), state_filter_left_over
+
+ # The list of ongoing requests which will help narrow the current request.
+ reusable_requests = []
+ for (request_state_filter, request_deferred) in inflight_requests.items():
+ new_state_filter_left_over = state_filter_left_over.approx_difference(
+ request_state_filter
+ )
+ if new_state_filter_left_over == state_filter_left_over:
+ # Reusing this request would not gain us anything, so don't bother.
+ continue
+
+ reusable_requests.append(request_deferred)
+ state_filter_left_over = new_state_filter_left_over
+ if state_filter_left_over == StateFilter.none():
+ # we have managed to collect enough of the in-flight requests
+ # to cover our StateFilter and give us the state we need.
+ break
+
+ return reusable_requests, state_filter_left_over
+
+ async def _get_state_for_group_fire_request(
+ self, group: int, state_filter: StateFilter
+ ) -> StateMap[str]:
+ """
+ Fires off a request to get the state at a state group,
+ potentially filtering by type and/or state key.
+
+ This request will be tracked in the in-flight request cache and automatically
+ removed when it is finished.
+
+ Used as part of _get_state_for_group_using_inflight_cache.
+
+ Args:
+ group: ID of the state group for which we want to get state
+ state_filter: the state filter used to fetch state from the database
+ """
+ cache_sequence_nm = self._state_group_cache.sequence
+ cache_sequence_m = self._state_group_members_cache.sequence
+
+ # Help the cache hit ratio by expanding the filter a bit
+ db_state_filter = state_filter.return_expanded()
+
+ async def _the_request() -> StateMap[str]:
+ group_to_state_dict = await self._get_state_groups_from_groups(
+ (group,), state_filter=db_state_filter
+ )
+
+ # Now let's update the caches
+ self._insert_into_cache(
+ group_to_state_dict,
+ db_state_filter,
+ cache_seq_num_members=cache_sequence_m,
+ cache_seq_num_non_members=cache_sequence_nm,
+ )
+
+ # Remove ourselves from the in-flight cache
+ group_request_dict = self._state_group_inflight_requests[group]
+ del group_request_dict[db_state_filter]
+ if not group_request_dict:
+ # If there are no more requests in-flight for this group,
+ # clean up the cache by removing the empty dictionary
+ del self._state_group_inflight_requests[group]
+
+ return group_to_state_dict[group]
+
+ # We don't immediately await the result, so must use run_in_background
+ # But we DO await the result before the current log context (request)
+ # finishes, so don't need to run it as a background process.
+ request_deferred = run_in_background(_the_request)
+ observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)
+
+ # Insert the ObservableDeferred into the cache
+ group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
+ group_request_dict[db_state_filter] = observable_deferred
+
+ return await make_deferred_yieldable(observable_deferred.observe())
+
+ async def _get_state_for_group_using_inflight_cache(
+ self, group: int, state_filter: StateFilter
+ ) -> MutableStateMap[str]:
+ """
+ Gets the state at a state group, potentially filtering by type and/or
+ state key.
+
+ 1. Calls _get_state_for_group_gather_inflight_requests to gather any
+ ongoing requests which might overlap with the current request.
+ 2. Fires a new request, using _get_state_for_group_fire_request,
+ for any state which cannot be gathered from ongoing requests.
+
+ Args:
+ group: ID of the state group for which we want to get state
+ state_filter: the state filter used to fetch state from the database
+ Returns:
+ state map
+ """
+
+ # first, figure out whether we can re-use any in-flight requests
+ # (and if so, what would be left over)
+ (
+ reusable_requests,
+ state_filter_left_over,
+ ) = self._get_state_for_group_gather_inflight_requests(group, state_filter)
+
+ if state_filter_left_over != StateFilter.none():
+ # Fetch remaining state
+ remaining = await self._get_state_for_group_fire_request(
+ group, state_filter_left_over
+ )
+ assembled_state: MutableStateMap[str] = dict(remaining)
+ else:
+ assembled_state = {}
+
+ gathered = await make_deferred_yieldable(
+ defer.gatherResults(
+ (r.observe() for r in reusable_requests), consumeErrors=True
+ )
+ ).addErrback(unwrapFirstError)
+
+ # assemble our result.
+ for result_piece in gathered:
+ assembled_state.update(result_piece)
+
+ # Filter out any state that may be more than what we asked for.
+ return state_filter.filter_state(assembled_state)
+
async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]:
@@ -269,31 +436,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not incomplete_groups:
return state
- cache_sequence_nm = self._state_group_cache.sequence
- cache_sequence_m = self._state_group_members_cache.sequence
-
- # Help the cache hit ratio by expanding the filter a bit
- db_state_filter = state_filter.return_expanded()
-
- group_to_state_dict = await self._get_state_groups_from_groups(
- list(incomplete_groups), state_filter=db_state_filter
- )
+ async def get_from_cache(group: int, state_filter: StateFilter) -> None:
+ state[group] = await self._get_state_for_group_using_inflight_cache(
+ group, state_filter
+ )
- # Now lets update the caches
- self._insert_into_cache(
- group_to_state_dict,
- db_state_filter,
- cache_seq_num_members=cache_sequence_m,
- cache_seq_num_non_members=cache_sequence_nm,
+ await yieldable_gather_results(
+ get_from_cache,
+ incomplete_groups,
+ state_filter,
)
- # And finally update the result dict, by filtering out any extra
- # stuff we pulled out of the database.
- for group, group_state_dict in group_to_state_dict.items():
- # We just replace any existing entries, as we will have loaded
- # everything we need from the database anyway.
- state[group] = state_filter.filter_state(group_state_dict)
-
return state
def _get_state_for_groups_using_cache(
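Taken together, the three new helpers above implement a simple de-duplication scheme for `_get_state_for_groups`: observe any overlapping in-flight requests, work out which part of the state filter they do not cover, fire exactly one new database request for that remainder (registering it so later callers can observe it too), and merge the pieces at the end. A stripped-down sketch of the same idea, using asyncio futures keyed by a single value instead of Synapse's ObservableDeferred/StateFilter machinery (all names here are illustrative):

    import asyncio
    from typing import Awaitable, Callable, Dict, Tuple

    class InflightDeduper:
        """Share in-flight fetches keyed by (group, key) so that concurrent
        callers issue at most one underlying request per key."""

        def __init__(self, fetch: Callable[[int, str], Awaitable[dict]]) -> None:
            self._fetch = fetch
            self._inflight: Dict[Tuple[int, str], asyncio.Future] = {}

        async def get(self, group: int, key: str) -> dict:
            existing = self._inflight.get((group, key))
            if existing is not None:
                # Re-use the in-flight request rather than issuing a duplicate.
                return await asyncio.shield(existing)

            async def _do() -> dict:
                try:
                    return await self._fetch(group, key)
                finally:
                    # Remove ourselves once finished, as the real cache does.
                    self._inflight.pop((group, key), None)

            task = asyncio.ensure_future(_do())
            self._inflight[(group, key)] = task
            return await task

The real implementation is more involved because a new request can partially overlap several in-flight ones, which is what the "left over" StateFilter computation handles.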
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 913448f0f9..e79ecf64a0 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -204,13 +204,16 @@ class StateFilter:
if get_all_members:
# We want to return everything.
return StateFilter.all()
- else:
+ elif EventTypes.Member in self.types:
# We want to return all non-members, but only particular
# memberships
return StateFilter(
types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
include_others=True,
)
+ else:
+ # We want to return all non-members
+ return _ALL_NON_MEMBER_STATE_FILTER
def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
"""Converts the filter to an SQL clause.
@@ -528,6 +531,9 @@ class StateFilter:
_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
+_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
+ types=frozendict({EventTypes.Member: frozenset()}), include_others=True
+)
_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
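The new `_ALL_NON_MEMBER_STATE_FILTER` constant gives the expansion logic above a third outcome: a filter that already wants all membership state expands to everything, one that wants specific memberships keeps those and opens up all non-member state, and one that asks for no membership state at all now expands to "all non-members" rather than to everything. A rough, simplified illustration of those three branches (not StateFilter's actual API; the mapping below uses a plain dict where an event type maps to wanted state keys, with None meaning "all keys of that type"):

    from typing import FrozenSet, Mapping, Optional

    MEMBER = "m.room.member"

    def expanded_fetch_plan(
        types: Mapping[str, Optional[FrozenSet[str]]],
    ) -> str:
        """Describe what the widened database fetch would cover.

        Simplified mirror of the three branches above; purely illustrative.
        """
        if MEMBER in types and types[MEMBER] is None:
            return "fetch all state (members and non-members)"
        if MEMBER in types:
            keys = ", ".join(sorted(types[MEMBER] or ()))
            return f"fetch all non-member state, plus memberships for: {keys}"
        return "fetch all non-member state only"

    # expanded_fetch_plan({MEMBER: None})
    #   -> "fetch all state (members and non-members)"
    # expanded_fetch_plan({MEMBER: frozenset({"@alice:example.org"})})
    #   -> "fetch all non-member state, plus memberships for: @alice:example.org"
    # expanded_fetch_plan({"m.room.name": frozenset({""})})
    #   -> "fetch all non-member state only"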
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 21591d0bfd..4ec2a713cf 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -37,14 +37,16 @@ class _EventSourcesInner:
account_data: AccountDataEventSource
def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
- for attribute in _EventSourcesInner.__attrs_attrs__: # type: ignore[attr-defined]
+ for attribute in attr.fields(_EventSourcesInner):
yield attribute.name, getattr(self, attribute.name)
class EventSources:
def __init__(self, hs: "HomeServer"):
self.sources = _EventSourcesInner(
- *(attribute.type(hs) for attribute in _EventSourcesInner.__attrs_attrs__) # type: ignore[attr-defined]
+ # mypy thinks attribute.type is `Optional`, but we know it's never `None` here since
+ # all the attributes of `_EventSourcesInner` are annotated.
+ *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner)) # type: ignore[misc]
)
self.store = hs.get_datastore()
diff --git a/synapse/types.py b/synapse/types.py
index f89fb216a6..53be3583a0 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -51,7 +51,7 @@ from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService
- from synapse.storage.databases.main import DataStore
+ from synapse.storage.databases.main import DataStore, PurgeEventsStore
# Define a state map type from type/state_key to T (usually an event ID or
# event)
@@ -485,7 +485,7 @@ class RoomStreamToken:
)
@classmethod
- async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
+ async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
@@ -502,7 +502,7 @@ class RoomStreamToken:
instance_id = int(key)
pos = int(value)
- instance_name = await store.get_name_from_instance_id(instance_id)
+ instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined]
instance_map[instance_name] = pos
return cls(
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 15debd6c46..1cbc180eda 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -56,6 +56,7 @@ response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["n
class EvictionReason(Enum):
size = auto()
time = auto()
+ invalidation = auto()
@attr.s(slots=True, auto_attribs=True)
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 67ee4c693b..c6a5d0dfc0 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -133,6 +133,11 @@ class ExpiringCache(Generic[KT, VT]):
raise KeyError(key)
return default
+ if self.iterable:
+ self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value))
+ else:
+ self.metrics.inc_evictions(EvictionReason.invalidation)
+
return value.value
def __contains__(self, key: KT) -> bool:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 7548b38548..45ff0de638 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -560,8 +560,10 @@ class LruCache(Generic[KT, VT]):
def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]:
node = cache.get(key, None)
if node:
- delete_node(node)
+ evicted_len = delete_node(node)
cache.pop(node.key, None)
+ if metrics:
+ metrics.inc_evictions(EvictionReason.invalidation, evicted_len)
return node.value
else:
return default
diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py
index de04f34e4e..031880ec39 100644
--- a/synapse/util/daemonize.py
+++ b/synapse/util/daemonize.py
@@ -20,7 +20,7 @@ import os
import signal
import sys
from types import FrameType, TracebackType
-from typing import NoReturn, Type
+from typing import NoReturn, Optional, Type
def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None:
@@ -100,7 +100,9 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
# also catch any other uncaught exceptions before we get that far.)
def excepthook(
- type_: Type[BaseException], value: BaseException, traceback: TracebackType
+ type_: Type[BaseException],
+ value: BaseException,
+ traceback: Optional[TracebackType],
) -> None:
logger.critical("Unhanded exception", exc_info=(type_, value, traceback))
@@ -123,7 +125,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
sys.exit(1)
# write a log line on SIGTERM.
- def sigterm(signum: signal.Signals, frame: FrameType) -> NoReturn:
+ def sigterm(signum: int, frame: Optional[FrameType]) -> NoReturn:
logger.warning("Caught signal %s. Stopping daemon." % signum)
sys.exit(0)
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 1f18654d47..6d4b0b7c5a 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -14,7 +14,7 @@
import functools
import sys
-from typing import Any, Callable, Generator, List, TypeVar
+from typing import Any, Callable, Generator, List, TypeVar, cast
from twisted.internet import defer
from twisted.internet.defer import Deferred
@@ -174,7 +174,9 @@ def _check_yield_points(
)
)
changes.append(err)
- return getattr(e, "value", None)
+ # The `StopIteration` or `_DefGen_Return` contains the return value from the
+ # generator.
+ return cast(T, e.value)
frame = gen.gi_frame
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
new file mode 100644
index 0000000000..ec8864dafe
--- /dev/null
+++ b/tests/federation/test_federation_client.py
@@ -0,0 +1,149 @@
+# Copyright 2022 Matrix.org Federation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from unittest import mock
+
+import twisted.web.client
+from twisted.internet import defer
+from twisted.internet.protocol import Protocol
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.room_versions import RoomVersions
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.unittest import FederatingHomeserverTestCase
+
+
+class FederationClientTest(FederatingHomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+ super().prepare(reactor, clock, homeserver)
+
+ # mock out the Agent used by the federation client, which is easier than
+ # catching the HTTPS connection and do the TLS stuff.
+ self._mock_agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True)
+ homeserver.get_federation_http_client().agent = self._mock_agent
+
+ def test_get_room_state(self):
+ creator = f"@creator:{self.OTHER_SERVER_NAME}"
+ test_room_id = "!room_id"
+
+ # mock up some events to use in the response.
+ # In real life, these would have things in `prev_events` and `auth_events`, but that's
+ # a bit annoying to mock up, and the code under test doesn't care, so we don't bother.
+ create_event_dict = self.add_hashes_and_signatures(
+ {
+ "room_id": test_room_id,
+ "type": "m.room.create",
+ "state_key": "",
+ "sender": creator,
+ "content": {"creator": creator},
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 500,
+ }
+ )
+ member_event_dict = self.add_hashes_and_signatures(
+ {
+ "room_id": test_room_id,
+ "type": "m.room.member",
+ "sender": creator,
+ "state_key": creator,
+ "content": {"membership": "join"},
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 600,
+ }
+ )
+ pl_event_dict = self.add_hashes_and_signatures(
+ {
+ "room_id": test_room_id,
+ "type": "m.room.power_levels",
+ "sender": creator,
+ "state_key": "",
+ "content": {},
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 700,
+ }
+ )
+
+ # mock up the response, and have the agent return it
+ self._mock_agent.request.return_value = defer.succeed(
+ _mock_response(
+ {
+ "pdus": [
+ create_event_dict,
+ member_event_dict,
+ pl_event_dict,
+ ],
+ "auth_chain": [
+ create_event_dict,
+ member_event_dict,
+ ],
+ }
+ )
+ )
+
+ # now fire off the request
+ state_resp, auth_resp = self.get_success(
+ self.hs.get_federation_client().get_room_state(
+ "yet_another_server",
+ test_room_id,
+ "event_id",
+ RoomVersions.V9,
+ )
+ )
+
+ # check the right call got made to the agent
+ self._mock_agent.request.assert_called_once_with(
+ b"GET",
+ b"matrix://yet_another_server/_matrix/federation/v1/state/%21room_id?event_id=event_id",
+ headers=mock.ANY,
+ bodyProducer=None,
+ )
+
+ # ... and that the response is correct.
+
+ # the auth_resp should be empty because all the events are also in state
+ self.assertEqual(auth_resp, [])
+
+ # all of the events should be returned in state_resp, though not necessarily
+ # in the same order. We just check the type on the assumption that if the type
+ # is right, so is the rest of the event.
+ self.assertCountEqual(
+ [e.type for e in state_resp],
+ ["m.room.create", "m.room.member", "m.room.power_levels"],
+ )
+
+
+def _mock_response(resp: JsonDict):
+ body = json.dumps(resp).encode("utf-8")
+
+ def deliver_body(p: Protocol):
+ p.dataReceived(body)
+ p.connectionLost(Failure(twisted.web.client.ResponseDone()))
+
+ response = mock.Mock(
+ code=200,
+ phrase=b"OK",
+ headers=twisted.web.client.Headers({"Content-Type": ["application/json"]}),
+ length=len(body),
+ deliverBody=deliver_body,
+ )
+ mock.seal(response)
+ return response
diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py
index a7031a55f2..c2320ce133 100644
--- a/tests/federation/transport/test_client.py
+++ b/tests/federation/transport/test_client.py
@@ -62,3 +62,35 @@ class SendJoinParserTestCase(TestCase):
self.assertEqual(len(parsed_response.state), 1, parsed_response)
self.assertEqual(parsed_response.event_dict, {}, parsed_response)
self.assertIsNone(parsed_response.event, parsed_response)
+ self.assertFalse(parsed_response.partial_state, parsed_response)
+ self.assertEqual(parsed_response.servers_in_room, None, parsed_response)
+
+ def test_partial_state(self) -> None:
+ """Check that the partial_state flag is correctly parsed"""
+ parser = SendJoinParser(RoomVersions.V1, False)
+ response = {
+ "org.matrix.msc3706.partial_state": True,
+ }
+
+ serialised_response = json.dumps(response).encode()
+
+ # Send data to the parser
+ parser.write(serialised_response)
+
+ # Retrieve and check the parsed SendJoinResponse
+ parsed_response = parser.finish()
+ self.assertTrue(parsed_response.partial_state)
+
+ def test_servers_in_room(self) -> None:
+ """Check that the servers_in_room field is correctly parsed"""
+ parser = SendJoinParser(RoomVersions.V1, False)
+ response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}
+
+ serialised_response = json.dumps(response).encode()
+
+ # Send data to the parser
+ parser.write(serialised_response)
+
+ # Retrieve and check the parsed SendJoinResponse
+ parsed_response = parser.finish()
+ self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"])
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 4740dd0a65..49d832de81 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -84,7 +84,7 @@ class CustomAuthProvider:
def __init__(self, config, api: ModuleApi):
api.register_password_auth_provider_callbacks(
- auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
+ auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
)
def check_auth(self, *args):
@@ -122,7 +122,7 @@ class PasswordCustomAuthProvider:
auth_checkers={
("test.login_type", ("test_field",)): self.check_auth,
("m.login.password", ("password",)): self.check_auth,
- },
+ }
)
pass
@@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
account.register_servlets,
]
+ CALLBACK_USERNAME = "get_username_for_registration"
+ CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
+
def setUp(self):
# we use a global mock device, so make sure we are starting with a clean slate
mock_password_provider.reset_mock()
@@ -754,7 +757,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"""Tests that the get_username_for_registration callback can define the username
of a user when registering.
"""
- self._setup_get_username_for_registration()
+ self._setup_get_name_for_registration(
+ callback_name=self.CALLBACK_USERNAME,
+ )
username = "rin"
channel = self.make_request(
@@ -777,30 +782,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"""Tests that the get_username_for_registration callback is only called at the
end of the UIA flow.
"""
- m = self._setup_get_username_for_registration()
-
- # Initiate the UIA flow.
- username = "rin"
- channel = self.make_request(
- "POST",
- "register",
- {"username": username, "type": "m.login.password", "password": "bar"},
+ m = self._setup_get_name_for_registration(
+ callback_name=self.CALLBACK_USERNAME,
)
- self.assertEqual(channel.code, 401)
- self.assertIn("session", channel.json_body)
- # Check that the callback hasn't been called yet.
- m.assert_not_called()
+ username = "rin"
+ res = self._do_uia_assert_mock_not_called(username, m)
- # Finish the UIA flow.
- session = channel.json_body["session"]
- channel = self.make_request(
- "POST",
- "register",
- {"auth": {"session": session, "type": LoginType.DUMMY}},
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- mxid = channel.json_body["user_id"]
+ mxid = res["user_id"]
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
# Check that the callback has been called.
@@ -817,6 +806,56 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self._test_3pid_allowed("rin", False)
self._test_3pid_allowed("kitay", True)
+ def test_displayname(self):
+ """Tests that the get_displayname_for_registration callback can define the
+ display name of a user when registering.
+ """
+ self._setup_get_name_for_registration(
+ callback_name=self.CALLBACK_DISPLAYNAME,
+ )
+
+ username = "rin"
+ channel = self.make_request(
+ "POST",
+ "/register",
+ {
+ "username": username,
+ "password": "bar",
+ "auth": {"type": LoginType.DUMMY},
+ },
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Our callback takes the username and appends "-foo" to it, check that's what we
+ # have.
+ user_id = UserID.from_string(channel.json_body["user_id"])
+ display_name = self.get_success(
+ self.hs.get_profile_handler().get_displayname(user_id)
+ )
+
+ self.assertEqual(display_name, username + "-foo")
+
+ def test_displayname_uia(self):
+ """Tests that the get_displayname_for_registration callback is only called at the
+ end of the UIA flow.
+ """
+ m = self._setup_get_name_for_registration(
+ callback_name=self.CALLBACK_DISPLAYNAME,
+ )
+
+ username = "rin"
+ res = self._do_uia_assert_mock_not_called(username, m)
+
+ user_id = UserID.from_string(res["user_id"])
+ display_name = self.get_success(
+ self.hs.get_profile_handler().get_displayname(user_id)
+ )
+
+ self.assertEqual(display_name, username + "-foo")
+
+ # Check that the callback has been called.
+ m.assert_called_once()
+
def _test_3pid_allowed(self, username: str, registration: bool):
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
either /register or /account URLs depending on the arguments.
@@ -877,23 +916,47 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
m.assert_called_once_with("email", "bar@test.com", registration)
- def _setup_get_username_for_registration(self) -> Mock:
- """Registers a get_username_for_registration callback that appends "-foo" to the
- username the client is trying to register.
+ def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
+ """Registers either a get_username_for_registration callback or a
+ get_displayname_for_registration callback that appends "-foo" to the username the
+ client is trying to register.
"""
- async def get_username_for_registration(uia_results, params):
+ async def callback(uia_results, params):
self.assertIn(LoginType.DUMMY, uia_results)
username = params["username"]
return username + "-foo"
- m = Mock(side_effect=get_username_for_registration)
+ m = Mock(side_effect=callback)
password_auth_provider = self.hs.get_password_auth_provider()
- password_auth_provider.get_username_for_registration_callbacks.append(m)
+ getattr(password_auth_provider, callback_name + "_callbacks").append(m)
return m
+ def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
+ # Initiate the UIA flow.
+ channel = self.make_request(
+ "POST",
+ "register",
+ {"username": username, "type": "m.login.password", "password": "bar"},
+ )
+ self.assertEqual(channel.code, 401)
+ self.assertIn("session", channel.json_body)
+
+ # Check that the callback hasn't been called yet.
+ m.assert_not_called()
+
+ # Finish the UIA flow.
+ session = channel.json_body["session"]
+ channel = self.make_request(
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": LoginType.DUMMY}},
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ return channel.json_body
+
def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index c068d329a9..e1e3fb97c5 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -571,9 +571,7 @@ class HTTPPusherTests(HomeserverTestCase):
# Carry out our option-value specific test
#
# This push should still only contain an unread count of 1 (for 1 unread room)
- self.assertEqual(
- self.push_attempts[5][2]["notification"]["counts"]["unread"], 1
- )
+ self._check_push_attempt(6, 1)
@override_config({"push": {"group_unread_count_by_room": False}})
def test_push_unread_count_message_count(self):
@@ -585,11 +583,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Carry out our option-value specific test
#
- # We're counting every unread message, so there should now be 4 since the
+ # We're counting every unread message, so there should now be 3 since the
# last read receipt
- self.assertEqual(
- self.push_attempts[5][2]["notification"]["counts"]["unread"], 4
- )
+ self._check_push_attempt(6, 3)
def _test_push_unread_count(self):
"""
@@ -597,8 +593,9 @@ class HTTPPusherTests(HomeserverTestCase):
Note that:
* Sending messages will cause push notifications to go out to relevant users
- * Sending a read receipt will cause a "badge update" notification to go out to
- the user that sent the receipt
+ * Sending a read receipt will cause the HTTP pusher to check whether the unread
+ count has changed since the last push notification. If so, a "badge update"
+ notification goes out to the user that sent the receipt
"""
# Register the user who gets notified
user_id = self.register_user("user", "pass")
@@ -642,24 +639,74 @@ class HTTPPusherTests(HomeserverTestCase):
# position in the room. We'll set the read position to this event in a moment
first_message_event_id = response["event_id"]
- # Advance time a bit (so the pusher will register something has happened) and
- # make the push succeed
- self.push_attempts[0][0].callback({})
+ expected_push_attempts = 1
+ self._check_push_attempt(expected_push_attempts, 0)
+
+ self._send_read_request(access_token, first_message_event_id, room_id)
+
+ # The unread count has not changed, so ensure that the read request does not
+ # trigger a push notification.
+ self.assertEqual(len(self.push_attempts), 1)
+
+ # Send another message
+ response2 = self.helper.send(
+ room_id, body="How's the weather today?", tok=other_access_token
+ )
+ second_message_event_id = response2["event_id"]
+
+ expected_push_attempts += 1
+
+ self._check_push_attempt(expected_push_attempts, 1)
+
+ self._send_read_request(access_token, second_message_event_id, room_id)
+ expected_push_attempts += 1
+
+ self._check_push_attempt(expected_push_attempts, 0)
+
+ # If we're grouping by room, sending more messages shouldn't increase the
+ # unread count, as they're all being sent in the same room. Otherwise, it
+ # should. Therefore, the last call to _check_push_attempt is done in the
+ # caller method.
+ self.helper.send(room_id, body="Hello?", tok=other_access_token)
+ expected_push_attempts += 1
+
+ self._advance_time_and_make_push_succeed(expected_push_attempts)
+
+ self.helper.send(room_id, body="Hello??", tok=other_access_token)
+ expected_push_attempts += 1
+
+ self._advance_time_and_make_push_succeed(expected_push_attempts)
+
+ self.helper.send(room_id, body="HELLO???", tok=other_access_token)
+
+ def _advance_time_and_make_push_succeed(self, expected_push_attempts: int) -> None:
self.pump()
+ self.push_attempts[expected_push_attempts - 1][0].callback({})
+ def _check_push_attempt(
+ self, expected_push_attempts: int, expected_unread_count_last_push: int
+ ) -> None:
+ """
+ Makes sure that the last expected push attempt succeeds and checks whether
+ it contains the expected unread count.
+ """
+ self._advance_time_and_make_push_succeed(expected_push_attempts)
# Check our push made it
- self.assertEqual(len(self.push_attempts), 1)
+ self.assertEqual(len(self.push_attempts), expected_push_attempts)
+ _, push_url, push_body = self.push_attempts[expected_push_attempts - 1]
self.assertEqual(
- self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ push_url,
+ "http://example.com/_matrix/push/v1/notify",
)
-
# Check that the unread count for the room is 0
#
# The unread count is zero as the user has no read receipt in the room yet
self.assertEqual(
- self.push_attempts[0][2]["notification"]["counts"]["unread"], 0
+ push_body["notification"]["counts"]["unread"],
+ expected_unread_count_last_push,
)
+ def _send_read_request(self, access_token: str, message_event_id: str, room_id: str) -> None:
# Now set the user's read receipt position to the first event
#
# This will actually trigger a new notification to be sent out so that
@@ -667,56 +714,8 @@ class HTTPPusherTests(HomeserverTestCase):
# count goes down
channel = self.make_request(
"POST",
- "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
+ "/rooms/%s/receipt/m.read/%s" % (room_id, message_event_id),
{},
access_token=access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
-
- # Advance time and make the push succeed
- self.push_attempts[1][0].callback({})
- self.pump()
-
- # Unread count is still zero as we've read the only message in the room
- self.assertEqual(len(self.push_attempts), 2)
- self.assertEqual(
- self.push_attempts[1][2]["notification"]["counts"]["unread"], 0
- )
-
- # Send another message
- self.helper.send(
- room_id, body="How's the weather today?", tok=other_access_token
- )
-
- # Advance time and make the push succeed
- self.push_attempts[2][0].callback({})
- self.pump()
-
- # This push should contain an unread count of 1 as there's now been one
- # message since our last read receipt
- self.assertEqual(len(self.push_attempts), 3)
- self.assertEqual(
- self.push_attempts[2][2]["notification"]["counts"]["unread"], 1
- )
-
- # Since we're grouping by room, sending more messages shouldn't increase the
- # unread count, as they're all being sent in the same room
- self.helper.send(room_id, body="Hello?", tok=other_access_token)
-
- # Advance time and make the push succeed
- self.pump()
- self.push_attempts[3][0].callback({})
-
- self.helper.send(room_id, body="Hello??", tok=other_access_token)
-
- # Advance time and make the push succeed
- self.pump()
- self.push_attempts[4][0].callback({})
-
- self.helper.send(room_id, body="HELLO???", tok=other_access_token)
-
- # Advance time and make the push succeed
- self.pump()
- self.push_attempts[5][0].callback({})
-
- self.assertEqual(len(self.push_attempts), 6)
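The rewritten test above exercises the behaviour described in its docstring: a read receipt only results in a badge-update push when the unread count actually differs from the count carried by the last successfully sent push. A small, self-contained sketch of that gating logic (a hypothetical pusher shape, not Synapse's HttpPusher API):

    from typing import Awaitable, Callable, Optional

    class BadgeGate:
        """Only forward a badge update when the unread count has changed since
        the last push that was actually delivered."""

        def __init__(self, send_badge: Callable[[int], Awaitable[bool]]) -> None:
            # send_badge returns True if the push was delivered successfully.
            self._send_badge = send_badge
            self._last_sent_count: Optional[int] = None

        async def on_unread_count(self, unread: int) -> None:
            if unread == self._last_sent_count:
                # Nothing changed since the last successful push: skip the badge update.
                return
            if await self._send_badge(unread):
                # Only remember the count if the push went through, so a failed
                # attempt is retried on the next receipt or message.
                self._last_sent_count = unread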
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 89d85b0a17..51146c471d 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -486,8 +486,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
{
"user_id": user_id,
"device_id": device_id,
- # Unstable until MSC3069 enters spec
+ # MSC3069 entered spec in Matrix 1.2 but maintained compatibility
"org.matrix.msc3069.is_guest": False,
+ "is_guest": False,
},
)
@@ -505,8 +506,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
{
"user_id": user_id,
"device_id": device_id,
- # Unstable until MSC3069 enters spec
+ # MSC3069 entered spec in Matrix 1.2 but maintained compatibility
"org.matrix.msc3069.is_guest": True,
+ "is_guest": True,
},
)
@@ -528,8 +530,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
whoami,
{
"user_id": user_id,
- # Unstable until MSC3069 enters spec
+ # MSC3069 entered spec in Matrix 1.2 but maintained compatibility
"org.matrix.msc3069.is_guest": False,
+ "is_guest": False,
},
)
self.assertFalse(hasattr(whoami, "device_id"))
diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_device_lists.py
new file mode 100644
index 0000000000..16070cf027
--- /dev/null
+++ b/tests/rest/client/test_device_lists.py
@@ -0,0 +1,155 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.rest import admin, devices, room, sync
+from synapse.rest.client import account, login, register
+
+from tests import unittest
+
+
+class DeviceListsTestCase(unittest.HomeserverTestCase):
+ """Tests regarding device list changes."""
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ register.register_servlets,
+ account.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def test_receiving_local_device_list_changes(self):
+ """Tests that a local users that share a room receive each other's device list
+ changes.
+ """
+ # Register two users
+ test_device_id = "TESTDEVICE"
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ bob_user_id = self.register_user("bob", "ponyponypony")
+ bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+ # Create a room for them to coexist peacefully in
+ new_room_id = self.helper.create_room_as(
+ alice_user_id, is_public=True, tok=alice_access_token
+ )
+ self.assertIsNotNone(new_room_id)
+
+ # Have Bob join the room
+ self.helper.invite(
+ new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
+ )
+ self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
+
+ # Now have Bob initiate an initial sync (in order to get a since token)
+ channel = self.make_request(
+ "GET",
+ "/sync",
+ access_token=bob_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ next_batch_token = channel.json_body["next_batch"]
+
+ # ...and then an incremental sync. This should block until the sync stream is woken up,
+ # which we hope will happen as a result of Alice updating their device list.
+ bob_sync_channel = self.make_request(
+ "GET",
+ f"/sync?since={next_batch_token}&timeout=30000",
+ access_token=bob_access_token,
+ # Start the request, then continue on.
+ await_result=False,
+ )
+
+ # Have alice update their device list
+ channel = self.make_request(
+ "PUT",
+ f"/devices/{test_device_id}",
+ {
+ "display_name": "New Device Name",
+ },
+ access_token=alice_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that bob's incremental sync contains the updated device list.
+ # If not, the client would only receive the device list update on the
+ # *next* sync.
+ bob_sync_channel.await_result()
+ self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+ changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+ "changed", []
+ )
+ self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
+
+ def test_not_receiving_local_device_list_changes(self):
+ """Tests a local users DO NOT receive device updates from each other if they do not
+ share a room.
+ """
+ # Register two users
+ test_device_id = "TESTDEVICE"
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ bob_user_id = self.register_user("bob", "ponyponypony")
+ bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+ # These users do not share a room. They are lonely.
+
+ # Have Bob initiate an initial sync (in order to get a since token)
+ channel = self.make_request(
+ "GET",
+ "/sync",
+ access_token=bob_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ next_batch_token = channel.json_body["next_batch"]
+
+ # ...and then an incremental sync. This should block until the sync stream is woken up,
+ # which we hope will happen as a result of Alice updating their device list.
+ bob_sync_channel = self.make_request(
+ "GET",
+ f"/sync?since={next_batch_token}&timeout=1000",
+ access_token=bob_access_token,
+ # Start the request, then continue on.
+ await_result=False,
+ )
+
+ # Have alice update their device list
+ channel = self.make_request(
+ "PUT",
+ f"/devices/{test_device_id}",
+ {
+ "display_name": "New Device Name",
+ },
+ access_token=alice_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that Bob's incremental sync does not contain the updated device list.
+ bob_sync_channel.await_result()
+ self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+ changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+ "changed", []
+ )
+ self.assertNotIn(
+ alice_user_id, changed_device_lists, bob_sync_channel.json_body
+ )
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index de80aca037..dfd9ffcb93 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1123,6 +1123,48 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
+ @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+ def test_edit_thread(self):
+ """Test that editing a thread works."""
+
+ # Create a thread and edit the last event.
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "A threaded reply!"},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ threaded_event_id = channel.json_body["event_id"]
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ parent_id=threaded_event_id,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # Fetch the thread root, to get the bundled aggregation for the thread.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # We expect that the edit message appears in the thread summary in the
+ # unsigned relations section.
+ relations_dict = channel.json_body["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.THREAD, relations_dict)
+
+ thread_summary = relations_dict[RelationTypes.THREAD]
+ self.assertIn("latest_event", thread_summary)
+ latest_event_in_thread = thread_summary["latest_event"]
+ self.assertEquals(
+ latest_event_in_thread["content"]["body"], "I've been edited!"
+ )
+
def test_edit_edit(self):
"""Test that an edit cannot be edited."""
new_body = {"msgtype": "m.text", "body": "Initial edit"}
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 1c0cb0cf4f..2b3fdadffa 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -106,9 +106,13 @@ class RestHelper:
default room version.
tok: The access token to use in the request.
expect_code: The expected HTTP response code.
+ extra_content: Extra keys to include in the body of the /createRoom request.
+ Note that if is_public is set, the "visibility" key will be overridden.
+ If room_version is set, the "room_version" key will be overridden.
+ custom_headers: HTTP headers to include in the request.
Returns:
- The ID of the newly created room.
+ The ID of the newly created room, or None if the request failed.
"""
temp_id = self.auth_user_id
self.auth_user_id = room_creator
diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py
new file mode 100644
index 0000000000..3a4a4a3a29
--- /dev/null
+++ b/tests/storage/databases/test_state_store.py
@@ -0,0 +1,283 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import typing
+from typing import Dict, List, Sequence, Tuple
+from unittest.mock import patch
+
+from twisted.internet.defer import Deferred, ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventTypes
+from synapse.storage.state import StateFilter
+from synapse.types import StateMap
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+if typing.TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+# StateFilter for ALL non-m.room.member state events
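+# (an empty set of state keys for m.room.member means "no member events at all";
+# include_others=True then pulls in every other event type)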
+ALL_NON_MEMBERS_STATE_FILTER = StateFilter.freeze(
+ types={EventTypes.Member: set()},
+ include_others=True,
+)
+
+FAKE_STATE = {
+ (EventTypes.Member, "@alice:test"): "join",
+ (EventTypes.Member, "@bob:test"): "leave",
+ (EventTypes.Member, "@charlie:test"): "invite",
+ ("test.type", "a"): "AAA",
+ ("test.type", "b"): "BBB",
+ ("other.event.type", "state.key"): "123",
+}
+
+
+class StateGroupInflightCachingTestCase(HomeserverTestCase):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: "HomeServer"
+ ) -> None:
+ self.state_storage = homeserver.get_storage().state
+ self.state_datastore = homeserver.get_datastores().state
+ # Patch out the `_get_state_groups_from_groups`.
+ # This is useful because it lets us pretend we have a slow database.
+ get_state_groups_patch = patch.object(
+ self.state_datastore,
+ "_get_state_groups_from_groups",
+ self._fake_get_state_groups_from_groups,
+ )
+ get_state_groups_patch.start()
+
+ self.addCleanup(get_state_groups_patch.stop)
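+ # Each entry records one fake database call: the requested state groups,
+ # the state filter used, and the Deferred with which to complete it.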
+ self.get_state_group_calls: List[
+ Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]]
+ ] = []
+
+ def _fake_get_state_groups_from_groups(
+ self, groups: Sequence[int], state_filter: StateFilter
+ ) -> "Deferred[Dict[int, StateMap[str]]]":
+ d: Deferred[Dict[int, StateMap[str]]] = Deferred()
+ self.get_state_group_calls.append((tuple(groups), state_filter, d))
+ return d
+
+ def _complete_request_fake(
+ self,
+ groups: Tuple[int, ...],
+ state_filter: StateFilter,
+ d: "Deferred[Dict[int, StateMap[str]]]",
+ ) -> None:
+ """
+ Assemble a fake database response and complete the database request.
+ """
+
+ # Return a filtered copy of the fake state
+ d.callback({group: state_filter.filter_state(FAKE_STATE) for group in groups})
+
+ def test_duplicate_requests_deduplicated(self) -> None:
+ """
+ Tests that duplicate requests for state are deduplicated.
+
+ This test:
+ - requests some state (state group 42, 'all' state filter)
+ - requests it again, before the first request finishes
+ - checks to see that only one database query was made
+ - completes the database query
+ - checks that both requests see the same retrieved state
+ """
+ req1 = ensureDeferred(
+ self.state_datastore._get_state_for_group_using_inflight_cache(
+ 42, StateFilter.all()
+ )
+ )
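+ # Let the reactor tick over so the request reaches the patched database method.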
+ self.pump(by=0.1)
+
+ # This should have gone to the database
+ self.assertEqual(len(self.get_state_group_calls), 1)
+ self.assertFalse(req1.called)
+
+ req2 = ensureDeferred(
+ self.state_datastore._get_state_for_group_using_inflight_cache(
+ 42, StateFilter.all()
+ )
+ )
+ self.pump(by=0.1)
+
+ # No more calls should have gone to the database
+ self.assertEqual(len(self.get_state_group_calls), 1)
+ self.assertFalse(req1.called)
+ self.assertFalse(req2.called)
+
+ groups, sf, d = self.get_state_group_calls[0]
+ self.assertEqual(groups, (42,))
+ self.assertEqual(sf, StateFilter.all())
+
+ # Now we can complete the request
+ self._complete_request_fake(groups, sf, d)
+
+ self.assertEqual(self.get_success(req1), FAKE_STATE)
+ self.assertEqual(self.get_success(req2), FAKE_STATE)
+
+ def test_smaller_request_deduplicated(self) -> None:
+ """
+ Tests that duplicate requests for state are deduplicated.
+
+ This test:
+ - requests some state (state group 42, 'all' state filter)
+ - requests a subset of that state, before the first request finishes
+ - checks to see that only one database query was made
+ - completes the database query
+ - checks that both requests see the correct retrieved state
+ """
+ req1 = ensureDeferred(
+ self.state_datastore._get_state_for_group_using_inflight_cache(
+ 42, StateFilter.from_types((("test.type", None),))
+ )
+ )
+ self.pump(by=0.1)
+
+ # This should have gone to the database
+ self.assertEqual(len(self.get_state_group_calls), 1)
+ self.assertFalse(req1.called)
+
+ req2 = ensureDeferred(
+ self.state_datastore._get_state_for_group_using_inflight_cache(
+ 42, StateFilter.from_types((("test.type", "b"),))
+ )
+ )
+ self.pump(by=0.1)
+
+ # No more calls should have gone to the database, because the second
+ # request was covered by a request already in the in-flight cache!
+ self.assertEqual(len(self.get_state_group_calls), 1)
+ self.assertFalse(req1.called)
+ self.assertFalse(req2.called)
+
+ groups, sf, d = self.get_state_group_calls[0]
+ self.assertEqual(groups, (42,))
+ # The state filter is expanded internally for increased cache hit rate,
+ # so the database sees a wider state filter than requested.
+ self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER)
+
+ # Now we can complete the request
+ self._complete_request_fake(groups, sf, d)
+
+ self.assertEqual(
+ self.get_success(req1),
+ {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"},
+ )
+ self.assertEqual(self.get_success(req2), {("test.type", "b"): "BBB"})
+
+ def test_partially_overlapping_request_deduplicated(self) -> None:
+ """
+ Tests that partially-overlapping requests are partially deduplicated.
+
+ This test:
+ - requests a single type of wildcard state
+ (This is internally expanded to be all non-member state)
+ - requests the entire state in parallel
+ - checks to see that two database queries were made, but that the second
+ one is only for member state.
+ - completes the database queries
+ - checks that both requests have the correct result.
+ """
+
+ req1 = ensureDeferred(
+ self.state_datastore._get_state_for_group_using_inflight_cache(
+ 42, StateFilter.from_types((("test.type", None),))
+ )
+ )
+ self.pump(by=0.1)
+
+ # This should have gone to the database
+ self.assertEqual(len(self.get_state_group_calls), 1)
+ self.assertFalse(req1.called)
+
+ req2 = ensureDeferred(
+ self.state_datastore._get_state_for_group_using_inflight_cache(
+ 42, StateFilter.all()
+ )
+ )
+ self.pump(by=0.1)
+
+ # Because it only partially overlaps, this also went to the database
+ self.assertEqual(len(self.get_state_group_calls), 2)
+ self.assertFalse(req1.called)
+ self.assertFalse(req2.called)
+
+ # First request:
+ groups, sf, d = self.get_state_group_calls[0]
+ self.assertEqual(groups, (42,))
+ # The state filter is expanded internally for increased cache hit rate,
+ # so the database sees a wider state filter than requested.
+ self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER)
+ self._complete_request_fake(groups, sf, d)
+
+ # Second request:
+ groups, sf, d = self.get_state_group_calls[1]
+ self.assertEqual(groups, (42,))
+ # The state filter is narrowed to only request membership state, because
+ # the remainder of the state is already being queried in the first request!
+ self.assertEqual(
+ sf, StateFilter.freeze({EventTypes.Member: None}, include_others=False)
+ )
+ self._complete_request_fake(groups, sf, d)
+
+ # Check the results are correct
+ self.assertEqual(
+ self.get_success(req1),
+ {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"},
+ )
+ self.assertEqual(self.get_success(req2), FAKE_STATE)
+
+ def test_in_flight_requests_stop_being_in_flight(self) -> None:
+ """
+ Tests that in-flight request deduplication doesn't somehow 'hold on'
+ to completed requests: once they're done, they're taken out of the
+ in-flight cache.
+ """
+ req1 = ensureDeferred(
+ self.state_datastore._get_state_for_group_using_inflight_cache(
+ 42, StateFilter.all()
+ )
+ )
+ self.pump(by=0.1)
+
+ # This should have gone to the database
+ self.assertEqual(len(self.get_state_group_calls), 1)
+ self.assertFalse(req1.called)
+
+ # Complete the request right away.
+ self._complete_request_fake(*self.get_state_group_calls[0])
+ self.assertTrue(req1.called)
+
+ # Send off another request
+ req2 = ensureDeferred(
+ self.state_datastore._get_state_for_group_using_inflight_cache(
+ 42, StateFilter.all()
+ )
+ )
+ self.pump(by=0.1)
+
+ # It should have gone to the database again, because the previous request
+ # isn't in-flight and therefore isn't available for deduplication.
+ self.assertEqual(len(self.get_state_group_calls), 2)
+ self.assertFalse(req2.called)
+
+ # Complete the request right away.
+ self._complete_request_fake(*self.get_state_group_calls[1])
+ self.assertTrue(req2.called)
+
+ self.assertEqual(self.get_success(req1), FAKE_STATE)
+ self.assertEqual(self.get_success(req2), FAKE_STATE)
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index f462a8b1c7..a8639d8f82 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -329,3 +329,110 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([local_message_event_id, remote_event_2.event_id])
+
+
+class InvalidateUsersInRoomCacheTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.state = self.hs.get_state_handler()
+ self.persistence = self.hs.get_storage().persistence
+ self.store = self.hs.get_datastore()
+
+ def test_remote_user_rooms_cache_invalidated(self):
+ """Test that if the server leaves a room the `get_rooms_for_user` cache
+ is invalidated for remote users.
+ """
+
+ # Set up a room with a local and remote user in it.
+ user_id = self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ room_id = self.helper.create_room_as(
+ "user", room_version=RoomVersions.V6.identifier, tok=token
+ )
+
+ body = self.helper.send(room_id, body="Test", tok=token)
+ local_message_event_id = body["event_id"]
+
+ # Fudge a join event for a remote user.
+ remote_user = "@user:other"
+ remote_event_1 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": remote_user,
+ "content": {"membership": Membership.JOIN},
+ "room_id": room_id,
+ "sender": remote_user,
+ "depth": 5,
+ "prev_events": [local_message_event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ context = self.get_success(self.state.compute_event_context(remote_event_1))
+ self.get_success(self.persistence.persist_event(remote_event_1, context))
+
+ # Call `get_rooms_for_user` to add the remote user to the cache
+ rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
+ self.assertEqual(set(rooms), {room_id})
+
+ # Now we have the local server leave the room, and check that calling
+ # `get_rooms_for_user` for the remote user no longer includes the room.
+ self.helper.leave(room_id, user_id, tok=token)
+
+ rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
+ self.assertEqual(set(rooms), set())
+
+ def test_room_remote_user_cache_invalidated(self):
+ """Test that if the server leaves a room the `get_users_in_room` cache
+ is invalidated for remote users.
+ """
+
+ # Set up a room with a local and remote user in it.
+ user_id = self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ room_id = self.helper.create_room_as(
+ "user", room_version=RoomVersions.V6.identifier, tok=token
+ )
+
+ body = self.helper.send(room_id, body="Test", tok=token)
+ local_message_event_id = body["event_id"]
+
+ # Fudge a join event for a remote user.
+ remote_user = "@user:other"
+ remote_event_1 = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "state_key": remote_user,
+ "content": {"membership": Membership.JOIN},
+ "room_id": room_id,
+ "sender": remote_user,
+ "depth": 5,
+ "prev_events": [local_message_event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ RoomVersions.V6,
+ )
+
+ context = self.get_success(self.state.compute_event_context(remote_event_1))
+ self.get_success(self.persistence.persist_event(remote_event_1, context))
+
+ # Call `get_users_in_room` to add the remote user to the cache
+ users = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual(set(users), {user_id, remote_user})
+
+ # Now we have the local server leave the room, and check that calling
+ # `get_users_in_room` no longer returns the remote user.
+ self.helper.leave(room_id, user_id, tok=token)
+
+ users = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual(users, [])
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 70d52b088c..28c767ecfd 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -992,3 +992,112 @@ class StateFilterDifferenceTestCase(TestCase):
StateFilter.none(),
StateFilter.all(),
)
+
+
+class StateFilterTestCase(TestCase):
+ def test_return_expanded(self):
+ """
+ Tests the behaviour of the return_expanded() function that expands
+ StateFilters to include more state types (for the sake of cache hit rate).
+ """
+
+ self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
+
+ self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
+
+ # Concrete-only state filters stay the same
+ # (Case: mixed filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": {""},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # Concrete-only state filters stay the same
+ # (Case: non-member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {"some.other.state.type": {""}}, include_others=False
+ ).return_expanded(),
+ StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
+ )
+
+ # Concrete-only state filters stay the same
+ # (Case: member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ },
+ include_others=False,
+ ),
+ )
+
+ # Wildcard member-only state filters stay the same
+ self.assertEqual(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # If there is a wildcard in the non-member portion of the filter,
+ # it's expanded to include ALL non-member events.
+ # (Case: mixed filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": None,
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:test", "@alicia:test"}},
+ include_others=True,
+ ),
+ )
+
+ # If there is a wildcard in the non-member portion of the filter,
+ # it's expanded to include ALL non-member events.
+ # (Case: non-member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ "some.other.state.type": None,
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+ )
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ "some.other.state.type": None,
+ "yet.another.state.type": {"wombat"},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+ )
diff --git a/tests/unittest.py b/tests/unittest.py
index a71892cb9d..7983c1e8b8 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -51,7 +51,10 @@ from twisted.web.server import Request
from synapse import events
from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.server import DEFAULT_ROOM_VERSION
+from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
@@ -839,6 +842,24 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
client_ip=client_ip,
)
+ def add_hashes_and_signatures(
+ self,
+ event_dict: JsonDict,
+ room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
+ ) -> JsonDict:
+ """Adds hashes and signatures to the given event dict
+
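+ Args:
+ event_dict: The event dict to sign; it is modified in place.
+ room_version: The room version to use when computing the hashes and signatures.
+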
+ Returns:
+ The modified event dict, for convenience
+ """
+ add_hashes_and_signatures(
+ room_version,
+ event_dict,
+ signature_name=self.OTHER_SERVER_NAME,
+ signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+ )
+ return event_dict
+
def _auth_header_for_request(
origin: str,
diff --git a/tox.ini b/tox.ini
index 32679e9106..436ecf7552 100644
--- a/tox.ini
+++ b/tox.ini
@@ -4,6 +4,11 @@ envlist = packaging, py37, py38, py39, py310, check_codestyle, check_isort
# we require tox>=2.3.2 for the fix to https://github.com/tox-dev/tox/issues/208
minversion = 2.3.2
+# the tox-venv plugin makes tox use python's built-in `venv` module rather than
+# the legacy `virtualenv` tool. `virtualenv` embeds its own `pip`, `setuptools`,
+# etc, and ends up being rather unreliable.
+requires = tox-venv
+
[base]
deps =
python-subunit
@@ -119,6 +124,9 @@ usedevelop = false
deps =
Automat == 0.8.0
lxml
+ # markupsafe 2.1 introduced a change that breaks Jinja 2.x. Since we depend on
+ # Jinja >= 2.9, this test suite will fail if markupsafe >= 2.1 is installed.
+ markupsafe < 2.1
{[base]deps}
commands =
@@ -158,7 +166,7 @@ commands =
[testenv:check_isort]
extras = lint
-commands = isort -c --df --sp setup.cfg {[base]lint_targets}
+commands = isort -c --df {[base]lint_targets}
[testenv:check-newsfragment]
skip_install = true
|