summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test__base.py299
-rw-r--r--tests/storage/test_appservice.py78
-rw-r--r--tests/storage/test_cleanup_extrems.py36
-rw-r--r--tests/storage/test_client_ips.py19
-rw-r--r--tests/storage/test_event_metrics.py4
-rw-r--r--tests/storage/test_id_generators.py25
-rw-r--r--tests/storage/test_redaction.py4
-rw-r--r--tests/storage/test_registration.py10
-rw-r--r--tests/storage/test_roommember.py4
9 files changed, 108 insertions, 371 deletions
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index f5afed017c..1ac4ebc61d 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -15,308 +15,9 @@
 # limitations under the License.
 
 
-from mock import Mock
-
-from twisted.internet import defer
-
-from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import Cache, cached
-
 from tests import unittest
 
 
-class CacheTestCase(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, homeserver):
-        self.cache = Cache("test")
-
-    def test_empty(self):
-        failed = False
-        try:
-            self.cache.get("foo")
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
-
-    def test_hit(self):
-        self.cache.prefill("foo", 123)
-
-        self.assertEquals(self.cache.get("foo"), 123)
-
-    def test_invalidate(self):
-        self.cache.prefill(("foo",), 123)
-        self.cache.invalidate(("foo",))
-
-        failed = False
-        try:
-            self.cache.get(("foo",))
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
-
-    def test_eviction(self):
-        cache = Cache("test", max_entries=2)
-
-        cache.prefill(1, "one")
-        cache.prefill(2, "two")
-        cache.prefill(3, "three")  # 1 will be evicted
-
-        failed = False
-        try:
-            cache.get(1)
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
-
-        cache.get(2)
-        cache.get(3)
-
-    def test_eviction_lru(self):
-        cache = Cache("test", max_entries=2)
-
-        cache.prefill(1, "one")
-        cache.prefill(2, "two")
-
-        # Now access 1 again, thus causing 2 to be least-recently used
-        cache.get(1)
-
-        cache.prefill(3, "three")
-
-        failed = False
-        try:
-            cache.get(2)
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
-
-        cache.get(1)
-        cache.get(3)
-
-
-class CacheDecoratorTestCase(unittest.HomeserverTestCase):
-    @defer.inlineCallbacks
-    def test_passthrough(self):
-        class A:
-            @cached()
-            def func(self, key):
-                return key
-
-        a = A()
-
-        self.assertEquals((yield a.func("foo")), "foo")
-        self.assertEquals((yield a.func("bar")), "bar")
-
-    @defer.inlineCallbacks
-    def test_hit(self):
-        callcount = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 1)
-
-        self.assertEquals((yield a.func("foo")), "foo")
-        self.assertEquals(callcount[0], 1)
-
-    @defer.inlineCallbacks
-    def test_invalidate(self):
-        callcount = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 1)
-
-        a.func.invalidate(("foo",))
-
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-
-    def test_invalidate_missing(self):
-        class A:
-            @cached()
-            def func(self, key):
-                return key
-
-        A().func.invalidate(("what",))
-
-    @defer.inlineCallbacks
-    def test_max_entries(self):
-        callcount = [0]
-
-        class A:
-            @cached(max_entries=10)
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-
-        for k in range(0, 12):
-            yield a.func(k)
-
-        self.assertEquals(callcount[0], 12)
-
-        # There must have been at least 2 evictions, meaning if we calculate
-        # all 12 values again, we must get called at least 2 more times
-        for k in range(0, 12):
-            yield a.func(k)
-
-        self.assertTrue(
-            callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
-        )
-
-    def test_prefill(self):
-        callcount = [0]
-
-        d = defer.succeed(123)
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return d
-
-        a = A()
-
-        a.func.prefill(("foo",), ObservableDeferred(d))
-
-        self.assertEquals(a.func("foo").result, d.result)
-        self.assertEquals(callcount[0], 0)
-
-    @defer.inlineCallbacks
-    def test_invalidate_context(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 1)
-
-        a.func.invalidate(("foo",))
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 1)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-    @defer.inlineCallbacks
-    def test_eviction_context(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached(max_entries=2)
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        yield a.func2("foo")
-        yield a.func2("foo2")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func("foo3")
-
-        self.assertEquals(callcount[0], 3)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 4)
-        self.assertEquals(callcount2[0], 3)
-
-    @defer.inlineCallbacks
-    def test_double_get(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 1)
-
-        a.func2.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
-
-        yield a.func2("foo")
-        a.func2.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 2)
-
-        a.func.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 3)
-
-
 class UpsertManyTests(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         self.storage = hs.get_datastore()
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 46f94914ff..1ce29af5fd 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         # must be done after inserts
         database = hs.get_datastores().databases[0]
         self.store = ApplicationServiceStore(
-            database, make_conn(database._database_config, database.engine), hs
+            database, make_conn(database._database_config, database.engine, "test"), hs
         )
 
     def tearDown(self):
@@ -132,7 +132,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         db_config = hs.config.get_single_database()
         self.store = TestTransactionStore(
-            database, make_conn(db_config, self.engine), hs
+            database, make_conn(db_config, self.engine, "test"), hs
         )
 
     def _add_service(self, url, as_token, id):
@@ -244,7 +244,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         service = Mock(id=self.as_list[0]["id"])
         events = [Mock(event_id="e1"), Mock(event_id="e2")]
         txn = yield defer.ensureDeferred(
-            self.store.create_appservice_txn(service, events)
+            self.store.create_appservice_txn(service, events, [])
         )
         self.assertEquals(txn.id, 1)
         self.assertEquals(txn.events, events)
@@ -258,7 +258,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._insert_txn(service.id, 9644, events)
         yield self._insert_txn(service.id, 9645, events)
         txn = yield defer.ensureDeferred(
-            self.store.create_appservice_txn(service, events)
+            self.store.create_appservice_txn(service, events, [])
         )
         self.assertEquals(txn.id, 9646)
         self.assertEquals(txn.events, events)
@@ -270,7 +270,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         events = [Mock(event_id="e1"), Mock(event_id="e2")]
         yield self._set_last_txn(service.id, 9643)
         txn = yield defer.ensureDeferred(
-            self.store.create_appservice_txn(service, events)
+            self.store.create_appservice_txn(service, events, [])
         )
         self.assertEquals(txn.id, 9644)
         self.assertEquals(txn.events, events)
@@ -293,7 +293,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._insert_txn(self.as_list[3]["id"], 9643, events)
 
         txn = yield defer.ensureDeferred(
-            self.store.create_appservice_txn(service, events)
+            self.store.create_appservice_txn(service, events, [])
         )
         self.assertEquals(txn.id, 9644)
         self.assertEquals(txn.events, events)
@@ -410,6 +410,62 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         )
 
 
+class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver()
+        return hs
+
+    def prepare(self, hs, reactor, clock):
+        self.service = Mock(id="foo")
+        self.store = self.hs.get_datastore()
+        self.get_success(self.store.set_appservice_state(self.service, "up"))
+
+    def test_get_type_stream_id_for_appservice_no_value(self):
+        value = self.get_success(
+            self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
+        )
+        self.assertEquals(value, 0)
+
+        value = self.get_success(
+            self.store.get_type_stream_id_for_appservice(self.service, "presence")
+        )
+        self.assertEquals(value, 0)
+
+    def test_get_type_stream_id_for_appservice_invalid_type(self):
+        self.get_failure(
+            self.store.get_type_stream_id_for_appservice(self.service, "foobar"),
+            ValueError,
+        )
+
+    def test_set_type_stream_id_for_appservice(self):
+        read_receipt_value = 1024
+        self.get_success(
+            self.store.set_type_stream_id_for_appservice(
+                self.service, "read_receipt", read_receipt_value
+            )
+        )
+        result = self.get_success(
+            self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
+        )
+        self.assertEqual(result, read_receipt_value)
+
+        self.get_success(
+            self.store.set_type_stream_id_for_appservice(
+                self.service, "presence", read_receipt_value
+            )
+        )
+        result = self.get_success(
+            self.store.get_type_stream_id_for_appservice(self.service, "presence")
+        )
+        self.assertEqual(result, read_receipt_value)
+
+    def test_set_type_stream_id_for_appservice_invalid_type(self):
+        self.get_failure(
+            self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024),
+            ValueError,
+        )
+
+
 # required for ApplicationServiceTransactionStoreTestCase tests
 class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -448,7 +504,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
 
         database = hs.get_datastores().databases[0]
         ApplicationServiceStore(
-            database, make_conn(database._database_config, database.engine), hs
+            database, make_conn(database._database_config, database.engine, "test"), hs
         )
 
     @defer.inlineCallbacks
