summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-x.ci/scripts/test_old_deps.sh4
-rw-r--r--.github/workflows/release-artifacts.yml14
-rw-r--r--.github/workflows/tests.yml17
-rw-r--r--changelog.d/10870.misc1
-rw-r--r--changelog.d/11835.feature1
-rw-r--r--changelog.d/11972.misc1
-rw-r--r--changelog.d/11974.misc1
-rw-r--r--changelog.d/11984.misc1
-rw-r--r--changelog.d/11985.feature1
-rw-r--r--changelog.d/11991.misc1
-rw-r--r--changelog.d/11992.bugfix1
-rw-r--r--changelog.d/11994.misc1
-rw-r--r--changelog.d/11996.misc1
-rw-r--r--changelog.d/11997.docker1
-rw-r--r--changelog.d/11999.bugfix1
-rw-r--r--changelog.d/12000.feature1
-rw-r--r--changelog.d/12003.doc1
-rw-r--r--changelog.d/12004.doc1
-rw-r--r--changelog.d/12005.misc1
-rw-r--r--changelog.d/12008.removal1
-rw-r--r--changelog.d/12009.feature1
-rw-r--r--changelog.d/12011.misc1
-rw-r--r--changelog.d/12013.misc1
-rw-r--r--changelog.d/12015.misc1
-rw-r--r--changelog.d/12016.misc1
-rw-r--r--changelog.d/12018.removal1
-rw-r--r--changelog.d/12019.misc1
-rw-r--r--changelog.d/12020.feature1
-rw-r--r--changelog.d/12021.feature1
-rw-r--r--changelog.d/12022.feature1
-rw-r--r--changelog.d/12024.bugfix1
-rw-r--r--changelog.d/12025.misc1
-rw-r--r--changelog.d/12030.misc1
-rw-r--r--changelog.d/12033.misc1
-rw-r--r--changelog.d/12034.misc1
-rw-r--r--changelog.d/12039.misc1
-rw-r--r--changelog.d/12051.misc1
-rw-r--r--changelog.d/12052.misc1
-rw-r--r--docker/Dockerfile2
-rw-r--r--docs/admin_api/user_admin_api.md3
-rw-r--r--docs/modules/password_auth_provider_callbacks.md35
-rw-r--r--docs/modules/spam_checker_callbacks.md40
-rw-r--r--docs/structured_logging.md14
-rw-r--r--docs/upgrade.md9
-rw-r--r--mypy.ini3
-rw-r--r--pyproject.toml12
-rw-r--r--setup.cfg11
-rwxr-xr-xsetup.py4
-rw-r--r--stubs/sortedcontainers/sorteddict.pyi13
-rw-r--r--synapse/config/experimental.py7
-rw-r--r--synapse/config/logger.py12
-rw-r--r--synapse/events/utils.py69
-rw-r--r--synapse/federation/federation_base.py10
-rw-r--r--synapse/federation/federation_client.py114
-rw-r--r--synapse/federation/sender/per_destination_queue.py18
-rw-r--r--synapse/federation/transport/client.py188
-rw-r--r--synapse/handlers/auth.py58
-rw-r--r--synapse/handlers/federation.py2
-rw-r--r--synapse/handlers/federation_event.py46
-rw-r--r--synapse/handlers/message.py10
-rw-r--r--synapse/handlers/presence.py26
-rw-r--r--synapse/handlers/register.py6
-rw-r--r--synapse/handlers/room_member.py46
-rw-r--r--synapse/handlers/search.py643
-rw-r--r--synapse/handlers/sync.py68
-rw-r--r--synapse/http/matrixfederationclient.py50
-rw-r--r--synapse/logging/_structured.py163
-rw-r--r--synapse/module_api/__init__.py11
-rw-r--r--synapse/notifier.py4
-rw-r--r--synapse/push/baserules.py10
-rw-r--r--synapse/push/httppusher.py9
-rw-r--r--synapse/python_dependencies.py3
-rw-r--r--synapse/res/providers.json4
-rw-r--r--synapse/rest/client/account.py2
-rw-r--r--synapse/rest/client/auth.py8
-rw-r--r--synapse/rest/client/capabilities.py14
-rw-r--r--synapse/rest/client/register.py7
-rw-r--r--synapse/rest/client/versions.py2
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py10
-rw-r--r--synapse/storage/databases/main/devices.py10
-rw-r--r--synapse/storage/databases/main/events.py27
-rw-r--r--synapse/storage/databases/main/events_worker.py4
-rw-r--r--synapse/storage/databases/main/presence.py61
-rw-r--r--synapse/storage/databases/main/purge_events.py13
-rw-r--r--synapse/storage/databases/main/registration.py21
-rw-r--r--synapse/storage/databases/main/relations.py26
-rw-r--r--synapse/storage/databases/main/room.py4
-rw-r--r--synapse/storage/databases/main/roommember.py62
-rw-r--r--synapse/storage/databases/main/search.py17
-rw-r--r--synapse/storage/databases/main/user_directory.py22
-rw-r--r--synapse/storage/databases/state/store.py203
-rw-r--r--synapse/storage/state.py8
-rw-r--r--synapse/streams/events.py6
-rw-r--r--synapse/types.py6
-rw-r--r--synapse/util/caches/__init__.py1
-rw-r--r--synapse/util/caches/expiringcache.py5
-rw-r--r--synapse/util/caches/lrucache.py4
-rw-r--r--synapse/util/daemonize.py8
-rw-r--r--synapse/util/patch_inline_callbacks.py6
-rw-r--r--tests/federation/test_federation_client.py149
-rw-r--r--tests/federation/transport/test_client.py32
-rw-r--r--tests/handlers/test_password_providers.py123
-rw-r--r--tests/push/test_http.py129
-rw-r--r--tests/rest/client/test_account.py9
-rw-r--r--tests/rest/client/test_device_lists.py155
-rw-r--r--tests/rest/client/test_relations.py42
-rw-r--r--tests/rest/client/utils.py6
-rw-r--r--tests/storage/databases/test_state_store.py283
-rw-r--r--tests/storage/test_events.py107
-rw-r--r--tests/storage/test_state.py109
-rw-r--r--tests/unittest.py21
-rw-r--r--tox.ini10
112 files changed, 2663 insertions, 862 deletions
diff --git a/.ci/scripts/test_old_deps.sh b/.ci/scripts/test_old_deps.sh
index 54ec3c8b0d..b2859f7522 100755
--- a/.ci/scripts/test_old_deps.sh
+++ b/.ci/scripts/test_old_deps.sh
@@ -8,7 +8,9 @@ export DEBIAN_FRONTEND=noninteractive
 set -ex
 
 apt-get update
-apt-get install -y python3 python3-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox libjpeg-dev libwebp-dev
+apt-get install -y \
+        python3 python3-dev python3-pip python3-venv \
+        libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox libjpeg-dev libwebp-dev
 
 export LANG="C.UTF-8"
 
diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml
index eb294f1619..eee3633d50 100644
--- a/.github/workflows/release-artifacts.yml
+++ b/.github/workflows/release-artifacts.yml
@@ -7,7 +7,7 @@ on:
   # of things breaking (but only build one set of debs)
   pull_request:
   push:
-    branches: ["develop"]
+    branches: ["develop", "release-*"]
 
     # we do the full build on tags.
     tags: ["v*"]
