summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/federation/test_federation_client.py149
-rw-r--r--tests/federation/transport/test_client.py32
-rw-r--r--tests/handlers/test_password_providers.py123
-rw-r--r--tests/push/test_http.py129
-rw-r--r--tests/rest/client/test_account.py9
-rw-r--r--tests/rest/client/test_auth.py93
-rw-r--r--tests/rest/client/test_device_lists.py155
-rw-r--r--tests/rest/client/test_relations.py42
-rw-r--r--tests/rest/client/utils.py6
-rw-r--r--tests/storage/databases/test_state_store.py352
-rw-r--r--tests/storage/test_events.py107
-rw-r--r--tests/storage/test_state.py109
-rw-r--r--tests/unittest.py21
13 files changed, 1226 insertions, 101 deletions
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
new file mode 100644
index 0000000000..ec8864dafe
--- /dev/null
+++ b/tests/federation/test_federation_client.py
@@ -0,0 +1,149 @@
+# Copyright 2022 Matrix.org Federation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from unittest import mock
+
+import twisted.web.client
+from twisted.internet import defer
+from twisted.internet.protocol import Protocol
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.room_versions import RoomVersions
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.unittest import FederatingHomeserverTestCase
+
+
+class FederationClientTest(FederatingHomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+        super().prepare(reactor, clock, homeserver)
+
+        # mock out the Agent used by the federation client, which is easier than
+        # catching the HTTPS connection and do the TLS stuff.
+        self._mock_agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True)
+        homeserver.get_federation_http_client().agent = self._mock_agent
+
+    def test_get_room_state(self):
+        creator = f"@creator:{self.OTHER_SERVER_NAME}"
+        test_room_id = "!room_id"
+
+        # mock up some events to use in the response.
+        # In real life, these would have things in `prev_events` and `auth_events`, but that's
+        # a bit annoying to mock up, and the code under test doesn't care, so we don't bother.
+        create_event_dict = self.add_hashes_and_signatures(
+            {
+                "room_id": test_room_id,
+                "type": "m.room.create",
+                "state_key": "",
+                "sender": creator,
+                "content": {"creator": creator},
+                "prev_events": [],
+                "auth_events": [],
+                "origin_server_ts": 500,
+            }
+        )
+        member_event_dict = self.add_hashes_and_signatures(
+            {
+                "room_id": test_room_id,
+                "type": "m.room.member",
+                "sender": creator,
+                "state_key": creator,
+                "content": {"membership": "join"},
+                "prev_events": [],
+                "auth_events": [],
+                "origin_server_ts": 600,
+            }
+        )
+        pl_event_dict = self.add_hashes_and_signatures(
+            {
+                "room_id": test_room_id,
+                "type": "m.room.power_levels",
+                "sender": creator,
+                "state_key": "",
+                "content": {},
+                "prev_events": [],
+                "auth_events": [],
+                "origin_server_ts": 700,
+            }
+        )
+
+        # mock up the response, and have the agent return it
+        self._mock_agent.request.return_value = defer.succeed(
+            _mock_response(
+                {
+                    "pdus": [
+                        create_event_dict,
+                        member_event_dict,
+                        pl_event_dict,
+                    ],
+                    "auth_chain": [
+                        create_event_dict,
+                        member_event_dict,
+                    ],
+                }
+            )
+        )
+
+        # now fire off the request
+        state_resp, auth_resp = self.get_success(
+            self.hs.get_federation_client().get_room_state(
+                "yet_another_server",
+                test_room_id,
+                "event_id",
+                RoomVersions.V9,
+            )
+        )
+
+        # check the right call got made to the agent
+        self._mock_agent.request.assert_called_once_with(
+            b"GET",
+            b"matrix://yet_another_server/_matrix/federation/v1/state/%21room_id?event_id=event_id",
+            headers=mock.ANY,
+            bodyProducer=None,
+        )
+
+        # ... and that the response is correct.
+
+        # the auth_resp should be empty because all the events are also in state
+        self.assertEqual(auth_resp, [])
+
+        # all of the events should be returned in state_resp, though not necessarily
+        # in the same order. We just check the type on the assumption that if the type
+        # is right, so is the rest of the event.
+        self.assertCountEqual(
+            [e.type for e in state_resp],
+            ["m.room.create", "m.room.member", "m.room.power_levels"],
+        )
+
+
+def _mock_response(resp: JsonDict):
+    body = json.dumps(resp).encode("utf-8")
+
+    def deliver_body(p: Protocol):
+        p.dataReceived(body)
+        p.connectionLost(Failure(twisted.web.client.ResponseDone()))
+
+    response = mock.Mock(
+        code=200,
+        phrase=b"OK",
+        headers=twisted.web.client.Headers({"content-Type": ["application/json"]}),
+        length=len(body),
+        deliverBody=deliver_body,
+    )
+    mock.seal(response)
+    return response
diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py
index a7031a55f2..c2320ce133 100644
--- a/tests/federation/transport/test_client.py
+++ b/tests/federation/transport/test_client.py
@@ -62,3 +62,35 @@ class SendJoinParserTestCase(TestCase):
         self.assertEqual(len(parsed_response.state), 1, parsed_response)
         self.assertEqual(parsed_response.event_dict, {}, parsed_response)
         self.assertIsNone(parsed_response.event, parsed_response)
