summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/admin/test_background_updates.py2
-rw-r--r--tests/rest/admin/test_federation.py4
-rw-r--r--tests/rest/admin/test_media.py4
-rw-r--r--tests/rest/admin/test_registration_tokens.py2
-rw-r--r--tests/rest/admin/test_room.py6
-rw-r--r--tests/rest/admin/test_server_notice.py2
-rw-r--r--tests/rest/admin/test_user.py20
-rw-r--r--tests/rest/client/test_account.py10
-rw-r--r--tests/rest/client/test_filter.py2
-rw-r--r--tests/rest/client/test_login.py4
-rw-r--r--tests/rest/client/test_profile.py2
-rw-r--r--tests/rest/client/test_register.py28
-rw-r--r--tests/rest/client/test_relations.py4
-rw-r--r--tests/rest/client/test_retention.py4
-rw-r--r--tests/rest/client/test_rooms.py10
-rw-r--r--tests/rest/client/test_shadow_banned.py2
-rw-r--r--tests/rest/client/test_shared_rooms.py2
-rw-r--r--tests/rest/client/test_sync.py2
-rw-r--r--tests/rest/client/test_typing.py2
-rw-r--r--tests/rest/client/test_upgrade_room.py2
-rw-r--r--tests/rest/media/v1/test_media_storage.py2
21 files changed, 60 insertions, 56 deletions
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 1e3fe9c62c..fb36aa9940 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -36,7 +36,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 71068d16cd..929bbdc37d 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -35,7 +35,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -537,7 +537,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 86aff7575c..0d47dd0aff 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -634,7 +634,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         media_repo = hs.get_media_repository_resource()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.server_name = hs.hostname
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -767,7 +767,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         media_repo = hs.get_media_repository_resource()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 8513b1d2df..8354250ec2 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -34,7 +34,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 23da0ad736..09c48e85c7 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -50,7 +50,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
         self.event_creation_handler._consent_uri_builder = consent_uri_builder
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -465,7 +465,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
         self.event_creation_handler._consent_uri_builder = consent_uri_builder
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -2239,7 +2239,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index 3c59f5f766..2c855bff99 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -38,7 +38,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.room_shutdown_handler = hs.get_room_shutdown_handler()
         self.pagination_handler = hs.get_pagination_handler()
         self.server_notices_manager = self.hs.get_server_notices_manager()
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 272637e965..a60ea0a563 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -410,7 +410,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         even if the MAU limit is reached.
         """
         handler = self.hs.get_registration_handler()
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         # Set monthly active users to the limit
         store.get_monthly_active_count = Mock(
@@ -455,7 +455,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
     url = "/_synapse/admin/v2/users"
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -913,7 +913,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -1167,7 +1167,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth_handler = hs.get_auth_handler()
 
         # create users and get access tokens
@@ -2609,7 +2609,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -2737,7 +2737,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.media_repo = hs.get_media_repository_resource()
         self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
 
@@ -3317,7 +3317,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -3609,7 +3609,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -3687,7 +3687,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -3913,7 +3913,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index afaa597f65..aa019c9a44 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -77,7 +77,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
 
     def test_basic_password_reset(self):
@@ -398,7 +398,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
 
         self.deactivate(user_id, tok)
 
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         # Check that the user has been marked as deactivated.
         self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
@@ -409,7 +409,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
 
     def test_pending_invites(self):
         """Tests that deactivating a user rejects every pending invite for them."""
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         inviter_id = self.register_user("inviter", "test")
         inviter_tok = self.login("inviter", "test")
@@ -527,7 +527,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
             namespaces={"users": [{"regex": user_id, "exclusive": True}]},
             sender=user_id,
         )
-        self.hs.get_datastore().services_cache.append(appservice)
+        self.hs.get_datastores().main.services_cache.append(appservice)
 
         whoami = self._whoami(as_token)
         self.assertEqual(
@@ -586,7 +586,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         return self.hs
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.user_id = self.register_user("kermit", "test")
         self.user_id_tok = self.login("kermit", "test")
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 475c6bed3d..a573cc3c2e 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -32,7 +32,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.filtering = hs.get_filtering()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     def test_add_filter(self):
         channel = self.make_request(
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 19f5e46537..26d0d83e00 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -1101,8 +1101,8 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        self.hs.get_datastore().services_cache.append(self.service)
-        self.hs.get_datastore().services_cache.append(self.another_service)
+        self.hs.get_datastores().main.services_cache.append(self.service)
+        self.hs.get_datastores().main.services_cache.append(self.another_service)
         return self.hs
 
     def test_login_appservice_user(self):
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index ead883ded8..b9647d5bd8 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -292,7 +292,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
                 properties are "mimetype" (for the file's type) and "size" (for the
                 file's size).
         """
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         for name, props in names_and_props.items():
             self.get_success(
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 0f1c47dcbb..2835d86e5b 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -56,7 +56,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             sender="@as:test",
         )
 