@@ -91,17 +91,7 @@ jobs:
 
   build-sdist:
     name: "Build pypi distribution files"
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v2
-      - uses: actions/setup-python@v2
-      - run: pip install wheel
-      - run: |
-          python setup.py sdist bdist_wheel
-      - uses: actions/upload-artifact@v2
-        with:
-          name: python-dist
-          path: dist/*
+    uses: "matrix-org/backend-meta/.github/workflows/packaging.yml@v1"
 
   # if it's a tag, create a release and attach the artifacts to it
   attach-assets:
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 75ac1304bf..bbf1033bdd 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -48,24 +48,10 @@ jobs:
         env:
           PULL_REQUEST_NUMBER: ${{ github.event.number }}
 
-  lint-sdist:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v2
-      - uses: actions/setup-python@v2
-        with:
-          python-version: "3.x"
-      - run: pip install wheel
-      - run: python setup.py sdist bdist_wheel
-      - uses: actions/upload-artifact@v2
-        with:
-          name: Python Distributions
-          path: dist/*
-
   # Dummy step to gate other tests on without repeating the whole list
   linting-done:
     if: ${{ !cancelled() }} # Run this even if prior jobs were skipped
-    needs: [lint, lint-crlf, lint-newsfile, lint-sdist]
+    needs: [lint, lint-crlf, lint-newsfile]
     runs-on: ubuntu-latest
     steps:
       - run: "true"
@@ -397,7 +383,6 @@ jobs:
       - lint
       - lint-crlf
       - lint-newsfile
-      - lint-sdist
       - trial
       - trial-olddeps
       - sytest
diff --git a/changelog.d/10870.misc b/changelog.d/10870.misc
new file mode 100644
index 0000000000..3af049b969
--- /dev/null
+++ b/changelog.d/10870.misc
@@ -0,0 +1 @@
+Deduplicate in-flight requests in `_get_state_for_groups`.
diff --git a/changelog.d/11835.feature b/changelog.d/11835.feature
new file mode 100644
index 0000000000..7cee39b08c
--- /dev/null
+++ b/changelog.d/11835.feature
@@ -0,0 +1 @@
+Make a `POST` to `/rooms/<room_id>/receipt/m.read/<event_id>` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push.
diff --git a/changelog.d/11972.misc b/changelog.d/11972.misc
new file mode 100644
index 0000000000..29c38bfd82
--- /dev/null
+++ b/changelog.d/11972.misc
@@ -0,0 +1 @@
+Add tests for device list changes between local users.
\ No newline at end of file
diff --git a/changelog.d/11974.misc b/changelog.d/11974.misc
new file mode 100644
index 0000000000..1debad2361
--- /dev/null
+++ b/changelog.d/11974.misc
@@ -0,0 +1 @@
+Optimise calculating device_list changes in `/sync`.
diff --git a/changelog.d/11984.misc b/changelog.d/11984.misc
new file mode 100644
index 0000000000..8e405b9226
--- /dev/null
+++ b/changelog.d/11984.misc
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
\ No newline at end of file
diff --git a/changelog.d/11985.feature b/changelog.d/11985.feature
new file mode 100644
index 0000000000..120d888a49
--- /dev/null
+++ b/changelog.d/11985.feature
@@ -0,0 +1 @@
+Fetch images when previewing Twitter URLs. Contributed by @AndrewRyanChama.
diff --git a/changelog.d/11991.misc b/changelog.d/11991.misc
new file mode 100644
index 0000000000..34a3b3a6b9
--- /dev/null
+++ b/changelog.d/11991.misc
@@ -0,0 +1 @@
+Refactor the search code for improved readability.
diff --git a/changelog.d/11992.bugfix b/changelog.d/11992.bugfix
new file mode 100644
index 0000000000..f73c86bb25
--- /dev/null
+++ b/changelog.d/11992.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary.
diff --git a/changelog.d/11994.misc b/changelog.d/11994.misc
new file mode 100644
index 0000000000..d64297dd78
--- /dev/null
+++ b/changelog.d/11994.misc
@@ -0,0 +1 @@
+Move common deduplication code down into `_auth_and_persist_outliers`.
diff --git a/changelog.d/11996.misc b/changelog.d/11996.misc
new file mode 100644
index 0000000000..6c675fd193
--- /dev/null
+++ b/changelog.d/11996.misc
@@ -0,0 +1 @@
+Limit concurrent joins from applications services.
\ No newline at end of file
diff --git a/changelog.d/11997.docker b/changelog.d/11997.docker
new file mode 100644
index 0000000000..1b3271457e
--- /dev/null
+++ b/changelog.d/11997.docker
@@ -0,0 +1 @@
+The docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage.
diff --git a/changelog.d/11999.bugfix b/changelog.d/11999.bugfix
new file mode 100644
index 0000000000..fd84095900
--- /dev/null
+++ b/changelog.d/11999.bugfix
@@ -0,0 +1 @@
+Fix long standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room.
diff --git a/changelog.d/12000.feature b/changelog.d/12000.feature
new file mode 100644
index 0000000000..246cc87f0b
--- /dev/null
+++ b/changelog.d/12000.feature
@@ -0,0 +1 @@
+Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time.
diff --git a/changelog.d/12003.doc b/changelog.d/12003.doc
new file mode 100644
index 0000000000..1ac8163559
--- /dev/null
+++ b/changelog.d/12003.doc
@@ -0,0 +1 @@
+Explain the meaning of spam checker callbacks' return values.
diff --git a/changelog.d/12004.doc b/changelog.d/12004.doc
new file mode 100644
index 0000000000..0b4baef210
--- /dev/null
+++ b/changelog.d/12004.doc
@@ -0,0 +1 @@
+Clarify information about external Identity Provider IDs.
diff --git a/changelog.d/12005.misc b/changelog.d/12005.misc
new file mode 100644
index 0000000000..45e21dbe59
--- /dev/null
+++ b/changelog.d/12005.misc
@@ -0,0 +1 @@
+Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`.
diff --git a/changelog.d/12008.removal b/changelog.d/12008.removal
new file mode 100644
index 0000000000..57599d9ee9
--- /dev/null
+++ b/changelog.d/12008.removal
@@ -0,0 +1 @@
+Remove support for the legacy structured logging configuration (please see the the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#legacy-structured-logging-configuration-removal) if you are using `structured: true` in the Synapse configuration).
diff --git a/changelog.d/12009.feature b/changelog.d/12009.feature
new file mode 100644
index 0000000000..c8a531481e
--- /dev/null
+++ b/changelog.d/12009.feature
@@ -0,0 +1 @@
+Enable modules to set a custom display name when registering a user.
diff --git a/changelog.d/12011.misc b/changelog.d/12011.misc
new file mode 100644
index 0000000000..258b0e389f
--- /dev/null
+++ b/changelog.d/12011.misc
@@ -0,0 +1 @@
+Preparation for faster-room-join work: parse msc3706 fields in send_join response.
diff --git a/changelog.d/12013.misc b/changelog.d/12013.misc
new file mode 100644
index 0000000000..c0fca8dccb
--- /dev/null
+++ b/changelog.d/12013.misc
@@ -0,0 +1 @@
+Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server.
diff --git a/changelog.d/12015.misc b/changelog.d/12015.misc
new file mode 100644
index 0000000000..3aa32ab4cf
--- /dev/null
+++ b/changelog.d/12015.misc
@@ -0,0 +1 @@
+Configure `tox` to use `venv` rather than `virtualenv`.
diff --git a/changelog.d/12016.misc b/changelog.d/12016.misc
new file mode 100644
index 0000000000..8856ef46a9
--- /dev/null
+++ b/changelog.d/12016.misc
@@ -0,0 +1 @@
+Fix bug in `StateFilter.return_expanded()` and add some tests.
\ No newline at end of file
diff --git a/changelog.d/12018.removal b/changelog.d/12018.removal
new file mode 100644
index 0000000000..e940b62228
--- /dev/null
+++ b/changelog.d/12018.removal
@@ -0,0 +1 @@
+Drop support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283) unstable flags now that the stable flags are supported.
diff --git a/changelog.d/12019.misc b/changelog.d/12019.misc
new file mode 100644
index 0000000000..b2186320ea
--- /dev/null
+++ b/changelog.d/12019.misc
@@ -0,0 +1 @@
+Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms.
\ No newline at end of file
diff --git a/changelog.d/12020.feature b/changelog.d/12020.feature
new file mode 100644
index 0000000000..1ac9d2060e
--- /dev/null
+++ b/changelog.d/12020.feature
@@ -0,0 +1 @@
+Advertise Matrix 1.1 support on `/_matrix/client/versions`.
\ No newline at end of file
diff --git a/changelog.d/12021.feature b/changelog.d/12021.feature
new file mode 100644
index 0000000000..01378df8ca
--- /dev/null
+++ b/changelog.d/12021.feature
@@ -0,0 +1 @@
+Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`.
\ No newline at end of file
diff --git a/changelog.d/12022.feature b/changelog.d/12022.feature
new file mode 100644
index 0000000000..188fb12570
--- /dev/null
+++ b/changelog.d/12022.feature
@@ -0,0 +1 @@
+Advertise Matrix 1.2 support on `/_matrix/client/versions`.
\ No newline at end of file
diff --git a/changelog.d/12024.bugfix b/changelog.d/12024.bugfix
new file mode 100644
index 0000000000..59bcdb93a5
--- /dev/null
+++ b/changelog.d/12024.bugfix
@@ -0,0 +1 @@
+Fix 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint.
diff --git a/changelog.d/12025.misc b/changelog.d/12025.misc
new file mode 100644
index 0000000000..d9475a7718
--- /dev/null
+++ b/changelog.d/12025.misc
@@ -0,0 +1 @@
+Update the `olddeps` CI job to use an old version of `markupsafe`.
diff --git a/changelog.d/12030.misc b/changelog.d/12030.misc
new file mode 100644
index 0000000000..607ee97ce6
--- /dev/null
+++ b/changelog.d/12030.misc
@@ -0,0 +1 @@
+Upgrade mypy to version 0.931.
diff --git a/changelog.d/12033.misc b/changelog.d/12033.misc
new file mode 100644
index 0000000000..3af049b969
--- /dev/null
+++ b/changelog.d/12033.misc
@@ -0,0 +1 @@
+Deduplicate in-flight requests in `_get_state_for_groups`.
diff --git a/changelog.d/12034.misc b/changelog.d/12034.misc
new file mode 100644
index 0000000000..8374a63220
--- /dev/null
+++ b/changelog.d/12034.misc
@@ -0,0 +1 @@
+Minor typing fixes.
diff --git a/changelog.d/12039.misc b/changelog.d/12039.misc
new file mode 100644
index 0000000000..45e21dbe59
--- /dev/null
+++ b/changelog.d/12039.misc
@@ -0,0 +1 @@
+Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`.
diff --git a/changelog.d/12051.misc b/changelog.d/12051.misc
new file mode 100644
index 0000000000..9959191352
--- /dev/null
+++ b/changelog.d/12051.misc
@@ -0,0 +1 @@
+Tidy up GitHub Actions config which builds distributions for PyPI.
\ No newline at end of file
diff --git a/changelog.d/12052.misc b/changelog.d/12052.misc
new file mode 100644
index 0000000000..fbaff67e95
--- /dev/null
+++ b/changelog.d/12052.misc
@@ -0,0 +1 @@
+Move `isort` configuration to `pyproject.toml`.
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 306f75ae56..e4c1c19b86 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -98,8 +98,6 @@ COPY --from=builder /install /usr/local
 COPY ./docker/start.py /start.py
 COPY ./docker/conf /conf
 
-VOLUME ["/data"]
-
 EXPOSE 8008/tcp 8009/tcp 8448/tcp
 
 ENTRYPOINT ["/start.py"]
diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md
index 1bbe237080..4076fcab65 100644
--- a/docs/admin_api/user_admin_api.md
+++ b/docs/admin_api/user_admin_api.md
@@ -126,7 +126,8 @@ Body parameters:
   [Sample Configuration File](../usage/configuration/homeserver_sample_config.html)
   section `sso` and `oidc_providers`.
   - `auth_provider` - string. ID of the external identity provider. Value of `idp_id`
-    in homeserver configuration.
+    in the homeserver configuration. Note that no error is raised if the provided
+    value is not in the homeserver configuration.
   - `external_id` - string, user ID in the external identity provider.
 - `avatar_url` - string, optional, must be a
   [MXC URI](https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris).
diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md
index 88b59bb09e..ec810fd292 100644
--- a/docs/modules/password_auth_provider_callbacks.md
+++ b/docs/modules/password_auth_provider_callbacks.md
@@ -85,7 +85,7 @@ If the authentication is unsuccessful, the module must return `None`.
 If multiple modules implement this callback, they will be considered in order. If a
 callback returns `None`, Synapse falls through to the next one. The value of the first
 callback that does not return `None` will be used. If this happens, Synapse will not call
-any of the subsequent implementations of this callback. If every callback return `None`,
+any of the subsequent implementations of this callback. If every callback returns `None`,
 the authentication is denied.
 
 ### `on_logged_out`
@@ -162,10 +162,38 @@ return `None`.
 If multiple modules implement this callback, they will be considered in order. If a
 callback returns `None`, Synapse falls through to the next one. The value of the first
 callback that does not return `None` will be used. If this happens, Synapse will not call
-any of the subsequent implementations of this callback. If every callback return `None`,
+any of the subsequent implementations of this callback. If every callback returns `None`,
 the username provided by the user is used, if any (otherwise one is automatically
 generated).
 
+### `get_displayname_for_registration`
+
+_First introduced in Synapse v1.54.0_
+
+```python
+async def get_displayname_for_registration(
+    uia_results: Dict[str, Any],
+    params: Dict[str, Any],
+) -> Optional[str]
+```
+
+Called when registering a new user. The module can return a display name to set for the
+user being registered by returning it as a string, or `None` if it doesn't wish to force a
+display name for this user.
+
+This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api)
+has been completed by the user. It is not called when registering a user via SSO. It is
+passed two dictionaries, which include the information that the user has provided during
+the registration process. These dictionaries are identical to the ones passed to
+[`get_username_for_registration`](#get_username_for_registration), so refer to the
+documentation of this callback for more information about them.
+
+If multiple modules implement this callback, they will be considered in order. If a
+callback returns `None`, Synapse falls through to the next one. The value of the first
+callback that does not return `None` will be used. If this happens, Synapse will not call
+any of the subsequent implementations of this callback. If every callback returns `None`,
+the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`).
+
 ## `is_3pid_allowed`
 
 _First introduced in Synapse v1.53.0_
@@ -194,8 +222,7 @@ The example module below implements authentication checkers for two different lo
     - Is checked by the method: `self.check_my_login`
 - `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based))
     - Expects a `password` field to be sent to `/login`
-    - Is checked by the method: `self.check_pass` 
-
+    - Is checked by the method: `self.check_pass`
 
 ```python
 from typing import Awaitable, Callable, Optional, Tuple
diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md
index 2eb9032f41..2b672b78f9 100644
--- a/docs/modules/spam_checker_callbacks.md
+++ b/docs/modules/spam_checker_callbacks.md
@@ -16,10 +16,12 @@ _First introduced in Synapse v1.37.0_
 async def check_event_for_spam(event: "synapse.events.EventBase") -> Union[bool, str]
 ```
 
-Called when receiving an event from a client or via federation. The module can return
-either a `bool` to indicate whether the event must be rejected because of spam, or a `str`
-to indicate the event must be rejected because of spam and to give a rejection reason to
-forward to clients.
+Called when receiving an event from a client or via federation. The callback must return
+either:
+- an error message string, to indicate the event must be rejected because of spam and 
+  give a rejection reason to forward to clients;
+- the boolean `True`, to indicate that the event is spammy, but not provide further details; or
+- the booelan `False`, to indicate that the event is not considered spammy.
 
 If multiple modules implement this callback, they will be considered in order. If a
 callback returns `False`, Synapse falls through to the next one. The value of the first
@@ -35,7 +37,10 @@ async def user_may_join_room(user: str, room: str, is_invited: bool) -> bool
 ```
 
 Called when a user is trying to join a room. The module must return a `bool` to indicate
-whether the user can join the room. The user is represented by their Matrix user ID (e.g.
+whether the user can join the room. Return `False` to prevent the user from joining the
+room; otherwise return `True` to permit the joining.
+
+The user is represented by their Matrix user ID (e.g.
 `@alice:example.com`) and the room is represented by its Matrix ID (e.g.
 `!room:example.com`). The module is also given a boolean to indicate whether the user
 currently has a pending invite in the room.
@@ -58,7 +63,8 @@ async def user_may_invite(inviter: str, invitee: str, room_id: str) -> bool
 
 Called when processing an invitation. The module must return a `bool` indicating whether
 the inviter can invite the invitee to the given room. Both inviter and invitee are
-represented by their Matrix user ID (e.g. `@alice:example.com`).
+represented by their Matrix user ID (e.g. `@alice:example.com`). Return `False` to prevent
+the invitation; otherwise return `True` to permit it.
 
 If multiple modules implement this callback, they will be considered in order. If a
 callback returns `True`, Synapse falls through to the next one. The value of the first
@@ -80,7 +86,8 @@ async def user_may_send_3pid_invite(
 
 Called when processing an invitation using a third-party identifier (also called a 3PID,
 e.g. an email address or a phone number). The module must return a `bool` indicating
-whether the inviter can invite the invitee to the given room.
+whether the inviter can invite the invitee to the given room. Return `False` to prevent
+the invitation; otherwise return `True` to permit it.
 
 The inviter is represented by their Matrix user ID (e.g. `@alice:example.com`), and the
 invitee is represented by its medium (e.g. "email") and its address
@@ -117,6 +124,7 @@ async def user_may_create_room(user: str) -> bool
 
 Called when processing a room creation request. The module must return a `bool` indicating
 whether the given user (represented by their Matrix user ID) is allowed to create a room.
+Return `False` to prevent room creation; otherwise return `True` to permit it.
 
 If multiple modules implement this callback, they will be considered in order. If a
 callback returns `True`, Synapse falls through to the next one. The value of the first
@@ -133,7 +141,8 @@ async def user_may_create_room_alias(user: str, room_alias: "synapse.types.RoomA
 
 Called when trying to associate an alias with an existing room. The module must return a
 `bool` indicating whether the given user (represented by their Matrix user ID) is allowed
-to set the given alias.
+to set the given alias. Return `False` to prevent the alias creation; otherwise return 
+`True` to permit it.
 
 If multiple modules implement this callback, they will be considered in order. If a
 callback returns `True`, Synapse falls through to the next one. The value of the first
@@ -150,7 +159,8 @@ async def user_may_publish_room(user: str, room_id: str) -> bool
 
 Called when trying to publish a room to the homeserver's public rooms directory. The
 module must return a `bool` indicating whether the given user (represented by their
-Matrix user ID) is allowed to publish the given room.
+Matrix user ID) is allowed to publish the given room. Return `False` to prevent the
+room from being published; otherwise return `True` to permit its publication.
 
 If multiple modules implement this callback, they will be considered in order. If a
 callback returns `True`, Synapse falls through to the next one. The value of the first
@@ -166,8 +176,11 @@ async def check_username_for_spam(user_profile: Dict[str, str]) -> bool
 ```
 
 Called when computing search results in the user directory. The module must return a
-`bool` indicating whether the given user profile can appear in search results. The profile
-is represented as a dictionary with the following keys:
+`bool` indicating whether the given user should be excluded from user directory 
+searches. Return `True` to indicate that the user is spammy and exclude them from 
+search results; otherwise return `False`.
+
+The profile is represented as a dictionary with the following keys:
 
 * `user_id`: The Matrix ID for this user.
 * `display_name`: The user's display name.
@@ -225,8 +238,9 @@ async def check_media_file_for_spam(
 ) -> bool
 ```
 
-Called when storing a local or remote file. The module must return a boolean indicating
-whether the given file can be stored in the homeserver's media store.
+Called when storing a local or remote file. The module must return a `bool` indicating
+whether the given file should be excluded from the homeserver's media store. Return
+`True` to prevent this file from being stored; otherwise return `False`.
 
 If multiple modules implement this callback, they will be considered in order. If a
 callback returns `False`, Synapse falls through to the next one. The value of the first
diff --git a/docs/structured_logging.md b/docs/structured_logging.md
index 14db85f587..805c867653 100644
--- a/docs/structured_logging.md
+++ b/docs/structured_logging.md
@@ -81,14 +81,12 @@ remote endpoint at 10.1.2.3:9999.
 
 ## Upgrading from legacy structured logging configuration
 
-Versions of Synapse prior to v1.23.0 included a custom structured logging
-configuration which is deprecated. It used a `structured: true` flag and
-configured `drains` instead of ``handlers`` and `formatters`.
-
-Synapse currently automatically converts the old configuration to the new
-configuration, but this will be removed in a future version of Synapse. The
-following reference can be used to update your configuration. Based on the drain
-`type`, we can pick a new handler:
+Versions of Synapse prior to v1.54.0 automatically converted the legacy
+structured logging configuration, which was deprecated in v1.23.0, to the standard
+library logging configuration.
+
+The following reference can be used to update your configuration. Based on the
+drain `type`, we can pick a new handler:
 
 1. For a type of `console`, `console_json`, or `console_json_terse`: a handler
    with a class of `logging.StreamHandler` and a `stream` of `ext://sys.stdout`
diff --git a/docs/upgrade.md b/docs/upgrade.md
index b722d3bb9d..f9be3ac6bc 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -85,6 +85,15 @@ process, for example:
     dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
     ```
 
+# Upgrading to v1.54.0
+
+## Legacy structured logging configuration removal
+
+This release removes support for the `structured: true` logging configuration
+which was deprecated in Synapse v1.23.0. If your logging configuration contains
+`structured: true` then it should be modified based on the
+[structured logging documentation](structured_logging.md).
+
 # Upgrading to v1.53.0
 
 ## Dropping support for `webclient` listeners and non-HTTP(S) `web_client_location`
diff --git a/mypy.ini b/mypy.ini
index 63848d664c..610660b9b7 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -31,14 +31,11 @@ exclude = (?x)
    |synapse/storage/databases/main/group_server.py
    |synapse/storage/databases/main/metrics.py
    |synapse/storage/databases/main/monthly_active_users.py
-   |synapse/storage/databases/main/presence.py
-   |synapse/storage/databases/main/purge_events.py
    |synapse/storage/databases/main/push_rule.py
    |synapse/storage/databases/main/receipts.py
    |synapse/storage/databases/main/roommember.py
    |synapse/storage/databases/main/search.py
    |synapse/storage/databases/main/state.py
-   |synapse/storage/databases/main/user_directory.py
    |synapse/storage/schema/
 
    |tests/api/test_auth.py
diff --git a/pyproject.toml b/pyproject.toml
index 963f149c6a..c9cd0cf6ec 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,3 +54,15 @@ exclude = '''
   )/
 )
 '''
+
+[tool.isort]
+line_length = 88
+sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TWISTED", "FIRSTPARTY", "TESTS", "LOCALFOLDER"]
+default_section = "THIRDPARTY"
+known_first_party = ["synapse"]
+known_tests = ["tests"]
+known_twisted = ["twisted", "OpenSSL"]
+multi_line_output = 3
+include_trailing_comma = true
+combine_as_imports = true
+
diff --git a/setup.cfg b/setup.cfg
index e5ceb7ed19..a0506572d9 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -19,14 +19,3 @@ ignore =
 #  E731: do not assign a lambda expression, use a def
 #  E501: Line too long (black enforces this for us)
 ignore=W503,W504,E203,E731,E501
-
-[isort]
-line_length = 88
-sections=FUTURE,STDLIB,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER
-default_section=THIRDPARTY
-known_first_party = synapse
-known_tests=tests
-known_twisted=twisted,OpenSSL
-multi_line_output=3
-include_trailing_comma=true
-combine_as_imports=true
diff --git a/setup.py b/setup.py
index d0511c767f..c80cb6f207 100755
--- a/setup.py
+++ b/setup.py
@@ -103,8 +103,8 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
 ]
 
 CONDITIONAL_REQUIREMENTS["mypy"] = [
-    "mypy==0.910",
-    "mypy-zope==0.3.2",
+    "mypy==0.931",
+    "mypy-zope==0.3.5",
     "types-bleach>=4.1.0",
     "types-jsonschema>=3.2.0",
     "types-opentracing>=2.4.2",
diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi
index 0eaef00498..344d55cce1 100644
--- a/stubs/sortedcontainers/sorteddict.pyi
+++ b/stubs/sortedcontainers/sorteddict.pyi
@@ -66,13 +66,18 @@ class SortedDict(Dict[_KT, _VT]):
     def __copy__(self: _SD) -> _SD: ...
     @classmethod
     @overload
-    def fromkeys(cls, seq: Iterable[_T_h]) -> SortedDict[_T_h, None]: ...
+    def fromkeys(
+        cls, seq: Iterable[_T_h], value: None = ...
+    ) -> SortedDict[_T_h, None]: ...
     @classmethod
     @overload
     def fromkeys(cls, seq: Iterable[_T_h], value: _S) -> SortedDict[_T_h, _S]: ...
-    def keys(self) -> SortedKeysView[_KT]: ...
-    def items(self) -> SortedItemsView[_KT, _VT]: ...
-    def values(self) -> SortedValuesView[_VT]: ...
+    # As of Python 3.10, `dict_{keys,items,values}` have an extra `mapping` attribute and so
+    # `Sorted{Keys,Items,Values}View` are no longer compatible with them.
+    # See https://github.com/python/typeshed/issues/6837
+    def keys(self) -> SortedKeysView[_KT]: ...  # type: ignore[override]
+    def items(self) -> SortedItemsView[_KT, _VT]: ...  # type: ignore[override]
+    def values(self) -> SortedValuesView[_VT]: ...  # type: ignore[override]
     @overload
     def pop(self, key: _KT) -> _VT: ...
     @overload
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 09d692d9a1..bcdeb9ee23 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -41,9 +41,6 @@ class ExperimentalConfig(Config):
         # MSC3244 (room version capabilities)
         self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True)
 
-        # MSC3283 (set displayname, avatar_url and change 3pid capabilities)
-        self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False)
-
         # MSC3266 (room summary api)
         self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False)
 
@@ -64,3 +61,7 @@ class ExperimentalConfig(Config):
 
         # MSC3706 (server-side support for partial state in /send_join responses)
         self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)
+
+        # experimental support for faster joins over federation (msc2775, msc3706)
+        # requires a target server with msc3706_enabled enabled.
+        self.faster_joins_enabled: bool = experimental.get("faster_joins", False)
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index b7145a44ae..cbbe221965 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -33,7 +33,6 @@ from twisted.logger import (
     globalLogBeginner,
 )
 
-from synapse.logging._structured import setup_structured_logging
 from synapse.logging.context import LoggingContextFilter
 from synapse.logging.filter import MetadataFilter
 
@@ -138,6 +137,12 @@ Support for the log_file configuration option and --log-file command-line option
 removed in Synapse 1.3.0. You should instead set up a separate log configuration file.
 """
 
+STRUCTURED_ERROR = """\
+Support for the structured configuration option was removed in Synapse 1.54.0.
+You should instead use the standard logging configuration. See
+https://matrix-org.github.io/synapse/v1.54/structured_logging.html
+"""
+
 
 class LoggingConfig(Config):
     section = "logging"
@@ -292,10 +297,9 @@ def _load_logging_config(log_config_path: str) -> None:
     if not log_config:
         logging.warning("Loaded a blank logging config?")
 
-    # If the old structured logging configuration is being used, convert it to
-    # the new style configuration.
+    # If the old structured logging configuration is being used, raise an error.
     if "structured" in log_config and log_config.get("structured"):
-        log_config = setup_structured_logging(log_config)
+        raise ConfigError(STRUCTURED_ERROR)
 
     logging.config.dictConfig(log_config)
 
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 243696b357..9386fa29dd 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -425,6 +425,33 @@ class EventClientSerializer:
 
         return serialized_event
 
+    def _apply_edit(
+        self, orig_event: EventBase, serialized_event: JsonDict, edit: EventBase
+    ) -> None:
+        """Replace the content, preserving existing relations of the serialized event.
+
+        Args:
+            orig_event: The original event.
+            serialized_event: The original event, serialized. This is modified.
+            edit: The event which edits the above.
+        """
+
+        # Ensure we take copies of the edit content, otherwise we risk modifying
+        # the original event.
+        edit_content = edit.content.copy()
+
+        # Unfreeze the event content if necessary, so that we may modify it below
+        edit_content = unfreeze(edit_content)
+        serialized_event["content"] = edit_content.get("m.new_content", {})
+
+        # Check for existing relations
+        relates_to = orig_event.content.get("m.relates_to")
+        if relates_to:
+            # Keep the relations, ensuring we use a dict copy of the original
+            serialized_event["content"]["m.relates_to"] = relates_to.copy()
+        else:
+            serialized_event["content"].pop("m.relates_to", None)
+
     def _inject_bundled_aggregations(
         self,
         event: EventBase,
@@ -450,26 +477,11 @@ class EventClientSerializer:
             serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references
 
         if aggregations.replace:
-            # If there is an edit replace the content, preserving existing
-            # relations.
+            # If there is an edit, apply it to the event.
             edit = aggregations.replace
+            self._apply_edit(event, serialized_event, edit)
 
-            # Ensure we take copies of the edit content, otherwise we risk modifying
-            # the original event.
-            edit_content = edit.content.copy()
-
-            # Unfreeze the event content if necessary, so that we may modify it below
-            edit_content = unfreeze(edit_content)
-            serialized_event["content"] = edit_content.get("m.new_content", {})
-
-            # Check for existing relations
-            relates_to = event.content.get("m.relates_to")
-            if relates_to:
-                # Keep the relations, ensuring we use a dict copy of the original
-                serialized_event["content"]["m.relates_to"] = relates_to.copy()
-            else:
-                serialized_event["content"].pop("m.relates_to", None)
-
+            # Include information about it in the relations dict.
             serialized_aggregations[RelationTypes.REPLACE] = {
                 "event_id": edit.event_id,
                 "origin_server_ts": edit.origin_server_ts,
@@ -478,13 +490,22 @@ class EventClientSerializer:
 
         # If this event is the start of a thread, include a summary of the replies.
         if aggregations.thread:
+            thread = aggregations.thread
+
+            # Don't bundle aggregations as this could recurse forever.
+            serialized_latest_event = self.serialize_event(
+                thread.latest_event, time_now, bundle_aggregations=None
+            )
+            # Manually apply an edit, if one exists.
+            if thread.latest_edit:
+                self._apply_edit(
+                    thread.latest_event, serialized_latest_event, thread.latest_edit
+                )
+
             serialized_aggregations[RelationTypes.THREAD] = {
-                # Don't bundle aggregations as this could recurse forever.
-                "latest_event": self.serialize_event(
-                    aggregations.thread.latest_event, time_now, bundle_aggregations=None
-                ),
-                "count": aggregations.thread.count,
-                "current_user_participated": aggregations.thread.current_user_participated,
+                "latest_event": serialized_latest_event,
+                "count": thread.count,
+                "current_user_participated": thread.current_user_participated,
             }
 
         # Include the bundled aggregations in the event.
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 896168c05c..fab6da3c08 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -47,6 +47,11 @@ class FederationBase:
     ) -> EventBase:
         """Checks that event is correctly signed by the sending server.
 
+        Also checks the content hash, and redacts the event if there is a mismatch.
+
+        Also runs the event through the spam checker; if it fails, redacts the event
+        and flags it as soft-failed.
+
         Args:
             room_version: The room version of the PDU
             pdu: the event to be checked
@@ -55,7 +60,10 @@ class FederationBase:
               * the original event if the checks pass
               * a redacted version of the event (if the signature
                 matched but the hash did not)
-              * throws a SynapseError if the signature check failed."""
+
+        Raises:
+              SynapseError if the signature check failed.
+        """
         try:
             await _check_sigs_on_pdu(self.keyring, room_version, pdu)
         except SynapseError as e:
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 74f17aa4da..c2997997da 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1,4 +1,4 @@
-# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2015-2022 The Matrix.org Foundation C.I.C.
 # Copyright 2020 Sorunome
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -89,6 +89,12 @@ class SendJoinResult:
     state: List[EventBase]
     auth_chain: List[EventBase]
 
+    # True if 'state' elides non-critical membership events
+    partial_state: bool
+
+    # if 'partial_state' is set, a list of the servers in the room (otherwise empty)
+    servers_in_room: List[str]
+
 
 class FederationClient(FederationBase):
     def __init__(self, hs: "HomeServer"):
@@ -413,26 +419,90 @@ class FederationClient(FederationBase):
 
         return state_event_ids, auth_event_ids
 
+    async def get_room_state(
+        self,
+        destination: str,
+        room_id: str,
+        event_id: str,
+        room_version: RoomVersion,
+    ) -> Tuple[List[EventBase], List[EventBase]]:
+        """Calls the /state endpoint to fetch the state at a particular point
+        in the room.
+
+        Any invalid events (those with incorrect or unverifiable signatures or hashes)
+        are filtered out from the response, and any duplicate events are removed.
+
+        (Size limits and other event-format checks are *not* performed.)
+
+        Note that the result is not ordered, so callers must be careful to process
+        the events in an order that handles dependencies.
+
+        Returns:
+            a tuple of (state events, auth events)
+        """
+        result = await self.transport_layer.get_room_state(
+            room_version,
+            destination,
+            room_id,
+            event_id,
+        )
+        state_events = result.state
+        auth_events = result.auth_events
+
+        # we may as well filter out any duplicates from the response, to save
+        # processing them multiple times. (In particular, events may be present in
+        # `auth_events` as well as `state`, which is redundant).
+        #
+        # We don't rely on the sort order of the events, so we can just stick them
+        # in a dict.
+        state_event_map = {event.event_id: event for event in state_events}
+        auth_event_map = {
+            event.event_id: event
+            for event in auth_events
+            if event.event_id not in state_event_map
+        }
+
+        logger.info(
+            "Processing from /state: %d state events, %d auth events",
+            len(state_event_map),
+            len(auth_event_map),
+        )
+
+        valid_auth_events = await self._check_sigs_and_hash_and_fetch(
+            destination, auth_event_map.values(), room_version
+        )
+
+        valid_state_events = await self._check_sigs_and_hash_and_fetch(
+            destination, state_event_map.values(), room_version
+        )
+
+        return valid_state_events, valid_auth_events
+
     async def _check_sigs_and_hash_and_fetch(
         self,
         origin: str,
         pdus: Collection[EventBase],
         room_version: RoomVersion,
     ) -> List[EventBase]:
-        """Takes a list of PDUs and checks the signatures and hashes of each
-        one. If a PDU fails its signature check then we check if we have it in
-        the database and if not then request if from the originating server of
-        that PDU.
+        """Checks the signatures and hashes of a list of events.
+
+        If a PDU fails its signature check then we check if we have it in
+        the database, and if not then request it from the sender's server (if that
+        is different from `origin`). If that still fails, the event is omitted from
+        the returned list.
 
         If a PDU fails its content hash check then it is redacted.
 
-        The given list of PDUs are not modified, instead the function returns
+        Also runs each event through the spam checker; if it fails, redacts the event
+        and flags it as soft-failed.
+
+        The given list of PDUs are not modified; instead the function returns
         a new list.
 
         Args:
-            origin
-            pdu
-            room_version
+            origin: The server that sent us these events
+            pdus: The events to be checked
+            room_version: the version of the room these events are in
 
         Returns:
             A list of PDUs that have valid signatures and hashes.
@@ -463,11 +533,16 @@ class FederationClient(FederationBase):
         origin: str,
         room_version: RoomVersion,
     ) -> Optional[EventBase]:
-        """Takes a PDU and checks its signatures and hashes. If the PDU fails
-        its signature check then we check if we have it in the database and if
-        not then request if from the originating server of that PDU.
+        """Takes a PDU and checks its signatures and hashes.
+
+        If the PDU fails its signature check then we check if we have it in the
+        database; if not, we then request it from sender's server (if that is not the
+        same as `origin`). If that still fails, we return None.
 
-        If then PDU fails its content hash check then it is redacted.
+        If the PDU fails its content hash check, it is redacted.
+
+        Also runs the event through the spam checker; if it fails, redacts the event
+        and flags it as soft-failed.
 
         Args:
             origin
@@ -864,23 +939,32 @@ class FederationClient(FederationBase):
             for s in signed_state:
                 s.internal_metadata = copy.deepcopy(s.internal_metadata)
 
-            # double-check that the same create event has ended up in the auth chain
+            # double-check that the auth chain doesn't include a different create event
             auth_chain_create_events = [
                 e.event_id
                 for e in signed_auth
                 if (e.type, e.state_key) == (EventTypes.Create, "")
             ]
-            if auth_chain_create_events != [create_event.event_id]:
+            if auth_chain_create_events and auth_chain_create_events != [
+                create_event.event_id
+            ]:
                 raise InvalidResponseError(
                     "Unexpected create event(s) in auth chain: %s"
                     % (auth_chain_create_events,)
                 )
 
+            if response.partial_state and not response.servers_in_room:
+                raise InvalidResponseError(
+                    "partial_state was set, but no servers were listed in the room"
+                )
+
             return SendJoinResult(
                 event=event,
                 state=signed_state,
                 auth_chain=signed_auth,
                 origin=destination,
+                partial_state=response.partial_state,
+                servers_in_room=response.servers_in_room or [],
             )
 
         # MSC3083 defines additional error codes for room joins.
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 8152e80b88..c3132f7319 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -381,7 +381,9 @@ class PerDestinationQueue:
                 )
             )
 
-        if self._last_successful_stream_ordering is None:
+        last_successful_stream_ordering = self._last_successful_stream_ordering
+
+        if last_successful_stream_ordering is None:
             # if it's still None, then this means we don't have the information
             # in our database ­ we haven't successfully sent a PDU to this server
             # (at least since the introduction of the feature tracking
@@ -394,8 +396,7 @@ class PerDestinationQueue:
         # get at most 50 catchup room/PDUs
         while True:
             event_ids = await self._store.get_catch_up_room_event_ids(
-                self._destination,
-                self._last_successful_stream_ordering,
+                self._destination, last_successful_stream_ordering
             )
 
             if not event_ids:
@@ -403,7 +404,7 @@ class PerDestinationQueue:
                 # of a race condition, so we check that no new events have been
                 # skipped due to us being in catch-up mode
 
-                if self._catchup_last_skipped > self._last_successful_stream_ordering:
+                if self._catchup_last_skipped > last_successful_stream_ordering:
                     # another event has been skipped because we were in catch-up mode
                     continue
 
@@ -470,7 +471,7 @@ class PerDestinationQueue:
                         # offline
                         if (
                             p.internal_metadata.stream_ordering
-                            < self._last_successful_stream_ordering
+                            < last_successful_stream_ordering
                         ):
                             continue
 
@@ -513,12 +514,11 @@ class PerDestinationQueue:
                 # from the *original* PDU, rather than the PDU(s) we actually
                 # send. This is because we use it to mark our position in the
                 # queue of missed PDUs to process.
-                self._last_successful_stream_ordering = (
-                    pdu.internal_metadata.stream_ordering
-                )
+                last_successful_stream_ordering = pdu.internal_metadata.stream_ordering
 
+                self._last_successful_stream_ordering = last_successful_stream_ordering
                 await self._store.set_destination_last_successful_stream_ordering(
-                    self._destination, self._last_successful_stream_ordering
+                    self._destination, last_successful_stream_ordering
                 )
 
     def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 8782586cd6..7e510e224a 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1,4 +1,4 @@
-# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
 # Copyright 2020 Sorunome
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -60,17 +60,17 @@ class TransportLayerClient:
     def __init__(self, hs):
         self.server_name = hs.hostname
         self.client = hs.get_federation_http_client()
+        self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
 
     async def get_room_state_ids(
         self, destination: str, room_id: str, event_id: str
     ) -> JsonDict:
-        """Requests all state for a given room from the given server at the
-        given event. Returns the state's event_id's
+        """Requests the IDs of all state for a given room at the given event.
 
         Args:
             destination: The host name of the remote homeserver we want
                 to get the state from.
-            context: The name of the context we want the state of
+            room_id: the room we want the state of
             event_id: The event we want the context at.
 
         Returns:
@@ -86,6 +86,29 @@ class TransportLayerClient:
             try_trailing_slash_on_400=True,
         )
 
+    async def get_room_state(
+        self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
+    ) -> "StateRequestResponse":
+        """Requests the full state for a given room at the given event.
+
+        Args:
+            room_version: the version of the room (required to build the event objects)
+            destination: The host name of the remote homeserver we want
+                to get the state from.
+            room_id: the room we want the state of
+            event_id: The event we want the context at.
+
+        Returns:
+            Results in a dict received from the remote homeserver.
+        """
+        path = _create_v1_path("/state/%s", room_id)
+        return await self.client.get_json(
+            destination,
+            path=path,
+            args={"event_id": event_id},
+            parser=_StateParser(room_version),
+        )
+
     async def get_event(
         self, destination: str, event_id: str, timeout: Optional[int] = None
     ) -> JsonDict:
@@ -336,10 +359,15 @@ class TransportLayerClient:
         content: JsonDict,
     ) -> "SendJoinResponse":
         path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
+        query_params: Dict[str, str] = {}
+        if self._faster_joins_enabled:
+            # lazy-load state on join
+            query_params["org.matrix.msc3706.partial_state"] = "true"
 
         return await self.client.put_json(
             destination=destination,
             path=path,
+            args=query_params,
             data=content,
             parser=SendJoinParser(room_version, v1_api=False),
             max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
@@ -1271,6 +1299,20 @@ class SendJoinResponse:
     # "event" is not included in the response.
     event: Optional[EventBase] = None
 
+    # The room state is incomplete
+    partial_state: bool = False
+
+    # List of servers in the room
+    servers_in_room: Optional[List[str]] = None
+
+
+@attr.s(slots=True, auto_attribs=True)
+class StateRequestResponse:
+    """The parsed response of a `/state` request."""
+
+    auth_events: List[EventBase]
+    state: List[EventBase]
+
 
 @ijson.coroutine
 def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
@@ -1297,6 +1339,32 @@ def _event_list_parser(
         events.append(event)
 
 
+@ijson.coroutine
+def _partial_state_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
+    """Helper function for use with `ijson.items_coro`
+
+    Parses the partial_state field in send_join responses
+    """
+    while True:
+        val = yield
+        if not isinstance(val, bool):
+            raise TypeError("partial_state must be a boolean")
+        response.partial_state = val
+
+
+@ijson.coroutine
+def _servers_in_room_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
+    """Helper function for use with `ijson.items_coro`
+
+    Parses the servers_in_room field in send_join responses
+    """
+    while True:
+        val = yield
+        if not isinstance(val, list) or any(not isinstance(x, str) for x in val):
+            raise TypeError("servers_in_room must be a list of strings")
+        response.servers_in_room = val
+
+
 class SendJoinParser(ByteParser[SendJoinResponse]):
     """A parser for the response to `/send_join` requests.
 
@@ -1308,44 +1376,62 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
     CONTENT_TYPE = "application/json"
 
     def __init__(self, room_version: RoomVersion, v1_api: bool):
-        self._response = SendJoinResponse([], [], {})
+        self._response = SendJoinResponse([], [], event_dict={})
         self._room_version = room_version
+        self._coros = []
 
         # The V1 API has the shape of `[200, {...}]`, which we handle by
         # prefixing with `item.*`.
         prefix = "item." if v1_api else ""
 
-        self._coro_state = ijson.items_coro(
-            _event_list_parser(room_version, self._response.state),
-            prefix + "state.item",
-            use_float=True,
-        )
-        self._coro_auth = ijson.items_coro(
-            _event_list_parser(room_version, self._response.auth_events),
-            prefix + "auth_chain.item",
-            use_float=True,
-        )
-        # TODO Remove the unstable prefix when servers have updated.
-        #
-        # By re-using the same event dictionary this will cause the parsing of
-        # org.matrix.msc3083.v2.event and event to stomp over each other.
-        # Generally this should be fine.
-        self._coro_unstable_event = ijson.kvitems_coro(
-            _event_parser(self._response.event_dict),
-            prefix + "org.matrix.msc3083.v2.event",
-            use_float=True,
-        )
-        self._coro_event = ijson.kvitems_coro(
-            _event_parser(self._response.event_dict),
-            prefix + "event",
-            use_float=True,
-        )
+        self._coros = [
+            ijson.items_coro(
+                _event_list_parser(room_version, self._response.state),
+                prefix + "state.item",
+                use_float=True,
+            ),
+            ijson.items_coro(
+                _event_list_parser(room_version, self._response.auth_events),
+                prefix + "auth_chain.item",
+                use_float=True,
+            ),
+            # TODO Remove the unstable prefix when servers have updated.
+            #
+            # By re-using the same event dictionary this will cause the parsing of
+            # org.matrix.msc3083.v2.event and event to stomp over each other.
+            # Generally this should be fine.
+            ijson.kvitems_coro(
+                _event_parser(self._response.event_dict),
+                prefix + "org.matrix.msc3083.v2.event",
+                use_float=True,
+            ),
+            ijson.kvitems_coro(
+                _event_parser(self._response.event_dict),
+                prefix + "event",
+                use_float=True,
+            ),
+        ]
+
+        if not v1_api:
+            self._coros.append(
+                ijson.items_coro(
+                    _partial_state_parser(self._response),
+                    "org.matrix.msc3706.partial_state",
+                    use_float="True",
+                )
+            )
+
+            self._coros.append(
+                ijson.items_coro(
+                    _servers_in_room_parser(self._response),
+                    "org.matrix.msc3706.servers_in_room",
+                    use_float="True",
+                )
+            )
 
     def write(self, data: bytes) -> int:
-        self._coro_state.send(data)
-        self._coro_auth.send(data)
-        self._coro_unstable_event.send(data)
-        self._coro_event.send(data)
+        for c in self._coros:
+            c.send(data)
 
         return len(data)
 
@@ -1355,3 +1441,37 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
                 self._response.event_dict, self._room_version
             )
         return self._response
+
+
+class _StateParser(ByteParser[StateRequestResponse]):
+    """A parser for the response to `/state` requests.
+
+    Args:
+        room_version: The version of the room.
+    """
+
+    CONTENT_TYPE = "application/json"
+
+    def __init__(self, room_version: RoomVersion):
+        self._response = StateRequestResponse([], [])
+        self._room_version = room_version
+        self._coros = [
+            ijson.items_coro(
+                _event_list_parser(room_version, self._response.state),
+                "pdus.item",
+                use_float=True,
+            ),
+            ijson.items_coro(
+                _event_list_parser(room_version, self._response.auth_events),
+                "auth_chain.item",
+                use_float=True,
+            ),
+        ]
+
+    def write(self, data: bytes) -> int:
+        for c in self._coros:
+            c.send(data)
+        return len(data)
+
+    def finish(self) -> StateRequestResponse:
+        return self._response
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6959d1aa7e..572f54b1e3 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2064,6 +2064,10 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
     [JsonDict, JsonDict],
     Awaitable[Optional[str]],
 ]
+GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
+    [JsonDict, JsonDict],
+    Awaitable[Optional[str]],
+]
 IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
 
 
@@ -2080,6 +2084,9 @@ class PasswordAuthProvider:
         self.get_username_for_registration_callbacks: List[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = []
+        self.get_displayname_for_registration_callbacks: List[
+            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+        ] = []
         self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
 
         # Mapping from login type to login parameters
@@ -2099,6 +2106,9 @@ class PasswordAuthProvider:
         get_username_for_registration: Optional[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = None,
+        get_displayname_for_registration: Optional[
+            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
     ) -> None:
         # Register check_3pid_auth callback
         if check_3pid_auth is not None:
@@ -2148,6 +2158,11 @@ class PasswordAuthProvider:
                 get_username_for_registration,
             )
 
+        if get_displayname_for_registration is not None:
+            self.get_displayname_for_registration_callbacks.append(
+                get_displayname_for_registration,
+            )
+
         if is_3pid_allowed is not None:
             self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
 
@@ -2350,6 +2365,49 @@ class PasswordAuthProvider:
 
         return None
 
+    async def get_displayname_for_registration(
+        self,
+        uia_results: JsonDict,
+        params: JsonDict,
+    ) -> Optional[str]:
+        """Defines the display name to use when registering the user, using the
+        credentials and parameters provided during the UIA flow.
+
+        Stops at the first callback that returns a tuple containing at least one string.
+
+        Args:
+            uia_results: The credentials provided during the UIA flow.
+            params: The parameters provided by the registration request.
+
+        Returns:
+            A tuple which first element is the display name, and the second is an MXC URL
+            to the user's avatar.
+        """
+        for callback in self.get_displayname_for_registration_callbacks:
+            try:
+                res = await callback(uia_results, params)
+
+                if isinstance(res, str):
+                    return res
+                elif res is not None:
+                    # mypy complains that this line is unreachable because it assumes the
+                    # data returned by the module fits the expected type. We just want
+                    # to make sure this is the case.
+                    logger.warning(  # type: ignore[unreachable]
+                        "Ignoring non-string value returned by"
+                        " get_displayname_for_registration callback %s: %s",
+                        callback,
+                        res,
+                    )
+            except Exception as e:
+                logger.error(
+                    "Module raised an exception in get_displayname_for_registration: %s",
+                    e,
+                )
+                raise SynapseError(code=500, msg="Internal Server Error")
+
+        return None
+
     async def is_3pid_allowed(
         self,
         medium: str,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c0f642005f..c8356f233d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -516,7 +516,7 @@ class FederationHandler:
             await self.store.upsert_room_on_join(
                 room_id=room_id,
                 room_version=room_version_obj,
-                auth_events=auth_chain,
+                state_events=state,
             )
 
             max_stream_id = await self._federation_event_handler.process_remote_join(
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 9edc7369d6..7683246bef 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -419,10 +419,8 @@ class FederationEventHandler:
         Raises:
             SynapseError if the response is in some way invalid.
         """
-        event_map = {e.event_id: e for e in itertools.chain(auth_events, state)}
-
         create_event = None
-        for e in auth_events:
+        for e in state:
             if (e.type, e.state_key) == (EventTypes.Create, ""):
                 create_event = e
                 break
@@ -439,11 +437,6 @@ class FederationEventHandler:
         if room_version.identifier != room_version_id:
             raise SynapseError(400, "Room version mismatch")
 
-        # filter out any events we have already seen
-        seen_remotes = await self._store.have_seen_events(room_id, event_map.keys())
-        for s in seen_remotes:
-            event_map.pop(s, None)
-
         # persist the auth chain and state events.
         #
         # any invalid events here will be marked as rejected, and we'll carry on.
@@ -455,7 +448,9 @@ class FederationEventHandler:
         # signatures right now doesn't mean that we will *never* be able to, so it
         # is premature to reject them.
         #
-        await self._auth_and_persist_outliers(room_id, event_map.values())
+        await self._auth_and_persist_outliers(
+            room_id, itertools.chain(auth_events, state)
+        )
 
         # and now persist the join event itself.
         logger.info("Peristing join-via-remote %s", event)
@@ -1245,6 +1240,16 @@ class FederationEventHandler:
         """
         event_map = {event.event_id: event for event in events}
 
+        # filter out any events we have already seen. This might happen because
+        # the events were eagerly pushed to us (eg, during a room join), or because
+        # another thread has raced against us since we decided to request the event.
+        #
+        # This is just an optimisation, so it doesn't need to be watertight - the event
+        # persister does another round of deduplication.
+        seen_remotes = await self._store.have_seen_events(room_id, event_map.keys())
+        for s in seen_remotes:
+            event_map.pop(s, None)
+
         # XXX: it might be possible to kick this process off in parallel with fetching
         # the events.
         while event_map:
@@ -1717,31 +1722,22 @@ class FederationEventHandler:
             event_id: the event for which we are lacking auth events
         """
         try:
-            remote_event_map = {
-                e.event_id: e
-                for e in await self._federation_client.get_event_auth(
-                    destination, room_id, event_id
-                )
-            }
+            remote_events = await self._federation_client.get_event_auth(
+                destination, room_id, event_id
+            )
+
         except RequestSendFailed as e1:
             # The other side isn't around or doesn't implement the
             # endpoint, so lets just bail out.
             logger.info("Failed to get event auth from remote: %s", e1)
             return
 
-        logger.info("/event_auth returned %i events", len(remote_event_map))
+        logger.info("/event_auth returned %i events", len(remote_events))
 
         # `event` may be returned, but we should not yet process it.
-        remote_event_map.pop(event_id, None)
-
-        # nor should we reprocess any events we have already seen.
-        seen_remotes = await self._store.have_seen_events(
-            room_id, remote_event_map.keys()
-        )
-        for s in seen_remotes:
-            remote_event_map.pop(s, None)
+        remote_auth_events = (e for e in remote_events if e.event_id != event_id)
 
-        await self._auth_and_persist_outliers(room_id, remote_event_map.values())
+        await self._auth_and_persist_outliers(room_id, remote_auth_events)
 
     async def _update_context_for_auth_events(
         self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9267e586a8..4d0da84287 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -550,10 +550,11 @@ class EventCreationHandler:
 
         if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
             room_version_id = event_dict["content"]["room_version"]
-            room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
-            if not room_version_obj:
+            maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
+            if not maybe_room_version_obj:
                 # this can happen if support is withdrawn for a room version
                 raise UnsupportedRoomVersionError(room_version_id)
+            room_version_obj = maybe_room_version_obj
         else:
             try:
                 room_version_obj = await self.store.get_room_version(
@@ -1145,12 +1146,13 @@ class EventCreationHandler:
             room_version_id = event.content.get(
                 "room_version", RoomVersions.V1.identifier
             )
-            room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
-            if not room_version_obj:
+            maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
+            if not maybe_room_version_obj:
                 raise UnsupportedRoomVersionError(
                     "Attempt to create a room with unsupported room version %s"
                     % (room_version_id,)
                 )
+            room_version_obj = maybe_room_version_obj
         else:
             room_version_obj = await self.store.get_room_version(event.room_id)
 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 067c43ae47..b223b72623 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -204,25 +204,27 @@ class BasePresenceHandler(abc.ABC):
         Returns:
             dict: `user_id` -> `UserPresenceState`
         """
-        states = {
-            user_id: self.user_to_current_state.get(user_id, None)
-            for user_id in user_ids
-        }
+        states = {}
+        missing = []
+        for user_id in user_ids:
+            state = self.user_to_current_state.get(user_id, None)
+            if state:
+                states[user_id] = state
+            else:
+                missing.append(user_id)
 
-        missing = [user_id for user_id, state in states.items() if not state]
         if missing:
             # There are things not in our in memory cache. Lets pull them out of
             # the database.
             res = await self.store.get_presence_for_users(missing)
             states.update(res)
 
-            missing = [user_id for user_id, state in states.items() if not state]
-            if missing:
-                new = {
-                    user_id: UserPresenceState.default(user_id) for user_id in missing
-                }
-                states.update(new)
-                self.user_to_current_state.update(new)
+            for user_id in missing:
+                # if user has no state in database, create the state
+                if not res.get(user_id, None):
+                    new_state = UserPresenceState.default(user_id)
+                    states[user_id] = new_state
+                    self.user_to_current_state[user_id] = new_state
 
         return states
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a719d5eef3..80320d2c07 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -320,12 +320,12 @@ class RegistrationHandler:
                 if fail_count > 10:
                     raise SynapseError(500, "Unable to find a suitable guest user ID")
 
-                localpart = await self.store.generate_user_id()
-                user = UserID(localpart, self.hs.hostname)
+                generated_localpart = await self.store.generate_user_id()
+                user = UserID(generated_localpart, self.hs.hostname)
                 user_id = user.to_string()
                 self.check_user_id_not_appservice_exclusive(user_id)
                 if generate_display_name:
-                    default_display_name = localpart
+                    default_display_name = generated_localpart
                 try:
                     await self.register_with_store(
                         user_id=user_id,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index bf1a47efb0..b2adc0f48b 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -82,6 +82,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.event_auth_handler = hs.get_event_auth_handler()
 
         self.member_linearizer: Linearizer = Linearizer(name="member")
+        self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
 
         self.clock = hs.get_clock()
         self.spam_checker = hs.get_spam_checker()
@@ -500,25 +501,32 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         key = (room_id,)
 
-        with (await self.member_linearizer.queue(key)):
-            result = await self.update_membership_locked(
-                requester,
-                target,
-                room_id,
-                action,
-                txn_id=txn_id,
-                remote_room_hosts=remote_room_hosts,
-                third_party_signed=third_party_signed,
-                ratelimit=ratelimit,
-                content=content,
-                new_room=new_room,
-                require_consent=require_consent,
-                outlier=outlier,
-                historical=historical,
-                allow_no_prev_events=allow_no_prev_events,
-                prev_event_ids=prev_event_ids,
-                auth_event_ids=auth_event_ids,
-            )
+        as_id = object()
+        if requester.app_service:
+            as_id = requester.app_service.id
+
+        # We first linearise by the application service (to try to limit concurrent joins
+        # by application services), and then by room ID.
+        with (await self.member_as_limiter.queue(as_id)):
+            with (await self.member_linearizer.queue(key)):
+                result = await self.update_membership_locked(
+                    requester,
+                    target,
+                    room_id,
+                    action,
+                    txn_id=txn_id,
+                    remote_room_hosts=remote_room_hosts,
+                    third_party_signed=third_party_signed,
+                    ratelimit=ratelimit,
+                    content=content,
+                    new_room=new_room,
+                    require_consent=require_consent,
+                    outlier=outlier,
+                    historical=historical,
+                    allow_no_prev_events=allow_no_prev_events,
+                    prev_event_ids=prev_event_ids,
+                    auth_event_ids=auth_event_ids,
+                )
 
         return result
 
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 41cb809078..0e0e58de02 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -14,8 +14,9 @@
 
 import itertools
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
 
+import attr
 from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.constants import EventTypes, Membership
@@ -32,6 +33,20 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _SearchResult:
+    # The count of results.
+    count: int
+    # A mapping of event ID to the rank of that event.
+    rank_map: Dict[str, int]
+    # A list of the resulting events.
+    allowed_events: List[EventBase]
+    # A map of room ID to results.
+    room_groups: Dict[str, JsonDict]
+    # A set of event IDs to highlight.
+    highlights: Set[str]
+
+
 class SearchHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -100,7 +115,7 @@ class SearchHandler:
         """Performs a full text search for a user.
 
         Args:
-            user
+            user: The user performing the search.
             content: Search parameters
             batch: The next_batch parameter. Used for pagination.
 
@@ -156,6 +171,8 @@ class SearchHandler:
 
             # Include context around each event?
             event_context = room_cat.get("event_context", None)
+            before_limit = after_limit = None
+            include_profile = False
 
             # Group results together? May allow clients to paginate within a
             # group
@@ -182,6 +199,73 @@ class SearchHandler:
                 % (set(group_keys) - {"room_id", "sender"},),
             )
 
+        return await self._search(
+            user,
+            batch_group,
+            batch_group_key,
+            batch_token,
+            search_term,
+            keys,
+            filter_dict,
+            order_by,
+            include_state,
+            group_keys,
+            event_context,
+            before_limit,
+            after_limit,
+            include_profile,
+        )
+
+    async def _search(
+        self,
+        user: UserID,
+        batch_group: Optional[str],
+        batch_group_key: Optional[str],
+        batch_token: Optional[str],
+        search_term: str,
+        keys: List[str],
+        filter_dict: JsonDict,
+        order_by: str,
+        include_state: bool,
+        group_keys: List[str],
+        event_context: Optional[bool],
+        before_limit: Optional[int],
+        after_limit: Optional[int],
+        include_profile: bool,
+    ) -> JsonDict:
+        """Performs a full text search for a user.
+
+        Args:
+            user: The user performing the search.
+            batch_group: Pagination information.
+            batch_group_key: Pagination information.
+            batch_token: Pagination information.
+            search_term: Search term to search for
+            keys: List of keys to search in, currently supports
+                "content.body", "content.name", "content.topic"
+            filter_dict: The JSON to build a filter out of.
+            order_by: How to order the results. Valid values ore "rank" and "recent".
+            include_state: True if the state of the room at each result should
+                be included.
+            group_keys: A list of ways to group the results. Valid values are
+                "room_id" and "sender".
+            event_context: True to include contextual events around results.
+            before_limit:
+                The number of events before a result to include as context.
+
+                Only used if event_context is True.
+            after_limit:
+                The number of events after a result to include as context.
+
+                Only used if event_context is True.
+            include_profile: True if historical profile information should be
+                included in the event context.
+
+                Only used if event_context is True.
+
+        Returns:
+            dict to be returned to the client with results of search
+        """
         search_filter = Filter(self.hs, filter_dict)
 
         # TODO: Search through left rooms too
@@ -216,278 +300,399 @@ class SearchHandler:
                 }
             }
 
-        rank_map = {}  # event_id -> rank of event
-        allowed_events = []
-        # Holds result of grouping by room, if applicable
-        room_groups: Dict[str, JsonDict] = {}
-        # Holds result of grouping by sender, if applicable
-        sender_group: Dict[str, JsonDict] = {}
+        sender_group: Optional[Dict[str, JsonDict]]
 
-        # Holds the next_batch for the entire result set if one of those exists
-        global_next_batch = None
-
-        highlights = set()
+        if order_by == "rank":
+            search_result, sender_group = await self._search_by_rank(
+                user, room_ids, search_term, keys, search_filter
+            )
+            # Unused return values for rank search.
+            global_next_batch = None
+        elif order_by == "recent":
+            search_result, global_next_batch = await self._search_by_recent(
+                user,
+                room_ids,
+                search_term,
+                keys,
+                search_filter,
+                batch_group,
+                batch_group_key,
+                batch_token,
+            )
+            # Unused return values for recent search.
+            sender_group = None
+        else:
+            # We should never get here due to the guard earlier.
+            raise NotImplementedError()
 
-        count = None
+        logger.info("Found %d events to return", len(search_result.allowed_events))
 
-        if order_by == "rank":
-            search_result = await self.store.search_msgs(room_ids, search_term, keys)
+        # If client has asked for "context" for each event (i.e. some surrounding
+        # events and state), fetch that
+        if event_context is not None:
+            # Note that before and after limit must be set in this case.
+            assert before_limit is not None
+            assert after_limit is not None
+
+            contexts = await self._calculate_event_contexts(
+                user,
+                search_result.allowed_events,
+                before_limit,
+                after_limit,
+                include_profile,
+            )
+        else:
+            contexts = {}
 
-            count = search_result["count"]
+        # TODO: Add a limit
 
-            if search_result["highlights"]:
-                highlights.update(search_result["highlights"])
+        state_results = {}
+        if include_state:
+            for room_id in {e.room_id for e in search_result.allowed_events}:
+                state = await self.state_handler.get_current_state(room_id)
+                state_results[room_id] = list(state.values())
 
-            results = search_result["results"]
+        aggregations = None
+        if self._msc3666_enabled:
+            aggregations = await self.store.get_bundled_aggregations(
+                # Generate an iterable of EventBase for all the events that will be
+                # returned, including contextual events.
+                itertools.chain(
+                    # The events_before and events_after for each context.
+                    itertools.chain.from_iterable(
+                        itertools.chain(context["events_before"], context["events_after"])  # type: ignore[arg-type]
+                        for context in contexts.values()
+                    ),
+                    # The returned events.
+                    search_result.allowed_events,
+                ),
+                user.to_string(),
+            )
 
-            rank_map.update({r["event"].event_id: r["rank"] for r in results})
+        # We're now about to serialize the events. We should not make any
+        # blocking calls after this. Otherwise, the 'age' will be wrong.
 
-            filtered_events = await search_filter.filter([r["event"] for r in results])
+        time_now = self.clock.time_msec()
 
-            events = await filter_events_for_client(
-                self.storage, user.to_string(), filtered_events
+        for context in contexts.values():
+            context["events_before"] = self._event_serializer.serialize_events(
+                context["events_before"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
+            )
+            context["events_after"] = self._event_serializer.serialize_events(
+                context["events_after"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
             )
 
-            events.sort(key=lambda e: -rank_map[e.event_id])
-            allowed_events = events[: search_filter.limit]
+        results = [
+            {
+                "rank": search_result.rank_map[e.event_id],
+                "result": self._event_serializer.serialize_event(
+                    e, time_now, bundle_aggregations=aggregations
+                ),
+                "context": contexts.get(e.event_id, {}),
+            }
+            for e in search_result.allowed_events
+        ]
 
-            for e in allowed_events:
-                rm = room_groups.setdefault(
-                    e.room_id, {"results": [], "order": rank_map[e.event_id]}
-                )
-                rm["results"].append(e.event_id)
+        rooms_cat_res: JsonDict = {
+            "results": results,
+            "count": search_result.count,
+            "highlights": list(search_result.highlights),
+        }
 
-                s = sender_group.setdefault(
-                    e.sender, {"results": [], "order": rank_map[e.event_id]}
-                )
-                s["results"].append(e.event_id)
+        if state_results:
+            rooms_cat_res["state"] = {
+                room_id: self._event_serializer.serialize_events(state_events, time_now)
+                for room_id, state_events in state_results.items()
+            }
 
-        elif order_by == "recent":
-            room_events: List[EventBase] = []
-            i = 0
-
-            pagination_token = batch_token
-
-            # We keep looping and we keep filtering until we reach the limit
-            # or we run out of things.
-            # But only go around 5 times since otherwise synapse will be sad.
-            while len(room_events) < search_filter.limit and i < 5:
-                i += 1
-                search_result = await self.store.search_rooms(
-                    room_ids,
-                    search_term,
-                    keys,
-                    search_filter.limit * 2,
-                    pagination_token=pagination_token,
-                )
+        if search_result.room_groups and "room_id" in group_keys:
+            rooms_cat_res.setdefault("groups", {})[
+                "room_id"
+            ] = search_result.room_groups
 
-                if search_result["highlights"]:
-                    highlights.update(search_result["highlights"])
+        if sender_group and "sender" in group_keys:
+            rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
 
-                count = search_result["count"]
+        if global_next_batch:
+            rooms_cat_res["next_batch"] = global_next_batch
 
-                results = search_result["results"]
+        return {"search_categories": {"room_events": rooms_cat_res}}
 
-                results_map = {r["event"].event_id: r for r in results}
+    async def _search_by_rank(
+        self,
+        user: UserID,
+        room_ids: Collection[str],
+        search_term: str,
+        keys: Iterable[str],
+        search_filter: Filter,
+    ) -> Tuple[_SearchResult, Dict[str, JsonDict]]:
+        """
+        Performs a full text search for a user ordering by rank.
 
-                rank_map.update({r["event"].event_id: r["rank"] for r in results})
+        Args:
+            user: The user performing the search.
+            room_ids: List of room ids to search in
+            search_term: Search term to search for
+            keys: List of keys to search in, currently supports
+                "content.body", "content.name", "content.topic"
+            search_filter: The event filter to use.
 
-                filtered_events = await search_filter.filter(
-                    [r["event"] for r in results]
-                )
+        Returns:
+            A tuple of:
+                The search results.
+                A map of sender ID to results.
+        """
+        rank_map = {}  # event_id -> rank of event
+        # Holds result of grouping by room, if applicable
+        room_groups: Dict[str, JsonDict] = {}
+        # Holds result of grouping by sender, if applicable
+        sender_group: Dict[str, JsonDict] = {}
 
-                events = await filter_events_for_client(
-                    self.storage, user.to_string(), filtered_events
-                )
+        search_result = await self.store.search_msgs(room_ids, search_term, keys)
 
-                room_events.extend(events)
-                room_events = room_events[: search_filter.limit]
+        if search_result["highlights"]:
+            highlights = search_result["highlights"]
+        else:
+            highlights = set()
 
-                if len(results) < search_filter.limit * 2:
-                    pagination_token = None
-                    break
-                else:
-                    pagination_token = results[-1]["pagination_token"]
-
-            for event in room_events:
-                group = room_groups.setdefault(event.room_id, {"results": []})
-                group["results"].append(event.event_id)
-
-            if room_events and len(room_events) >= search_filter.limit:
-                last_event_id = room_events[-1].event_id
-                pagination_token = results_map[last_event_id]["pagination_token"]
-
-                # We want to respect the given batch group and group keys so
-                # that if people blindly use the top level `next_batch` token
-                # it returns more from the same group (if applicable) rather
-                # than reverting to searching all results again.
-                if batch_group and batch_group_key:
-                    global_next_batch = encode_base64(
-                        (
-                            "%s\n%s\n%s"
-                            % (batch_group, batch_group_key, pagination_token)
-                        ).encode("ascii")
-                    )
-                else:
-                    global_next_batch = encode_base64(
-                        ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
-                    )
+        results = search_result["results"]
 
-                for room_id, group in room_groups.items():
-                    group["next_batch"] = encode_base64(
-                        ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
-                            "ascii"
-                        )
-                    )
+        # event_id -> rank of event
+        rank_map = {r["event"].event_id: r["rank"] for r in results}
 
-            allowed_events.extend(room_events)
+        filtered_events = await search_filter.filter([r["event"] for r in results])
 
-        else:
-            # We should never get here due to the guard earlier.
-            raise NotImplementedError()
+        events = await filter_events_for_client(
+            self.storage, user.to_string(), filtered_events
+        )
 
-        logger.info("Found %d events to return", len(allowed_events))
+        events.sort(key=lambda e: -rank_map[e.event_id])
+        allowed_events = events[: search_filter.limit]
 
-        # If client has asked for "context" for each event (i.e. some surrounding
-        # events and state), fetch that
-        if event_context is not None:
-            now_token = self.hs.get_event_sources().get_current_token()
+        for e in allowed_events:
+            rm = room_groups.setdefault(
+                e.room_id, {"results": [], "order": rank_map[e.event_id]}
+            )
+            rm["results"].append(e.event_id)
 
-            contexts = {}
-            for event in allowed_events:
-                res = await self.store.get_events_around(
-                    event.room_id, event.event_id, before_limit, after_limit
-                )
+            s = sender_group.setdefault(
+                e.sender, {"results": [], "order": rank_map[e.event_id]}
+            )
+            s["results"].append(e.event_id)
+
+        return (
+            _SearchResult(
+                search_result["count"],
+                rank_map,
+                allowed_events,
+                room_groups,
+                highlights,
+            ),
+            sender_group,
+        )
 
-                logger.info(
-                    "Context for search returned %d and %d events",
-                    len(res.events_before),
-                    len(res.events_after),
-                )
+    async def _search_by_recent(
+        self,
+        user: UserID,
+        room_ids: Collection[str],
+        search_term: str,
+        keys: Iterable[str],
+        search_filter: Filter,
+        batch_group: Optional[str],
+        batch_group_key: Optional[str],
+        batch_token: Optional[str],
+    ) -> Tuple[_SearchResult, Optional[str]]:
+        """
+        Performs a full text search for a user ordering by recent.
 
-                events_before = await filter_events_for_client(
-                    self.storage, user.to_string(), res.events_before
-                )
+        Args:
+            user: The user performing the search.
+            room_ids: List of room ids to search in
+            search_term: Search term to search for
+            keys: List of keys to search in, currently supports
+                "content.body", "content.name", "content.topic"
+            search_filter: The event filter to use.
+            batch_group: Pagination information.
+            batch_group_key: Pagination information.
+            batch_token: Pagination information.
 
-                events_after = await filter_events_for_client(
-                    self.storage, user.to_string(), res.events_after
-                )
+        Returns:
+            A tuple of:
+                The search results.
+                Optionally, a pagination token.
+        """
+        rank_map = {}  # event_id -> rank of event
+        # Holds result of grouping by room, if applicable
+        room_groups: Dict[str, JsonDict] = {}
 
-                context = {
-                    "events_before": events_before,
-                    "events_after": events_after,
-                    "start": await now_token.copy_and_replace(
-                        "room_key", res.start
-                    ).to_string(self.store),
-                    "end": await now_token.copy_and_replace(
-                        "room_key", res.end
-                    ).to_string(self.store),
-                }
+        # Holds the next_batch for the entire result set if one of those exists
+        global_next_batch = None
 
-                if include_profile:
-                    senders = {
-                        ev.sender
-                        for ev in itertools.chain(events_before, [event], events_after)
-                    }
+        highlights = set()
 
-                    if events_after:
-                        last_event_id = events_after[-1].event_id
-                    else:
-                        last_event_id = event.event_id
+        room_events: List[EventBase] = []
+        i = 0
+
+        pagination_token = batch_token
+
+        # We keep looping and we keep filtering until we reach the limit
+        # or we run out of things.
+        # But only go around 5 times since otherwise synapse will be sad.
+        while len(room_events) < search_filter.limit and i < 5:
+            i += 1
+            search_result = await self.store.search_rooms(
+                room_ids,
+                search_term,
+                keys,
+                search_filter.limit * 2,
+                pagination_token=pagination_token,
+            )
 
-                    state_filter = StateFilter.from_types(
-                        [(EventTypes.Member, sender) for sender in senders]
-                    )
+            if search_result["highlights"]:
+                highlights.update(search_result["highlights"])
+
+            count = search_result["count"]
+
+            results = search_result["results"]
+
+            results_map = {r["event"].event_id: r for r in results}
+
+            rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+            filtered_events = await search_filter.filter([r["event"] for r in results])
 
-                    state = await self.state_store.get_state_for_event(
-                        last_event_id, state_filter
+            events = await filter_events_for_client(
+                self.storage, user.to_string(), filtered_events
+            )
+
+            room_events.extend(events)
+            room_events = room_events[: search_filter.limit]
+
+            if len(results) < search_filter.limit * 2:
+                break
+            else:
+                pagination_token = results[-1]["pagination_token"]
+
+        for event in room_events:
+            group = room_groups.setdefault(event.room_id, {"results": []})
+            group["results"].append(event.event_id)
+
+        if room_events and len(room_events) >= search_filter.limit:
+            last_event_id = room_events[-1].event_id
+            pagination_token = results_map[last_event_id]["pagination_token"]
+
+            # We want to respect the given batch group and group keys so
+            # that if people blindly use the top level `next_batch` token
+            # it returns more from the same group (if applicable) rather
+            # than reverting to searching all results again.
+            if batch_group and batch_group_key:
+                global_next_batch = encode_base64(
+                    (
+                        "%s\n%s\n%s" % (batch_group, batch_group_key, pagination_token)
+                    ).encode("ascii")
+                )
+            else:
+                global_next_batch = encode_base64(
+                    ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
+                )
+
+            for room_id, group in room_groups.items():
+                group["next_batch"] = encode_base64(
+                    ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
+                        "ascii"
                     )
+                )
 
-                    context["profile_info"] = {
-                        s.state_key: {
-                            "displayname": s.content.get("displayname", None),
-                            "avatar_url": s.content.get("avatar_url", None),
-                        }
-                        for s in state.values()
-                        if s.type == EventTypes.Member and s.state_key in senders
-                    }
+        return (
+            _SearchResult(count, rank_map, room_events, room_groups, highlights),
+            global_next_batch,
+        )
 
-                contexts[event.event_id] = context
-        else:
-            contexts = {}
+    async def _calculate_event_contexts(
+        self,
+        user: UserID,
+        allowed_events: List[EventBase],
+        before_limit: int,
+        after_limit: int,
+        include_profile: bool,
+    ) -> Dict[str, JsonDict]:
+        """
+        Calculates the contextual events for any search results.
 
-        # TODO: Add a limit
+        Args:
+            user: The user performing the search.
+            allowed_events: The search results.
+            before_limit:
+                The number of events before a result to include as context.
+            after_limit:
+                The number of events after a result to include as context.
+            include_profile: True if historical profile information should be
+                included in the event context.
 
-        time_now = self.clock.time_msec()
+        Returns:
+            A map of event ID to contextual information.
+        """
+        now_token = self.hs.get_event_sources().get_current_token()
 
-        aggregations = None
-        if self._msc3666_enabled:
-            aggregations = await self.store.get_bundled_aggregations(
-                # Generate an iterable of EventBase for all the events that will be
-                # returned, including contextual events.
-                itertools.chain(
-                    # The events_before and events_after for each context.
-                    itertools.chain.from_iterable(
-                        itertools.chain(context["events_before"], context["events_after"])  # type: ignore[arg-type]
-                        for context in contexts.values()
-                    ),
-                    # The returned events.
-                    allowed_events,
-                ),
-                user.to_string(),
+        contexts = {}
+        for event in allowed_events:
+            res = await self.store.get_events_around(
+                event.room_id, event.event_id, before_limit, after_limit
             )
 
-        for context in contexts.values():
-            context["events_before"] = self._event_serializer.serialize_events(
-                context["events_before"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
+            logger.info(
+                "Context for search returned %d and %d events",
+                len(res.events_before),
+                len(res.events_after),
             )
-            context["events_after"] = self._event_serializer.serialize_events(
-                context["events_after"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
+
+            events_before = await filter_events_for_client(
+                self.storage, user.to_string(), res.events_before
             )
 
-        state_results = {}
-        if include_state:
-            for room_id in {e.room_id for e in allowed_events}:
-                state = await self.state_handler.get_current_state(room_id)
-                state_results[room_id] = list(state.values())
+            events_after = await filter_events_for_client(
+                self.storage, user.to_string(), res.events_after
+            )
 
-        # We're now about to serialize the events. We should not make any
-        # blocking calls after this. Otherwise the 'age' will be wrong
+            context: JsonDict = {
+                "events_before": events_before,
+                "events_after": events_after,
+                "start": await now_token.copy_and_replace(
+                    "room_key", res.start
+                ).to_string(self.store),
+                "end": await now_token.copy_and_replace("room_key", res.end).to_string(
+                    self.store
+                ),
+            }
 
-        results = []
-        for e in allowed_events:
-            results.append(
-                {
-                    "rank": rank_map[e.event_id],
-                    "result": self._event_serializer.serialize_event(
-                        e, time_now, bundle_aggregations=aggregations
-                    ),
-                    "context": contexts.get(e.event_id, {}),
+            if include_profile:
+                senders = {
+                    ev.sender
+                    for ev in itertools.chain(events_before, [event], events_after)
                 }
-            )
 
-        rooms_cat_res = {
-            "results": results,
-            "count": count,
-            "highlights": list(highlights),
-        }
+                if events_after:
+                    last_event_id = events_after[-1].event_id
+                else:
+                    last_event_id = event.event_id
 
-        if state_results:
-            s = {}
-            for room_id, state_events in state_results.items():
-                s[room_id] = self._event_serializer.serialize_events(
-                    state_events, time_now
+                state_filter = StateFilter.from_types(
+                    [(EventTypes.Member, sender) for sender in senders]
                 )
 
-            rooms_cat_res["state"] = s
-
-        if room_groups and "room_id" in group_keys:
-            rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
+                state = await self.state_store.get_state_for_event(
+                    last_event_id, state_filter
+                )
 
-        if sender_group and "sender" in group_keys:
-            rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
+                context["profile_info"] = {
+                    s.state_key: {
+                        "displayname": s.content.get("displayname", None),
+                        "avatar_url": s.content.get("avatar_url", None),
+                    }
+                    for s in state.values()
+                    if s.type == EventTypes.Member and s.state_key in senders
+                }
 
-        if global_next_batch:
-            rooms_cat_res["next_batch"] = global_next_batch
+            contexts[event.event_id] = context
 
-        return {"search_categories": {"room_events": rooms_cat_res}}
+        return contexts
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index aa9a76f8a9..e6050cbce6 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1289,23 +1289,54 @@ class SyncHandler:
             # room with by looking at all users that have left a room plus users
             # that were in a room we've left.
 
-            users_who_share_room = await self.store.get_users_who_share_room_with_user(
-                user_id
-            )
-
-            # Always tell the user about their own devices. We check as the user
-            # ID is almost certainly already included (unless they're not in any
-            # rooms) and taking a copy of the set is relatively expensive.
-            if user_id not in users_who_share_room:
-                users_who_share_room = set(users_who_share_room)
-                users_who_share_room.add(user_id)
+            users_that_have_changed = set()
 
-            tracked_users = users_who_share_room
+            joined_rooms = sync_result_builder.joined_room_ids
 
-            # Step 1a, check for changes in devices of users we share a room with
-            users_that_have_changed = await self.store.get_users_whose_devices_changed(
-                since_token.device_list_key, tracked_users
+            # Step 1a, check for changes in devices of users we share a room
+            # with
+            #
+            # We do this in two different ways depending on what we have cached.
+            # If we already have a list of all the user that have changed since
+            # the last sync then it's likely more efficient to compare the rooms
+            # they're in with the rooms the syncing user is in.
+            #
+            # If we don't have that info cached then we get all the users that
+            # share a room with our user and check if those users have changed.
+            changed_users = self.store.get_cached_device_list_changes(
+                since_token.device_list_key
             )
+            if changed_users is not None:
+                result = await self.store.get_rooms_for_users_with_stream_ordering(
+                    changed_users
+                )
+
+                for changed_user_id, entries in result.items():
+                    # Check if the changed user shares any rooms with the user,
+                    # or if the changed user is the syncing user (as we always
+                    # want to include device list updates of their own devices).
+                    if user_id == changed_user_id or any(
+                        e.room_id in joined_rooms for e in entries
+                    ):
+                        users_that_have_changed.add(changed_user_id)
+            else:
+                users_who_share_room = (
+                    await self.store.get_users_who_share_room_with_user(user_id)
+                )
+
+                # Always tell the user about their own devices. We check as the user
+                # ID is almost certainly already included (unless they're not in any
+                # rooms) and taking a copy of the set is relatively expensive.
+                if user_id not in users_who_share_room:
+                    users_who_share_room = set(users_who_share_room)
+                    users_who_share_room.add(user_id)
+
+                tracked_users = users_who_share_room
+                users_that_have_changed = (
+                    await self.store.get_users_whose_devices_changed(
+                        since_token.device_list_key, tracked_users
+                    )
+                )
 
             # Step 1b, check for newly joined rooms
             for room_id in newly_joined_rooms:
@@ -1329,7 +1360,14 @@ class SyncHandler:
                 newly_left_users.update(left_users)
 
             # Remove any users that we still share a room with.
-            newly_left_users -= users_who_share_room
+            left_users_rooms = (
+                await self.store.get_rooms_for_users_with_stream_ordering(
+                    newly_left_users
+                )
+            )
+            for user_id, entries in left_users_rooms.items():
+                if any(e.room_id in joined_rooms for e in entries):
+                    newly_left_users.discard(user_id)
 
             return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
         else:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c5f8fcbb2a..e7656fbb9f 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -958,6 +958,7 @@ class MatrixFederationHttpClient:
         )
         return body
 
+    @overload
     async def get_json(
         self,
         destination: str,
@@ -967,7 +968,38 @@ class MatrixFederationHttpClient:
         timeout: Optional[int] = None,
         ignore_backoff: bool = False,
         try_trailing_slash_on_400: bool = False,
+        parser: Literal[None] = None,
+        max_response_size: Optional[int] = None,
     ) -> Union[JsonDict, list]:
+        ...
+
+    @overload
+    async def get_json(
+        self,
+        destination: str,
+        path: str,
+        args: Optional[QueryArgs] = ...,
+        retry_on_dns_fail: bool = ...,
+        timeout: Optional[int] = ...,
+        ignore_backoff: bool = ...,
+        try_trailing_slash_on_400: bool = ...,
+        parser: ByteParser[T] = ...,
+        max_response_size: Optional[int] = ...,
+    ) -> T:
+        ...
+
+    async def get_json(
+        self,
+        destination: str,
+        path: str,
+        args: Optional[QueryArgs] = None,
+        retry_on_dns_fail: bool = True,
+        timeout: Optional[int] = None,
+        ignore_backoff: bool = False,
+        try_trailing_slash_on_400: bool = False,
+        parser: Optional[ByteParser] = None,
+        max_response_size: Optional[int] = None,
+    ):
         """GETs some json from the given host homeserver and path
 
         Args:
@@ -992,6 +1024,13 @@ class MatrixFederationHttpClient:
             try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
                 response we should try appending a trailing slash to the end of
                 the request. Workaround for #3622 in Synapse <= v0.99.3.
+
+            parser: The parser to use to decode the response. Defaults to
+                parsing as JSON.
+
+            max_response_size: The maximum size to read from the response. If None,
+                uses the default.
+
         Returns:
             Succeeds when we get a 2xx HTTP response. The
             result will be the decoded JSON body.
@@ -1026,8 +1065,17 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
+        if parser is None:
+            parser = JsonParser()
+
         body = await _handle_response(
-            self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
+            self.reactor,
+            _sec_timeout,
+            request,
+            response,
+            start_ms,
+            parser=parser,
+            max_response_size=max_response_size,
         )
 
         return body
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
deleted file mode 100644
index b9933a1528..0000000000
--- a/synapse/logging/_structured.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os.path
-from typing import Any, Dict, Generator, Optional, Tuple
-
-from constantly import NamedConstant, Names
-
-from synapse.config._base import ConfigError
-
-
-class DrainType(Names):
-    CONSOLE = NamedConstant()
-    CONSOLE_JSON = NamedConstant()
-    CONSOLE_JSON_TERSE = NamedConstant()
-    FILE = NamedConstant()
-    FILE_JSON = NamedConstant()
-    NETWORK_JSON_TERSE = NamedConstant()
-
-
-DEFAULT_LOGGERS = {"synapse": {"level": "info"}}
-
-
-def parse_drain_configs(
-    drains: dict,
-) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-    """
-    Parse the drain configurations.
-
-    Args:
-        drains (dict): A list of drain configurations.
-
-    Yields:
-        dict instances representing a logging handler.
-
-    Raises:
-        ConfigError: If any of the drain configuration items are invalid.
-    """
-
-    for name, config in drains.items():
-        if "type" not in config:
-            raise ConfigError("Logging drains require a 'type' key.")
-
-        try:
-            logging_type = DrainType.lookupByName(config["type"].upper())
-        except ValueError:
-            raise ConfigError(
-                "%s is not a known logging drain type." % (config["type"],)
-            )
-
-        # Either use the default formatter or the tersejson one.
-        if logging_type in (
-            DrainType.CONSOLE_JSON,
-            DrainType.FILE_JSON,
-        ):
-            formatter: Optional[str] = "json"
-        elif logging_type in (
-            DrainType.CONSOLE_JSON_TERSE,
-            DrainType.NETWORK_JSON_TERSE,
-        ):
-            formatter = "tersejson"
-        else:
-            # A formatter of None implies using the default formatter.
-            formatter = None
-
-        if logging_type in [
-            DrainType.CONSOLE,
-            DrainType.CONSOLE_JSON,
-            DrainType.CONSOLE_JSON_TERSE,
-        ]:
-            location = config.get("location")
-            if location is None or location not in ["stdout", "stderr"]:
-                raise ConfigError(
-                    (
-                        "The %s drain needs the 'location' key set to "
-                        "either 'stdout' or 'stderr'."
-                    )
-                    % (logging_type,)
-                )
-
-            yield name, {
-                "class": "logging.StreamHandler",
-                "formatter": formatter,
-                "stream": "ext://sys." + location,
-            }
-
-        elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]:
-            if "location" not in config:
-                raise ConfigError(
-                    "The %s drain needs the 'location' key set." % (logging_type,)
-                )
-
-            location = config.get("location")
-            if os.path.abspath(location) != location:
-                raise ConfigError(
-                    "File paths need to be absolute, '%s' is a relative path"
-                    % (location,)
-                )
-
-            yield name, {
-                "class": "logging.FileHandler",
-                "formatter": formatter,
-                "filename": location,
-            }
-
-        elif logging_type in [DrainType.NETWORK_JSON_TERSE]:
-            host = config.get("host")
-            port = config.get("port")
-            maximum_buffer = config.get("maximum_buffer", 1000)
-
-            yield name, {
-                "class": "synapse.logging.RemoteHandler",
-                "formatter": formatter,
-                "host": host,
-                "port": port,
-                "maximum_buffer": maximum_buffer,
-            }
-
-        else:
-            raise ConfigError(
-                "The %s drain type is currently not implemented."
-                % (config["type"].upper(),)
-            )
-
-
-def setup_structured_logging(
-    log_config: dict,
-) -> dict:
-    """
-    Convert a legacy structured logging configuration (from Synapse < v1.23.0)
-    to one compatible with the new standard library handlers.
-    """
-    if "drains" not in log_config:
-        raise ConfigError("The logging configuration requires a list of drains.")
-
-    new_config = {
-        "version": 1,
-        "formatters": {
-            "json": {"class": "synapse.logging.JsonFormatter"},
-            "tersejson": {"class": "synapse.logging.TerseJsonFormatter"},
-        },
-        "handlers": {},
-        "loggers": log_config.get("loggers", DEFAULT_LOGGERS),
-        "root": {"handlers": []},
-    }
-
-    for handler_name, handler in parse_drain_configs(log_config["drains"]):
-        new_config["handlers"][handler_name] = handler
-
-        # Add each handler to the root logger.
-        new_config["root"]["handlers"].append(handler_name)
-
-    return new_config
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d4fca36923..07020bfb8d 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -70,6 +70,7 @@ from synapse.handlers.account_validity import (
 from synapse.handlers.auth import (
     CHECK_3PID_AUTH_CALLBACK,
     CHECK_AUTH_CALLBACK,
+    GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
     GET_USERNAME_FOR_REGISTRATION_CALLBACK,
     IS_3PID_ALLOWED_CALLBACK,
     ON_LOGGED_OUT_CALLBACK,
@@ -317,6 +318,9 @@ class ModuleApi:
         get_username_for_registration: Optional[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = None,
+        get_displayname_for_registration: Optional[
+            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
     ) -> None:
         """Registers callbacks for password auth provider capabilities.
 
@@ -328,6 +332,7 @@ class ModuleApi:
             is_3pid_allowed=is_3pid_allowed,
             auth_checkers=auth_checkers,
             get_username_for_registration=get_username_for_registration,
+            get_displayname_for_registration=get_displayname_for_registration,
         )
 
     def register_background_update_controller_callbacks(
@@ -648,7 +653,11 @@ class ModuleApi:
         Added in Synapse v1.9.0.
 
         Args:
-            auth_provider: identifier for the remote auth provider
+            auth_provider: identifier for the remote auth provider, see `sso` and
+                `oidc_providers` in the homeserver configuration.
+
+                Note that no error is raised if the provided value is not in the
+                homeserver configuration.
             external_id: id on that system
             user_id: complete mxid that it is mapped to
         """
diff --git a/synapse/notifier.py b/synapse/notifier.py
index e0fad2da66..753dd6b6a5 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -138,7 +138,7 @@ class _NotifierUserStream:
         self.current_token = self.current_token.copy_and_advance(stream_key, stream_id)
         self.last_notified_token = self.current_token
         self.last_notified_ms = time_now_ms
-        noify_deferred = self.notify_deferred
+        notify_deferred = self.notify_deferred
 
         log_kv(
             {
@@ -153,7 +153,7 @@ class _NotifierUserStream:
 
         with PreserveLoggingContext():
             self.notify_deferred = ObservableDeferred(defer.Deferred())
-            noify_deferred.callback(self.current_token)
+            notify_deferred.callback(self.current_token)
 
     def remove(self, notifier: "Notifier") -> None:
         """Remove this listener from all the indexes in the Notifier
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 910b05c0da..832eaa34e9 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -130,7 +130,9 @@ def make_base_prepend_rules(
     return rules
 
 
-BASE_APPEND_CONTENT_RULES = [
+# We have to annotate these types, otherwise mypy infers them as
+# `List[Dict[str, Sequence[Collection[str]]]]`.
+BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
     {
         "rule_id": "global/content/.m.rule.contains_user_name",
         "conditions": [
@@ -149,7 +151,7 @@ BASE_APPEND_CONTENT_RULES = [
 ]
 
 
-BASE_PREPEND_OVERRIDE_RULES = [
+BASE_PREPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
     {
         "rule_id": "global/override/.m.rule.master",
         "enabled": False,
@@ -159,7 +161,7 @@ BASE_PREPEND_OVERRIDE_RULES = [
 ]
 
 
-BASE_APPEND_OVERRIDE_RULES = [
+BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
     {
         "rule_id": "global/override/.m.rule.suppress_notices",
         "conditions": [
@@ -278,7 +280,7 @@ BASE_APPEND_OVERRIDE_RULES = [
 ]
 
 
-BASE_APPEND_UNDERRIDE_RULES = [
+BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
     {
         "rule_id": "global/underride/.m.rule.call",
         "conditions": [
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 96559081d0..52c7ff3572 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -109,6 +109,7 @@ class HttpPusher(Pusher):
         self.data_minus_url = {}
         self.data_minus_url.update(self.data)
         del self.data_minus_url["url"]
+        self.badge_count_last_call: Optional[int] = None
 
     def on_started(self, should_check_for_notifs: bool) -> None:
         """Called when this pusher has been started.
@@ -136,7 +137,9 @@ class HttpPusher(Pusher):
             self.user_id,
             group_by_room=self._group_unread_count_by_room,
         )
-        await self._send_badge(badge)
+        if self.badge_count_last_call is None or self.badge_count_last_call != badge:
+            self.badge_count_last_call = badge
+            await self._send_badge(badge)
 
     def on_timer(self) -> None:
         self._start_processing()
@@ -322,7 +325,7 @@ class HttpPusher(Pusher):
         # This was checked in the __init__, but mypy doesn't seem to know that.
         assert self.data is not None
         if self.data.get("format") == "event_id_only":
-            d = {
+            d: Dict[str, Any] = {
                 "notification": {
                     "event_id": event.event_id,
                     "room_id": event.room_id,
@@ -402,6 +405,8 @@ class HttpPusher(Pusher):
         rejected = []
         if "rejected" in resp:
             rejected = resp["rejected"]
+        else:
+            self.badge_count_last_call = badge
         return rejected
 
     async def _send_badge(self, badge: int) -> None:
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 86162e0f2c..f43fbb5842 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -87,7 +87,8 @@ REQUIREMENTS = [
     # We enforce that we have a `cryptography` version that bundles an `openssl`
     # with the latest security patches.
     "cryptography>=3.4.7",
-    "ijson>=3.1",
+    # ijson 3.1.4 fixes a bug with "." in property names
+    "ijson>=3.1.4",
     "matrix-common~=1.1.0",
 ]
 
diff --git a/synapse/res/providers.json b/synapse/res/providers.json
index f1838f9559..7b9958e454 100644
--- a/synapse/res/providers.json
+++ b/synapse/res/providers.json
@@ -5,8 +5,6 @@
         "endpoints": [
             {
                 "schemes": [
-                    "https://twitter.com/*/status/*",
-                    "https://*.twitter.com/*/status/*",
                     "https://twitter.com/*/moments/*",
                     "https://*.twitter.com/*/moments/*"
                 ],
@@ -14,4 +12,4 @@
             }
         ]
     }
-]
\ No newline at end of file
+]
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index cfa2aee76d..efe299e698 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -883,7 +883,9 @@ class WhoamiRestServlet(RestServlet):
         response = {
             "user_id": requester.user.to_string(),
             # MSC: https://github.com/matrix-org/matrix-doc/pull/3069
+            # Entered spec in Matrix 1.2
             "org.matrix.msc3069.is_guest": bool(requester.is_guest),
+            "is_guest": bool(requester.is_guest),
         }
 
         # Appservices and similar accounts do not have device IDs
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 9c15a04338..e0b2b80e5b 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -62,7 +62,7 @@ class AuthRestServlet(RestServlet):
         if stagetype == LoginType.RECAPTCHA:
             html = self.recaptcha_template.render(
                 session=session,
-                myurl="%s/r0/auth/%s/fallback/web"
+                myurl="%s/v3/auth/%s/fallback/web"
                 % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
                 sitekey=self.hs.config.captcha.recaptcha_public_key,
             )
@@ -74,7 +74,7 @@ class AuthRestServlet(RestServlet):
                     self.hs.config.server.public_baseurl,
                     self.hs.config.consent.user_consent_version,
                 ),
-                myurl="%s/r0/auth/%s/fallback/web"
+                myurl="%s/v3/auth/%s/fallback/web"
                 % (CLIENT_API_PREFIX, LoginType.TERMS),
             )
 
@@ -118,7 +118,7 @@ class AuthRestServlet(RestServlet):
                 # Authentication failed, let user try again
                 html = self.recaptcha_template.render(
                     session=session,
-                    myurl="%s/r0/auth/%s/fallback/web"
+                    myurl="%s/v3/auth/%s/fallback/web"
                     % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
                     sitekey=self.hs.config.captcha.recaptcha_public_key,
                     error=e.msg,
@@ -143,7 +143,7 @@ class AuthRestServlet(RestServlet):
                         self.hs.config.server.public_baseurl,
                         self.hs.config.consent.user_consent_version,
                     ),
-                    myurl="%s/r0/auth/%s/fallback/web"
+                    myurl="%s/v3/auth/%s/fallback/web"
                     % (CLIENT_API_PREFIX, LoginType.TERMS),
                     error=e.msg,
                 )
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 6682da077a..e05c926b6f 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -72,20 +72,6 @@ class CapabilitiesRestServlet(RestServlet):
                 "org.matrix.msc3244.room_capabilities"
             ] = MSC3244_CAPABILITIES
 
-        # Must be removed in later versions.
-        # Is only included for migration.
-        # Also the parts in `synapse/config/experimental.py`.
-        if self.config.experimental.msc3283_enabled:
-            response["capabilities"]["org.matrix.msc3283.set_displayname"] = {
-                "enabled": self.config.registration.enable_set_displayname
-            }
-            response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = {
-                "enabled": self.config.registration.enable_set_avatar_url
-            }
-            response["capabilities"]["org.matrix.msc3283.3pid_changes"] = {
-                "enabled": self.config.registration.enable_3pid_changes
-            }
-
         if self.config.experimental.msc3440_enabled:
             response["capabilities"]["io.element.thread"] = {"enabled": True}
 
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index c965e2bda2..b8a5135e02 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -694,11 +694,18 @@ class RegisterRestServlet(RestServlet):
                 session_id
             )
 
+            display_name = await (
+                self.password_auth_provider.get_displayname_for_registration(
+                    auth_result, params
+                )
+            )
+
             registered_user_id = await self.registration_handler.register_user(
                 localpart=desired_username,
                 password_hash=password_hash,
                 guest_access_token=guest_access_token,
                 threepid=threepid,
+                default_display_name=display_name,
                 address=client_addr,
                 user_agent_ips=entries,
             )
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 2290c57c12..00f29344a8 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -73,6 +73,8 @@ class VersionsRestServlet(RestServlet):
                     "r0.5.0",
                     "r0.6.0",
                     "r0.6.1",
+                    "v1.1",
+                    "v1.2",
                 ],
                 # as per MSC1497:
                 "unstable_features": {
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 8d3d1e54dc..c08b60d10a 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -402,7 +402,15 @@ class PreviewUrlResource(DirectServeJsonResource):
                 url,
                 output_stream=output_stream,
                 max_size=self.max_spider_size,
-                headers={"Accept-Language": self.url_preview_accept_language},
+                headers={
+                    b"Accept-Language": self.url_preview_accept_language,
+                    # Use a custom user agent for the preview because some sites will only return
+                    # Open Graph metadata to crawler user agents. Omit the Synapse version
+                    # string to avoid leaking information.
+                    b"User-Agent": [
+                        "Synapse (bot; +https://github.com/matrix-org/synapse)"
+                    ],
+                },
                 is_allowed_content_type=_is_previewable,
             )
         except SynapseError:
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 8d845fe951..3b3a089b76 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -670,6 +670,16 @@ class DeviceWorkerStore(SQLBaseStore):
             device["device_id"]: db_to_json(device["content"]) for device in devices
         }
 
+    def get_cached_device_list_changes(
+        self,
+        from_key: int,
+    ) -> Optional[Set[str]]:
+        """Get set of users whose devices have changed since `from_key`, or None
+        if that information is not in our cache.
+        """
+
+        return self._device_list_stream_cache.get_all_entities_changed(from_key)
+
     async def get_users_whose_devices_changed(
         self, from_key: int, user_ids: Iterable[str]
     ) -> Set[str]:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5246fccad5..a1d7a9b413 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -975,6 +975,17 @@ class PersistEventsStore:
             to_delete = delta_state.to_delete
             to_insert = delta_state.to_insert
 
+            # Figure out the changes of membership to invalidate the
+            # `get_rooms_for_user` cache.
+            # We find out which membership events we may have deleted
+            # and which we have added, then we invalidate the caches for all
+            # those users.
+            members_changed = {
+                state_key
+                for ev_type, state_key in itertools.chain(to_delete, to_insert)
+                if ev_type == EventTypes.Member
+            }
+
             if delta_state.no_longer_in_room:
                 # Server is no longer in the room so we delete the room from
                 # current_state_events, being careful we've already updated the
@@ -993,6 +1004,11 @@ class PersistEventsStore:
                 """
                 txn.execute(sql, (stream_id, self._instance_name, room_id))
 
+                # We also want to invalidate the membership caches for users
+                # that were in the room.
+                users_in_room = self.store.get_users_in_room_txn(txn, room_id)
+                members_changed.update(users_in_room)
+
                 self.db_pool.simple_delete_txn(
                     txn,
                     table="current_state_events",
@@ -1102,17 +1118,6 @@ class PersistEventsStore:
 
             # Invalidate the various caches
 
-            # Figure out the changes of membership to invalidate the
-            # `get_rooms_for_user` cache.
-            # We find out which membership events we may have deleted
-            # and which we have added, then we invalidate the caches for all
-            # those users.
-            members_changed = {
-                state_key
-                for ev_type, state_key in itertools.chain(to_delete, to_insert)
-                if ev_type == EventTypes.Member
-            }
-
             for member in members_changed:
                 txn.call_after(
                     self.store.get_rooms_for_user_with_stream_ordering.invalidate,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 8d4287045a..2a255d1031 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -408,7 +408,7 @@ class EventsWorkerStore(SQLBaseStore):
                 include the previous states content in the unsigned field.
 
             allow_rejected: If True, return rejected events. Otherwise,
-                omits rejeted events from the response.
+                omits rejected events from the response.
 
         Returns:
             A mapping from event_id to event.
@@ -1854,7 +1854,7 @@ class EventsWorkerStore(SQLBaseStore):
             forward_edge_query = """
                 SELECT 1 FROM event_edges
                 /* Check to make sure the event referencing our event in question is not rejected */