+        self.assertFalse(parsed_response.partial_state, parsed_response)
+        self.assertEqual(parsed_response.servers_in_room, None, parsed_response)
+
+    def test_partial_state(self) -> None:
+        """Check that the partial_state flag is correctly parsed"""
+        parser = SendJoinParser(RoomVersions.V1, False)
+        response = {
+            "org.matrix.msc3706.partial_state": True,
+        }
+
+        serialised_response = json.dumps(response).encode()
+
+        # Send data to the parser
+        parser.write(serialised_response)
+
+        # Retrieve and check the parsed SendJoinResponse
+        parsed_response = parser.finish()
+        self.assertTrue(parsed_response.partial_state)
+
+    def test_servers_in_room(self) -> None:
+        """Check that the servers_in_room field is correctly parsed"""
+        parser = SendJoinParser(RoomVersions.V1, False)
+        response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}
+
+        serialised_response = json.dumps(response).encode()
+
+        # Send data to the parser
+        parser.write(serialised_response)
+
+        # Retrieve and check the parsed SendJoinResponse
+        parsed_response = parser.finish()
+        self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"])
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 4740dd0a65..49d832de81 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -84,7 +84,7 @@ class CustomAuthProvider:
 
     def __init__(self, config, api: ModuleApi):
         api.register_password_auth_provider_callbacks(
-            auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
+            auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
         )
 
     def check_auth(self, *args):
@@ -122,7 +122,7 @@ class PasswordCustomAuthProvider:
             auth_checkers={
                 ("test.login_type", ("test_field",)): self.check_auth,
                 ("m.login.password", ("password",)): self.check_auth,
-            },
+            }
         )
         pass
 
@@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         account.register_servlets,
     ]
 
+    CALLBACK_USERNAME = "get_username_for_registration"
+    CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
+
     def setUp(self):
         # we use a global mock device, so make sure we are starting with a clean slate
         mock_password_provider.reset_mock()
@@ -754,7 +757,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         """Tests that the get_username_for_registration callback can define the username
         of a user when registering.
         """
-        self._setup_get_username_for_registration()
+        self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_USERNAME,
+        )
 
         username = "rin"
         channel = self.make_request(
@@ -777,30 +782,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         """Tests that the get_username_for_registration callback is only called at the
         end of the UIA flow.
         """
-        m = self._setup_get_username_for_registration()
-
-        # Initiate the UIA flow.
-        username = "rin"
-        channel = self.make_request(
-            "POST",
-            "register",
-            {"username": username, "type": "m.login.password", "password": "bar"},
+        m = self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_USERNAME,
         )
-        self.assertEqual(channel.code, 401)
-        self.assertIn("session", channel.json_body)
 
-        # Check that the callback hasn't been called yet.
-        m.assert_not_called()
+        username = "rin"
+        res = self._do_uia_assert_mock_not_called(username, m)
 
-        # Finish the UIA flow.
-        session = channel.json_body["session"]
-        channel = self.make_request(
-            "POST",
-            "register",
-            {"auth": {"session": session, "type": LoginType.DUMMY}},
-        )
-        self.assertEqual(channel.code, 200, channel.json_body)
-        mxid = channel.json_body["user_id"]
+        mxid = res["user_id"]
         self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
 
         # Check that the callback has been called.
@@ -817,6 +806,56 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self._test_3pid_allowed("rin", False)
         self._test_3pid_allowed("kitay", True)
 
+    def test_displayname(self):
+        """Tests that the get_displayname_for_registration callback can define the
+        display name of a user when registering.
+        """
+        self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_DISPLAYNAME,
+        )
+
+        username = "rin"
+        channel = self.make_request(
+            "POST",
+            "/register",
+            {
+                "username": username,
+                "password": "bar",
+                "auth": {"type": LoginType.DUMMY},
+            },
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Our callback takes the username and appends "-foo" to it, check that's what we
+        # have.
+        user_id = UserID.from_string(channel.json_body["user_id"])
+        display_name = self.get_success(
+            self.hs.get_profile_handler().get_displayname(user_id)
+        )
+
+        self.assertEqual(display_name, username + "-foo")
+
+    def test_displayname_uia(self):
+        """Tests that the get_displayname_for_registration callback is only called at the
+        end of the UIA flow.
+        """
+        m = self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_DISPLAYNAME,
+        )
+
+        username = "rin"
+        res = self._do_uia_assert_mock_not_called(username, m)
+
+        user_id = UserID.from_string(res["user_id"])
+        display_name = self.get_success(
+            self.hs.get_profile_handler().get_displayname(user_id)
+        )
+
+        self.assertEqual(display_name, username + "-foo")
+
+        # Check that the callback has been called.
+        m.assert_called_once()
+
     def _test_3pid_allowed(self, username: str, registration: bool):
         """Tests that the "is_3pid_allowed" module callback is called correctly, using
         either /register or /account URLs depending on the arguments.
@@ -877,23 +916,47 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
         m.assert_called_once_with("email", "bar@test.com", registration)
 
-    def _setup_get_username_for_registration(self) -> Mock:
-        """Registers a get_username_for_registration callback that appends "-foo" to the
-        username the client is trying to register.
+    def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
+        """Registers either a get_username_for_registration callback or a
+        get_displayname_for_registration callback that appends "-foo" to the username the
+        client is trying to register.
         """
 
-        async def get_username_for_registration(uia_results, params):
+        async def callback(uia_results, params):
             self.assertIn(LoginType.DUMMY, uia_results)
             username = params["username"]
             return username + "-foo"
 
-        m = Mock(side_effect=get_username_for_registration)
+        m = Mock(side_effect=callback)
 
         password_auth_provider = self.hs.get_password_auth_provider()
-        password_auth_provider.get_username_for_registration_callbacks.append(m)
+        getattr(password_auth_provider, callback_name + "_callbacks").append(m)
 
         return m
 