-        self.hs.get_datastore().services_cache.append(appservice)
+        self.hs.get_datastores().main.services_cache.append(appservice)
         request_data = json.dumps(
             {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
         )
@@ -80,7 +80,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             sender="@as:test",
         )
 
-        self.hs.get_datastore().services_cache.append(appservice)
+        self.hs.get_datastores().main.services_cache.append(appservice)
         request_data = json.dumps({"username": "as_user_kermit"})
 
         channel = self.make_request(
@@ -210,7 +210,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         username = "kermit"
         device_id = "frogfone"
         token = "abcd"
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.get_success(
             store.db_pool.simple_insert(
                 "registration_tokens",
@@ -316,7 +316,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
     @override_config({"registration_requires_token": True})
     def test_POST_registration_token_limit_uses(self):
         token = "abcd"
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         # Create token that can be used once
         self.get_success(
             store.db_pool.simple_insert(
@@ -391,7 +391,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
     def test_POST_registration_token_expiry(self):
         token = "abcd"
         now = self.hs.get_clock().time_msec()
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         # Create token that expired yesterday
         self.get_success(
             store.db_pool.simple_insert(
@@ -439,7 +439,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
     def test_POST_registration_token_session_expiry(self):
         """Test `pending` is decremented when an uncompleted session expires."""
         token = "abcd"
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.get_success(
             store.db_pool.simple_insert(
                 "registration_tokens",
@@ -530,7 +530,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         3. Expire the session
         """
         token = "abcd"
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.get_success(
             store.db_pool.simple_insert(
                 "registration_tokens",
@@ -657,7 +657,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         # Add a threepid
         self.get_success(
-            self.hs.get_datastore().user_add_threepid(
+            self.hs.get_datastores().main.user_add_threepid(
                 user_id=user_id,
                 medium="email",
                 address=email,
@@ -941,7 +941,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         self.email_attempts = []
         self.hs.get_send_email_handler()._sendmail = sendmail
 
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
 
         return self.hs
 
@@ -1126,10 +1126,12 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
         # We need to set these directly, instead of in the homeserver config dict above.
         # This is due to account validity-related config options not being read by
         # Synapse when account_validity.enabled is False.
-        self.hs.get_datastore()._account_validity_period = self.validity_period
-        self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta
+        self.hs.get_datastores().main._account_validity_period = self.validity_period
+        self.hs.get_datastores().main._account_validity_startup_job_max_delta = (
+            self.max_delta
+        )
 
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
 
         return self.hs
 
@@ -1163,7 +1165,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
 
     def test_GET_token_valid(self):
         token = "abcd"
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.get_success(
             store.db_pool.simple_insert(
                 "registration_tokens",
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index dfd9ffcb93..5687dea48d 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -53,7 +53,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         return config
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.user_id, self.user_token = self._create_user("alice")
         self.user2_id, self.user2_token = self._create_user("bob")
@@ -107,7 +107,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         # Unless that event is referenced from another event!
         self.get_success(
-            self.hs.get_datastore().db_pool.simple_insert(
+            self.hs.get_datastores().main.db_pool.simple_insert(
                 table="event_relations",
                 values={
                     "event_id": "bar",
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index fe5b536d97..c41a1c14a1 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -51,7 +51,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         self.user_id = self.register_user("user", "password")
         self.token = self.login("user", "password")
 
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
         self.serializer = self.hs.get_event_client_serializer()
         self.clock = self.hs.get_clock()
 
@@ -114,7 +114,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         """Tests that synapse.visibility.filter_events_for_client correctly filters out
         outdated events
         """
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         storage = self.hs.get_storage()
         room_id = self.helper.create_room_as(self.user_id, tok=self.token)
         events = []
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index b7f086927b..1afd96b8f5 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -65,7 +65,7 @@ class RoomBase(unittest.HomeserverTestCase):
         async def _insert_client_ip(*args, **kwargs):
             return None
 
-        self.hs.get_datastore().insert_client_ip = _insert_client_ip
+        self.hs.get_datastores().main.insert_client_ip = _insert_client_ip
 
         return self.hs
 
@@ -667,7 +667,7 @@ class RoomsCreateTestCase(RoomBase):
 
         # Add the current user to the ratelimit overrides, allowing them no ratelimiting.
         self.get_success(
-            self.hs.get_datastore().set_ratelimit_for_user(self.user_id, 0, 0)
+            self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0)
         )
 
         # Test that the invites aren't ratelimited anymore.
@@ -1060,7 +1060,9 @@ class RoomJoinRatelimitTestCase(RoomBase):
         user_id = self.register_user("testuser", "password")
 
         # Check that the new user successfully joined the four rooms
-        rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id))
+        rooms = self.get_success(
+            self.hs.get_datastores().main.get_rooms_for_user(user_id)
+        )
         self.assertEqual(len(rooms), 4)
 
 
@@ -1184,7 +1186,7 @@ class RoomMessageListTestCase(RoomBase):
         self.assertTrue("end" in channel.json_body)
 
     def test_room_messages_purge(self):
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         pagination_handler = self.hs.get_pagination_handler()
 
         # Send a first message in the room, which will be removed by the purge.
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index b0c44af033..7d0e66b534 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -34,7 +34,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase):
         self.banned_user_id = self.register_user("banned", "test")
         self.banned_access_token = self.login("banned", "test")
 
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
 
         self.get_success(
             self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True)
diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py
index 283eccd53f..c42c8aff6c 100644
--- a/tests/rest/client/test_shared_rooms.py
+++ b/tests/rest/client/test_shared_rooms.py
@@ -36,7 +36,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         return self.setup_test_homeserver(config=config)
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.handler = hs.get_user_directory_handler()
 
     def _get_shared_rooms(self, token, other_user) -> FakeChannel:
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index cd4af2b1f3..e062561365 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -299,7 +299,7 @@ class SyncKnockTestCase(
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.url = "/sync?since=%s"
         self.next_batch = "s0"
 
diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py
index ee0abd5295..de312cb63c 100644
--- a/tests/rest/client/test_typing.py
+++ b/tests/rest/client/test_typing.py
@@ -57,7 +57,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
         async def _insert_client_ip(*args, **kwargs):
             return None
 
-        hs.get_datastore().insert_client_ip = _insert_client_ip
+        hs.get_datastores().main.insert_client_ip = _insert_client_ip
 
         return hs
 
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index a42388b26f..7f79336abc 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -32,7 +32,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.creator = self.register_user("creator", "pass")
         self.creator_token = self.login(self.creator, "pass")
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 4cf1ed5ddf..6878ccddbf 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -243,7 +243,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         media_resource = hs.get_media_repository_resource()
         self.download_resource = media_resource.children[b"download"]
         self.thumbnail_resource = media_resource.children[b"thumbnail"]
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.media_repo = hs.get_media_repository()
 
         self.media_id = "example.com/12345"