-                LEFT JOIN rejections ON event_edges.event_id == rejections.event_id
+                LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
                 WHERE
                     event_edges.room_id = ?
                     AND event_edges.prev_event_id = ?
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 4f05811a77..d3c4611686 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,15 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
 
 from synapse.api.presence import PresenceState, UserPresenceState
 from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.iterutils import batch_iter
@@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
-    ):
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
         # Used by `PresenceStore._get_active_presence()`
@@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
-    ):
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
+        self._instance_name = hs.get_instance_name()
+        self._presence_id_gen: AbstractStreamIdGenerator
+
         self._can_persist_presence = (
-            hs.get_instance_name() in hs.config.worker.writers.presence
+            self._instance_name in hs.config.worker.writers.presence
         )
 
         if isinstance(database.engine, PostgresEngine):
@@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return stream_orderings[-1], self._presence_id_gen.get_current_token()
 
-    def _update_presence_txn(self, txn, stream_orderings, presence_states):
+    def _update_presence_txn(
+        self, txn: LoggingTransaction, stream_orderings, presence_states
+    ) -> None:
         for stream_id, state in zip(stream_orderings, presence_states):
             txn.call_after(
                 self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
@@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_presence_updates_txn(txn):
+        def get_all_presence_updates_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, list]], int, bool]:
             sql = """
                 SELECT stream_id, user_id, state, last_active_ts,
                     last_federation_update_ts, last_user_sync_ts,
-                    status_msg,
-                currently_active
+                    status_msg, currently_active
                 FROM presence_stream
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
                 LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            updates = [(row[0], row[1:]) for row in txn]
+            updates = cast(
+                List[Tuple[int, list]],
+                [(row[0], row[1:]) for row in txn],
+            )
 
             upper_bound = current_id
             limited = False
@@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         )
 
     @cached()
-    def _get_presence_for_user(self, user_id):
+    def _get_presence_for_user(self, user_id: str) -> None:
         raise NotImplementedError()
 
     @cachedList(
@@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         list_name="user_ids",
         num_args=1,
     )
-    async def get_presence_for_users(self, user_ids):
+    async def get_presence_for_users(
+        self, user_ids: Iterable[str]
+    ) -> Dict[str, UserPresenceState]:
         rows = await self.db_pool.simple_select_many_batch(
             table="presence_stream",
             column="user_id",
@@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             True if the user should have full presence sent to them, False otherwise.
         """
 
