author     Erik Johnston <erik@matrix.org>  2024-06-24 16:21:16 +0100
committer  Erik Johnston <erik@matrix.org>  2024-06-24 16:21:16 +0100
commit     8a2a335db4e0cbe18b9ae1f390c77dfe3b8a88c8 (patch)
tree       18adcc01ff6eee5ab873055c0a3045d090325e74
parent     Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes (diff)
parent     Fix room `type` typo in mailer (#17336) (diff)
download   synapse-8a2a335db4e0cbe18b9ae1f390c77dfe3b8a88c8.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
-rw-r--r--  Cargo.lock                                            4
-rw-r--r--  changelog.d/17255.feature                             1
-rw-r--r--  changelog.d/17333.misc                                1
-rw-r--r--  changelog.d/17336.bugfix                              1
-rw-r--r--  changelog.d/17338.misc                                1
-rw-r--r--  changelog.d/17339.misc                                1
-rw-r--r--  poetry.lock                                          76
-rw-r--r--  synapse/config/experimental.py                        4
-rw-r--r--  synapse/handlers/message.py                          11
-rw-r--r--  synapse/http/servlet.py                              12
-rw-r--r--  synapse/push/mailer.py                                5
-rw-r--r--  synapse/replication/tcp/client.py                    19
-rw-r--r--  synapse/replication/tcp/streams/_base.py             12
-rw-r--r--  synapse/rest/admin/__init__.py                        3
-rw-r--r--  synapse/rest/admin/federation.py                      8
-rw-r--r--  synapse/rest/admin/media.py                          12
-rw-r--r--  synapse/rest/admin/statistics.py                      8
-rw-r--r--  synapse/rest/admin/users.py                          43
-rw-r--r--  synapse/rest/client/profile.py                       26
-rw-r--r--  synapse/rest/client/room.py                          25
-rw-r--r--  synapse/storage/controllers/persist_events.py        12
-rw-r--r--  synapse/storage/databases/main/devices.py            93
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py     4
-rw-r--r--  synapse/storage/databases/main/event_federation.py   20
-rw-r--r--  synapse/storage/databases/main/events.py            251
-rw-r--r--  synapse/streams/config.py                             3
-rw-r--r--  tests/rest/admin/test_user.py                        84
-rw-r--r--  tests/rest/client/test_rooms.py                     105
-rw-r--r--  tests/storage/test_devices.py                         8
-rw-r--r--  tests/storage/test_event_chain.py                     9
-rw-r--r--  tests/storage/test_event_federation.py               44
-rw-r--r--  tests/unittest.py                                     2
32 files changed, 695 insertions, 213 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 7472e16291..1955c1a4e7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -212,9 +212,9 @@ dependencies = [
 
 [[package]]
 name = "lazy_static"
-version = "1.4.0"
+version = "1.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
+checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
 
 [[package]]
 name = "libc"
diff --git a/changelog.d/17255.feature b/changelog.d/17255.feature
new file mode 100644
index 0000000000..4093de1146
--- /dev/null
+++ b/changelog.d/17255.feature
@@ -0,0 +1 @@
+Add support for [MSC3823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823) - Account suspension.
\ No newline at end of file
diff --git a/changelog.d/17333.misc b/changelog.d/17333.misc
new file mode 100644
index 0000000000..d3ef0b3777
--- /dev/null
+++ b/changelog.d/17333.misc
@@ -0,0 +1 @@
+Handle device lists notifications for large accounts more efficiently in worker mode.
diff --git a/changelog.d/17336.bugfix b/changelog.d/17336.bugfix
new file mode 100644
index 0000000000..618834302e
--- /dev/null
+++ b/changelog.d/17336.bugfix
@@ -0,0 +1 @@
+Fix email notification subject when invited to a space.
diff --git a/changelog.d/17338.misc b/changelog.d/17338.misc
new file mode 100644
index 0000000000..1a81bdef85
--- /dev/null
+++ b/changelog.d/17338.misc
@@ -0,0 +1 @@
+Do not block event sending/receiving while calculating large event auth chains.
diff --git a/changelog.d/17339.misc b/changelog.d/17339.misc
new file mode 100644
index 0000000000..1d7cb96c8b
--- /dev/null
+++ b/changelog.d/17339.misc
@@ -0,0 +1 @@
+Tidy up `parse_integer` docs and call sites to reflect the fact that they require non-negative integers by default, and bring the `parse_integer_from_args` default into alignment. Contributed by Denis Kasak (@dkasak).
diff --git a/poetry.lock b/poetry.lock
index 58981ff6e1..1bae0ea388 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -35,13 +35,13 @@ tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "p
 
 [[package]]
 name = "authlib"
-version = "1.3.0"
+version = "1.3.1"
 description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients."
 optional = true
 python-versions = ">=3.8"
 files = [
-    {file = "Authlib-1.3.0-py2.py3-none-any.whl", hash = "sha256:9637e4de1fb498310a56900b3e2043a206b03cb11c05422014b0302cbc814be3"},
-    {file = "Authlib-1.3.0.tar.gz", hash = "sha256:959ea62a5b7b5123c5059758296122b57cd2585ae2ed1c0622c21b371ffdae06"},
+    {file = "Authlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:d35800b973099bbadc49b42b256ecb80041ad56b7fe1216a362c7943c088f377"},
+    {file = "authlib-1.3.1.tar.gz", hash = "sha256:7ae843f03c06c5c0debd63c9db91f9fda64fa62a42a77419fa15fbb7e7a58917"},
 ]
 
 [package.dependencies]
@@ -1461,13 +1461,13 @@ test = ["lxml", "pytest (>=4.6)", "pytest-cov"]
 
 [[package]]
 name = "netaddr"
-version = "1.2.1"
+version = "1.3.0"
 description = "A network address manipulation library for Python"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "netaddr-1.2.1-py3-none-any.whl", hash = "sha256:bd9e9534b0d46af328cf64f0e5a23a5a43fca292df221c85580b27394793496e"},
-    {file = "netaddr-1.2.1.tar.gz", hash = "sha256:6eb8fedf0412c6d294d06885c110de945cf4d22d2b510d0404f4e06950857987"},
+    {file = "netaddr-1.3.0-py3-none-any.whl", hash = "sha256:c2c6a8ebe5554ce33b7d5b3a306b71bbb373e000bbbf2350dd5213cc56e3dbbe"},
+    {file = "netaddr-1.3.0.tar.gz", hash = "sha256:5c3c3d9895b551b763779ba7db7a03487dc1f8e3b385af819af341ae9ef6e48a"},
 ]
 
 [package.extras]
@@ -1488,13 +1488,13 @@ tests = ["Sphinx", "doubles", "flake8", "flake8-quotes", "gevent", "mock", "pyte
 
 [[package]]
 name = "packaging"
-version = "24.0"
+version = "24.1"
 description = "Core utilities for Python packages"
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"},
-    {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
+    {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"},
+    {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"},
 ]
 
 [[package]]
@@ -2157,13 +2157,13 @@ rpds-py = ">=0.7.0"
 
 [[package]]
 name = "requests"
-version = "2.31.0"
+version = "2.32.2"
 description = "Python HTTP for Humans."
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"},
-    {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"},
+    {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"},
+    {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"},
 ]
 
 [package.dependencies]
@@ -2387,13 +2387,13 @@ doc = ["Sphinx", "sphinx-rtd-theme"]
 
 [[package]]
 name = "sentry-sdk"
-version = "2.3.1"
+version = "2.6.0"
 description = "Python client for Sentry (https://sentry.io)"
 optional = true
 python-versions = ">=3.6"
 files = [
-    {file = "sentry_sdk-2.3.1-py2.py3-none-any.whl", hash = "sha256:c5aeb095ba226391d337dd42a6f9470d86c9fc236ecc71cfc7cd1942b45010c6"},
-    {file = "sentry_sdk-2.3.1.tar.gz", hash = "sha256:139a71a19f5e9eb5d3623942491ce03cf8ebc14ea2e39ba3e6fe79560d8a5b1f"},
+    {file = "sentry_sdk-2.6.0-py2.py3-none-any.whl", hash = "sha256:422b91cb49378b97e7e8d0e8d5a1069df23689d45262b86f54988a7db264e874"},
+    {file = "sentry_sdk-2.6.0.tar.gz", hash = "sha256:65cc07e9c6995c5e316109f138570b32da3bd7ff8d0d0ee4aaf2628c3dd8127d"},
 ]
 
 [package.dependencies]
@@ -2598,22 +2598,22 @@ files = [
 
 [[package]]
 name = "tornado"
-version = "6.4"
+version = "6.4.1"
 description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed."
 optional = true
-python-versions = ">= 3.8"
+python-versions = ">=3.8"
 files = [
-    {file = "tornado-6.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:02ccefc7d8211e5a7f9e8bc3f9e5b0ad6262ba2fbb683a6443ecc804e5224ce0"},
-    {file = "tornado-6.4-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:27787de946a9cffd63ce5814c33f734c627a87072ec7eed71f7fc4417bb16263"},
-    {file = "tornado-6.4-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7894c581ecdcf91666a0912f18ce5e757213999e183ebfc2c3fdbf4d5bd764e"},
-    {file = "tornado-6.4-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e43bc2e5370a6a8e413e1e1cd0c91bedc5bd62a74a532371042a18ef19e10579"},
-    {file = "tornado-6.4-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0251554cdd50b4b44362f73ad5ba7126fc5b2c2895cc62b14a1c2d7ea32f212"},
-    {file = "tornado-6.4-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:fd03192e287fbd0899dd8f81c6fb9cbbc69194d2074b38f384cb6fa72b80e9c2"},
-    {file = "tornado-6.4-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:88b84956273fbd73420e6d4b8d5ccbe913c65d31351b4c004ae362eba06e1f78"},
-    {file = "tornado-6.4-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:71ddfc23a0e03ef2df1c1397d859868d158c8276a0603b96cf86892bff58149f"},
-    {file = "tornado-6.4-cp38-abi3-win32.whl", hash = "sha256:6f8a6c77900f5ae93d8b4ae1196472d0ccc2775cc1dfdc9e7727889145c45052"},
-    {file = "tornado-6.4-cp38-abi3-win_amd64.whl", hash = "sha256:10aeaa8006333433da48dec9fe417877f8bcc21f48dda8d661ae79da357b2a63"},
-    {file = "tornado-6.4.tar.gz", hash = "sha256:72291fa6e6bc84e626589f1c29d90a5a6d593ef5ae68052ee2ef000dfd273dee"},
+    {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:163b0aafc8e23d8cdc3c9dfb24c5368af84a81e3364745ccb4427669bf84aec8"},
+    {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6d5ce3437e18a2b66fbadb183c1d3364fb03f2be71299e7d10dbeeb69f4b2a14"},
+    {file = "tornado-6.4.1-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e20b9113cd7293f164dc46fffb13535266e713cdb87bd2d15ddb336e96cfc4"},
+    {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ae50a504a740365267b2a8d1a90c9fbc86b780a39170feca9bcc1787ff80842"},
+    {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:613bf4ddf5c7a95509218b149b555621497a6cc0d46ac341b30bd9ec19eac7f3"},
+    {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:25486eb223babe3eed4b8aecbac33b37e3dd6d776bc730ca14e1bf93888b979f"},
+    {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:454db8a7ecfcf2ff6042dde58404164d969b6f5d58b926da15e6b23817950fc4"},
+    {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a02a08cc7a9314b006f653ce40483b9b3c12cda222d6a46d4ac63bb6c9057698"},
+    {file = "tornado-6.4.1-cp38-abi3-win32.whl", hash = "sha256:d9a566c40b89757c9aa8e6f032bcdb8ca8795d7c1a9762910c722b1635c9de4d"},
+    {file = "tornado-6.4.1-cp38-abi3-win_amd64.whl", hash = "sha256:b24b8982ed444378d7f21d563f4180a2de31ced9d8d84443907a0a64da2072e7"},
+    {file = "tornado-6.4.1.tar.gz", hash = "sha256:92d3ab53183d8c50f8204a51e6f91d18a15d5ef261e84d452800d4ff6fc504e9"},
 ]
 
 [[package]]
@@ -2917,13 +2917,13 @@ files = [
 
 [[package]]
 name = "typing-extensions"
-version = "4.11.0"
+version = "4.12.2"
 description = "Backported and Experimental Type Hints for Python 3.8+"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
-    {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
+    {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"},
+    {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
 ]
 
 [[package]]
@@ -2939,18 +2939,18 @@ files = [
 
 [[package]]
 name = "urllib3"
-version = "2.0.7"
+version = "2.2.2"
 description = "HTTP library with thread-safe connection pooling, file post, and more."
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "urllib3-2.0.7-py3-none-any.whl", hash = "sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e"},
-    {file = "urllib3-2.0.7.tar.gz", hash = "sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84"},
+    {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"},
+    {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"},
 ]
 
 [package.extras]
 brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
-secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"]
+h2 = ["h2 (>=4,<5)"]
 socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
 zstd = ["zstandard (>=0.18.0)"]
 
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 23e96da6a3..1b72727b75 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -433,6 +433,10 @@ class ExperimentalConfig(Config):
                 ("experimental", "msc4108_delegation_endpoint"),
             )
 
+        self.msc3823_account_suspension = experimental.get(
+            "msc3823_account_suspension", False
+        )
+
         self.msc3916_authenticated_media_enabled = experimental.get(
             "msc3916_authenticated_media_enabled", False
         )
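The new flag defaults to off; a minimal sketch of enabling it in homeserver.yaml (assuming the standard `experimental_features` section that `ExperimentalConfig` reads from):

    experimental_features:
      msc3823_account_suspension: true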
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 16d01efc67..5aa48230ec 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -642,6 +642,17 @@ class EventCreationHandler:
         """
         await self.auth_blocking.check_auth_blocking(requester=requester)
 
+        if event_dict["type"] == EventTypes.Message:
+            requester_suspended = await self.store.get_user_suspended_status(
+                requester.user.to_string()
+            )
+            if requester_suspended:
+                raise SynapseError(
+                    403,
+                    "Sending messages while account is suspended is not allowed.",
+                    Codes.USER_ACCOUNT_SUSPENDED,
+                )
+
         if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
             room_version_id = event_dict["content"]["room_version"]
             maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
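With this check in place, a suspended user's send attempt fails with a 403. A sketch of the response body a client would see (the errcode shown assumes MSC3823's unstable identifier for `Codes.USER_ACCOUNT_SUSPENDED`):

    {
      "errcode": "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED",
      "error": "Sending messages while account is suspended is not allowed."
    }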
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index ab12951da8..08b8ff7afd 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -119,14 +119,15 @@ def parse_integer(
         default: value to use if the parameter is absent, defaults to None.
         required: whether to raise a 400 SynapseError if the parameter is absent,
             defaults to False.
-        negative: whether to allow negative integers, defaults to True.
+        negative: whether to allow negative integers, defaults to False (disallowing
+            negatives).
     Returns:
         An int value or the default.
 
     Raises:
         SynapseError: if the parameter is absent and required, if the
             parameter is present and not an integer, or if the
-            parameter is illegitimate negative.
+            parameter is illegitimately negative.
     """
     args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
     return parse_integer_from_args(args, name, default, required, negative)
@@ -164,7 +165,7 @@ def parse_integer_from_args(
     name: str,
     default: Optional[int] = None,
     required: bool = False,
-    negative: bool = True,
+    negative: bool = False,
 ) -> Optional[int]:
     """Parse an integer parameter from the request string
 
@@ -174,7 +175,8 @@ def parse_integer_from_args(
         default: value to use if the parameter is absent, defaults to None.
         required: whether to raise a 400 SynapseError if the parameter is absent,
             defaults to False.
-        negative: whether to allow negative integers, defaults to True.
+        negative: whether to allow negative integers, defaults to False (disallowing
+            negatives).
 
     Returns:
         An int value or the default.
@@ -182,7 +184,7 @@ def parse_integer_from_args(
     Raises:
         SynapseError: if the parameter is absent and required, if the
             parameter is present and not an integer, or if the
-            parameter is illegitimate negative.
+            parameter is illegitimately negative.
     """
     name_bytes = name.encode("ascii")
 
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 77cc69a71f..cf611bd90b 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -28,7 +28,7 @@ import jinja2
 from markupsafe import Markup
 from prometheus_client import Counter
 
-from synapse.api.constants import EventTypes, Membership, RoomTypes
+from synapse.api.constants import EventContentFields, EventTypes, Membership, RoomTypes
 from synapse.api.errors import StoreError
 from synapse.config.emailconfig import EmailSubjectConfig
 from synapse.events import EventBase
@@ -716,7 +716,8 @@ class Mailer:
                 )
                 if (
                     create_event
-                    and create_event.content.get("room_type") == RoomTypes.SPACE
+                    and create_event.content.get(EventContentFields.ROOM_TYPE)
+                    == RoomTypes.SPACE
                 ):
                     return self.email_subjects.invite_from_person_to_space % {
                         "person": inviter_name,
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2d6d49eed7..3dddbb70b4 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -114,13 +114,19 @@ class ReplicationDataHandler:
         """
         all_room_ids: Set[str] = set()
         if stream_name == DeviceListsStream.NAME:
-            if any(row.entity.startswith("@") and not row.is_signature for row in rows):
+            if any(not row.is_signature and not row.hosts_calculated for row in rows):
                 prev_token = self.store.get_device_stream_token()
                 all_room_ids = await self.store.get_all_device_list_changes(
                     prev_token, token
                 )
                 self.store.device_lists_in_rooms_have_changed(all_room_ids, token)
 
+            # If we're sending federation we need to update the device lists
+            # outbound pokes stream change cache with updated hosts.
+            if self.send_handler and any(row.hosts_calculated for row in rows):
+                hosts = await self.store.get_destinations_for_device(token)
+                self.store.device_lists_outbound_pokes_have_changed(hosts, token)
+
         self.store.process_replication_rows(stream_name, instance_name, token, rows)
         # NOTE: this must be called after process_replication_rows to ensure any
         # cache invalidations are first handled before any stream ID advances.
@@ -433,12 +439,11 @@ class FederationSenderHandler:
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
             # changes.
-            hosts = {
-                row.entity
-                for row in rows
-                if not row.entity.startswith("@") and not row.is_signature
-            }
-            await self.federation_sender.send_device_messages(hosts, immediate=False)
+            if any(row.hosts_calculated for row in rows):
+                hosts = await self.store.get_destinations_for_device(token)
+                await self.federation_sender.send_device_messages(
+                    hosts, immediate=False
+                )
 
         elif stream_name == ToDeviceStream.NAME:
             # The to_device stream includes stuff to be pushed to both local
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 661206c841..d021904de7 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -549,10 +549,14 @@ class DeviceListsStream(_StreamFromIdGen):
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
     class DeviceListsStreamRow:
-        entity: str
+        user_id: str
         # Indicates that a user has signed their own device with their user-signing key
         is_signature: bool
 
+        # Indicates if this is a notification that we've calculated the hosts we
+        # need to send the update to.
+        hosts_calculated: bool
+
     NAME = "device_lists"
     ROW_TYPE = DeviceListsStreamRow
 
@@ -594,13 +598,13 @@ class DeviceListsStream(_StreamFromIdGen):
             upper_limit_token = min(upper_limit_token, signatures_to_token)
 
         device_updates = [
-            (stream_id, (entity, False))
-            for stream_id, (entity,) in device_updates
+            (stream_id, (entity, False, hosts))
+            for stream_id, (entity, hosts) in device_updates
             if stream_id <= upper_limit_token
         ]
 
         signatures_updates = [
-            (stream_id, (entity, True))
+            (stream_id, (entity, True, False))
             for stream_id, (entity,) in signatures_updates
             if stream_id <= upper_limit_token
         ]
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 6da1d79168..cdaee17451 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -101,6 +101,7 @@ from synapse.rest.admin.users import (
     ResetPasswordRestServlet,
     SearchUsersRestServlet,
     ShadowBanRestServlet,
+    SuspendAccountRestServlet,
     UserAdminServlet,
     UserByExternalId,
     UserByThreePid,
@@ -327,6 +328,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     BackgroundUpdateRestServlet(hs).register(http_server)
     BackgroundUpdateStartJobRestServlet(hs).register(http_server)
     ExperimentalFeaturesRestServlet(hs).register(http_server)
+    if hs.config.experimental.msc3823_account_suspension:
+        SuspendAccountRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index 14ab4644cb..d85a04b825 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -61,8 +61,8 @@ class ListDestinationsRestServlet(RestServlet):
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self._auth, request)
 
-        start = parse_integer(request, "from", default=0, negative=False)
-        limit = parse_integer(request, "limit", default=100, negative=False)
+        start = parse_integer(request, "from", default=0)
+        limit = parse_integer(request, "limit", default=100)
 
         destination = parse_string(request, "destination")
 
@@ -181,8 +181,8 @@ class DestinationMembershipRestServlet(RestServlet):
         if not await self._store.is_destination_known(destination):
             raise NotFoundError("Unknown destination")
 
-        start = parse_integer(request, "from", default=0, negative=False)
-        limit = parse_integer(request, "limit", default=100, negative=False)
+        start = parse_integer(request, "from", default=0)
+        limit = parse_integer(request, "limit", default=100)
 
         direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
 
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index a05b7252ec..ee6a681285 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -311,8 +311,8 @@ class DeleteMediaByDateSize(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        before_ts = parse_integer(request, "before_ts", required=True, negative=False)
-        size_gt = parse_integer(request, "size_gt", default=0, negative=False)
+        before_ts = parse_integer(request, "before_ts", required=True)
+        size_gt = parse_integer(request, "size_gt", default=0)
         keep_profiles = parse_boolean(request, "keep_profiles", default=True)
 
         if before_ts < 30000000000:  # Dec 1970 in milliseconds, Aug 2920 in seconds
@@ -377,8 +377,8 @@ class UserMediaRestServlet(RestServlet):
         if user is None:
             raise NotFoundError("Unknown user")
 
-        start = parse_integer(request, "from", default=0, negative=False)
-        limit = parse_integer(request, "limit", default=100, negative=False)
+        start = parse_integer(request, "from", default=0)
+        limit = parse_integer(request, "limit", default=100)
 
         # If neither `order_by` nor `dir` is set, set the default order
         # to newest media is on top for backward compatibility.
@@ -421,8 +421,8 @@ class UserMediaRestServlet(RestServlet):
         if user is None:
             raise NotFoundError("Unknown user")
 
-        start = parse_integer(request, "from", default=0, negative=False)
-        limit = parse_integer(request, "limit", default=100, negative=False)
+        start = parse_integer(request, "from", default=0)
+        limit = parse_integer(request, "limit", default=100)
 
         # If neither `order_by` nor `dir` is set, set the default order
         # to newest media is on top for backward compatibility.
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index dc27a41dd9..0adc5b7005 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -63,10 +63,10 @@ class UserMediaStatisticsRestServlet(RestServlet):
             ),
         )
 
-        start = parse_integer(request, "from", default=0, negative=False)
-        limit = parse_integer(request, "limit", default=100, negative=False)
-        from_ts = parse_integer(request, "from_ts", default=0, negative=False)
-        until_ts = parse_integer(request, "until_ts", negative=False)
+        start = parse_integer(request, "from", default=0)
+        limit = parse_integer(request, "limit", default=100)
+        from_ts = parse_integer(request, "from_ts", default=0)
+        until_ts = parse_integer(request, "until_ts")
 
         if until_ts is not None:
             if until_ts <= from_ts:
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 5bf12c4979..ad515bd5a3 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -27,11 +27,13 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import attr
 
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
 from synapse.api.constants import Direction, UserTypes
 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
+    parse_and_validate_json_object_from_request,
     parse_boolean,
     parse_enum,
     parse_integer,
@@ -49,10 +51,17 @@ from synapse.rest.client._base import client_patterns
 from synapse.storage.databases.main.registration import ExternalIDReuseException
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.types import JsonDict, JsonMapping, UserID
+from synapse.types.rest import RequestBodyModel
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+    from pydantic.v1 import StrictBool
+else:
+    from pydantic import StrictBool
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -90,8 +99,8 @@ class UsersRestServletV2(RestServlet):
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        start = parse_integer(request, "from", default=0, negative=False)
-        limit = parse_integer(request, "limit", default=100, negative=False)
+        start = parse_integer(request, "from", default=0)
+        limit = parse_integer(request, "limit", default=100)
 
         user_id = parse_string(request, "user_id")
         name = parse_string(request, "name", encoding="utf-8")
@@ -732,6 +741,36 @@ class DeactivateAccountRestServlet(RestServlet):
         return HTTPStatus.OK, {"id_server_unbind_result": id_server_unbind_result}
 
 
+class SuspendAccountRestServlet(RestServlet):
+    PATTERNS = admin_patterns("/suspend/(?P<target_user_id>[^/]*)$")
+
+    def __init__(self, hs: "HomeServer"):
+        self.auth = hs.get_auth()
+        self.is_mine = hs.is_mine
+        self.store = hs.get_datastores().main
+
+    class PutBody(RequestBodyModel):
+        suspend: StrictBool
+
+    async def on_PUT(
+        self, request: SynapseRequest, target_user_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester)
+
+        if not self.is_mine(UserID.from_string(target_user_id)):
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only suspend local users")
+
+        if not await self.store.get_user_by_id(target_user_id):
+            raise NotFoundError("User not found")
+
+        body = parse_and_validate_json_object_from_request(request, self.PutBody)
+        suspend = body.suspend
+        await self.store.set_user_suspended_status(target_user_id, suspend)
+
+        return HTTPStatus.OK, {f"user_{target_user_id}_suspended": suspend}
+
+
 class AccountValidityRenewServlet(RestServlet):
     PATTERNS = admin_patterns("/account_validity/validity$")
 
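A hedged sketch of exercising the new endpoint over HTTP; the `/_synapse/admin/v1` prefix follows the `admin_patterns` default, while the hostname and token are placeholders:

    import requests

    resp = requests.put(
        "https://homeserver.example/_synapse/admin/v1/suspend/@alice:example.com",
        headers={"Authorization": "Bearer <admin_access_token>"},
        # StrictBool: the body must be a JSON boolean, not the string "true"
        json={"suspend": True},
    )
    # Expected: 200 with {"user_@alice:example.com_suspended": true}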
diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py
index 0323f6afa1..c1a80c5c3d 100644
--- a/synapse/rest/client/profile.py
+++ b/synapse/rest/client/profile.py
@@ -108,6 +108,19 @@ class ProfileDisplaynameRestServlet(RestServlet):
 
         propagate = _read_propagate(self.hs, request)
 
+        requester_suspended = (
+            await self.hs.get_datastores().main.get_user_suspended_status(
+                requester.user.to_string()
+            )
+        )
+
+        if requester_suspended:
+            raise SynapseError(
+                403,
+                "Updating displayname while account is suspended is not allowed.",
+                Codes.USER_ACCOUNT_SUSPENDED,
+            )
+
         await self.profile_handler.set_displayname(
             user, requester, new_name, is_admin, propagate=propagate
         )
@@ -167,6 +180,19 @@ class ProfileAvatarURLRestServlet(RestServlet):
 
         propagate = _read_propagate(self.hs, request)
 
+        requester_suspended = (
+            await self.hs.get_datastores().main.get_user_suspended_status(
+                requester.user.to_string()
+            )
+        )
+
+        if requester_suspended:
+            raise SynapseError(
+                403,
+                "Updating avatar URL while account is suspended is not allowed.",
+                Codes.USER_ACCOUNT_SUSPENDED,
+            )
+
         await self.profile_handler.set_avatar_url(
             user, requester, new_avatar_url, is_admin, propagate=propagate
         )
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index c98241f6ce..903c74f6d8 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -510,7 +510,7 @@ class PublicRoomListRestServlet(RestServlet):
             if server:
                 raise e
 
-        limit: Optional[int] = parse_integer(request, "limit", 0, negative=False)
+        limit: Optional[int] = parse_integer(request, "limit", 0)
         since_token = parse_string(request, "since")
 
         if limit == 0:
@@ -1120,6 +1120,20 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
     ) -> Tuple[int, JsonDict]:
         content = parse_json_object_from_request(request)
 
+        requester_suspended = await self._store.get_user_suspended_status(
+            requester.user.to_string()
+        )
+
+        if requester_suspended:
+            event = await self._store.get_event(event_id, allow_none=True)
+            if event:
+                if event.sender != requester.user.to_string():
+                    raise SynapseError(
+                        403,
+                        "You can only redact your own events while account is suspended.",
+                        Codes.USER_ACCOUNT_SUSPENDED,
+                    )
+
         # Ensure the redacts property in the content matches the one provided in
         # the URL.
         room_version = await self._store.get_room_version(room_id)
@@ -1430,16 +1444,7 @@ class RoomHierarchyRestServlet(RestServlet):
         requester = await self._auth.get_user_by_req(request, allow_guest=True)
 
         max_depth = parse_integer(request, "max_depth")
-        if max_depth is not None and max_depth < 0:
-            raise SynapseError(
-                400, "'max_depth' must be a non-negative integer", Codes.BAD_JSON
-            )
-
         limit = parse_integer(request, "limit")
-        if limit is not None and limit <= 0:
-            raise SynapseError(
-                400, "'limit' must be a positive integer", Codes.BAD_JSON
-            )
 
         return 200, await self._room_summary_handler.get_room_hierarchy(
             requester,
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index 84699a2ee1..d0e015bf19 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -617,6 +617,17 @@ class EventsPersistenceStorageController:
                         room_id, chunk
                     )
 
+            with Measure(self._clock, "calculate_chain_cover_index_for_events"):
+                # We now calculate chain ID/sequence numbers for any state events we're
+                # persisting. We ignore out of band memberships as we're not in the room
+                # and won't have their auth chain (we'll fix it up later if we join the
+                # room).
+                #
+                # See: docs/auth_chain_difference_algorithm.md
+                new_event_links = await self.persist_events_store.calculate_chain_cover_index_for_events(
+                    room_id, [e for e, _ in chunk]
+                )
+
             await self.persist_events_store._persist_events_and_state_updates(
                 room_id,
                 chunk,
@@ -624,6 +635,7 @@ class EventsPersistenceStorageController:
                 new_forward_extremities=new_forward_extremities,
                 use_negative_stream_ordering=backfilled,
                 inhibit_local_membership_updates=backfilled,
+                new_event_links=new_event_links,
             )
 
         return replaced_events
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d687fc9e78..b62ef0b507 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -164,22 +164,24 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             prefilled_cache=user_signature_stream_prefill,
         )
 
-        (
-            device_list_federation_prefill,
-            device_list_federation_list_id,
-        ) = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_lists_outbound_pokes",
-            entity_column="destination",
-            stream_column="stream_id",
-            max_value=device_list_max,
-            limit=10000,
-        )
-        self._device_list_federation_stream_cache = StreamChangeCache(
-            "DeviceListFederationStreamChangeCache",
-            device_list_federation_list_id,
-            prefilled_cache=device_list_federation_prefill,
-        )
+        self._device_list_federation_stream_cache = None
+        if hs.should_send_federation():
+            (
+                device_list_federation_prefill,
+                device_list_federation_list_id,
+            ) = self.db_pool.get_cache_dict(
+                db_conn,
+                "device_lists_outbound_pokes",
+                entity_column="destination",
+                stream_column="stream_id",
+                max_value=device_list_max,
+                limit=10000,
+            )
+            self._device_list_federation_stream_cache = StreamChangeCache(
+                "DeviceListFederationStreamChangeCache",
+                device_list_federation_list_id,
+                prefilled_cache=device_list_federation_prefill,
+            )
 
         # vdh,rei 2023-10-13: disable because it is eating DB
         # https://github.com/matrix-org/synapse/issues/16480
@@ -209,23 +211,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     ) -> None:
         for row in rows:
             if row.is_signature:
-                self._user_signature_stream_cache.entity_has_changed(row.entity, token)
+                self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
                 continue
 
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
             # changes.
-            if row.entity.startswith("@"):
-                self._device_list_stream_cache.entity_has_changed(row.entity, token)
-                self.get_cached_devices_for_user.invalidate((row.entity,))
-                self._get_cached_user_device.invalidate((row.entity,))
-                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
-
-            else:
-                self._device_list_federation_stream_cache.entity_has_changed(
-                    row.entity, token
+            if not row.hosts_calculated:
+                self._device_list_stream_cache.entity_has_changed(row.user_id, token)
+                self.get_cached_devices_for_user.invalidate((row.user_id,))
+                self._get_cached_user_device.invalidate((row.user_id,))
+                self.get_device_list_last_stream_id_for_remote.invalidate(
+                    (row.user_id,)
                 )
 
+    def device_lists_outbound_pokes_have_changed(
+        self, destinations: StrCollection, token: int
+    ) -> None:
+        assert self._device_list_federation_stream_cache is not None
+
+        for destination in destinations:
+            self._device_list_federation_stream_cache.entity_has_changed(
+                destination, token
+            )
+
     def device_lists_in_rooms_have_changed(
         self, room_ids: StrCollection, token: int
     ) -> None:
@@ -365,6 +374,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
               EDU contents.
         """
         now_stream_id = self.get_device_stream_token()
+        if from_stream_id == now_stream_id:
+            return now_stream_id, []
+
+        if self._device_list_federation_stream_cache is None:
+            raise Exception("This function can only be used on federation senders")
 
         has_changed = self._device_list_federation_stream_cache.has_entity_changed(
             destination, int(from_stream_id)
@@ -1020,10 +1034,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             # This query Does The Right Thing where it'll correctly apply the
             # bounds to the inner queries.
             sql = """
-                SELECT stream_id, entity FROM (
-                    SELECT stream_id, user_id AS entity FROM device_lists_stream
+                SELECT stream_id, user_id, hosts FROM (
+                    SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
                     UNION ALL
-                    SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+                    SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
                 ) AS e
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
@@ -1579,6 +1593,14 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             get_device_list_changes_in_room_txn,
         )
 
+    async def get_destinations_for_device(self, stream_id: int) -> StrCollection:
+        return await self.db_pool.simple_select_onecol(
+            table="device_lists_outbound_pokes",
+            keyvalues={"stream_id": stream_id},
+            retcol="destination",
+            desc="get_destinations_for_device",
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(
@@ -2114,12 +2136,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         stream_ids: List[int],
         context: Optional[Dict[str, str]],
     ) -> None:
-        for host in hosts:
-            txn.call_after(
-                self._device_list_federation_stream_cache.entity_has_changed,
-                host,
-                stream_ids[-1],
-            )
+        if self._device_list_federation_stream_cache:
+            for host in hosts:
+                txn.call_after(
+                    self._device_list_federation_stream_cache.entity_has_changed,
+                    host,
+                    stream_ids[-1],
+                )
 
         now = self._clock.time_msec()
         stream_id_iterator = iter(stream_ids)
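With destinations no longer inlined in stream rows, a federation-sender worker resolves them on demand; a usage sketch (the `store` and `federation_sender` names are assumed):

    # Look up which remote servers were recorded as needing this device-list
    # update, then hand them to the federation sender.
    destinations = await store.get_destinations_for_device(stream_id)
    await federation_sender.send_device_messages(destinations, immediate=False)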
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 38d8785faa..9e6c9561ae 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -123,9 +123,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         if stream_name == DeviceListsStream.NAME:
             for row in rows:
                 assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
-                if row.entity.startswith("@"):
+                if not row.hosts_calculated:
                     self._get_e2e_device_keys_for_federation_query_inner.invalidate(
-                        (row.entity,)
+                        (row.user_id,)
                     )
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index fb132ef090..24abab4a23 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -148,6 +148,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             500000, "_event_auth_cache", size_callback=len
         )
 
+        # Flag used by unit tests to disable fallback when there is no chain cover
+        # index.
+        self.tests_allow_no_chain_cover_index = True
+
         self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
 
         if isinstance(self.database_engine, PostgresEngine):
@@ -220,8 +224,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 )
             except _NoChainCoverIndex:
                 # For whatever reason we don't actually have a chain cover index
-                # for the events in question, so we fall back to the old method.
-                pass
+                # for the events in question, so we fall back to the old method
+                # (except in tests)
+                if not self.tests_allow_no_chain_cover_index:
+                    raise
 
         return await self.db_pool.runInteraction(
             "get_auth_chain_ids",
@@ -271,7 +277,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         if events_missing_chain_info:
             # This can happen due to e.g. downgrade/upgrade of the server. We
             # raise an exception and fall back to the previous algorithm.
-            logger.info(
+            logger.error(
                 "Unexpectedly found that events don't have chain IDs in room %s: %s",
                 room_id,
                 events_missing_chain_info,
@@ -482,8 +488,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 )
             except _NoChainCoverIndex:
                 # For whatever reason we don't actually have a chain cover index
-                # for the events in question, so we fall back to the old method.
-                pass
+                # for the events in question, so we fall back to the old method
+                # (except in tests)
+                if not self.tests_allow_no_chain_cover_index:
+                    raise
 
         return await self.db_pool.runInteraction(
             "get_auth_chain_difference",
@@ -710,7 +718,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         if events_missing_chain_info - event_to_auth_ids.keys():
             # Uh oh, we somehow haven't correctly done the chain cover index,
             # bail and fall back to the old method.
-            logger.info(
+            logger.error(
                 "Unexpectedly found that events don't have chain IDs in room %s: %s",
                 room_id,
                 events_missing_chain_info - event_to_auth_ids.keys(),
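A sketch of how a test might use the new flag: setting it to False turns the previously silent fallback into a hard failure, so a missing chain cover index surfaces during testing (the test scaffolding here is assumed):

    def test_auth_chain_uses_cover_index(self) -> None:
        # Fail loudly instead of falling back to the slow algorithm if the
        # chain cover index is missing for the events under test.
        self.store.tests_allow_no_chain_cover_index = False

        ...  # persist events, then call get_auth_chain_difference(...)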
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 66428e6c8e..1f7acdb859 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -34,7 +34,6 @@ from typing import (
     Optional,
     Set,
     Tuple,
-    Union,
     cast,
 )
 
@@ -100,6 +99,23 @@ class DeltaState:
         return not self.to_delete and not self.to_insert and not self.no_longer_in_room
 
 
+@attr.s(slots=True, auto_attribs=True)
+class NewEventChainLinks:
+    """Information about new auth chain links that need to be added to the DB.
+
+    Attributes:
+        chain_id, sequence_number: the IDs corresponding to the event being
+            inserted, and the starting point of the links
+        links: Lists the links that need to be added, 2-tuple of the chain
+            ID/sequence number of the end point of the link.
+    """
+
+    chain_id: int
+    sequence_number: int
+
+    links: List[Tuple[int, int]] = attr.Factory(list)
+
+
 class PersistEventsStore:
     """Contains all the functions for writing events to the database.
 
@@ -148,6 +164,7 @@ class PersistEventsStore:
         *,
         state_delta_for_room: Optional[DeltaState],
         new_forward_extremities: Optional[Set[str]],
+        new_event_links: Dict[str, NewEventChainLinks],
         use_negative_stream_ordering: bool = False,
         inhibit_local_membership_updates: bool = False,
     ) -> None:
@@ -217,6 +234,7 @@ class PersistEventsStore:
                 inhibit_local_membership_updates=inhibit_local_membership_updates,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremities=new_forward_extremities,
+                new_event_links=new_event_links,
             )
             persist_event_counter.inc(len(events_and_contexts))
 
@@ -243,6 +261,87 @@ class PersistEventsStore:
                     (room_id,), frozenset(new_forward_extremities)
                 )
 
+    async def calculate_chain_cover_index_for_events(
+        self, room_id: str, events: Collection[EventBase]
+    ) -> Dict[str, NewEventChainLinks]:
+        # Filter to state events, and ensure there are no duplicates.
+        state_events = []
+        seen_events = set()
+        for event in events:
+            if not event.is_state() or event.event_id in seen_events:
+                continue
+
+            state_events.append(event)
+            seen_events.add(event.event_id)
+
+        if not state_events:
+            return {}
+
+        return await self.db_pool.runInteraction(
+            "_calculate_chain_cover_index_for_events",
+            self.calculate_chain_cover_index_for_events_txn,
+            room_id,
+            state_events,
+        )
+
+    def calculate_chain_cover_index_for_events_txn(
+        self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase]
+    ) -> Dict[str, NewEventChainLinks]:
+        # We now calculate chain ID/sequence numbers for any state events we're
+        # persisting. We ignore out of band memberships as we're not in the room
+        # and won't have their auth chain (we'll fix it up later if we join the
+        # room).
+        #
+        # See: docs/auth_chain_difference_algorithm.md
+
+        # We ignore legacy rooms that we aren't filling the chain cover index
+        # for.
+        row = self.db_pool.simple_select_one_txn(
+            txn,
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            retcols=("room_id", "has_auth_chain_index"),
+            allow_none=True,
+        )
+        if row is None or row[1] is False:
+            return {}
+
+        # Filter out events that we've already calculated.
+        rows = self.db_pool.simple_select_many_txn(
+            txn,
+            table="event_auth_chains",
+            column="event_id",
+            iterable=[e.event_id for e in state_events],
+            keyvalues={},
+            retcols=("event_id",),
+        )
+        already_persisted_events = {event_id for event_id, in rows}
+        state_events = [
+            event
+            for event in state_events
+            if event.event_id not in already_persisted_events
+        ]
+
+        if not state_events:
+            return {}
+
+        # We need to know the type/state_key and auth events of the events we're
+        # calculating chain IDs for. We don't rely on having the full Event
+        # instances as we'll potentially be pulling more events from the DB and
+        # we don't need the overhead of fetching/parsing the full event JSON.
+        event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events}
+        event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events}
+        event_to_room_id = {e.event_id: e.room_id for e in state_events}
+
+        return self._calculate_chain_cover_index(
+            txn,
+            self.db_pool,
+            self.store.event_chain_id_gen,
+            event_to_room_id,
+            event_to_types,
+            event_to_auth_chain,
+        )
+
     async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
         """Filter the supplied list of event_ids to get those which are prev_events of
         existing (non-outlier/rejected) events.
@@ -358,6 +457,7 @@ class PersistEventsStore:
         inhibit_local_membership_updates: bool,
         state_delta_for_room: Optional[DeltaState],
         new_forward_extremities: Optional[Set[str]],
+        new_event_links: Dict[str, NewEventChainLinks],
     ) -> None:
         """Insert some number of room events into the necessary database tables.
 
@@ -466,7 +566,9 @@ class PersistEventsStore:
         # Insert into event_to_state_groups.
         self._store_event_state_mappings_txn(txn, events_and_contexts)
 
-        self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
+        self._persist_event_auth_chain_txn(
+            txn, [e for e, _ in events_and_contexts], new_event_links
+        )
 
         # _store_rejected_events_txn filters out any events which were
         # rejected, and returns the filtered list.
@@ -496,7 +598,11 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         events: List[EventBase],
+        new_event_links: Dict[str, NewEventChainLinks],
     ) -> None:
+        if new_event_links:
+            self._persist_chain_cover_index(txn, self.db_pool, new_event_links)
+
         # We only care about state events, so return early if there are no state events.
         if not any(e.is_state() for e in events):
             return
@@ -519,62 +625,37 @@ class PersistEventsStore:
             ],
         )
 
-        # We now calculate chain ID/sequence numbers for any state events we're
-        # persisting. We ignore out of band memberships as we're not in the room
-        # and won't have their auth chain (we'll fix it up later if we join the
-        # room).
-        #
-        # See: docs/auth_chain_difference_algorithm.md
-
-        # We ignore legacy rooms that we aren't filling the chain cover index
-        # for.
-        rows = cast(
-            List[Tuple[str, Optional[Union[int, bool]]]],
-            self.db_pool.simple_select_many_txn(
-                txn,
-                table="rooms",
-                column="room_id",
-                iterable={event.room_id for event in events if event.is_state()},
-                keyvalues={},
-                retcols=("room_id", "has_auth_chain_index"),
-            ),
-        )
-        rooms_using_chain_index = {
-            room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
-        }
-
-        state_events = {
-            event.event_id: event
-            for event in events
-            if event.is_state() and event.room_id in rooms_using_chain_index
-        }
-
-        if not state_events:
-            return
+    @classmethod
+    def _add_chain_cover_index(
+        cls,
+        txn: LoggingTransaction,
+        db_pool: DatabasePool,
+        event_chain_id_gen: SequenceGenerator,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, StrCollection],
+    ) -> None:
+        """Calculate and persist the chain cover index for the given events.
 
-        # We need to know the type/state_key and auth events of the events we're
-        # calculating chain IDs for. We don't rely on having the full Event
-        # instances as we'll potentially be pulling more events from the DB and
-        # we don't need the overhead of fetching/parsing the full event JSON.
-        event_to_types = {
-            e.event_id: (e.type, e.state_key) for e in state_events.values()
-        }
-        event_to_auth_chain = {
-            e.event_id: e.auth_event_ids() for e in state_events.values()
-        }
-        event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+        Args:
+            event_to_room_id: Event ID to the room ID of the event
+            event_to_types: Event ID to type and state_key of the event
+            event_to_auth_chain: Event ID to list of auth event IDs of the
+                event (events with no auth events can be excluded).
+        """
 
-        self._add_chain_cover_index(
+        new_event_links = cls._calculate_chain_cover_index(
             txn,
-            self.db_pool,
-            self.store.event_chain_id_gen,
+            db_pool,
+            event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
         )
+        cls._persist_chain_cover_index(txn, db_pool, new_event_links)
 
     @classmethod
-    def _add_chain_cover_index(
+    def _calculate_chain_cover_index(
         cls,
         txn: LoggingTransaction,
         db_pool: DatabasePool,
@@ -582,7 +663,7 @@ class PersistEventsStore:
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, StrCollection],
-    ) -> None:
+    ) -> Dict[str, NewEventChainLinks]:
         """Calculate the chain cover index for the given events.
 
         Args:
@@ -590,6 +671,10 @@ class PersistEventsStore:
             event_to_types: Event ID to type and state_key of the event
             event_to_auth_chain: Event ID to list of auth event IDs of the
                 event (events with no auth events can be excluded).
+
+        Returns:
+            A mapping with any new auth chain links we need to add, keyed by
+            event ID.
         """
 
         # Map from event ID to chain ID/sequence number.
@@ -708,11 +793,11 @@ class PersistEventsStore:
                     room_id = event_to_room_id.get(event_id)
                     if room_id:
                         e_type, state_key = event_to_types[event_id]
-                        db_pool.simple_insert_txn(
+                        db_pool.simple_upsert_txn(
                             txn,
                             table="event_auth_chain_to_calculate",
+                            keyvalues={"event_id": event_id},
                             values={
-                                "event_id": event_id,
                                 "room_id": room_id,
                                 "type": e_type,
                                 "state_key": state_key,
@@ -724,7 +809,7 @@ class PersistEventsStore:
                     break
 
         if not events_to_calc_chain_id_for:
-            return
+            return {}
 
         # Allocate chain ID/sequence numbers to each new event.
         new_chain_tuples = cls._allocate_chain_ids(
@@ -739,23 +824,10 @@ class PersistEventsStore:
         )
         chain_map.update(new_chain_tuples)
 
-        db_pool.simple_insert_many_txn(
-            txn,
-            table="event_auth_chains",
-            keys=("event_id", "chain_id", "sequence_number"),
-            values=[
-                (event_id, c_id, seq)
-                for event_id, (c_id, seq) in new_chain_tuples.items()
-            ],
-        )
-
-        db_pool.simple_delete_many_txn(
-            txn,
-            table="event_auth_chain_to_calculate",
-            keyvalues={},
-            column="event_id",
-            values=new_chain_tuples,
-        )
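+        # Collect the newly allocated chain positions; the links found below
+        # are appended to these entries and the caller persists them all in
+        # one go.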
+        to_return = {
+            event_id: NewEventChainLinks(chain_id, sequence_number)
+            for event_id, (chain_id, sequence_number) in new_chain_tuples.items()
+        }
 
         # Now we need to calculate any new links between chains caused by
         # the new events.
@@ -825,10 +897,38 @@ class PersistEventsStore:
                 auth_chain_id, auth_sequence_number = chain_map[auth_id]
 
                 # Step 2a, add link between the event and auth event
+                to_return[event_id].links.append((auth_chain_id, auth_sequence_number))
                 chain_links.add_link(
                     (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
                 )
 
+        return to_return
+
+    @classmethod
+    def _persist_chain_cover_index(
+        cls,
+        txn: LoggingTransaction,
+        db_pool: DatabasePool,
+        new_event_links: Dict[str, NewEventChainLinks],
+    ) -> None:
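+        """Persist the chain cover index entries calculated by
+        `_calculate_chain_cover_index`: the allocated chain ID/sequence
+        number for each event, plus any new links between chains."""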
+        db_pool.simple_insert_many_txn(
+            txn,
+            table="event_auth_chains",
+            keys=("event_id", "chain_id", "sequence_number"),
+            values=[
+                (event_id, new_links.chain_id, new_links.sequence_number)
+                for event_id, new_links in new_event_links.items()
+            ],
+        )
+
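+        # These events now have a chain cover, so remove them from the
+        # "to calculate" table.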
+        db_pool.simple_delete_many_txn(
+            txn,
+            table="event_auth_chain_to_calculate",
+            keyvalues={},
+            column="event_id",
+            values=new_event_links,
+        )
+
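+        # Finally, write out any new links between chains.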
         db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chain_links",
@@ -838,7 +938,16 @@ class PersistEventsStore:
                 "target_chain_id",
                 "target_sequence_number",
             ),
-            values=list(chain_links.get_additions()),
+            values=[
+                (
+                    new_links.chain_id,
+                    new_links.sequence_number,
+                    target_chain_id,
+                    target_sequence_number,
+                )
+                for new_links in new_event_links.values()
+                for (target_chain_id, target_sequence_number) in new_links.links
+            ],
         )
 
     @staticmethod
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index eeafe889de..9fee5bfb92 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -75,9 +75,6 @@ class PaginationConfig:
             raise SynapseError(400, "'to' parameter is invalid")
 
         limit = parse_integer(request, "limit", default=default_limit)
-        if limit < 0:
-            raise SynapseError(400, "Limit must be 0 or above")
-
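+        # Negative values are now rejected by `parse_integer` itself, so no
+        # explicit check is needed here.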
         limit = min(limit, MAX_LIMIT)
 
         try:
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index c5da1e9686..16bb4349f5 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -37,6 +37,7 @@ from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
 from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
 from synapse.api.room_versions import RoomVersions
 from synapse.media.filepath import MediaFilePaths
+from synapse.rest import admin
 from synapse.rest.client import (
     devices,
     login,
@@ -5005,3 +5006,86 @@ class AllowCrossSigningReplacementTestCase(unittest.HomeserverTestCase):
         )
         assert timestamp is not None
         self.assertGreater(timestamp, self.clock.time_msec())
+
+
+class UserSuspensionTestCase(unittest.HomeserverTestCase):
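+    """Tests for the admin API for suspending users (MSC3823)."""
+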
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.admin = self.register_user("thomas", "hackme", True)
+        self.admin_tok = self.login("thomas", "hackme")
+
+        self.bad_user = self.register_user("teresa", "hackme")
+        self.bad_user_tok = self.login("teresa", "hackme")
+
+        self.store = hs.get_datastores().main
+
+    @override_config({"experimental_features": {"msc3823_account_suspension": True}})
+    def test_suspend_user(self) -> None:
+        # test that suspending a user works
+        channel = self.make_request(
+            "PUT",
+            f"/_synapse/admin/v1/suspend/{self.bad_user}",
+            {"suspend": True},
+            access_token=self.admin_tok,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body, {f"user_{self.bad_user}_suspended": True})
+
+        res = self.get_success(self.store.get_user_suspended_status(self.bad_user))
+        self.assertEqual(True, res)
+
+        # test that un-suspending the user works
+        channel2 = self.make_request(
+            "PUT",
+            f"/_synapse/admin/v1/suspend/{self.bad_user}",
+            {"suspend": False},
+            access_token=self.admin_tok,
+        )
+        self.assertEqual(channel2.code, 200)
+        self.assertEqual(channel2.json_body, {f"user_{self.bad_user}_suspended": False})
+
+        res2 = self.get_success(self.store.get_user_suspended_status(self.bad_user))
+        self.assertEqual(False, res2)
+
+        # test that un-suspending a user who isn't suspended doesn't cause problems
+        channel3 = self.make_request(
+            "PUT",
+            f"/_synapse/admin/v1/suspend/{self.bad_user}",
+            {"suspend": False},
+            access_token=self.admin_tok,
+        )
+        self.assertEqual(channel3.code, 200)
+        self.assertEqual(channel3.json_body, {f"user_{self.bad_user}_suspended": False})
+
+        res3 = self.get_success(self.store.get_user_suspended_status(self.bad_user))
+        self.assertEqual(False, res3)
+
+        # suspend the user again
+        channel4 = self.make_request(
+            "PUT",
+            f"/_synapse/admin/v1/suspend/{self.bad_user}",
+            {"suspend": True},
+            access_token=self.admin_tok,
+        )
+        self.assertEqual(channel4.code, 200)
+        self.assertEqual(channel4.json_body, {f"user_{self.bad_user}_suspended": True})
+
+        res4 = self.get_success(self.store.get_user_suspended_status(self.bad_user))
+        self.assertEqual(True, res4)
+
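+        # test that suspending an already-suspended user doesn't cause problems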
+        channel5 = self.make_request(
+            "PUT",
+            f"/_synapse/admin/v1/suspend/{self.bad_user}",
+            {"suspend": True},
+            access_token=self.admin_tok,
+        )
+        self.assertEqual(channel5.code, 200)
+        self.assertEqual(channel5.json_body, {f"user_{self.bad_user}_suspended": True})
+
+        res5 = self.get_success(self.store.get_user_suspended_status(self.bad_user))
+        self.assertEqual(True, res5)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index d398cead1c..c559dfda83 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -3819,3 +3819,108 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):
 
         # Make sure the outlier event is not returned
         self.assertNotEqual(channel.json_body["event_id"], outlier_event.event_id)
+
+
+class UserSuspensionTests(unittest.HomeserverTestCase):
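+    """Tests that suspended users are blocked from the affected client APIs."""
+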
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+        profile.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.user1 = self.register_user("thomas", "hackme")
+        self.tok1 = self.login("thomas", "hackme")
+
+        self.user2 = self.register_user("teresa", "hackme")
+        self.tok2 = self.login("teresa", "hackme")
+
+        self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+        self.store = hs.get_datastores().main
+
+    def test_suspended_user_cannot_send_message_to_room(self) -> None:
+        # set the user as suspended
+        self.get_success(self.store.set_user_suspended_status(self.user1, True))
+
+        channel = self.make_request(
+            "PUT",
+            f"/rooms/{self.room1}/send/m.room.message/1",
+            access_token=self.tok1,
+            content={"body": "hello", "msgtype": "m.text"},
+        )
+        self.assertEqual(
+            channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
+        )
+
+    def test_suspended_user_cannot_change_profile_data(self) -> None:
+        # set the user as suspended
+        self.get_success(self.store.set_user_suspended_status(self.user1, True))
+
+        channel = self.make_request(
+            "PUT",
+            f"/_matrix/client/v3/profile/{self.user1}/avatar_url",
+            access_token=self.tok1,
+            content={"avatar_url": "mxc://matrix.org/wefh34uihSDRGhw34"},
+            shorthand=False,
+        )
+        self.assertEqual(
+            channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
+        )
+
+        channel2 = self.make_request(
+            "PUT",
+            f"/_matrix/client/v3/profile/{self.user1}/displayname",
+            access_token=self.tok1,
+            content={"displayname": "something offensive"},
+            shorthand=False,
+        )
+        self.assertEqual(
+            channel2.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
+        )
+
+    def test_suspended_user_cannot_redact_messages_other_than_their_own(self) -> None:
+        # the second user joins the room and sends a message
+        self.make_request("POST", f"/rooms/{self.room1}/join", access_token=self.tok2)
+        res = self.helper.send_event(
+            self.room1,
+            "m.room.message",
+            {"body": "hello", "msgtype": "m.text"},
+            tok=self.tok2,
+        )
+        event_id = res["event_id"]
+
+        # the first user sends a message of their own
+        self.make_request("POST", f"/rooms/{self.room1}/join", access_token=self.tok1)
+        res2 = self.helper.send_event(
+            self.room1,
+            "m.room.message",
+            {"body": "bad_message", "msgtype": "m.text"},
+            tok=self.tok1,
+        )
+        event_id2 = res2["event_id"]
+
+        # set the first user as suspended
+        self.get_success(self.store.set_user_suspended_status(self.user1, True))
+
+        # the suspended user can't redact the other user's message
+        channel = self.make_request(
+            "PUT",
+            f"/_matrix/client/v3/rooms/{self.room1}/redact/{event_id}/1",
+            access_token=self.tok1,
+            content={"reason": "bogus"},
+            shorthand=False,
+        )
+        self.assertEqual(
+            channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
+        )
+
+        # but they can still redact their own
+        channel = self.make_request(
+            "PUT",
+            f"/_matrix/client/v3/rooms/{self.room1}/redact/{event_id2}/1",
+            access_token=self.tok1,
+            content={"reason": "bogus"},
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 7f975d04ff..ba01b038ab 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -36,6 +36,14 @@ class DeviceStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
 
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+
+        # We 'enable' federation here, otherwise `get_device_updates_by_remote`
+        # will throw an exception.
+        config["federation_sender_instances"] = ["master"]
+        return config
+
     def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
         """Add a device list change for the given device to
         `device_lists_outbound_pokes` table.
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 81feb3ec29..c4e216c308 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -447,7 +447,14 @@ class EventChainStoreTestCase(HomeserverTestCase):
             )
 
             # Actually call the function that calculates the auth chain stuff.
-            persist_events_store._persist_event_auth_chain_txn(txn, events)
+            new_event_links = (
+                persist_events_store.calculate_chain_cover_index_for_events_txn(
+                    txn, events[0].room_id, [e for e in events if e.is_state()]
+                )
+            )
+            persist_events_store._persist_event_auth_chain_txn(
+                txn, events, new_event_links
+            )
 
         self.get_success(
             persist_events_store.db_pool.runInteraction(
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 0a6253e22c..088f0d24f9 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -365,12 +365,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                     },
                 )
 
+            events = [
+                cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
+                for event_id in AUTH_GRAPH
+            ]
+            new_event_links = (
+                self.persist_events.calculate_chain_cover_index_for_events_txn(
+                    txn, room_id, [e for e in events if e.is_state()]
+                )
+            )
             self.persist_events._persist_event_auth_chain_txn(
                 txn,
-                [
-                    cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
-                    for event_id in AUTH_GRAPH
-                ],
+                events,
+                new_event_links,
             )
 
         self.get_success(
@@ -544,6 +551,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         rooms.
         """
 
+        # We allow a partial chain cover index for this test
+        self.hs.get_datastores().main.tests_allow_no_chain_cover_index = True
+
         room_id = "@ROOM:local"
 
         # The silly auth graph we use to test the auth difference algorithm,
@@ -628,13 +638,20 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 )
 
             # Insert all events apart from 'B'
+            events = [
+                cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
+                for event_id in auth_graph
+                if event_id != "b"
+            ]
+            new_event_links = (
+                self.persist_events.calculate_chain_cover_index_for_events_txn(
+                    txn, room_id, [e for e in events if e.is_state()]
+                )
+            )
             self.persist_events._persist_event_auth_chain_txn(
                 txn,
-                [
-                    cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
-                    for event_id in auth_graph
-                    if event_id != "b"
-                ],
+                events,
+                new_event_links,
             )
 
             # Now we insert the event 'B' without a chain cover, by temporarily
@@ -647,9 +664,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 updatevalues={"has_auth_chain_index": False},
             )
 
+            events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))]
+            new_event_links = (
+                self.persist_events.calculate_chain_cover_index_for_events_txn(
+                    txn, room_id, [e for e in events if e.is_state()]
+                )
+            )
             self.persist_events._persist_event_auth_chain_txn(
-                txn,
-                [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
+                txn, events, new_event_links
             )
 
             self.store.db_pool.simple_update_txn(
diff --git a/tests/unittest.py b/tests/unittest.py
index 18963b9e32..a7c20556a0 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -344,6 +344,8 @@ class HomeserverTestCase(TestCase):
         self._hs_args = {"clock": self.clock, "reactor": self.reactor}
         self.hs = self.make_homeserver(self.reactor, self.clock)
 
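+        # By default, require every persisted event to get a chain cover
+        # index; individual tests can opt out of this check.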
+        self.hs.get_datastores().main.tests_allow_no_chain_cover_index = False
+
         # Honour the `use_frozen_dicts` config option. We have to do this
         # manually because this is taken care of in the app `start` code, which
         # we don't run. Plus we want to reset it on tearDown.