@@ -467,7 +523,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         with self.assertRaises(ConfigError) as cm:
             database = hs.get_datastores().databases[0]
             ApplicationServiceStore(
-                database, make_conn(database._database_config, database.engine), hs
+                database,
+                make_conn(database._database_config, database.engine, "test"),
+                hs,
             )
 
         e = cm.exception
@@ -491,7 +549,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         with self.assertRaises(ConfigError) as cm:
             database = hs.get_datastores().databases[0]
             ApplicationServiceStore(
-                database, make_conn(database._database_config, database.engine), hs
+                database,
+                make_conn(database._database_config, database.engine, "test"),
+                hs,
             )
 
         e = cm.exception
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 080761d1d2..c13a57dad1 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -22,7 +22,7 @@ import synapse.rest.admin
 from synapse.api.constants import EventTypes
 from synapse.rest.client.v1 import login, room
 from synapse.storage import prepare_database
-from synapse.types import Requester, UserID
+from synapse.types import UserID, create_requester
 
 from tests.unittest import HomeserverTestCase
 
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
 
         # Create a test user and room
         self.user = UserID("alice", "test")
-        self.requester = Requester(self.user, None, False, False, None, None)
+        self.requester = create_requester(self.user)
         info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
         self.room_id = info["room_id"]
 
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
         # Create a test user and room
         self.user = UserID.from_string(self.register_user("user1", "password"))
         self.token1 = self.login("user1", "password")