-        def _should_user_receive_full_presence_with_token_txn(txn):
+        def _should_user_receive_full_presence_with_token_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             sql = """
                 SELECT 1 FROM users_to_send_full_presence_to
                 WHERE user_id = ?
@@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             _should_user_receive_full_presence_with_token_txn,
         )
 
-    async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
+    async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
         """Adds to the list of users who should receive a full snapshot of presence
         upon their next sync.
 
@@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return users_to_state
 
-    def get_current_presence_token(self):
+    def get_current_presence_token(self) -> int:
         return self._presence_id_gen.get_current_token()
 
-    def _get_active_presence(self, db_conn: Connection):
+    def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
         """Fetch non-offline presence from the database so that we can register
         the appropriate time outs.
         """
@@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return [UserPresenceState(**row) for row in rows]
 
-    def take_presence_startup_info(self):
+    def take_presence_startup_info(self) -> List[UserPresenceState]:
         active_on_startup = self._presence_on_startup
-        self._presence_on_startup = None
+        self._presence_on_startup = []
         return active_on_startup
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
         if stream_name == PresenceStream.NAME:
             self._presence_id_gen.advance(instance_name, token)
             for row in rows:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index e87a8fb85d..2e3818e432 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -13,9 +13,10 @@
 # limitations under the License.
 
 import logging
