diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 4b53b6d40b..686d17c0de 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -16,6 +16,8 @@ from unittest.mock import Mock
import pymacaroons
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import (
@@ -26,8 +28,10 @@ from synapse.api.errors import (
ResourceLimitError,
)
from synapse.appservice import ApplicationService
+from synapse.server import HomeServer
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import Requester
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import simple_async_mock
@@ -36,10 +40,10 @@ from tests.utils import mock_getRawHeaders
class AuthTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
self.store = Mock()
- hs.get_datastore = Mock(return_value=self.store)
+ hs.datastores.main = self.store
hs.get_auth_handler().store = self.store
self.auth = Auth(hs)
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index b7fc33dc94..973f0f7fa1 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -40,7 +40,7 @@ def MockEvent(**kwargs):
class FilteringTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.filtering = hs.get_filtering()
- self.datastore = hs.get_datastore()
+ self.datastore = hs.get_datastores().main
def test_errors_on_invalid_filters(self):
invalid_filters = [
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index dcf0110c16..4ef754a186 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -8,7 +8,7 @@ from tests import unittest
class TestRatelimiter(unittest.HomeserverTestCase):
def test_allowed_via_can_do_action(self):
limiter = Ratelimiter(
- store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
)
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(None, key="test_id", _time_now_s=0)
@@ -39,7 +39,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
as_requester = create_requester("@user:example.com", app_service=appservice)
limiter = Ratelimiter(
- store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
)
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(as_requester, _time_now_s=0)
@@ -70,7 +70,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
as_requester = create_requester("@user:example.com", app_service=appservice)
limiter = Ratelimiter(
- store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
)
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(as_requester, _time_now_s=0)
@@ -92,7 +92,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
def test_allowed_via_ratelimit(self):
limiter = Ratelimiter(
- store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
)
# Shouldn't raise
@@ -116,7 +116,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
"""
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
limiter = Ratelimiter(
- store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
)
# First attempt should be allowed
@@ -162,7 +162,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
"""
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
limiter = Ratelimiter(
- store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
)
# First attempt should be allowed
@@ -190,7 +190,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
def test_pruning(self):
limiter = Ratelimiter(
- store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
)
self.get_success_or_raise(
limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
@@ -208,7 +208,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
"""Test that users that have ratelimiting disabled in the DB aren't
ratelimited.
"""
- store = self.hs.get_datastore()
+ store = self.hs.get_datastores().main
user_id = "@user:test"
requester = create_requester(user_id)
@@ -233,7 +233,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
def test_multiple_actions(self):
limiter = Ratelimiter(
- store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=3
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
)
# Test that 4 actions aren't allowed with a maximum burst of 3.
allowed, time_allowed = self.get_success_or_raise(
|