+    def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
+        # Initiate the UIA flow.
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"username": username, "type": "m.login.password", "password": "bar"},
+        )
+        self.assertEqual(channel.code, 401)
+        self.assertIn("session", channel.json_body)
+
+        # Check that the callback hasn't been called yet.
+        m.assert_not_called()
+
+        # Finish the UIA flow.
+        session = channel.json_body["session"]
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"auth": {"session": session, "type": LoginType.DUMMY}},
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        return channel.json_body
+
     def _get_login_flows(self) -> JsonDict:
         channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index c068d329a9..e1e3fb97c5 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -571,9 +571,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # Carry out our option-value specific test
         #
         # This push should still only contain an unread count of 1 (for 1 unread room)
-        self.assertEqual(
-            self.push_attempts[5][2]["notification"]["counts"]["unread"], 1
-        )
+        self._check_push_attempt(6, 1)
 
     @override_config({"push": {"group_unread_count_by_room": False}})
     def test_push_unread_count_message_count(self):
@@ -585,11 +583,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Carry out our option-value specific test
         #
-        # We're counting every unread message, so there should now be 4 since the
+        # We're counting every unread message, so there should now be 3 since the
         # last read receipt
-        self.assertEqual(
-            self.push_attempts[5][2]["notification"]["counts"]["unread"], 4
-        )
+        self._check_push_attempt(6, 3)
 
     def _test_push_unread_count(self):
         """
@@ -597,8 +593,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         Note that:
         * Sending messages will cause push notifications to go out to relevant users
-        * Sending a read receipt will cause a "badge update" notification to go out to
-          the user that sent the receipt
+        * Sending a read receipt will cause the HTTP pusher to check whether the unread
+            count has changed since the last push notification. If so, a "badge update"
+            notification goes out to the user that sent the receipt
         """
         # Register the user who gets notified
         user_id = self.register_user("user", "pass")
@@ -642,24 +639,74 @@ class HTTPPusherTests(HomeserverTestCase):
         # position in the room. We'll set the read position to this event in a moment
         first_message_event_id = response["event_id"]
 
-        # Advance time a bit (so the pusher will register something has happened) and
-        # make the push succeed
-        self.push_attempts[0][0].callback({})
+        expected_push_attempts = 1
+        self._check_push_attempt(expected_push_attempts, 0)
+
+        self._send_read_request(access_token, first_message_event_id, room_id)
+
+        # Unread count has not changed. Therefore, ensure that read request does not
+        # trigger a push notification.
+        self.assertEqual(len(self.push_attempts), 1)
+
+        # Send another message
+        response2 = self.helper.send(
+            room_id, body="How's the weather today?", tok=other_access_token
+        )
+        second_message_event_id = response2["event_id"]
+
+        expected_push_attempts += 1
+
+        self._check_push_attempt(expected_push_attempts, 1)
+
+        self._send_read_request(access_token, second_message_event_id, room_id)
+        expected_push_attempts += 1
+
+        self._check_push_attempt(expected_push_attempts, 0)
+
+        # If we're grouping by room, sending more messages shouldn't increase the
+        # unread count, as they're all being sent in the same room. Otherwise, it
+        # should. Therefore, the last call to _check_push_attempt is done in the
+        # caller method.
+        self.helper.send(room_id, body="Hello?", tok=other_access_token)
+        expected_push_attempts += 1
+
+        self._advance_time_and_make_push_succeed(expected_push_attempts)
+
+        self.helper.send(room_id, body="Hello??", tok=other_access_token)
+        expected_push_attempts += 1
+
+        self._advance_time_and_make_push_succeed(expected_push_attempts)
+
+        self.helper.send(room_id, body="HELLO???", tok=other_access_token)
+
+    def _advance_time_and_make_push_succeed(self, expected_push_attempts):
         self.pump()
+        self.push_attempts[expected_push_attempts - 1][0].callback({})
 
+    def _check_push_attempt(
+        self, expected_push_attempts: int, expected_unread_count_last_push: int
+    ) -> None:
+        """
+        Makes sure that the last expected push attempt succeeds and checks whether
+        it contains the expected unread count.
+        """
+        self._advance_time_and_make_push_succeed(expected_push_attempts)
         # Check our push made it
-        self.assertEqual(len(self.push_attempts), 1)
+        self.assertEqual(len(self.push_attempts), expected_push_attempts)
+        _, push_url, push_body = self.push_attempts[expected_push_attempts - 1]
         self.assertEqual(
-            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+            push_url,
+            "http://example.com/_matrix/push/v1/notify",
         )
-
         # Check that the unread count for the room is 0
         #
         # The unread count is zero as the user has no read receipt in the room yet
         self.assertEqual(
-            self.push_attempts[0][2]["notification"]["counts"]["unread"], 0
+            push_body["notification"]["counts"]["unread"],
+            expected_unread_count_last_push,
         )
 
+    def _send_read_request(self, access_token, message_event_id, room_id):
         # Now set the user's read receipt position to the first event
         #
         # This will actually trigger a new notification to be sent out so that