-from typing import Any, List, Set, Tuple
+from typing import Any, List, Set, Tuple, cast
 
 from synapse.api.errors import SynapseError
+from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.state import StateGroupWorkerStore
 from synapse.types import RoomStreamToken
@@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         )
 
     def _purge_history_txn(
-        self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        token: RoomStreamToken,
+        delete_local_events: bool,
     ) -> Set[int]:
         # Tables that should be pruned:
         #     event_auth
@@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         """,
             (room_id,),
         )
-        (min_depth,) = txn.fetchone()
+        (min_depth,) = cast(Tuple[int], txn.fetchone())
 
         logger.info("[purge] updating room_depth to %d", min_depth)
 
@@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "purge_room", self._purge_room_txn, room_id
         )
 
-    def _purge_room_txn(self, txn, room_id: str) -> List[int]:
+    def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
         # First we fetch all the state groups that should be deleted, before
         # we delete that information.
         txn.execute(
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index aac94fa464..17110bb033 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -622,10 +622,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     ) -> None:
         """Record a mapping from an external user id to a mxid
 
+        See notes in _record_user_external_id_txn about what constitutes valid data.
+
         Args:
             auth_provider: identifier for the remote auth provider
             external_id: id on that system
             user_id: complete mxid that it is mapped to
+
         Raises:
             ExternalIDReuseException if the new external_id could not be mapped.
         """
@@ -648,6 +651,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         external_id: str,
         user_id: str,
     ) -> None:
+        """
+        Record a mapping from an external user id to a mxid.
+
+        Note that the auth provider IDs (and the external IDs) are not validated
+        against configured IdPs as Synapse does not know its relationship to
+        external systems. For example, it might be useful to pre-configure users
+        before enabling a new IdP or an IdP might be temporarily offline, but
+        still valid.
+
+        Args:
+            txn: The database transaction.
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+            user_id: complete mxid that it is mapped to
+        """
 
         self.db_pool.simple_insert_txn(
             txn,
@@ -687,10 +705,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         """Replace mappings from external user ids to a mxid in a single transaction.
         All mappings are deleted and the new ones are created.
 