-        self.requester = Requester(self.user, None, False, False, None, None)
+        self.requester = create_requester(self.user)
         info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
         self.room_id = info["room_id"]
         self.event_creator = homeserver.get_event_creation_handler()
@@ -309,36 +309,6 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
         )
         self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
 
-    @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
-    def test_send_dummy_event_without_consent(self):
-        self._create_extremity_rich_graph()
-        self._enable_consent_checking()
-
-        # Pump the reactor repeatedly so that the background updates have a
-        # chance to run. Attempt to add dummy event with user that has not consented
-        # Check that dummy event send fails.
-        self.pump(10 * 60)
-        latest_event_ids = self.get_success(
-            self.store.get_latest_event_ids_in_room(self.room_id)
-        )
-        self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT)
-
-        # Create new user, and add consent
-        user2 = self.register_user("user2", "password")
-        token2 = self.login("user2", "password")
-        self.get_success(
-            self.store.user_set_consent_version(user2, self.CONSENT_VERSION)
-        )
-        self.helper.join(self.room_id, user2, tok=token2)
-
-        # Background updates should now cause a dummy event to be added to the graph
-        self.pump(10 * 60)
-
-        latest_event_ids = self.get_success(
-            self.store.get_latest_event_ids_in_room(self.room_id)
-        )
-        self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
-
     @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
     def test_expiry_logic(self):
         """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 755c70db31..a69117c5a9 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -21,6 +21,7 @@ from synapse.http.site import XForwardedForRequest
 from synapse.rest.client.v1 import login
 
 from tests import unittest
+from tests.server import make_request
 from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
@@ -408,18 +409,18 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
         # Advance to a known time
         self.reactor.advance(123456 - self.reactor.seconds())
 
-        request, channel = self.make_request(
+        headers1 = {b"User-Agent": b"Mozzila pizza"}
+        headers1.update(headers)
+
+        make_request(
+            self.reactor,
+            self.site,
             "GET",
-            "/_matrix/client/r0/admin/users/" + self.user_id,
+            "/_synapse/admin/v1/users/" + self.user_id,
             access_token=access_token,
-            **make_request_args
+            custom_headers=headers1.items(),
+            **make_request_args,
         )
-        request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
-
-        # Add the optional headers
-        for h, v in headers.items():
-            request.requestHeaders.addRawHeader(h, v)
-        self.render(request)
 
         # Advance so the save loop occurs
         self.reactor.advance(100)
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 3957471f3f..7691f2d790 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from synapse.metrics import REGISTRY, generate_latest
-from synapse.types import Requester, UserID
+from synapse.types import UserID, create_requester
 
 from tests.unittest import HomeserverTestCase
 
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
         room_creator = self.hs.get_room_creation_handler()
 
         user = UserID("alice", "test")
-        requester = Requester(user, None, False, False, None, None)
+        requester = create_requester(user)
 
         # Real events, forward extremities
         events = [(3, 2), (6, 2), (4, 6)]
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 392b08832b..cc0612cf65 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -199,10 +199,17 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
-        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+        # The first ID gen will notice that it can advance its token to 7 as it
+        # has no in progress writes...
+        self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7})
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
 
+        # ... but the second ID gen doesn't know that.
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
 
@@ -211,7 +218,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.assertEqual(stream_id, 8)
 
                 self.assertEqual(
-                    first_id_gen.get_positions(), {"first": 3, "second": 7}
+                    first_id_gen.get_positions(), {"first": 7, "second": 7}
                 )
 
         self.get_success(_get_next_async())
@@ -279,7 +286,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self._insert_row_with_id("first", 3)
         self._insert_row_with_id("second", 5)
 
-        id_gen = self._create_id_generator("first", writers=["first", "second"])
+        id_gen = self._create_id_generator("worker", writers=["first", "second"])
 
         self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
 
@@ -319,14 +326,14 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         id_gen = self._create_id_generator("first", writers=["first", "second"])
 
-        self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+        self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
 
-        self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 5)
 
         async def _get_next_async():
             async with id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 6)
-                self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+                self.assertEqual(id_gen.get_persisted_upto_position(), 5)
 
         self.get_success(_get_next_async())
 
@@ -388,7 +395,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self._insert_row_with_id("second", 5)
 
         # Initial config has two writers
-        id_gen = self._create_id_generator("first", writers=["first", "second"])
+        id_gen = self._create_id_generator("worker", writers=["first", "second"])
         self.assertEqual(id_gen.get_persisted_upto_position(), 3)
         self.assertEqual(id_gen.get_current_token_for_writer("first"), 3)
         self.assertEqual(id_gen.get_current_token_for_writer("second"), 5)
@@ -568,7 +575,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(_get_next_async2())
 
-        self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+        self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2})
         self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 1ea35d60c1..d4f9e809db 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -236,9 +236,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
                 self._event_id = event_id
 
             @defer.inlineCallbacks
-            def build(self, prev_event_ids):
+            def build(self, prev_event_ids, auth_event_ids):
                 built_event = yield defer.ensureDeferred(
-                    self._base_builder.build(prev_event_ids)
+                    self._base_builder.build(prev_event_ids, auth_event_ids)
                 )
 
                 built_event._event_id = self._event_id
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 6b582771fe..c8c7a90e5d 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase):
             self.store.get_user_by_access_token(self.tokens[1])
         )
 
-        self.assertDictContainsSubset(
-            {"name": self.user_id, "device_id": self.device_id}, result
-        )
-
-        self.assertTrue("token_id" in result)
+        self.assertEqual(result.user_id, self.user_id)
+        self.assertEqual(result.device_id, self.device_id)
+        self.assertIsNotNone(result.token_id)
 
     @defer.inlineCallbacks
     def test_user_delete_access_tokens(self):
@@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
         user = yield defer.ensureDeferred(
             self.store.get_user_by_access_token(self.tokens[0])
         )
-        self.assertEqual(self.user_id, user["name"])
+        self.assertEqual(self.user_id, user.user_id)
 
         # now delete the rest
         yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 12ccc1f53e..ff972daeaa 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -19,7 +19,7 @@ from unittest.mock import Mock
 from synapse.api.constants import Membership
 from synapse.rest.admin import register_servlets_for_client_rest_resource
 from synapse.rest.client.v1 import login, room
-from synapse.types import Requester, UserID
+from synapse.types import UserID, create_requester
 
 from tests import unittest
 from tests.test_utils import event_injection
@@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
 
         # Now let's create a room, which will insert a membership
         user = UserID("alice", "test")
-        requester = Requester(user, None, False, False, None, None)
+        requester = create_requester(user)
         self.get_success(self.room_creator.create_room(requester, {}))
 
         # Register the background update to run again.