diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index f5afed017c..1ac4ebc61d 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -15,308 +15,9 @@
# limitations under the License.
-from mock import Mock
-
-from twisted.internet import defer
-
-from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import Cache, cached
-
from tests import unittest
-class CacheTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
- self.cache = Cache("test")
-
- def test_empty(self):
- failed = False
- try:
- self.cache.get("foo")
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
-
- def test_hit(self):
- self.cache.prefill("foo", 123)
-
- self.assertEquals(self.cache.get("foo"), 123)
-
- def test_invalidate(self):
- self.cache.prefill(("foo",), 123)
- self.cache.invalidate(("foo",))
-
- failed = False
- try:
- self.cache.get(("foo",))
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
-
- def test_eviction(self):
- cache = Cache("test", max_entries=2)
-
- cache.prefill(1, "one")
- cache.prefill(2, "two")
- cache.prefill(3, "three") # 1 will be evicted
-
- failed = False
- try:
- cache.get(1)
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
-
- cache.get(2)
- cache.get(3)
-
- def test_eviction_lru(self):
- cache = Cache("test", max_entries=2)
-
- cache.prefill(1, "one")
- cache.prefill(2, "two")
-
- # Now access 1 again, thus causing 2 to be least-recently used
- cache.get(1)
-
- cache.prefill(3, "three")
-
- failed = False
- try:
- cache.get(2)
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
-
- cache.get(1)
- cache.get(3)
-
-
-class CacheDecoratorTestCase(unittest.HomeserverTestCase):
- @defer.inlineCallbacks
- def test_passthrough(self):
- class A:
- @cached()
- def func(self, key):
- return key
-
- a = A()
-
- self.assertEquals((yield a.func("foo")), "foo")
- self.assertEquals((yield a.func("bar")), "bar")
-
- @defer.inlineCallbacks
- def test_hit(self):
- callcount = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- a = A()
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 1)
-
- self.assertEquals((yield a.func("foo")), "foo")
- self.assertEquals(callcount[0], 1)
-
- @defer.inlineCallbacks
- def test_invalidate(self):
- callcount = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- a = A()
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 1)
-
- a.func.invalidate(("foo",))
-
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 2)
-
- def test_invalidate_missing(self):
- class A:
- @cached()
- def func(self, key):
- return key
-
- A().func.invalidate(("what",))
-
- @defer.inlineCallbacks
- def test_max_entries(self):
- callcount = [0]
-
- class A:
- @cached(max_entries=10)
- def func(self, key):
- callcount[0] += 1
- return key
-
- a = A()
-
- for k in range(0, 12):
- yield a.func(k)
-
- self.assertEquals(callcount[0], 12)
-
- # There must have been at least 2 evictions, meaning if we calculate
- # all 12 values again, we must get called at least 2 more times
- for k in range(0, 12):
- yield a.func(k)
-
- self.assertTrue(
- callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
- )
-
- def test_prefill(self):
- callcount = [0]
-
- d = defer.succeed(123)
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return d
-
- a = A()
-
- a.func.prefill(("foo",), ObservableDeferred(d))
-
- self.assertEquals(a.func("foo").result, d.result)
- self.assertEquals(callcount[0], 0)
-
- @defer.inlineCallbacks
- def test_invalidate_context(self):
- callcount = [0]
- callcount2 = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- @cached(cache_context=True)
- def func2(self, key, cache_context):
- callcount2[0] += 1
- return self.func(key, on_invalidate=cache_context.invalidate)
-
- a = A()
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 1)
- self.assertEquals(callcount2[0], 1)
-
- a.func.invalidate(("foo",))
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 1)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- @defer.inlineCallbacks
- def test_eviction_context(self):
- callcount = [0]
- callcount2 = [0]
-
- class A:
- @cached(max_entries=2)
- def func(self, key):
- callcount[0] += 1
- return key
-
- @cached(cache_context=True)
- def func2(self, key, cache_context):
- callcount2[0] += 1
- return self.func(key, on_invalidate=cache_context.invalidate)
-
- a = A()
- yield a.func2("foo")
- yield a.func2("foo2")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func2("foo")
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func("foo3")
-
- self.assertEquals(callcount[0], 3)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 4)
- self.assertEquals(callcount2[0], 3)
-
- @defer.inlineCallbacks
- def test_double_get(self):
- callcount = [0]
- callcount2 = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- @cached(cache_context=True)
- def func2(self, key, cache_context):
- callcount2[0] += 1
- return self.func(key, on_invalidate=cache_context.invalidate)
-
- a = A()
- a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 1)
- self.assertEquals(callcount2[0], 1)
-
- a.func2.invalidate(("foo",))
- self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
-
- yield a.func2("foo")
- a.func2.invalidate(("foo",))
- self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
-
- self.assertEquals(callcount[0], 1)
- self.assertEquals(callcount2[0], 2)
-
- a.func.invalidate(("foo",))
- self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 3)
-
-
class UpsertManyTests(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.storage = hs.get_datastore()
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 46f94914ff..1ce29af5fd 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
# must be done after inserts
database = hs.get_datastores().databases[0]
self.store = ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database, make_conn(database._database_config, database.engine, "test"), hs
)
def tearDown(self):
@@ -132,7 +132,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
db_config = hs.config.get_single_database()
self.store = TestTransactionStore(
- database, make_conn(db_config, self.engine), hs
+ database, make_conn(db_config, self.engine, "test"), hs
)
def _add_service(self, url, as_token, id):
@@ -244,7 +244,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
txn = yield defer.ensureDeferred(
- self.store.create_appservice_txn(service, events)
+ self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events)
@@ -258,7 +258,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._insert_txn(service.id, 9644, events)
yield self._insert_txn(service.id, 9645, events)
txn = yield defer.ensureDeferred(
- self.store.create_appservice_txn(service, events)
+ self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
@@ -270,7 +270,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
events = [Mock(event_id="e1"), Mock(event_id="e2")]
yield self._set_last_txn(service.id, 9643)
txn = yield defer.ensureDeferred(
- self.store.create_appservice_txn(service, events)
+ self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
@@ -293,7 +293,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._insert_txn(self.as_list[3]["id"], 9643, events)
txn = yield defer.ensureDeferred(
- self.store.create_appservice_txn(service, events)
+ self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
@@ -410,6 +410,62 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
)
+class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+ return hs
+
+ def prepare(self, hs, reactor, clock):
+ self.service = Mock(id="foo")
+ self.store = self.hs.get_datastore()
+ self.get_success(self.store.set_appservice_state(self.service, "up"))
+
+ def test_get_type_stream_id_for_appservice_no_value(self):
+ value = self.get_success(
+ self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
+ )
+ self.assertEquals(value, 0)
+
+ value = self.get_success(
+ self.store.get_type_stream_id_for_appservice(self.service, "presence")
+ )
+ self.assertEquals(value, 0)
+
+ def test_get_type_stream_id_for_appservice_invalid_type(self):
+ self.get_failure(
+ self.store.get_type_stream_id_for_appservice(self.service, "foobar"),
+ ValueError,
+ )
+
+ def test_set_type_stream_id_for_appservice(self):
+ read_receipt_value = 1024
+ self.get_success(
+ self.store.set_type_stream_id_for_appservice(
+ self.service, "read_receipt", read_receipt_value
+ )
+ )
+ result = self.get_success(
+ self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
+ )
+ self.assertEqual(result, read_receipt_value)
+
+ self.get_success(
+ self.store.set_type_stream_id_for_appservice(
+ self.service, "presence", read_receipt_value
+ )
+ )
+ result = self.get_success(
+ self.store.get_type_stream_id_for_appservice(self.service, "presence")
+ )
+ self.assertEqual(result, read_receipt_value)
+
+ def test_set_type_stream_id_for_appservice_invalid_type(self):
+ self.get_failure(
+ self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024),
+ ValueError,
+ )
+
+
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -448,7 +504,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
database = hs.get_datastores().databases[0]
ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database, make_conn(database._database_config, database.engine, "test"), hs
)
@defer.inlineCallbacks
@@ -467,7 +523,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
with self.assertRaises(ConfigError) as cm:
database = hs.get_datastores().databases[0]
ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database,
+ make_conn(database._database_config, database.engine, "test"),
+ hs,
)
e = cm.exception
@@ -491,7 +549,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
with self.assertRaises(ConfigError) as cm:
database = hs.get_datastores().databases[0]
ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database,
+ make_conn(database._database_config, database.engine, "test"),
+ hs,
)
e = cm.exception
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 080761d1d2..c13a57dad1 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -22,7 +22,7 @@ import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client.v1 import login, room
from synapse.storage import prepare_database
-from synapse.types import Requester, UserID
+from synapse.types import UserID, create_requester
from tests.unittest import HomeserverTestCase
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
- self.requester = Requester(self.user, None, False, False, None, None)
+ self.requester = create_requester(self.user)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
- self.requester = Requester(self.user, None, False, False, None, None)
+ self.requester = create_requester(self.user)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
@@ -309,36 +309,6 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
)
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
- @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
- def test_send_dummy_event_without_consent(self):
- self._create_extremity_rich_graph()
- self._enable_consent_checking()
-
- # Pump the reactor repeatedly so that the background updates have a
- # chance to run. Attempt to add dummy event with user that has not consented
- # Check that dummy event send fails.
- self.pump(10 * 60)
- latest_event_ids = self.get_success(
- self.store.get_latest_event_ids_in_room(self.room_id)
- )
- self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT)
-
- # Create new user, and add consent
- user2 = self.register_user("user2", "password")
- token2 = self.login("user2", "password")
- self.get_success(
- self.store.user_set_consent_version(user2, self.CONSENT_VERSION)
- )
- self.helper.join(self.room_id, user2, tok=token2)
-
- # Background updates should now cause a dummy event to be added to the graph
- self.pump(10 * 60)
-
- latest_event_ids = self.get_success(
- self.store.get_latest_event_ids_in_room(self.room_id)
- )
- self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
-
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
def test_expiry_logic(self):
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 755c70db31..a69117c5a9 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -21,6 +21,7 @@ from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login
from tests import unittest
+from tests.server import make_request
from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -408,18 +409,18 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
# Advance to a known time
self.reactor.advance(123456 - self.reactor.seconds())
- request, channel = self.make_request(
+ headers1 = {b"User-Agent": b"Mozzila pizza"}
+ headers1.update(headers)
+
+ make_request(
+ self.reactor,
+ self.site,
"GET",
- "/_matrix/client/r0/admin/users/" + self.user_id,
+ "/_synapse/admin/v1/users/" + self.user_id,
access_token=access_token,
- **make_request_args
+ custom_headers=headers1.items(),
+ **make_request_args,
)
- request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
-
- # Add the optional headers
- for h, v in headers.items():
- request.requestHeaders.addRawHeader(h, v)
- self.render(request)
# Advance so the save loop occurs
self.reactor.advance(100)
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 3957471f3f..7691f2d790 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -14,7 +14,7 @@
# limitations under the License.
from synapse.metrics import REGISTRY, generate_latest
-from synapse.types import Requester, UserID
+from synapse.types import UserID, create_requester
from tests.unittest import HomeserverTestCase
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
room_creator = self.hs.get_room_creation_handler()
user = UserID("alice", "test")
- requester = Requester(user, None, False, False, None, None)
+ requester = create_requester(user)
# Real events, forward extremities
events = [(3, 2), (6, 2), (4, 6)]
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 392b08832b..cc0612cf65 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -199,10 +199,17 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
- self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+ # The first ID gen will notice that it can advance its token to 7 as it
+ # has no in progress writes...
+ self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7})
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
+ # ... but the second ID gen doesn't know that.
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -211,7 +218,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(
- first_id_gen.get_positions(), {"first": 3, "second": 7}
+ first_id_gen.get_positions(), {"first": 7, "second": 7}
)
self.get_success(_get_next_async())
@@ -279,7 +286,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
- id_gen = self._create_id_generator("first", writers=["first", "second"])
+ id_gen = self._create_id_generator("worker", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
@@ -319,14 +326,14 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator("first", writers=["first", "second"])
- self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+ self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
- self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 5)
async def _get_next_async():
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 6)
- self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 5)
self.get_success(_get_next_async())
@@ -388,7 +395,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("second", 5)
# Initial config has two writers
- id_gen = self._create_id_generator("first", writers=["first", "second"])
+ id_gen = self._create_id_generator("worker", writers=["first", "second"])
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
self.assertEqual(id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(id_gen.get_current_token_for_writer("second"), 5)
@@ -568,7 +575,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(_get_next_async2())
- self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2})
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 1ea35d60c1..d4f9e809db 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -236,9 +236,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._event_id = event_id
@defer.inlineCallbacks
- def build(self, prev_event_ids):
+ def build(self, prev_event_ids, auth_event_ids):
built_event = yield defer.ensureDeferred(
- self._base_builder.build(prev_event_ids)
+ self._base_builder.build(prev_event_ids, auth_event_ids)
)
built_event._event_id = self._event_id
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 6b582771fe..c8c7a90e5d 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.store.get_user_by_access_token(self.tokens[1])
)
- self.assertDictContainsSubset(
- {"name": self.user_id, "device_id": self.device_id}, result
- )
-
- self.assertTrue("token_id" in result)
+ self.assertEqual(result.user_id, self.user_id)
+ self.assertEqual(result.device_id, self.device_id)
+ self.assertIsNotNone(result.token_id)
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
@@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[0])
)
- self.assertEqual(self.user_id, user["name"])
+ self.assertEqual(self.user_id, user.user_id)
# now delete the rest
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 12ccc1f53e..ff972daeaa 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -19,7 +19,7 @@ from unittest.mock import Mock
from synapse.api.constants import Membership
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
-from synapse.types import Requester, UserID
+from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import event_injection
@@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
- requester = Requester(user, None, False, False, None, None)
+ requester = create_requester(user)
self.get_success(self.room_creator.create_room(requester, {}))
# Register the background update to run again.
|