+        See notes in _record_user_external_id_txn about what constitutes valid data.
+
         Args:
             record_external_ids:
                 List with tuple of auth_provider and external_id to record
             user_id: complete mxid that it is mapped to
+
         Raises:
             ExternalIDReuseException if the new external_id could not be mapped.
         """
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index e2c27e594b..36aa1092f6 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -53,8 +53,13 @@ logger = logging.getLogger(__name__)
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _ThreadAggregation:
+    # The latest event in the thread.
     latest_event: EventBase
+    # The latest edit to the latest event in the thread.
+    latest_edit: Optional[EventBase]
+    # The total number of events in the thread.
     count: int
+    # True if the current user has sent an event to the thread.
     current_user_participated: bool
 
 
@@ -461,8 +466,8 @@ class RelationsWorkerStore(SQLBaseStore):
     @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
     async def _get_thread_summaries(
         self, event_ids: Collection[str]
-    ) -> Dict[str, Optional[Tuple[int, EventBase]]]:
-        """Get the number of threaded replies and the latest reply (if any) for the given event.
+    ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
+        """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
 
         Args:
             event_ids: Summarize the thread related to this event ID.
@@ -471,8 +476,10 @@ class RelationsWorkerStore(SQLBaseStore):
             A map of the thread summary each event. A missing event implies there
             are no threaded replies.
 
-            Each summary includes the number of items in the thread and the most
-            recent response.
+            Each summary is a tuple of:
+                The number of events in the thread.
+                The most recent event in the thread.
+                The most recent edit to the most recent event in the thread, if applicable.
         """
 
         def _get_thread_summaries_txn(
@@ -482,7 +489,7 @@ class RelationsWorkerStore(SQLBaseStore):
             # TODO Should this only allow m.room.message events.
             if isinstance(self.database_engine, PostgresEngine):
                 # The `DISTINCT ON` clause will pick the *first* row it encounters,
-                # so ordering by topologica ordering + stream ordering desc will
+                # so ordering by topological ordering + stream ordering desc will
                 # ensure we get the latest event in the thread.
                 sql = """
                     SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child
@@ -558,6 +565,9 @@ class RelationsWorkerStore(SQLBaseStore):
 
         latest_events = await self.get_events(latest_event_ids.values())  # type: ignore[attr-defined]
 
+        # Check to see if any of those events are edited.
+        latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+
         # Map to the event IDs to the thread summary.
         #
         # There might not be a summary due to there not being a thread or
@@ -568,7 +578,8 @@ class RelationsWorkerStore(SQLBaseStore):
 
             summary = None
             if latest_event:
-                summary = (counts[parent_event_id], latest_event)
+                latest_edit = latest_edits.get(latest_event_id)
+                summary = (counts[parent_event_id], latest_event, latest_edit)
             summaries[parent_event_id] = summary
 
         return summaries
@@ -828,11 +839,12 @@ class RelationsWorkerStore(SQLBaseStore):
             )
             for event_id, summary in summaries.items():
                 if summary:
-                    thread_count, latest_thread_event = summary
+                    thread_count, latest_thread_event, edit = summary
                     results.setdefault(
                         event_id, BundledAggregations()
                     ).thread = _ThreadAggregation(
                         latest_event=latest_thread_event,
+                        latest_edit=edit,
                         count=thread_count,
                         # If there's a thread summary it must also exist in the
                         # participated dictionary.
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 95167116c9..0416df64ce 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1498,7 +1498,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
 
     async def upsert_room_on_join(
-        self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
+        self, room_id: str, room_version: RoomVersion, state_events: List[EventBase]
     ) -> None:
         """Ensure that the room is stored in the table
 
@@ -1511,7 +1511,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         has_auth_chain_index = await self.has_auth_chain_index(room_id)
 
         create_event = None
-        for e in auth_events:
+        for e in state_events:
             if (e.type, e.state_key) == (EventTypes.Create, ""):
                 create_event = e
                 break
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4489732fda..e48ec5f495 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -504,6 +504,68 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             for room_id, instance, stream_id in txn
         )
 
+    @cachedList(
+        cached_method_name="get_rooms_for_user_with_stream_ordering",
+        list_name="user_ids",
+    )
+    async def get_rooms_for_users_with_stream_ordering(
+        self, user_ids: Collection[str]
+    ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
+        """A batched version of `get_rooms_for_user_with_stream_ordering`.
+
+        Returns:
+            Map from user_id to set of rooms that is currently in.
+        """
+        return await self.db_pool.runInteraction(
+            "get_rooms_for_users_with_stream_ordering",
+            self._get_rooms_for_users_with_stream_ordering_txn,
+            user_ids,
+        )
+
+    def _get_rooms_for_users_with_stream_ordering_txn(
+        self, txn, user_ids: Collection[str]
+    ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
+
+        clause, args = make_in_list_sql_clause(
+            self.database_engine,
+            "c.state_key",
+            user_ids,
+        )
+
+        if self._current_state_events_membership_up_to_date:
+            sql = f"""
+                SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
+                FROM current_state_events AS c
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    c.type = 'm.room.member'
+                    AND c.membership = ?
+                    AND {clause}
+            """
+        else:
+            sql = f"""
+                SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
+                FROM current_state_events AS c
+                INNER JOIN room_memberships AS m USING (room_id, event_id)
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    c.type = 'm.room.member'
+                    AND m.membership = ?
+                    AND {clause}
+            """
+
+        txn.execute(sql, [Membership.JOIN] + args)
+
+        result = {user_id: set() for user_id in user_ids}
+        for user_id, room_id, instance, stream_id in txn:
+            result[user_id].add(
+                GetRoomsForUserWithStreamOrdering(
+                    room_id, PersistedEventPosition(instance, stream_id)
+                )
+            )
+
+        return {user_id: frozenset(v) for user_id, v in result.items()}
+
     async def get_users_server_still_shares_room_with(
         self, user_ids: Collection[str]
     ) -> Set[str]:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 2d085a5764..acea300ed3 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -28,6 +28,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import JsonDict
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -381,17 +382,19 @@ class SearchStore(SearchBackgroundUpdateStore):
     ):
         super().__init__(database, db_conn, hs)
 
-    async def search_msgs(self, room_ids, search_term, keys):
+    async def search_msgs(
+        self, room_ids: Collection[str], search_term: str, keys: Iterable[str]
+    ) -> JsonDict:
         """Performs a full text search over events with given keys.
 
         Args:
-            room_ids (list): List of room ids to search in
-            search_term (str): Search term to search for
-            keys (list): List of keys to search in, currently supports
+            room_ids: List of room ids to search in
+            search_term: Search term to search for
+            keys: List of keys to search in, currently supports
                 "content.body", "content.name", "content.topic"
 
         Returns:
-            list of dicts
+            Dictionary of results
         """
         clauses = []
 
@@ -499,10 +502,10 @@ class SearchStore(SearchBackgroundUpdateStore):
         self,
         room_ids: Collection[str],
         search_term: str,
-        keys: List[str],
+        keys: Iterable[str],
         limit,
         pagination_token: Optional[str] = None,
-    ) -> List[dict]:
+    ) -> JsonDict:
         """Performs a full text search over events with given keys.
 
         Args:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f7c778bdf2..e7fddd2426 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
-    ):
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
@@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         processed_event_count = 0
 
         for room_id, event_count in rooms_to_work_on:
-            is_in_room = await self.is_host_joined(room_id, self.server_name)
+            is_in_room = await self.is_host_joined(room_id, self.server_name)  # type: ignore[attr-defined]
 
             if is_in_room:
-                users_with_profile = await self.get_users_in_room_with_profiles(room_id)
+                users_with_profile = await self.get_users_in_room_with_profiles(room_id)  # type: ignore[attr-defined]
                 # Throw away users excluded from the directory.
                 users_with_profile = {
                     user_id: profile
@@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
         for user_id in users_to_work_on:
             if await self.should_include_local_user_in_dir(user_id):
-                profile = await self.get_profileinfo(get_localpart_from_id(user_id))
+                profile = await self.get_profileinfo(get_localpart_from_id(user_id))  # type: ignore[attr-defined]
                 await self.update_profile_in_user_dir(
                     user_id, profile.display_name, profile.avatar_url
                 )
@@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         # technically it could be DM-able. In the future, this could potentially
         # be configurable per-appservice whether the appservice sender can be
         # contacted.
-        if self.get_app_service_by_user_id(user) is not None:
+        if self.get_app_service_by_user_id(user) is not None:  # type: ignore[attr-defined]
             return False
 
         # We're opting to exclude appservice users (anyone matching the user
@@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         # they could be DM-able. In the future, this could potentially
         # be configurable per-appservice whether the appservice users can be
         # contacted.
-        if self.get_if_app_services_interested_in_user(user):
+        if self.get_if_app_services_interested_in_user(user):  # type: ignore[attr-defined]
             # TODO we might want to make this configurable for each app service
             return False
 
         # Support users are for diagnostics and should not appear in the user directory.
-        if await self.is_support_user(user):
+        if await self.is_support_user(user):  # type: ignore[attr-defined]
             return False
 
         # Deactivated users aren't contactable, so should not appear in the user directory.
         try:
-            if await self.get_user_deactivated_status(user):
+            if await self.get_user_deactivated_status(user):  # type: ignore[attr-defined]
                 return False
         except StoreError:
             # No such user in the users table. No need to do this when calling
@@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             (EventTypes.RoomHistoryVisibility, ""),
         )
 
-        current_state_ids = await self.get_filtered_current_state_ids(
+        current_state_ids = await self.get_filtered_current_state_ids(  # type: ignore[attr-defined]
             room_id, StateFilter.from_types(types_to_filter)
         )
 
         join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
         if join_rules_id:
-            join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
+            join_rule_ev = await self.get_event(join_rules_id, allow_none=True)  # type: ignore[attr-defined]
             if join_rule_ev:
                 if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
                     return True
 
         hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
         if hist_vis_id:
-            hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
+            hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)  # type: ignore[attr-defined]
             if hist_vis_ev:
                 if (
                     hist_vis_ev.content.get("history_visibility")
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 7614d76ac6..3af69a2076 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -13,11 +13,23 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 import attr
 
+from twisted.internet import defer
+
 from synapse.api.constants import EventTypes
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -29,6 +41,12 @@ from synapse.storage.state import StateFilter
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import MutableStateMap, StateKey, StateMap
+from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import (
+    AbstractObservableDeferred,
+    ObservableDeferred,
+    yieldable_gather_results,
+)
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
 
@@ -37,7 +55,6 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-
 MAX_STATE_DELTA_HOPS = 100
 
 
@@ -106,6 +123,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             500000,
         )
 
+        # Current ongoing get_state_for_groups in-flight requests
+        # {group ID -> {StateFilter -> ObservableDeferred}}
+        self._state_group_inflight_requests: Dict[
+            int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
+        ] = {}
+
         def get_max_state_group_txn(txn: Cursor) -> int:
             txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
             return txn.fetchone()[0]  # type: ignore
@@ -157,7 +180,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )
 
     async def _get_state_groups_from_groups(
-        self, groups: List[int], state_filter: StateFilter
+        self, groups: Sequence[int], state_filter: StateFilter
     ) -> Dict[int, StateMap[str]]:
         """Returns the state groups for a given set of groups from the
         database, filtering on types of state events.
@@ -228,6 +251,150 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return state_filter.filter_state(state_dict_ids), not missing_types
 