@@ -667,56 +714,8 @@ class HTTPPusherTests(HomeserverTestCase):
         # count goes down
         channel = self.make_request(
             "POST",
-            "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
+            "/rooms/%s/receipt/m.read/%s" % (room_id, message_event_id),
             {},
             access_token=access_token,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
-
-        # Advance time and make the push succeed
-        self.push_attempts[1][0].callback({})
-        self.pump()
-
-        # Unread count is still zero as we've read the only message in the room
-        self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(
-            self.push_attempts[1][2]["notification"]["counts"]["unread"], 0
-        )
-
-        # Send another message
-        self.helper.send(
-            room_id, body="How's the weather today?", tok=other_access_token
-        )
-
-        # Advance time and make the push succeed
-        self.push_attempts[2][0].callback({})
-        self.pump()
-
-        # This push should contain an unread count of 1 as there's now been one
-        # message since our last read receipt
-        self.assertEqual(len(self.push_attempts), 3)
-        self.assertEqual(
-            self.push_attempts[2][2]["notification"]["counts"]["unread"], 1
-        )
-
-        # Since we're grouping by room, sending more messages shouldn't increase the
-        # unread count, as they're all being sent in the same room
-        self.helper.send(room_id, body="Hello?", tok=other_access_token)
-
-        # Advance time and make the push succeed
-        self.pump()
-        self.push_attempts[3][0].callback({})
-
-        self.helper.send(room_id, body="Hello??", tok=other_access_token)
-
-        # Advance time and make the push succeed
-        self.pump()
-        self.push_attempts[4][0].callback({})
-
-        self.helper.send(room_id, body="HELLO???", tok=other_access_token)
-
-        # Advance time and make the push succeed
-        self.pump()
-        self.push_attempts[5][0].callback({})
-
-        self.assertEqual(len(self.push_attempts), 6)
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 89d85b0a17..51146c471d 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -486,8 +486,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
             {
                 "user_id": user_id,
                 "device_id": device_id,
-                # Unstable until MSC3069 enters spec
+                # MSC3069 entered spec in Matrix 1.2 but maintained compatibility
                 "org.matrix.msc3069.is_guest": False,
+                "is_guest": False,
             },
         )
 
@@ -505,8 +506,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
             {
                 "user_id": user_id,
                 "device_id": device_id,
-                # Unstable until MSC3069 enters spec
+                # MSC3069 entered spec in Matrix 1.2 but maintained compatibility
                 "org.matrix.msc3069.is_guest": True,
+                "is_guest": True,
             },
         )
 
@@ -528,8 +530,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
             whoami,
             {
                 "user_id": user_id,
-                # Unstable until MSC3069 enters spec
+                # MSC3069 entered spec in Matrix 1.2 but maintained compatibility
                 "org.matrix.msc3069.is_guest": False,
+                "is_guest": False,
             },
         )
         self.assertFalse(hasattr(whoami, "device_id"))
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 27cb856b0a..4a68d66573 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -13,15 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from http import HTTPStatus
-from typing import Optional, Union
+from typing import Optional, Tuple, Union
 
 from twisted.internet.defer import succeed
 
 import synapse.rest.admin
 from synapse.api.constants import LoginType
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
-from synapse.rest.client import account, auth, devices, login, register
+from synapse.rest.client import account, auth, devices, login, logout, register
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.storage.database import LoggingTransaction
 from synapse.types import JsonDict, UserID
 
 from tests import unittest
@@ -527,6 +528,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         auth.register_servlets,
         account.register_servlets,
         login.register_servlets,
+        logout.register_servlets,
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         register.register_servlets,
     ]
@@ -984,3 +986,90 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         self.assertEqual(
             fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
         )