+    def _get_state_for_group_gather_inflight_requests(
+        self, group: int, state_filter_left_over: StateFilter
+    ) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]:
+        """
+        Attempts to gather in-flight requests and re-use them to retrieve state
+        for the given state group, filtered with the given state filter.
+
+        Used as part of _get_state_for_group_using_inflight_cache.
+
+        Returns:
+            Tuple of two values:
+                A sequence of ObservableDeferreds to observe
+                A StateFilter representing what else needs to be requested to fulfill the request
+        """
+
+        inflight_requests = self._state_group_inflight_requests.get(group)
+        if inflight_requests is None:
+            # no requests for this group, need to retrieve it all ourselves
+            return (), state_filter_left_over
+
+        # The list of ongoing requests which will help narrow the current request.
+        reusable_requests = []
+        for (request_state_filter, request_deferred) in inflight_requests.items():
+            new_state_filter_left_over = state_filter_left_over.approx_difference(
+                request_state_filter
+            )
+            if new_state_filter_left_over == state_filter_left_over:
+                # Reusing this request would not gain us anything, so don't bother.
+                continue
+
+            reusable_requests.append(request_deferred)
+            state_filter_left_over = new_state_filter_left_over
+            if state_filter_left_over == StateFilter.none():
+                # we have managed to collect enough of the in-flight requests
+                # to cover our StateFilter and give us the state we need.
+                break
+
+        return reusable_requests, state_filter_left_over
+
+    async def _get_state_for_group_fire_request(
+        self, group: int, state_filter: StateFilter
+    ) -> StateMap[str]:
+        """
+        Fires off a request to get the state at a state group,
+        potentially filtering by type and/or state key.
+
+        This request will be tracked in the in-flight request cache and automatically
+        removed when it is finished.
+
+        Used as part of _get_state_for_group_using_inflight_cache.
+
+        Args:
+            group: ID of the state group for which we want to get state
+            state_filter: the state filter used to fetch state from the database
+        """
+        cache_sequence_nm = self._state_group_cache.sequence
+        cache_sequence_m = self._state_group_members_cache.sequence
+
+        # Help the cache hit ratio by expanding the filter a bit
+        db_state_filter = state_filter.return_expanded()
+
+        async def _the_request() -> StateMap[str]:
+            group_to_state_dict = await self._get_state_groups_from_groups(
+                (group,), state_filter=db_state_filter
+            )
+
+            # Now let's update the caches
+            self._insert_into_cache(
+                group_to_state_dict,
+                db_state_filter,
+                cache_seq_num_members=cache_sequence_m,
+                cache_seq_num_non_members=cache_sequence_nm,
+            )
+
+            # Remove ourselves from the in-flight cache
+            group_request_dict = self._state_group_inflight_requests[group]
+            del group_request_dict[db_state_filter]
+            if not group_request_dict:
+                # If there are no more requests in-flight for this group,
+                # clean up the cache by removing the empty dictionary
+                del self._state_group_inflight_requests[group]
+
+            return group_to_state_dict[group]
+
+        # We don't immediately await the result, so must use run_in_background
+        # But we DO await the result before the current log context (request)
+        # finishes, so don't need to run it as a background process.
+        request_deferred = run_in_background(_the_request)
+        observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)
+
+        # Insert the ObservableDeferred into the cache
+        group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
+        group_request_dict[db_state_filter] = observable_deferred
+
+        return await make_deferred_yieldable(observable_deferred.observe())
+
+    async def _get_state_for_group_using_inflight_cache(
+        self, group: int, state_filter: StateFilter
+    ) -> MutableStateMap[str]:
+        """
+        Gets the state at a state group, potentially filtering by type and/or
+        state key.
+
+        1. Calls _get_state_for_group_gather_inflight_requests to gather any
+           ongoing requests which might overlap with the current request.
+        2. Fires a new request, using _get_state_for_group_fire_request,
+           for any state which cannot be gathered from ongoing requests.
+
+        Args:
+            group: ID of the state group for which we want to get state
+            state_filter: the state filter used to fetch state from the database
+        Returns:
+            state map
+        """
+
+        # first, figure out whether we can re-use any in-flight requests
+        # (and if so, what would be left over)
+        (
+            reusable_requests,
+            state_filter_left_over,
+        ) = self._get_state_for_group_gather_inflight_requests(group, state_filter)
+
+        if state_filter_left_over != StateFilter.none():
+            # Fetch remaining state
+            remaining = await self._get_state_for_group_fire_request(
+                group, state_filter_left_over
+            )
+            assembled_state: MutableStateMap[str] = dict(remaining)
+        else:
+            assembled_state = {}
+
+        gathered = await make_deferred_yieldable(
+            defer.gatherResults(
+                (r.observe() for r in reusable_requests), consumeErrors=True
+            )
+        ).addErrback(unwrapFirstError)
+
+        # assemble our result.
+        for result_piece in gathered:
+            assembled_state.update(result_piece)
+
+        # Filter out any state that may be more than what we asked for.
+        return state_filter.filter_state(assembled_state)
+
     async def _get_state_for_groups(
         self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Dict[int, MutableStateMap[str]]:
@@ -269,31 +436,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         if not incomplete_groups:
             return state
 
-        cache_sequence_nm = self._state_group_cache.sequence
-        cache_sequence_m = self._state_group_members_cache.sequence
-
-        # Help the cache hit ratio by expanding the filter a bit
-        db_state_filter = state_filter.return_expanded()
-
-        group_to_state_dict = await self._get_state_groups_from_groups(
-            list(incomplete_groups), state_filter=db_state_filter
-        )
+        async def get_from_cache(group: int, state_filter: StateFilter) -> None:
+            state[group] = await self._get_state_for_group_using_inflight_cache(
+                group, state_filter
+            )
 
-        # Now lets update the caches
-        self._insert_into_cache(
-            group_to_state_dict,
-            db_state_filter,
-            cache_seq_num_members=cache_sequence_m,
-            cache_seq_num_non_members=cache_sequence_nm,
+        await yieldable_gather_results(
+            get_from_cache,
+            incomplete_groups,
+            state_filter,
         )
 
-        # And finally update the result dict, by filtering out any extra
-        # stuff we pulled out of the database.
-        for group, group_state_dict in group_to_state_dict.items():
-            # We just replace any existing entries, as we will have loaded
-            # everything we need from the database anyway.
-            state[group] = state_filter.filter_state(group_state_dict)
-
         return state
 
     def _get_state_for_groups_using_cache(
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 913448f0f9..e79ecf64a0 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -204,13 +204,16 @@ class StateFilter:
         if get_all_members:
             # We want to return everything.
             return StateFilter.all()
-        else:
+        elif EventTypes.Member in self.types:
             # We want to return all non-members, but only particular
             # memberships
             return StateFilter(
                 types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
                 include_others=True,
             )
+        else:
+            # We want to return all non-members
+            return _ALL_NON_MEMBER_STATE_FILTER
 
     def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
         """Converts the filter to an SQL clause.
@@ -528,6 +531,9 @@ class StateFilter:
 
 
 _ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
+_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
+    types=frozendict({EventTypes.Member: frozenset()}), include_others=True
+)
 _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
 
 
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 21591d0bfd..4ec2a713cf 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -37,14 +37,16 @@ class _EventSourcesInner:
     account_data: AccountDataEventSource
 
     def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
-        for attribute in _EventSourcesInner.__attrs_attrs__:  # type: ignore[attr-defined]
+        for attribute in attr.fields(_EventSourcesInner):
             yield attribute.name, getattr(self, attribute.name)
 
 
 class EventSources:
     def __init__(self, hs: "HomeServer"):
         self.sources = _EventSourcesInner(
-            *(attribute.type(hs) for attribute in _EventSourcesInner.__attrs_attrs__)  # type: ignore[attr-defined]
+            # mypy thinks attribute.type is `Optional`, but we know it's never `None` here since
+            # all the attributes of `_EventSourcesInner` are annotated.
+            *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner))  # type: ignore[misc]
         )
         self.store = hs.get_datastore()
 
diff --git a/synapse/types.py b/synapse/types.py
index f89fb216a6..53be3583a0 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -51,7 +51,7 @@ from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
     from synapse.appservice.api import ApplicationService
-    from synapse.storage.databases.main import DataStore
+    from synapse.storage.databases.main import DataStore, PurgeEventsStore
 
 # Define a state map type from type/state_key to T (usually an event ID or
 # event)
@@ -485,7 +485,7 @@ class RoomStreamToken:
             )
 
     @classmethod
-    async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
+    async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
         try:
             if string[0] == "s":
                 return cls(topological=None, stream=int(string[1:]))
@@ -502,7 +502,7 @@ class RoomStreamToken:
                     instance_id = int(key)
                     pos = int(value)
 
-                    instance_name = await store.get_name_from_instance_id(instance_id)
+                    instance_name = await store.get_name_from_instance_id(instance_id)  # type: ignore[attr-defined]
                     instance_map[instance_name] = pos
 
                 return cls(
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 15debd6c46..1cbc180eda 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -56,6 +56,7 @@ response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["n
 class EvictionReason(Enum):
     size = auto()
     time = auto()
+    invalidation = auto()
 
 
 @attr.s(slots=True, auto_attribs=True)
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 67ee4c693b..c6a5d0dfc0 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -133,6 +133,11 @@ class ExpiringCache(Generic[KT, VT]):
                 raise KeyError(key)
             return default
 
+        if self.iterable:
+            self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value))
+        else:
+            self.metrics.inc_evictions(EvictionReason.invalidation)
+
         return value.value
 
     def __contains__(self, key: KT) -> bool:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 7548b38548..45ff0de638 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -560,8 +560,10 @@ class LruCache(Generic[KT, VT]):
         def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]:
             node = cache.get(key, None)
             if node:
-                delete_node(node)
+                evicted_len = delete_node(node)
                 cache.pop(node.key, None)
+                if metrics:
+                    metrics.inc_evictions(EvictionReason.invalidation, evicted_len)
                 return node.value
             else:
                 return default
diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py
index de04f34e4e..031880ec39 100644
--- a/synapse/util/daemonize.py
+++ b/synapse/util/daemonize.py
@@ -20,7 +20,7 @@ import os
 import signal
 import sys
 from types import FrameType, TracebackType
-from typing import NoReturn, Type
+from typing import NoReturn, Optional, Type
 
 
 def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None:
@@ -100,7 +100,9 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
     # also catch any other uncaught exceptions before we get that far.)
 
     def excepthook(
-        type_: Type[BaseException], value: BaseException, traceback: TracebackType
+        type_: Type[BaseException],
+        value: BaseException,
+        traceback: Optional[TracebackType],
     ) -> None:
         logger.critical("Unhanded exception", exc_info=(type_, value, traceback))
 
@@ -123,7 +125,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
         sys.exit(1)
 
     # write a log line on SIGTERM.
-    def sigterm(signum: signal.Signals, frame: FrameType) -> NoReturn:
+    def sigterm(signum: int, frame: Optional[FrameType]) -> NoReturn:
         logger.warning("Caught signal %s. Stopping daemon." % signum)
         sys.exit(0)
 
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 1f18654d47..6d4b0b7c5a 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -14,7 +14,7 @@
 
 import functools
 import sys
-from typing import Any, Callable, Generator, List, TypeVar
+from typing import Any, Callable, Generator, List, TypeVar, cast
 
 from twisted.internet import defer
 from twisted.internet.defer import Deferred
@@ -174,7 +174,9 @@ def _check_yield_points(
                         )
                     )
                     changes.append(err)
-                return getattr(e, "value", None)
+                # The `StopIteration` or `_DefGen_Return` contains the return value from the
+                # generator.
+                return cast(T, e.value)
 
             frame = gen.gi_frame
 
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
new file mode 100644
index 0000000000..ec8864dafe
--- /dev/null
+++ b/tests/federation/test_federation_client.py
@@ -0,0 +1,149 @@
+# Copyright 2022 Matrix.org Federation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from unittest import mock
+
+import twisted.web.client
+from twisted.internet import defer
+from twisted.internet.protocol import Protocol
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.room_versions import RoomVersions
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.unittest import FederatingHomeserverTestCase
+
+
+class FederationClientTest(FederatingHomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+        super().prepare(reactor, clock, homeserver)
+
+        # mock out the Agent used by the federation client, which is easier than
+        # catching the HTTPS connection and do the TLS stuff.
+        self._mock_agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True)
+        homeserver.get_federation_http_client().agent = self._mock_agent
+
+    def test_get_room_state(self):
+        creator = f"@creator:{self.OTHER_SERVER_NAME}"
+        test_room_id = "!room_id"
+
+        # mock up some events to use in the response.
+        # In real life, these would have things in `prev_events` and `auth_events`, but that's
+        # a bit annoying to mock up, and the code under test doesn't care, so we don't bother.
+        create_event_dict = self.add_hashes_and_signatures(
+            {
+                "room_id": test_room_id,
+                "type": "m.room.create",
+                "state_key": "",
+                "sender": creator,
+                "content": {"creator": creator},
+                "prev_events": [],
+                "auth_events": [],
+                "origin_server_ts": 500,
+            }
+        )
+        member_event_dict = self.add_hashes_and_signatures(
+            {
+                "room_id": test_room_id,
+                "type": "m.room.member",
+                "sender": creator,
+                "state_key": creator,
+                "content": {"membership": "join"},
+                "prev_events": [],
+                "auth_events": [],
+                "origin_server_ts": 600,
+            }
+        )
+        pl_event_dict = self.add_hashes_and_signatures(
+            {
+                "room_id": test_room_id,
+                "type": "m.room.power_levels",
+                "sender": creator,
+                "state_key": "",
+                "content": {},
+                "prev_events": [],
+                "auth_events": [],
+                "origin_server_ts": 700,
+            }
+        )
+
+        # mock up the response, and have the agent return it
+        self._mock_agent.request.return_value = defer.succeed(
+            _mock_response(
+                {
+                    "pdus": [
+                        create_event_dict,
+                        member_event_dict,
+                        pl_event_dict,
+                    ],
+                    "auth_chain": [
+                        create_event_dict,
+                        member_event_dict,
+                    ],
+                }
+            )
+        )
+
+        # now fire off the request
+        state_resp, auth_resp = self.get_success(
+            self.hs.get_federation_client().get_room_state(
+                "yet_another_server",
+                test_room_id,
+                "event_id",
+                RoomVersions.V9,
+            )
+        )
+
+        # check the right call got made to the agent
+        self._mock_agent.request.assert_called_once_with(
+            b"GET",
+            b"matrix://yet_another_server/_matrix/federation/v1/state/%21room_id?event_id=event_id",
+            headers=mock.ANY,
+            bodyProducer=None,
+        )
+
+        # ... and that the response is correct.
+
+        # the auth_resp should be empty because all the events are also in state
+        self.assertEqual(auth_resp, [])
+
+        # all of the events should be returned in state_resp, though not necessarily
+        # in the same order. We just check the type on the assumption that if the type
+        # is right, so is the rest of the event.
+        self.assertCountEqual(
+            [e.type for e in state_resp],
+            ["m.room.create", "m.room.member", "m.room.power_levels"],
+        )
+
+
+def _mock_response(resp: JsonDict):
+    body = json.dumps(resp).encode("utf-8")
+
+    def deliver_body(p: Protocol):
+        p.dataReceived(body)
+        p.connectionLost(Failure(twisted.web.client.ResponseDone()))
+
+    response = mock.Mock(
+        code=200,
+        phrase=b"OK",
+        headers=twisted.web.client.Headers({"content-Type": ["application/json"]}),
+        length=len(body),
+        deliverBody=deliver_body,
+    )
+    mock.seal(response)
+    return response
diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py
index a7031a55f2..c2320ce133 100644
--- a/tests/federation/transport/test_client.py
+++ b/tests/federation/transport/test_client.py
@@ -62,3 +62,35 @@ class SendJoinParserTestCase(TestCase):
         self.assertEqual(len(parsed_response.state), 1, parsed_response)
         self.assertEqual(parsed_response.event_dict, {}, parsed_response)
         self.assertIsNone(parsed_response.event, parsed_response)
+        self.assertFalse(parsed_response.partial_state, parsed_response)
+        self.assertEqual(parsed_response.servers_in_room, None, parsed_response)
+
+    def test_partial_state(self) -> None:
+        """Check that the partial_state flag is correctly parsed"""
+        parser = SendJoinParser(RoomVersions.V1, False)
+        response = {
+            "org.matrix.msc3706.partial_state": True,
+        }
+
+        serialised_response = json.dumps(response).encode()
+
+        # Send data to the parser
+        parser.write(serialised_response)
+
+        # Retrieve and check the parsed SendJoinResponse
+        parsed_response = parser.finish()
+        self.assertTrue(parsed_response.partial_state)
+
+    def test_servers_in_room(self) -> None:
+        """Check that the servers_in_room field is correctly parsed"""
+        parser = SendJoinParser(RoomVersions.V1, False)
+        response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}
+
+        serialised_response = json.dumps(response).encode()
+
+        # Send data to the parser
+        parser.write(serialised_response)
+
+        # Retrieve and check the parsed SendJoinResponse
+        parsed_response = parser.finish()
+        self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"])
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 4740dd0a65..49d832de81 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -84,7 +84,7 @@ class CustomAuthProvider:
 
     def __init__(self, config, api: ModuleApi):
         api.register_password_auth_provider_callbacks(
-            auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
+            auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
         )
 
     def check_auth(self, *args):
@@ -122,7 +122,7 @@ class PasswordCustomAuthProvider:
             auth_checkers={
                 ("test.login_type", ("test_field",)): self.check_auth,
                 ("m.login.password", ("password",)): self.check_auth,
-            },
+            }
         )
         pass
 
@@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         account.register_servlets,
     ]
 
+    CALLBACK_USERNAME = "get_username_for_registration"
+    CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
+
     def setUp(self):
         # we use a global mock device, so make sure we are starting with a clean slate
         mock_password_provider.reset_mock()
@@ -754,7 +757,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         """Tests that the get_username_for_registration callback can define the username
         of a user when registering.
         """
-        self._setup_get_username_for_registration()
+        self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_USERNAME,
+        )
 
         username = "rin"
         channel = self.make_request(
@@ -777,30 +782,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         """Tests that the get_username_for_registration callback is only called at the
         end of the UIA flow.
         """
-        m = self._setup_get_username_for_registration()
-
-        # Initiate the UIA flow.
-        username = "rin"
-        channel = self.make_request(
-            "POST",
-            "register",
-            {"username": username, "type": "m.login.password", "password": "bar"},
+        m = self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_USERNAME,
         )
-        self.assertEqual(channel.code, 401)
-        self.assertIn("session", channel.json_body)
 
-        # Check that the callback hasn't been called yet.
-        m.assert_not_called()
+        username = "rin"
+        res = self._do_uia_assert_mock_not_called(username, m)
 
-        # Finish the UIA flow.
-        session = channel.json_body["session"]
-        channel = self.make_request(
-            "POST",
-            "register",
-            {"auth": {"session": session, "type": LoginType.DUMMY}},
-        )
-        self.assertEqual(channel.code, 200, channel.json_body)
-        mxid = channel.json_body["user_id"]
+        mxid = res["user_id"]
         self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
 
         # Check that the callback has been called.
@@ -817,6 +806,56 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self._test_3pid_allowed("rin", False)
         self._test_3pid_allowed("kitay", True)
 
+    def test_displayname(self):
+        """Tests that the get_displayname_for_registration callback can define the
+        display name of a user when registering.
+        """
+        self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_DISPLAYNAME,
+        )
+
+        username = "rin"
+        channel = self.make_request(
+            "POST",
+            "/register",
+            {
+                "username": username,
+                "password": "bar",
+                "auth": {"type": LoginType.DUMMY},
+            },
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Our callback takes the username and appends "-foo" to it, check that's what we
+        # have.
+        user_id = UserID.from_string(channel.json_body["user_id"])
+        display_name = self.get_success(
+            self.hs.get_profile_handler().get_displayname(user_id)
+        )
+
+        self.assertEqual(display_name, username + "-foo")
+
+    def test_displayname_uia(self):
+        """Tests that the get_displayname_for_registration callback is only called at the
+        end of the UIA flow.
+        """
+        m = self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_DISPLAYNAME,
+        )
+
+        username = "rin"
+        res = self._do_uia_assert_mock_not_called(username, m)
+
+        user_id = UserID.from_string(res["user_id"])
+        display_name = self.get_success(
+            self.hs.get_profile_handler().get_displayname(user_id)
+        )
+
+        self.assertEqual(display_name, username + "-foo")
+
+        # Check that the callback has been called.
+        m.assert_called_once()
+
     def _test_3pid_allowed(self, username: str, registration: bool):
         """Tests that the "is_3pid_allowed" module callback is called correctly, using
         either /register or /account URLs depending on the arguments.
@@ -877,23 +916,47 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
         m.assert_called_once_with("email", "bar@test.com", registration)
 
-    def _setup_get_username_for_registration(self) -> Mock:
-        """Registers a get_username_for_registration callback that appends "-foo" to the
-        username the client is trying to register.
+    def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
+        """Registers either a get_username_for_registration callback or a
+        get_displayname_for_registration callback that appends "-foo" to the username the
+        client is trying to register.
         """
 
-        async def get_username_for_registration(uia_results, params):
+        async def callback(uia_results, params):
             self.assertIn(LoginType.DUMMY, uia_results)
             username = params["username"]
             return username + "-foo"
 
-        m = Mock(side_effect=get_username_for_registration)
+        m = Mock(side_effect=callback)
 
         password_auth_provider = self.hs.get_password_auth_provider()
-        password_auth_provider.get_username_for_registration_callbacks.append(m)
+        getattr(password_auth_provider, callback_name + "_callbacks").append(m)
 
         return m
 
+    def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
+        # Initiate the UIA flow.
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"username": username, "type": "m.login.password", "password": "bar"},
+        )
+        self.assertEqual(channel.code, 401)
+        self.assertIn("session", channel.json_body)
+
+        # Check that the callback hasn't been called yet.
+        m.assert_not_called()
+
+        # Finish the UIA flow.
+        session = channel.json_body["session"]
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"auth": {"session": session, "type": LoginType.DUMMY}},
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        return channel.json_body
+
     def _get_login_flows(self) -> JsonDict:
         channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index c068d329a9..e1e3fb97c5 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -571,9 +571,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # Carry out our option-value specific test
         #
         # This push should still only contain an unread count of 1 (for 1 unread room)
-        self.assertEqual(
-            self.push_attempts[5][2]["notification"]["counts"]["unread"], 1
-        )
+        self._check_push_attempt(6, 1)
 
     @override_config({"push": {"group_unread_count_by_room": False}})
     def test_push_unread_count_message_count(self):
@@ -585,11 +583,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Carry out our option-value specific test
         #
-        # We're counting every unread message, so there should now be 4 since the
+        # We're counting every unread message, so there should now be 3 since the
         # last read receipt
-        self.assertEqual(
-            self.push_attempts[5][2]["notification"]["counts"]["unread"], 4
-        )
+        self._check_push_attempt(6, 3)
 
     def _test_push_unread_count(self):
         """
@@ -597,8 +593,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         Note that:
         * Sending messages will cause push notifications to go out to relevant users
-        * Sending a read receipt will cause a "badge update" notification to go out to
-          the user that sent the receipt
+        * Sending a read receipt will cause the HTTP pusher to check whether the unread
+            count has changed since the last push notification. If so, a "badge update"
+            notification goes out to the user that sent the receipt
         """
         # Register the user who gets notified
         user_id = self.register_user("user", "pass")
@@ -642,24 +639,74 @@ class HTTPPusherTests(HomeserverTestCase):
         # position in the room. We'll set the read position to this event in a moment
         first_message_event_id = response["event_id"]
 
-        # Advance time a bit (so the pusher will register something has happened) and
-        # make the push succeed
-        self.push_attempts[0][0].callback({})
+        expected_push_attempts = 1
+        self._check_push_attempt(expected_push_attempts, 0)
+
+        self._send_read_request(access_token, first_message_event_id, room_id)
+
+        # Unread count has not changed. Therefore, ensure that read request does not
+        # trigger a push notification.
+        self.assertEqual(len(self.push_attempts), 1)
+
+        # Send another message
+        response2 = self.helper.send(
+            room_id, body="How's the weather today?", tok=other_access_token
+        )
+        second_message_event_id = response2["event_id"]
+
+        expected_push_attempts += 1
+
+        self._check_push_attempt(expected_push_attempts, 1)
+
+        self._send_read_request(access_token, second_message_event_id, room_id)
+        expected_push_attempts += 1
+
+        self._check_push_attempt(expected_push_attempts, 0)
+
+        # If we're grouping by room, sending more messages shouldn't increase the
+        # unread count, as they're all being sent in the same room. Otherwise, it
+        # should. Therefore, the last call to _check_push_attempt is done in the
+        # caller method.
+        self.helper.send(room_id, body="Hello?", tok=other_access_token)
+        expected_push_attempts += 1
+
+        self._advance_time_and_make_push_succeed(expected_push_attempts)
+
+        self.helper.send(room_id, body="Hello??", tok=other_access_token)
+        expected_push_attempts += 1
+
+        self._advance_time_and_make_push_succeed(expected_push_attempts)
+
+        self.helper.send(room_id, body="HELLO???", tok=other_access_token)
+
+    def _advance_time_and_make_push_succeed(self, expected_push_attempts):
         self.pump()
+        self.push_attempts[expected_push_attempts - 1][0].callback({})
 
+    def _check_push_attempt(
+        self, expected_push_attempts: int, expected_unread_count_last_push: int
+    ) -> None:
+        """
+        Makes sure that the last expected push attempt succeeds and checks whether
+        it contains the expected unread count.
+        """
+        self._advance_time_and_make_push_succeed(expected_push_attempts)
         # Check our push made it
-        self.assertEqual(len(self.push_attempts), 1)
+        self.assertEqual(len(self.push_attempts), expected_push_attempts)
+        _, push_url, push_body = self.push_attempts[expected_push_attempts - 1]
         self.assertEqual(
-            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+            push_url,
+            "http://example.com/_matrix/push/v1/notify",
         )
-
         # Check that the unread count for the room is 0
         #
         # The unread count is zero as the user has no read receipt in the room yet
         self.assertEqual(
-            self.push_attempts[0][2]["notification"]["counts"]["unread"], 0
+            push_body["notification"]["counts"]["unread"],
+            expected_unread_count_last_push,
         )
 
+    def _send_read_request(self, access_token, message_event_id, room_id):
         # Now set the user's read receipt position to the first event
         #
         # This will actually trigger a new notification to be sent out so that
@@ -667,56 +714,8 @@ class HTTPPusherTests(HomeserverTestCase):
         # count goes down
         channel = self.make_request(
             "POST",
-            "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
+            "/rooms/%s/receipt/m.read/%s" % (room_id, message_event_id),
             {},
             access_token=access_token,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
-
-        # Advance time and make the push succeed
-        self.push_attempts[1][0].callback({})
-        self.pump()
-
-        # Unread count is still zero as we've read the only message in the room
-        self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(
-            self.push_attempts[1][2]["notification"]["counts"]["unread"], 0
-        )
-
-        # Send another message
-        self.helper.send(
-            room_id, body="How's the weather today?", tok=other_access_token
-        )
-
-        # Advance time and make the push succeed
-        self.push_attempts[2][0].callback({})
-        self.pump()
-
-        # This push should contain an unread count of 1 as there's now been one
-        # message since our last read receipt
-        self.assertEqual(len(self.push_attempts), 3)
-        self.assertEqual(
-            self.push_attempts[2][2]["notification"]["counts"]["unread"], 1
-        )
-
-        # Since we're grouping by room, sending more messages shouldn't increase the
-        # unread count, as they're all being sent in the same room
-        self.helper.send(room_id, body="Hello?", tok=other_access_token)
-
-        # Advance time and make the push succeed
-        self.pump()
-        self.push_attempts[3][0].callback({})
-
-        self.helper.send(room_id, body="Hello??", tok=other_access_token)
-
-        # Advance time and make the push succeed
-        self.pump()
-        self.push_attempts[4][0].callback({})
-
-        self.helper.send(room_id, body="HELLO???", tok=other_access_token)
-
-        # Advance time and make the push succeed
-        self.pump()
-        self.push_attempts[5][0].callback({})
-
-        self.assertEqual(len(self.push_attempts), 6)
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 89d85b0a17..51146c471d 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -486,8 +486,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
             {
                 "user_id": user_id,
                 "device_id": device_id,
-                # Unstable until MSC3069 enters spec
+                # MSC3069 entered spec in Matrix 1.2 but maintained compatibility
                 "org.matrix.msc3069.is_guest": False,
+                "is_guest": False,
             },
         )
 
@@ -505,8 +506,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
             {
                 "user_id": user_id,
                 "device_id": device_id,
-                # Unstable until MSC3069 enters spec
+                # MSC3069 entered spec in Matrix 1.2 but maintained compatibility
                 "org.matrix.msc3069.is_guest": True,
+                "is_guest": True,
             },
         )
 
@@ -528,8 +530,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
             whoami,
             {
                 "user_id": user_id,
-                # Unstable until MSC3069 enters spec
+                # MSC3069 entered spec in Matrix 1.2 but maintained compatibility
                 "org.matrix.msc3069.is_guest": False,
+                "is_guest": False,
             },
         )
         self.assertFalse(hasattr(whoami, "device_id"))
diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_device_lists.py
new file mode 100644
index 0000000000..16070cf027
--- /dev/null
+++ b/tests/rest/client/test_device_lists.py
@@ -0,0 +1,155 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.rest import admin, devices, room, sync
+from synapse.rest.client import account, login, register
+
+from tests import unittest
+
+
+class DeviceListsTestCase(unittest.HomeserverTestCase):
+    """Tests regarding device list changes."""
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        register.register_servlets,
+        account.register_servlets,
+        room.register_servlets,
+        sync.register_servlets,
+        devices.register_servlets,
+    ]
+
+    def test_receiving_local_device_list_changes(self):
+        """Tests that a local users that share a room receive each other's device list
+        changes.
+        """
+        # Register two users
+        test_device_id = "TESTDEVICE"
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        bob_user_id = self.register_user("bob", "ponyponypony")
+        bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+        # Create a room for them to coexist peacefully in
+        new_room_id = self.helper.create_room_as(
+            alice_user_id, is_public=True, tok=alice_access_token
+        )
+        self.assertIsNotNone(new_room_id)
+
+        # Have Bob join the room
+        self.helper.invite(
+            new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
+        )
+        self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
+
+        # Now have Bob initiate an initial sync (in order to get a since token)
+        channel = self.make_request(
+            "GET",
+            "/sync",
+            access_token=bob_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        next_batch_token = channel.json_body["next_batch"]
+
+        # ...and then an incremental sync. This should block until the sync stream is woken up,
+        # which we hope will happen as a result of Alice updating their device list.
+        bob_sync_channel = self.make_request(
+            "GET",
+            f"/sync?since={next_batch_token}&timeout=30000",
+            access_token=bob_access_token,
+            # Start the request, then continue on.
+            await_result=False,
+        )
+
+        # Have alice update their device list
+        channel = self.make_request(
+            "PUT",
+            f"/devices/{test_device_id}",
+            {
+                "display_name": "New Device Name",
+            },
+            access_token=alice_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that bob's incremental sync contains the updated device list.
+        # If not, the client would only receive the device list update on the
+        # *next* sync.
+        bob_sync_channel.await_result()
+        self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+        changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+            "changed", []
+        )
+        self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
+
+    def test_not_receiving_local_device_list_changes(self):
+        """Tests a local users DO NOT receive device updates from each other if they do not
+        share a room.
+        """
+        # Register two users
+        test_device_id = "TESTDEVICE"
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        bob_user_id = self.register_user("bob", "ponyponypony")
+        bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+        # These users do not share a room. They are lonely.
+
+        # Have Bob initiate an initial sync (in order to get a since token)
+        channel = self.make_request(
+            "GET",
+            "/sync",
+            access_token=bob_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        next_batch_token = channel.json_body["next_batch"]
+
+        # ...and then an incremental sync. This should block until the sync stream is woken up,
+        # which we hope will happen as a result of Alice updating their device list.
+        bob_sync_channel = self.make_request(
+            "GET",
+            f"/sync?since={next_batch_token}&timeout=1000",
+            access_token=bob_access_token,
+            # Start the request, then continue on.
+            await_result=False,
+        )
+
+        # Have alice update their device list
+        channel = self.make_request(
+            "PUT",
+            f"/devices/{test_device_id}",
+            {
+                "display_name": "New Device Name",
+            },
+            access_token=alice_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that bob's incremental sync does not contain the updated device list.
+        bob_sync_channel.await_result()
+        self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+        changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+            "changed", []
+        )
+        self.assertNotIn(
+            alice_user_id, changed_device_lists, bob_sync_channel.json_body
+        )
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index de80aca037..dfd9ffcb93 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1123,6 +1123,48 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
         )
 
+    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+    def test_edit_thread(self):
+        """Test that editing a thread works."""
+
+        # Create a thread and edit the last event.
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "A threaded reply!"},
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        threaded_event_id = channel.json_body["event_id"]
+
+        new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+        channel = self._send_relation(
+            RelationTypes.REPLACE,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+            parent_id=threaded_event_id,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # Fetch the thread root, to get the bundled aggregation for the thread.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/event/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # We expect that the edit message appears in the thread summary in the
+        # unsigned relations section.
+        relations_dict = channel.json_body["unsigned"].get("m.relations")
+        self.assertIn(RelationTypes.THREAD, relations_dict)
+
+        thread_summary = relations_dict[RelationTypes.THREAD]
+        self.assertIn("latest_event", thread_summary)
+        latest_event_in_thread = thread_summary["latest_event"]
+        self.assertEquals(
+            latest_event_in_thread["content"]["body"], "I've been edited!"
+        )
+
     def test_edit_edit(self):
         """Test that an edit cannot be edited."""
         new_body = {"msgtype": "m.text", "body": "Initial edit"}
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 1c0cb0cf4f..2b3fdadffa 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -106,9 +106,13 @@ class RestHelper:
                 default room version.
             tok: The access token to use in the request.
             expect_code: The expected HTTP response code.
+            extra_content: Extra keys to include in the body of the /createRoom request.
+                Note that if is_public is set, the "visibility" key will be overridden.
+                If room_version is set, the "room_version" key will be overridden.
+            custom_headers: HTTP headers to include in the request.
 
         Returns:
-            The ID of the newly created room.
+            The ID of the newly created room, or None if the request failed.
         """
         temp_id = self.auth_user_id
         self.auth_user_id = room_creator
diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py
new file mode 100644
index 0000000000..3a4a4a3a29
--- /dev/null
+++ b/tests/storage/databases/test_state_store.py
@@ -0,0 +1,283 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import typing
+from typing import Dict, List, Sequence, Tuple
+from unittest.mock import patch
+
+from twisted.internet.defer import Deferred, ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventTypes
+from synapse.storage.state import StateFilter
+from synapse.types import StateMap
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+if typing.TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+# StateFilter for ALL non-m.room.member state events
+ALL_NON_MEMBERS_STATE_FILTER = StateFilter.freeze(
+    types={EventTypes.Member: set()},
+    include_others=True,
+)
+
+FAKE_STATE = {
+    (EventTypes.Member, "@alice:test"): "join",
+    (EventTypes.Member, "@bob:test"): "leave",
+    (EventTypes.Member, "@charlie:test"): "invite",
+    ("test.type", "a"): "AAA",
+    ("test.type", "b"): "BBB",
+    ("other.event.type", "state.key"): "123",
+}
+
+
+class StateGroupInflightCachingTestCase(HomeserverTestCase):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: "HomeServer"
+    ) -> None:
+        self.state_storage = homeserver.get_storage().state
+        self.state_datastore = homeserver.get_datastores().state
+        # Patch out the `_get_state_groups_from_groups`.
+        # This is useful because it lets us pretend we have a slow database.
+        get_state_groups_patch = patch.object(
+            self.state_datastore,
+            "_get_state_groups_from_groups",
+            self._fake_get_state_groups_from_groups,
+        )
+        get_state_groups_patch.start()
+
+        self.addCleanup(get_state_groups_patch.stop)
+        self.get_state_group_calls: List[
+            Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]]
+        ] = []
+
+    def _fake_get_state_groups_from_groups(
+        self, groups: Sequence[int], state_filter: StateFilter
+    ) -> "Deferred[Dict[int, StateMap[str]]]":
+        d: Deferred[Dict[int, StateMap[str]]] = Deferred()
+        self.get_state_group_calls.append((tuple(groups), state_filter, d))
+        return d
+
+    def _complete_request_fake(
+        self,
+        groups: Tuple[int, ...],
+        state_filter: StateFilter,
+        d: "Deferred[Dict[int, StateMap[str]]]",
+    ) -> None:
+        """
+        Assemble a fake database response and complete the database request.
+        """
+
+        # Return a filtered copy of the fake state
+        d.callback({group: state_filter.filter_state(FAKE_STATE) for group in groups})
+
+    def test_duplicate_requests_deduplicated(self) -> None:
+        """
+        Tests that duplicate requests for state are deduplicated.
+
+        This test:
+        - requests some state (state group 42, 'all' state filter)
+        - requests it again, before the first request finishes
+        - checks to see that only one database query was made
+        - completes the database query
+        - checks that both requests see the same retrieved state
+        """
+        req1 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.all()
+            )
+        )
+        self.pump(by=0.1)
+
+        # This should have gone to the database
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+
+        req2 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.all()
+            )
+        )
+        self.pump(by=0.1)
+
+        # No more calls should have gone to the database
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+        self.assertFalse(req2.called)
+
+        groups, sf, d = self.get_state_group_calls[0]
+        self.assertEqual(groups, (42,))
+        self.assertEqual(sf, StateFilter.all())
+
+        # Now we can complete the request
+        self._complete_request_fake(groups, sf, d)
+
+        self.assertEqual(self.get_success(req1), FAKE_STATE)
+        self.assertEqual(self.get_success(req2), FAKE_STATE)
+
+    def test_smaller_request_deduplicated(self) -> None:
+        """
+        Tests that duplicate requests for state are deduplicated.
+
+        This test:
+        - requests some state (state group 42, 'all' state filter)
+        - requests a subset of that state, before the first request finishes
+        - checks to see that only one database query was made
+        - completes the database query
+        - checks that both requests see the correct retrieved state
+        """
+        req1 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.from_types((("test.type", None),))
+            )
+        )
+        self.pump(by=0.1)
+
+        # This should have gone to the database
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+
+        req2 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.from_types((("test.type", "b"),))
+            )
+        )
+        self.pump(by=0.1)
+
+        # No more calls should have gone to the database, because the second
+        # request was already in the in-flight cache!
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+        self.assertFalse(req2.called)
+
+        groups, sf, d = self.get_state_group_calls[0]
+        self.assertEqual(groups, (42,))
+        # The state filter is expanded internally for increased cache hit rate,
+        # so we the database sees a wider state filter than requested.
+        self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER)
+
+        # Now we can complete the request
+        self._complete_request_fake(groups, sf, d)
+
+        self.assertEqual(
+            self.get_success(req1),
+            {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"},
+        )
+        self.assertEqual(self.get_success(req2), {("test.type", "b"): "BBB"})
+
+    def test_partially_overlapping_request_deduplicated(self) -> None:
+        """
+        Tests that partially-overlapping requests are partially deduplicated.
+
+        This test:
+        - requests a single type of wildcard state
+          (This is internally expanded to be all non-member state)
+        - requests the entire state in parallel
+        - checks to see that two database queries were made, but that the second
+          one is only for member state.
+        - completes the database queries
+        - checks that both requests have the correct result.
+        """
+
+        req1 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.from_types((("test.type", None),))
+            )
+        )
+        self.pump(by=0.1)
+
+        # This should have gone to the database
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+
+        req2 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.all()
+            )
+        )
+        self.pump(by=0.1)
+
+        # Because it only partially overlaps, this also went to the database
+        self.assertEqual(len(self.get_state_group_calls), 2)
+        self.assertFalse(req1.called)
+        self.assertFalse(req2.called)
+
+        # First request:
+        groups, sf, d = self.get_state_group_calls[0]
+        self.assertEqual(groups, (42,))
+        # The state filter is expanded internally for increased cache hit rate,
+        # so we the database sees a wider state filter than requested.
+        self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER)
+        self._complete_request_fake(groups, sf, d)
+
+        # Second request:
+        groups, sf, d = self.get_state_group_calls[1]
+        self.assertEqual(groups, (42,))
+        # The state filter is narrowed to only request membership state, because
+        # the remainder of the state is already being queried in the first request!
+        self.assertEqual(
+            sf, StateFilter.freeze({EventTypes.Member: None}, include_others=False)
+        )
+        self._complete_request_fake(groups, sf, d)
+
+        # Check the results are correct
+        self.assertEqual(
+            self.get_success(req1),
+            {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"},
+        )
+        self.assertEqual(self.get_success(req2), FAKE_STATE)
+
+    def test_in_flight_requests_stop_being_in_flight(self) -> None:
+        """
+        Tests that in-flight request deduplication doesn't somehow 'hold on'
+        to completed requests: once they're done, they're taken out of the
+        in-flight cache.
+        """
+        req1 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.all()
+            )
+        )
+        self.pump(by=0.1)
+
+        # This should have gone to the database
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+
+        # Complete the request right away.
+        self._complete_request_fake(*self.get_state_group_calls[0])
+        self.assertTrue(req1.called)
+
+        # Send off another request
+        req2 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.all()
+            )
+        )
+        self.pump(by=0.1)
+
+        # It should have gone to the database again, because the previous request
+        # isn't in-flight and therefore isn't available for deduplication.
+        self.assertEqual(len(self.get_state_group_calls), 2)
+        self.assertFalse(req2.called)
+
+        # Complete the request right away.
+        self._complete_request_fake(*self.get_state_group_calls[1])
+        self.assertTrue(req2.called)
+        groups, sf, d = self.get_state_group_calls[0]
+
+        self.assertEqual(self.get_success(req1), FAKE_STATE)
+        self.assertEqual(self.get_success(req2), FAKE_STATE)
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index f462a8b1c7..a8639d8f82 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -329,3 +329,110 @@ class ExtremPruneTestCase(HomeserverTestCase):
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([local_message_event_id, remote_event_2.event_id])
+
+
+class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.state = self.hs.get_state_handler()
+        self.persistence = self.hs.get_storage().persistence
+        self.store = self.hs.get_datastore()
+
+    def test_remote_user_rooms_cache_invalidated(self):
+        """Test that if the server leaves a room the `get_rooms_for_user` cache
+        is invalidated for remote users.
+        """
+
+        # Set up a room with a local and remote user in it.
+        user_id = self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        room_id = self.helper.create_room_as(
+            "user", room_version=RoomVersions.V6.identifier, tok=token
+        )
+
+        body = self.helper.send(room_id, body="Test", tok=token)
+        local_message_event_id = body["event_id"]
+
+        # Fudge a join event for a remote user.
+        remote_user = "@user:other"
+        remote_event_1 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": remote_user,
+                "content": {"membership": Membership.JOIN},
+                "room_id": room_id,
+                "sender": remote_user,
+                "depth": 5,
+                "prev_events": [local_message_event_id],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        context = self.get_success(self.state.compute_event_context(remote_event_1))
+        self.get_success(self.persistence.persist_event(remote_event_1, context))
+
+        # Call `get_rooms_for_user` to add the remote user to the cache
+        rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
+        self.assertEqual(set(rooms), {room_id})
+
+        # Now we have the local server leave the room, and check that calling
+        # `get_user_in_room` for the remote user no longer includes the room.
+        self.helper.leave(room_id, user_id, tok=token)
+
+        rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
+        self.assertEqual(set(rooms), set())
+
+    def test_room_remote_user_cache_invalidated(self):
+        """Test that if the server leaves a room the `get_users_in_room` cache
+        is invalidated for remote users.
+        """
+
+        # Set up a room with a local and remote user in it.
+        user_id = self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        room_id = self.helper.create_room_as(
+            "user", room_version=RoomVersions.V6.identifier, tok=token
+        )
+
+        body = self.helper.send(room_id, body="Test", tok=token)
+        local_message_event_id = body["event_id"]
+
+        # Fudge a join event for a remote user.
+        remote_user = "@user:other"
+        remote_event_1 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": remote_user,
+                "content": {"membership": Membership.JOIN},
+                "room_id": room_id,
+                "sender": remote_user,
+                "depth": 5,
+                "prev_events": [local_message_event_id],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        context = self.get_success(self.state.compute_event_context(remote_event_1))
+        self.get_success(self.persistence.persist_event(remote_event_1, context))
+
+        # Call `get_users_in_room` to add the remote user to the cache
+        users = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual(set(users), {user_id, remote_user})
+
+        # Now we have the local server leave the room, and check that calling
+        # `get_user_in_room` for the remote user no longer includes the room.
+        self.helper.leave(room_id, user_id, tok=token)
+
+        users = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual(users, [])
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 70d52b088c..28c767ecfd 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -992,3 +992,112 @@ class StateFilterDifferenceTestCase(TestCase):
             StateFilter.none(),
             StateFilter.all(),
         )