+
+    def test_many_token_refresh(self):
+        """
+        If a refresh is performed many times during a session, there shouldn't be
+        extra 'cruft' built up over time.
+
+        This test was written specifically to troubleshoot a case where logout
+        was very slow if a lot of refreshes had been performed for the session.
+        """
+
+        def _refresh(refresh_token: str) -> Tuple[str, str]:
+            """
+            Performs one refresh, returning the next refresh token and access token.
+            """
+            refresh_response = self.use_refresh_token(refresh_token)
+            self.assertEqual(
+                refresh_response.code, HTTPStatus.OK, refresh_response.result
+            )
+            return (
+                refresh_response.json_body["refresh_token"],
+                refresh_response.json_body["access_token"],
+            )
+
+        def _table_length(table_name: str) -> int:
+            """
+            Helper to get the size of a table, in rows.
+            For testing only; trivially vulnerable to SQL injection.
+            """
+
+            def _txn(txn: LoggingTransaction) -> int:
+                txn.execute(f"SELECT COUNT(1) FROM {table_name}")
+                row = txn.fetchone()
+                # Query is infallible
+                assert row is not None
+                return row[0]
+
+            return self.get_success(
+                self.hs.get_datastores().main.db_pool.runInteraction(
+                    "_table_length", _txn
+                )
+            )
+
+        # Before we log in, there are no access tokens.
+        self.assertEqual(_table_length("access_tokens"), 0)
+        self.assertEqual(_table_length("refresh_tokens"), 0)
+
+        body = {
+            "type": "m.login.password",
+            "user": "test",
+            "password": self.user_pass,
+            "refresh_token": True,
+        }
+        login_response = self.make_request(
+            "POST",
+            "/_matrix/client/v3/login",
+            body,
+        )
+        self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
+
+        access_token = login_response.json_body["access_token"]
+        refresh_token = login_response.json_body["refresh_token"]
+
+        # Now that we have logged in, there should be one access token and one
+        # refresh token
+        self.assertEqual(_table_length("access_tokens"), 1)
+        self.assertEqual(_table_length("refresh_tokens"), 1)
+
+        for _ in range(5):
+            refresh_token, access_token = _refresh(refresh_token)
+
+        # After 5 sequential refreshes, there should only be the latest two
+        # refresh/access token pairs.
+        # (The last one is preserved because it's in use!
+        # The one before that is preserved because it can still be used to
+        # replace the last token pair, in case of e.g. a network interruption.)
+        self.assertEqual(_table_length("access_tokens"), 2)
+        self.assertEqual(_table_length("refresh_tokens"), 2)
+
+        logout_response = self.make_request(
+            "POST", "/_matrix/client/v3/logout", {}, access_token=access_token
+        )
+        self.assertEqual(logout_response.code, HTTPStatus.OK, logout_response.result)
+
+        # Now that we have logged in, there should be no access token
+        # and no refresh token
+        self.assertEqual(_table_length("access_tokens"), 0)
+        self.assertEqual(_table_length("refresh_tokens"), 0)
diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_device_lists.py
new file mode 100644
index 0000000000..16070cf027
--- /dev/null
+++ b/tests/rest/client/test_device_lists.py
@@ -0,0 +1,155 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.rest import admin, devices, room, sync
+from synapse.rest.client import account, login, register
+
+from tests import unittest
+
+
+class DeviceListsTestCase(unittest.HomeserverTestCase):
+    """Tests regarding device list changes."""
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        register.register_servlets,
+        account.register_servlets,
+        room.register_servlets,
+        sync.register_servlets,
+        devices.register_servlets,
+    ]
+
+    def test_receiving_local_device_list_changes(self):
+        """Tests that a local users that share a room receive each other's device list
+        changes.
+        """
+        # Register two users
+        test_device_id = "TESTDEVICE"
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        bob_user_id = self.register_user("bob", "ponyponypony")
+        bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+        # Create a room for them to coexist peacefully in
+        new_room_id = self.helper.create_room_as(
+            alice_user_id, is_public=True, tok=alice_access_token
+        )
+        self.assertIsNotNone(new_room_id)
+
+        # Have Bob join the room
+        self.helper.invite(
+            new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
+        )
+        self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
+
+        # Now have Bob initiate an initial sync (in order to get a since token)
+        channel = self.make_request(
+            "GET",
+            "/sync",
+            access_token=bob_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        next_batch_token = channel.json_body["next_batch"]
+
+        # ...and then an incremental sync. This should block until the sync stream is woken up,
+        # which we hope will happen as a result of Alice updating their device list.
+        bob_sync_channel = self.make_request(
+            "GET",
+            f"/sync?since={next_batch_token}&timeout=30000",
+            access_token=bob_access_token,
+            # Start the request, then continue on.
+            await_result=False,
+        )
+
+        # Have alice update their device list
+        channel = self.make_request(
+            "PUT",
+            f"/devices/{test_device_id}",
+            {
+                "display_name": "New Device Name",
+            },
+            access_token=alice_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that bob's incremental sync contains the updated device list.
+        # If not, the client would only receive the device list update on the
+        # *next* sync.
+        bob_sync_channel.await_result()
+        self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+        changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+            "changed", []
+        )
+        self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
+
+    def test_not_receiving_local_device_list_changes(self):
+        """Tests a local users DO NOT receive device updates from each other if they do not
+        share a room.
+        """
+        # Register two users
+        test_device_id = "TESTDEVICE"
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        bob_user_id = self.register_user("bob", "ponyponypony")
+        bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+        # These users do not share a room. They are lonely.
+
+        # Have Bob initiate an initial sync (in order to get a since token)
+        channel = self.make_request(
+            "GET",
+            "/sync",
+            access_token=bob_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        next_batch_token = channel.json_body["next_batch"]
+
+        # ...and then an incremental sync. This should block until the sync stream is woken up,
+        # which we hope will happen as a result of Alice updating their device list.
+        bob_sync_channel = self.make_request(
+            "GET",
+            f"/sync?since={next_batch_token}&timeout=1000",
+            access_token=bob_access_token,
+            # Start the request, then continue on.
+            await_result=False,
+        )
+
+        # Have alice update their device list
+        channel = self.make_request(
+            "PUT",
+            f"/devices/{test_device_id}",
+            {
+                "display_name": "New Device Name",
+            },
+            access_token=alice_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that bob's incremental sync does not contain the updated device list.
+        bob_sync_channel.await_result()
+        self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+        changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+            "changed", []
+        )
+        self.assertNotIn(
+            alice_user_id, changed_device_lists, bob_sync_channel.json_body
+        )
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index de80aca037..dfd9ffcb93 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1123,6 +1123,48 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
         )
 
+    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+    def test_edit_thread(self):
+        """Test that editing a thread works."""
+
+        # Create a thread and edit the last event.
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "A threaded reply!"},
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        threaded_event_id = channel.json_body["event_id"]
+
+        new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+        channel = self._send_relation(
+            RelationTypes.REPLACE,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+            parent_id=threaded_event_id,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # Fetch the thread root, to get the bundled aggregation for the thread.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/event/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # We expect that the edit message appears in the thread summary in the
+        # unsigned relations section.
+        relations_dict = channel.json_body["unsigned"].get("m.relations")
+        self.assertIn(RelationTypes.THREAD, relations_dict)
+
+        thread_summary = relations_dict[RelationTypes.THREAD]
+        self.assertIn("latest_event", thread_summary)
+        latest_event_in_thread = thread_summary["latest_event"]
+        self.assertEquals(
+            latest_event_in_thread["content"]["body"], "I've been edited!"
+        )
+
     def test_edit_edit(self):
         """Test that an edit cannot be edited."""
         new_body = {"msgtype": "m.text", "body": "Initial edit"}
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 1c0cb0cf4f..2b3fdadffa 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -106,9 +106,13 @@ class RestHelper:
                 default room version.
             tok: The access token to use in the request.
             expect_code: The expected HTTP response code.
+            extra_content: Extra keys to include in the body of the /createRoom request.
+                Note that if is_public is set, the "visibility" key will be overridden.
+                If room_version is set, the "room_version" key will be overridden.
+            custom_headers: HTTP headers to include in the request.
 
         Returns:
-            The ID of the newly created room.
+            The ID of the newly created room, or None if the request failed.
         """
         temp_id = self.auth_user_id
         self.auth_user_id = room_creator
diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py
new file mode 100644
index 0000000000..076b660809
--- /dev/null
+++ b/tests/storage/databases/test_state_store.py
@@ -0,0 +1,352 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import typing
+from typing import Dict, List, Sequence, Tuple
+from unittest.mock import patch
+
+from twisted.internet.defer import Deferred, ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventTypes
+from synapse.storage.databases.state.store import MAX_INFLIGHT_REQUESTS_PER_GROUP
+from synapse.storage.state import StateFilter
+from synapse.types import StateMap
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+if typing.TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+# StateFilter for ALL non-m.room.member state events
+ALL_NON_MEMBERS_STATE_FILTER = StateFilter.freeze(
+    types={EventTypes.Member: set()},
+    include_others=True,
+)
+
+FAKE_STATE = {
+    (EventTypes.Member, "@alice:test"): "join",
+    (EventTypes.Member, "@bob:test"): "leave",
+    (EventTypes.Member, "@charlie:test"): "invite",
+    ("test.type", "a"): "AAA",
+    ("test.type", "b"): "BBB",
+    ("other.event.type", "state.key"): "123",
+}
+
+
+class StateGroupInflightCachingTestCase(HomeserverTestCase):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: "HomeServer"
+    ) -> None:
+        self.state_storage = homeserver.get_storage().state
+        self.state_datastore = homeserver.get_datastores().state
+        # Patch out the `_get_state_groups_from_groups`.
+        # This is useful because it lets us pretend we have a slow database.
+        get_state_groups_patch = patch.object(
+            self.state_datastore,
+            "_get_state_groups_from_groups",
+            self._fake_get_state_groups_from_groups,
+        )
+        get_state_groups_patch.start()
+
+        self.addCleanup(get_state_groups_patch.stop)
+        self.get_state_group_calls: List[
+            Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]]
+        ] = []
+
+    def _fake_get_state_groups_from_groups(
+        self, groups: Sequence[int], state_filter: StateFilter
+    ) -> "Deferred[Dict[int, StateMap[str]]]":
+        d: Deferred[Dict[int, StateMap[str]]] = Deferred()
+        self.get_state_group_calls.append((tuple(groups), state_filter, d))
+        return d
+
+    def _complete_request_fake(
+        self,
+        groups: Tuple[int, ...],
+        state_filter: StateFilter,
+        d: "Deferred[Dict[int, StateMap[str]]]",
+    ) -> None:
+        """
+        Assemble a fake database response and complete the database request.
+        """
+
+        # Return a filtered copy of the fake state
+        d.callback({group: state_filter.filter_state(FAKE_STATE) for group in groups})
+
+    def test_duplicate_requests_deduplicated(self) -> None:
+        """
+        Tests that duplicate requests for state are deduplicated.
+
+        This test:
+        - requests some state (state group 42, 'all' state filter)
+        - requests it again, before the first request finishes
+        - checks to see that only one database query was made
+        - completes the database query
+        - checks that both requests see the same retrieved state
+        """
+        req1 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.all()
+            )
+        )
+        self.pump(by=0.1)
+
+        # This should have gone to the database
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+
+        req2 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.all()
+            )
+        )
+        self.pump(by=0.1)
+
+        # No more calls should have gone to the database
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+        self.assertFalse(req2.called)
+
+        groups, sf, d = self.get_state_group_calls[0]
+        self.assertEqual(groups, (42,))
+        self.assertEqual(sf, StateFilter.all())
+
+        # Now we can complete the request
+        self._complete_request_fake(groups, sf, d)
+
+        self.assertEqual(self.get_success(req1), FAKE_STATE)
+        self.assertEqual(self.get_success(req2), FAKE_STATE)
+
+    def test_smaller_request_deduplicated(self) -> None:
+        """
+        Tests that duplicate requests for state are deduplicated.
+
+        This test:
+        - requests some state (state group 42, 'all' state filter)
+        - requests a subset of that state, before the first request finishes
+        - checks to see that only one database query was made
+        - completes the database query
+        - checks that both requests see the correct retrieved state
+        """
+        req1 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.from_types((("test.type", None),))
+            )
+        )
+        self.pump(by=0.1)
+
+        # This should have gone to the database
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+
+        req2 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.from_types((("test.type", "b"),))
+            )
+        )
+        self.pump(by=0.1)
+
+        # No more calls should have gone to the database, because the second
+        # request was already in the in-flight cache!
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+        self.assertFalse(req2.called)
+
+        groups, sf, d = self.get_state_group_calls[0]
+        self.assertEqual(groups, (42,))
+        # The state filter is expanded internally for increased cache hit rate,
+        # so we the database sees a wider state filter than requested.
+        self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER)
+
+        # Now we can complete the request
+        self._complete_request_fake(groups, sf, d)
+
+        self.assertEqual(
+            self.get_success(req1),
+            {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"},
+        )
+        self.assertEqual(self.get_success(req2), {("test.type", "b"): "BBB"})
+
+    def test_partially_overlapping_request_deduplicated(self) -> None:
+        """
+        Tests that partially-overlapping requests are partially deduplicated.
+
+        This test:
+        - requests a single type of wildcard state
+          (This is internally expanded to be all non-member state)
+        - requests the entire state in parallel
+        - checks to see that two database queries were made, but that the second
+          one is only for member state.
+        - completes the database queries
+        - checks that both requests have the correct result.
+        """
+
+        req1 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.from_types((("test.type", None),))
+            )
+        )
+        self.pump(by=0.1)
+
+        # This should have gone to the database
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+
+        req2 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.all()
+            )
+        )
+        self.pump(by=0.1)
+
+        # Because it only partially overlaps, this also went to the database
+        self.assertEqual(len(self.get_state_group_calls), 2)
+        self.assertFalse(req1.called)
+        self.assertFalse(req2.called)
+
+        # First request:
+        groups, sf, d = self.get_state_group_calls[0]
+        self.assertEqual(groups, (42,))
+        # The state filter is expanded internally for increased cache hit rate,
+        # so we the database sees a wider state filter than requested.
+        self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER)
+        self._complete_request_fake(groups, sf, d)
+
+        # Second request:
+        groups, sf, d = self.get_state_group_calls[1]
+        self.assertEqual(groups, (42,))
+        # The state filter is narrowed to only request membership state, because
+        # the remainder of the state is already being queried in the first request!
+        self.assertEqual(
+            sf, StateFilter.freeze({EventTypes.Member: None}, include_others=False)
+        )
+        self._complete_request_fake(groups, sf, d)
+
+        # Check the results are correct
+        self.assertEqual(
+            self.get_success(req1),
+            {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"},
+        )
+        self.assertEqual(self.get_success(req2), FAKE_STATE)
+
+    def test_in_flight_requests_stop_being_in_flight(self) -> None:
+        """
+        Tests that in-flight request deduplication doesn't somehow 'hold on'
+        to completed requests: once they're done, they're taken out of the
+        in-flight cache.
+        """
+        req1 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.all()
+            )
+        )
+        self.pump(by=0.1)
+
+        # This should have gone to the database
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+
+        # Complete the request right away.
+        self._complete_request_fake(*self.get_state_group_calls[0])
+        self.assertTrue(req1.called)
+
+        # Send off another request
+        req2 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.all()
+            )
+        )
+        self.pump(by=0.1)
+
+        # It should have gone to the database again, because the previous request
+        # isn't in-flight and therefore isn't available for deduplication.
+        self.assertEqual(len(self.get_state_group_calls), 2)
+        self.assertFalse(req2.called)
+
+        # Complete the request right away.
+        self._complete_request_fake(*self.get_state_group_calls[1])
+        self.assertTrue(req2.called)
+        groups, sf, d = self.get_state_group_calls[0]
+
+        self.assertEqual(self.get_success(req1), FAKE_STATE)
+        self.assertEqual(self.get_success(req2), FAKE_STATE)
+
+    def test_inflight_requests_capped(self) -> None:
+        """
+        Tests that the number of in-flight requests is capped to 5.
+
+        - requests several pieces of state separately
+          (5 to hit the limit, 1 to 'shunt out', another that comes after the
+          group has been 'shunted out')
+        - checks to see that the torrent of requests is shunted out by
+          rewriting one of the filters as the 'all' state filter
+        - requests after that one do not cause any additional queries
+        """
+        # 5 at the time of writing.
+        CAP_COUNT = MAX_INFLIGHT_REQUESTS_PER_GROUP
+
+        reqs = []
+
+        # Request 7 different keys (1 to 7) of the `some.state` type.
+        for req_id in range(CAP_COUNT + 2):
+            reqs.append(
+                ensureDeferred(
+                    self.state_datastore._get_state_for_group_using_inflight_cache(
+                        42,
+                        StateFilter.freeze(
+                            {"some.state": {str(req_id + 1)}}, include_others=False
+                        ),
+                    )
+                )
+            )
+        self.pump(by=0.1)
+
+        # There should only be 6 calls to the database, not 7.
+        self.assertEqual(len(self.get_state_group_calls), CAP_COUNT + 1)
+
+        # Assert that the first 5 are exact requests for the individual pieces
+        # wanted
+        for req_id in range(CAP_COUNT):
+            groups, sf, d = self.get_state_group_calls[req_id]
+            self.assertEqual(
+                sf,
+                StateFilter.freeze(
+                    {"some.state": {str(req_id + 1)}}, include_others=False
+                ),
+            )
+
+        # The 6th request should be the 'all' state filter
+        groups, sf, d = self.get_state_group_calls[CAP_COUNT]
+        self.assertEqual(sf, StateFilter.all())
+
+        # Complete the queries and check which requests complete as a result
+        for req_id in range(CAP_COUNT):
+            # This request should not have been completed yet
+            self.assertFalse(reqs[req_id].called)
+
+            groups, sf, d = self.get_state_group_calls[req_id]
+            self._complete_request_fake(groups, sf, d)
+
+            # This should have only completed this one request
+            self.assertTrue(reqs[req_id].called)
+
+        # Now complete the final query; the last 2 requests should complete
+        # as a result
+        self.assertFalse(reqs[CAP_COUNT].called)
+        self.assertFalse(reqs[CAP_COUNT + 1].called)
+        groups, sf, d = self.get_state_group_calls[CAP_COUNT]
+        self._complete_request_fake(groups, sf, d)
+        self.assertTrue(reqs[CAP_COUNT].called)
+        self.assertTrue(reqs[CAP_COUNT + 1].called)
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index f462a8b1c7..a8639d8f82 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -329,3 +329,110 @@ class ExtremPruneTestCase(HomeserverTestCase):
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([local_message_event_id, remote_event_2.event_id])
+
+
+class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.state = self.hs.get_state_handler()
+        self.persistence = self.hs.get_storage().persistence
+        self.store = self.hs.get_datastore()
+
+    def test_remote_user_rooms_cache_invalidated(self):
+        """Test that if the server leaves a room the `get_rooms_for_user` cache
+        is invalidated for remote users.
+        """
+
+        # Set up a room with a local and remote user in it.
+        user_id = self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        room_id = self.helper.create_room_as(
+            "user", room_version=RoomVersions.V6.identifier, tok=token
+        )
+
+        body = self.helper.send(room_id, body="Test", tok=token)
+        local_message_event_id = body["event_id"]
+
+        # Fudge a join event for a remote user.
+        remote_user = "@user:other"
+        remote_event_1 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": remote_user,
+                "content": {"membership": Membership.JOIN},
+                "room_id": room_id,
+                "sender": remote_user,
+                "depth": 5,
+                "prev_events": [local_message_event_id],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        context = self.get_success(self.state.compute_event_context(remote_event_1))
+        self.get_success(self.persistence.persist_event(remote_event_1, context))
+
+        # Call `get_rooms_for_user` to add the remote user to the cache
+        rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
+        self.assertEqual(set(rooms), {room_id})
+
+        # Now we have the local server leave the room, and check that calling
+        # `get_user_in_room` for the remote user no longer includes the room.
+        self.helper.leave(room_id, user_id, tok=token)
+
+        rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
+        self.assertEqual(set(rooms), set())
+
+    def test_room_remote_user_cache_invalidated(self):
+        """Test that if the server leaves a room the `get_users_in_room` cache
+        is invalidated for remote users.
+        """
+
+        # Set up a room with a local and remote user in it.
+        user_id = self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        room_id = self.helper.create_room_as(
+            "user", room_version=RoomVersions.V6.identifier, tok=token
+        )
+
+        body = self.helper.send(room_id, body="Test", tok=token)
+        local_message_event_id = body["event_id"]
+
+        # Fudge a join event for a remote user.
+        remote_user = "@user:other"
+        remote_event_1 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": remote_user,
+                "content": {"membership": Membership.JOIN},
+                "room_id": room_id,
+                "sender": remote_user,
+                "depth": 5,
+                "prev_events": [local_message_event_id],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        context = self.get_success(self.state.compute_event_context(remote_event_1))
+        self.get_success(self.persistence.persist_event(remote_event_1, context))
+
+        # Call `get_users_in_room` to add the remote user to the cache
+        users = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual(set(users), {user_id, remote_user})
+
+        # Now we have the local server leave the room, and check that calling
+        # `get_user_in_room` for the remote user no longer includes the room.
+        self.helper.leave(room_id, user_id, tok=token)
+
+        users = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual(users, [])
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 70d52b088c..28c767ecfd 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -992,3 +992,112 @@ class StateFilterDifferenceTestCase(TestCase):
             StateFilter.none(),
             StateFilter.all(),
         )
+
+
+class StateFilterTestCase(TestCase):
+    def test_return_expanded(self):
+        """
+        Tests the behaviour of the return_expanded() function that expands
+        StateFilters to include more state types (for the sake of cache hit rate).
+        """
+
+        self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
+
+        self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
+
+        # Concrete-only state filters stay the same
+        # (Case: mixed filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                    "some.other.state.type": {""},
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                    "some.other.state.type": {""},
+                },
+                include_others=False,
+            ),
+        )
+
+        # Concrete-only state filters stay the same
+        # (Case: non-member-only filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {"some.other.state.type": {""}}, include_others=False
+            ).return_expanded(),
+            StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
+        )
+
+        # Concrete-only state filters stay the same
+        # (Case: member-only filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                },
+                include_others=False,
+            ),
+        )
+
+        # Wildcard member-only state filters stay the same
+        self.assertEqual(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # If there is a wildcard in the non-member portion of the filter,
+        # it's expanded to include ALL non-member events.
+        # (Case: mixed filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                    "some.other.state.type": None,
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze(
+                {EventTypes.Member: {"@wombat:test", "@alicia:test"}},
+                include_others=True,
+            ),
+        )
+
+        # If there is a wildcard in the non-member portion of the filter,
+        # it's expanded to include ALL non-member events.
+        # (Case: non-member-only filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    "some.other.state.type": None,
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+        )
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    "some.other.state.type": None,
+                    "yet.another.state.type": {"wombat"},
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+        )
diff --git a/tests/unittest.py b/tests/unittest.py
index a71892cb9d..7983c1e8b8 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -51,7 +51,10 @@ from twisted.web.server import Request
 
 from synapse import events
 from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.config.homeserver import HomeServerConfig
+from synapse.config.server import DEFAULT_ROOM_VERSION
+from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.federation.transport.server import TransportLayerServer
 from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest, SynapseSite
@@ -839,6 +842,24 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
             client_ip=client_ip,
         )
 
+    def add_hashes_and_signatures(
+        self,
+        event_dict: JsonDict,
+        room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
+    ) -> JsonDict:
+        """Adds hashes and signatures to the given event dict
+
+        Returns:
+             The modified event dict, for convenience
+        """
+        add_hashes_and_signatures(
+            room_version,
+            event_dict,
+            signature_name=self.OTHER_SERVER_NAME,
+            signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+        )
+        return event_dict
+
 
 def _auth_header_for_request(
     origin: str,