+
+
+class StateFilterTestCase(TestCase):
+    def test_return_expanded(self):
+        """
+        Tests the behaviour of the return_expanded() function that expands
+        StateFilters to include more state types (for the sake of cache hit rate).
+        """
+
+        self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
+
+        self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
+
+        # Concrete-only state filters stay the same
+        # (Case: mixed filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                    "some.other.state.type": {""},
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                    "some.other.state.type": {""},
+                },
+                include_others=False,
+            ),
+        )
+
+        # Concrete-only state filters stay the same
+        # (Case: non-member-only filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {"some.other.state.type": {""}}, include_others=False
+            ).return_expanded(),
+            StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
+        )
+
+        # Concrete-only state filters stay the same
+        # (Case: member-only filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                },
+                include_others=False,
+            ),
+        )
+
+        # Wildcard member-only state filters stay the same
+        self.assertEqual(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # If there is a wildcard in the non-member portion of the filter,
+        # it's expanded to include ALL non-member events.
+        # (Case: mixed filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                    "some.other.state.type": None,
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze(
+                {EventTypes.Member: {"@wombat:test", "@alicia:test"}},
+                include_others=True,
+            ),
+        )
+
+        # If there is a wildcard in the non-member portion of the filter,
+        # it's expanded to include ALL non-member events.
+        # (Case: non-member-only filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    "some.other.state.type": None,
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+        )
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    "some.other.state.type": None,
+                    "yet.another.state.type": {"wombat"},
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+        )
diff --git a/tests/unittest.py b/tests/unittest.py
index a71892cb9d..7983c1e8b8 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -51,7 +51,10 @@ from twisted.web.server import Request
 
 from synapse import events
 from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.config.homeserver import HomeServerConfig
+from synapse.config.server import DEFAULT_ROOM_VERSION
+from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.federation.transport.server import TransportLayerServer
 from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest, SynapseSite
@@ -839,6 +842,24 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
             client_ip=client_ip,
         )
 
+    def add_hashes_and_signatures(
+        self,
+        event_dict: JsonDict,
+        room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
+    ) -> JsonDict:
+        """Adds hashes and signatures to the given event dict
+
+        Returns:
+             The modified event dict, for convenience
+        """
+        add_hashes_and_signatures(
+            room_version,
+            event_dict,
+            signature_name=self.OTHER_SERVER_NAME,
+            signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+        )
+        return event_dict
+
 
 def _auth_header_for_request(
     origin: str,
diff --git a/tox.ini b/tox.ini
index 32679e9106..436ecf7552 100644
--- a/tox.ini
+++ b/tox.ini
@@ -4,6 +4,11 @@ envlist = packaging, py37, py38, py39, py310, check_codestyle, check_isort
 # we require tox>=2.3.2 for the fix to https://github.com/tox-dev/tox/issues/208
 minversion = 2.3.2
 
+# the tox-venv plugin makes tox use python's built-in `venv` module rather than
+# the legacy `virtualenv` tool. `virtualenv` embeds its own `pip`, `setuptools`,
+# etc, and ends up being rather unreliable.
+requires = tox-venv
+
 [base]
 deps =
     python-subunit
@@ -119,6 +124,9 @@ usedevelop = false
 deps =
     Automat == 0.8.0
     lxml
+    # markupsafe 2.1 introduced a change that breaks Jinja 2.x. Since we depend on
+    # Jinja >= 2.9, it means this test suite will fail if markupsafe >= 2.1 is installed.
+    markupsafe < 2.1
     {[base]deps}
 
 commands =
@@ -158,7 +166,7 @@ commands =
 
 [testenv:check_isort]
 extras = lint
-commands = isort -c --df --sp setup.cfg {[base]lint_targets}
+commands = isort -c --df {[base]lint_targets}
 
 [testenv:check-newsfragment]
 skip_install = true