summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/events/test_snapshot.py100
-rw-r--r--tests/handlers/test_profile.py60
-rw-r--r--tests/handlers/test_register.py29
-rw-r--r--tests/replication/tcp/streams/_base.py34
-rw-r--r--tests/replication/tcp/streams/test_events.py24
-rw-r--r--tests/replication/tcp/streams/test_receipts.py7
-rw-r--r--tests/replication/tcp/streams/test_typing.py7
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py50
-rw-r--r--tests/test_federation.py6
-rw-r--r--tests/test_utils/event_injection.py26
10 files changed, 234 insertions, 109 deletions
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
new file mode 100644
index 0000000000..640f5f3bce
--- /dev/null
+++ b/tests/events/test_snapshot.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.events.snapshot import EventContext
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+from tests.test_utils.event_injection import create_event
+
+
+class TestEventContext(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
+
+        self.user_id = self.register_user("u1", "pass")
+        self.user_tok = self.login("u1", "pass")
+        self.room_id = self.helper.create_room_as(tok=self.user_tok)
+
+    def test_serialize_deserialize_msg(self):
+        """Test that an EventContext for a message event is the same after
+        serialize/deserialize.
+        """
+
+        event, context = create_event(
+            self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+        )
+
+        self._check_serialize_deserialize(event, context)
+
+    def test_serialize_deserialize_state_no_prev(self):
+        """Test that an EventContext for a state event (with not previous entry)
+        is the same after serialize/deserialize.
+        """
+        event, context = create_event(
+            self.hs,
+            room_id=self.room_id,
+            type="m.test",
+            sender=self.user_id,
+            state_key="",
+        )
+
+        self._check_serialize_deserialize(event, context)
+
+    def test_serialize_deserialize_state_prev(self):
+        """Test that an EventContext for a state event (which replaces a
+        previous entry) is the same after serialize/deserialize.
+        """
+        event, context = create_event(
+            self.hs,
+            room_id=self.room_id,
+            type="m.room.member",
+            sender=self.user_id,
+            state_key=self.user_id,
+            content={"membership": "leave"},
+        )
+
+        self._check_serialize_deserialize(event, context)
+
+    def _check_serialize_deserialize(self, event, context):
+        serialized = self.get_success(context.serialize(event, self.store))
+
+        d_context = EventContext.deserialize(self.storage, serialized)
+
+        self.assertEqual(context.state_group, d_context.state_group)
+        self.assertEqual(context.rejected, d_context.rejected)
+        self.assertEqual(
+            context.state_group_before_event, d_context.state_group_before_event
+        )
+        self.assertEqual(context.prev_group, d_context.prev_group)
+        self.assertEqual(context.delta_ids, d_context.delta_ids)
+        self.assertEqual(context.app_service, d_context.app_service)
+
+        self.assertEqual(
+            self.get_success(context.get_current_state_ids()),
+            self.get_success(d_context.get_current_state_ids()),
+        )
+        self.assertEqual(
+            self.get_success(context.get_prev_state_ids()),
+            self.get_success(d_context.get_prev_state_ids()),
+        )
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index be665262c6..8aa56f1496 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -82,18 +82,26 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_set_my_name(self):
-        yield self.handler.set_displayname(
-            self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+        yield defer.ensureDeferred(
+            self.handler.set_displayname(
+                self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+            )
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_displayname(self.frank.localpart)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.frank.localpart)
+                )
+            ),
             "Frank Jr.",
         )
 
         # Set displayname again
-        yield self.handler.set_displayname(
-            self.frank, synapse.types.create_requester(self.frank), "Frank"
+        yield defer.ensureDeferred(
+            self.handler.set_displayname(
+                self.frank, synapse.types.create_requester(self.frank), "Frank"
+            )
         )
 
         self.assertEquals(
@@ -112,16 +120,20 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Setting displayname a second time is forbidden
-        d = self.handler.set_displayname(
-            self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+        d = defer.ensureDeferred(
+            self.handler.set_displayname(
+                self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+            )
         )
 
         yield self.assertFailure(d, SynapseError)
 
     @defer.inlineCallbacks
     def test_set_my_name_noauth(self):
-        d = self.handler.set_displayname(
-            self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
+        d = defer.ensureDeferred(
+            self.handler.set_displayname(
+                self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
+            )
         )
 
         yield self.assertFailure(d, AuthError)
@@ -165,10 +177,12 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_set_my_avatar(self):
-        yield self.handler.set_avatar_url(
-            self.frank,
-            synapse.types.create_requester(self.frank),
-            "http://my.server/pic.gif",
+        yield defer.ensureDeferred(
+            self.handler.set_avatar_url(
+                self.frank,
+                synapse.types.create_requester(self.frank),
+                "http://my.server/pic.gif",
+            )
         )
 
         self.assertEquals(
@@ -177,10 +191,12 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Set avatar again
-        yield self.handler.set_avatar_url(
-            self.frank,
-            synapse.types.create_requester(self.frank),
-            "http://my.server/me.png",
+        yield defer.ensureDeferred(
+            self.handler.set_avatar_url(
+                self.frank,
+                synapse.types.create_requester(self.frank),
+                "http://my.server/me.png",
+            )
         )
 
         self.assertEquals(
@@ -203,10 +219,12 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Set avatar a second time is forbidden
-        d = self.handler.set_avatar_url(
-            self.frank,
-            synapse.types.create_requester(self.frank),
-            "http://my.server/pic.gif",
+        d = defer.ensureDeferred(
+            self.handler.set_avatar_url(
+                self.frank,
+                synapse.types.create_requester(self.frank),
+                "http://my.server/pic.gif",
+            )
         )
 
         yield self.assertFailure(d, SynapseError)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index f1dc51d6c9..1b7935cef2 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
 
-        self.store.is_real_user = Mock(return_value=False)
+        self.store.is_real_user = Mock(return_value=defer.succeed(False))
         user_id = self.get_success(self.handler.register_user(localpart="support"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
@@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
 
-        self.store.count_real_users = Mock(return_value=1)
-        self.store.is_real_user = Mock(return_value=True)
+        self.store.count_real_users = Mock(return_value=defer.succeed(1))
+        self.store.is_real_user = Mock(return_value=defer.succeed(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         directory_handler = self.hs.get_handlers().directory_handler
@@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
 
-        self.store.count_real_users = Mock(return_value=2)
-        self.store.is_real_user = Mock(return_value=True)
+        self.store.count_real_users = Mock(return_value=defer.succeed(2))
+        self.store.is_real_user = Mock(return_value=defer.succeed(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
@@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             self.handler.register_user(localpart=invalid_user_id), SynapseError
         )
 
-    @defer.inlineCallbacks
-    def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
+    async def get_or_create_user(
+        self, requester, localpart, displayname, password_hash=None
+    ):
         """Creates a new user if the user does not exist,
         else revokes all previous access tokens and generates a new one.
 
@@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         """
         if localpart is None:
             raise SynapseError(400, "Request must include user id")
-        yield self.hs.get_auth().check_auth_blocking()
+        await self.hs.get_auth().check_auth_blocking()
         need_register = True
 
         try:
-            yield self.handler.check_username(localpart)
+            await self.handler.check_username(localpart)
         except SynapseError as e:
             if e.errcode == Codes.USER_IN_USE:
                 need_register = False
@@ -288,23 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         token = self.macaroon_generator.generate_access_token(user_id)
 
         if need_register:
-            yield self.handler.register_with_store(
+            await self.handler.register_with_store(
                 user_id=user_id,
                 password_hash=password_hash,
                 create_profile_with_displayname=user.localpart,
             )
         else:
-            yield defer.ensureDeferred(
-                self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
-            )
+            await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
 
-        yield self.store.add_access_token_to_user(
+        await self.store.add_access_token_to_user(
             user_id=user_id, token=token, device_id=None, valid_until_ms=None
         )
 
         if displayname is not None:
             # logger.info("setting user display name: %s -> %s", user_id, displayname)
-            yield self.hs.get_profile_handler().set_displayname(
+            await self.hs.get_profile_handler().set_displayname(
                 user, requester, displayname, by_admin=True
             )
 
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index 83e16cfe3d..7b56d2028d 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import attr
 
@@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
 
-from synapse.app.generic_worker import GenericWorkerServer
+from synapse.app.generic_worker import (
+    GenericWorkerReplicationHandler,
+    GenericWorkerServer,
+)
 from synapse.http.site import SynapseRequest
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.tcp.client import ReplicationDataHandler
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
@@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._server_transport = None
 
     def _build_replication_data_handler(self):
-        return TestReplicationDataHandler(self.worker_hs.get_datastore())
+        return TestReplicationDataHandler(self.worker_hs)
 
     def reconnect(self):
         if self._client_transport:
@@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.assertEqual(request.method, b"GET")
 
 
-class TestReplicationDataHandler(ReplicationDataHandler):
+class TestReplicationDataHandler(GenericWorkerReplicationHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
 
-    def __init__(self, store: BaseSlavedStore):
-        super().__init__(store)
-
-        # streams to subscribe to: map from stream id to position
-        self.stream_positions = {}  # type: Dict[str, int]
+    def __init__(self, hs: HomeServer):
+        super().__init__(hs)
 
         # list of received (stream_name, token, row) tuples
         self.received_rdata_rows = []  # type: List[Tuple[str, int, Any]]
 
-    def get_streams_to_replicate(self):
-        return self.stream_positions
-
-    async def on_rdata(self, stream_name, token, rows):
-        await super().on_rdata(stream_name, token, rows)
+    async def on_rdata(self, stream_name, instance_name, token, rows):
+        await super().on_rdata(stream_name, instance_name, token, rows)
         for r in rows:
             self.received_rdata_rows.append((stream_name, token, r))
 
-        if (
-            stream_name in self.stream_positions
-            and token > self.stream_positions[stream_name]
-        ):
-            self.stream_positions[stream_name] = token
-
 
 @attr.s()
 class OneShotRequestFactory:
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 1fa28084f9..8bd67bb9f1 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.user_tok = self.login("u1", "pass")
 
         self.reconnect()
-        self.test_handler.stream_positions["events"] = 0
 
         self.room_id = self.helper.create_room_as(tok=self.user_tok)
         self.test_handler.received_rdata_rows.clear()
@@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.reconnect()
         self.replicate()
 
-        # we should have received all the expected rows in the right order
-        received_rows = self.test_handler.received_rdata_rows
+        # we should have received all the expected rows in the right order (as
+        # well as various cache invalidation updates which we ignore)
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
+
         for event in events:
             stream_name, token, row = received_rows.pop(0)
             self.assertEqual("events", stream_name)
@@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.reconnect()
         self.replicate()
 
-        # now we should have received all the expected rows in the right order.
+        # we should have received all the expected rows in the right order (as
+        # well as various cache invalidation updates which we ignore)
         #
         # we expect:
         #
@@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
         #       of the states that got reverted.
         # - two rows for state2
 
-        received_rows = self.test_handler.received_rdata_rows
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
 
         # first check the first two rows, which should be state1
 
@@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.reconnect()
         self.replicate()
 
-        # we should have received all the expected rows in the right order
-
-        received_rows = self.test_handler.received_rdata_rows
+        # we should have received all the expected rows in the right order (as
+        # well as various cache invalidation updates which we ignore)
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
         self.assertGreaterEqual(len(received_rows), len(events))
         for i in range(NUM_USERS):
             # for each user, we expect the PL event row, followed by state rows for
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index c122b8589c..5853314fd4 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
     def test_receipt(self):
         self.reconnect()
 
-        # make the client subscribe to the receipts stream
-        self.test_handler.stream_positions.update({"receipts": 0})
-
         # tell the master to send a new receipt
         self.get_success(
             self.hs.get_datastore().insert_receipt(
@@ -44,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
 
         # there should be one RDATA command
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "receipts")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow
@@ -74,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
 
         # We should now have caught up and get the missing data
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "receipts")
         self.assertEqual(token, 3)
         self.assertEqual(1, len(rdata_rows))
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 4d354a9db8..d25a7b194e 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase):
 
         self.reconnect()
 
-        # make the client subscribe to the typing stream
-        self.test_handler.stream_positions.update({"typing": 0})
-
         typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
 
         self.reactor.advance(0)
@@ -50,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
@@ -77,7 +74,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assertEqual(int(request.args[b"from_token"][0]), token)
 
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 93eb053b8c..406f29a7c0 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -55,26 +55,19 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self._rlsn._store.user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(1000)
         )
-        self._send_notice = self._rlsn._server_notices_manager.send_notice
-        self._rlsn._server_notices_manager.send_notice = Mock()
-        self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None))
-        self._rlsn._store.get_events = Mock(return_value=defer.succeed({}))
-
+        self._rlsn._server_notices_manager.send_notice = Mock(
+            return_value=defer.succeed(Mock())
+        )
         self._send_notice = self._rlsn._server_notices_manager.send_notice
 
         self.hs.config.limit_usage_by_mau = True
         self.user_id = "@user_id:test"
 
-        # self.server_notices_mxid = "@server:test"
-        # self.server_notices_mxid_display_name = None
-        # self.server_notices_mxid_avatar_url = None
-        # self.server_notices_room_name = "Server Notices"
-
         self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
-            returnValue=""
+            return_value=defer.succeed("!something:localhost")
         )
-        self._rlsn._store.add_tag_to_room = Mock()
-        self._rlsn._store.get_tags_for_room = Mock(return_value={})
+        self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
+        self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
         self.hs.config.admin_contact = "mailto:user@test.com"
 
     def test_maybe_send_server_notice_to_user_flag_off(self):
@@ -95,14 +88,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
     def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
         """Test when user has blocked notice, but should have it removed"""
 
-        self._rlsn._auth.check_auth_blocking = Mock()
+        self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
         mock_event = Mock(
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
         self._rlsn._store.get_events = Mock(
             return_value=defer.succeed({"123": mock_event})
         )
-
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
         # Would be better to check the content, but once == remove blocking event
         self._send_notice.assert_called_once()
@@ -112,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test when user has blocked notice, but notice ought to be there (NOOP)
         """
         self._rlsn._auth.check_auth_blocking = Mock(
-            side_effect=ResourceLimitError(403, "foo")
+            return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
         )
 
         mock_event = Mock(
@@ -121,6 +113,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self._rlsn._store.get_events = Mock(
             return_value=defer.succeed({"123": mock_event})
         )
+
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
         self._send_notice.assert_not_called()
@@ -129,9 +122,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user does not have blocked notice, but should have one
         """
-
         self._rlsn._auth.check_auth_blocking = Mock(
-            side_effect=ResourceLimitError(403, "foo")
+            return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -142,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user does not have blocked notice, nor should they (NOOP)
         """
-        self._rlsn._auth.check_auth_blocking = Mock()
+        self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
 
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -153,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test when user is not part of the MAU cohort - this should not ever
         happen - but ...
         """
-        self._rlsn._auth.check_auth_blocking = Mock()
+        self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
         self._rlsn._store.user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(None)
         )
@@ -167,24 +159,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         an alert message is not sent into the room
         """
         self.hs.config.mau_limit_alerting = False
+
         self._rlsn._auth.check_auth_blocking = Mock(
+            return_value=defer.succeed(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
-            )
+            ),
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
-        self.assertTrue(self._send_notice.call_count == 0)
+        self.assertEqual(self._send_notice.call_count, 0)
 
     def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
         """
         Test that when a server is disabled, that MAU limit alerting is ignored.
         """
         self.hs.config.mau_limit_alerting = False
+
         self._rlsn._auth.check_auth_blocking = Mock(
+            return_value=defer.succeed(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
-            )
+            ),
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -198,10 +194,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         self.hs.config.mau_limit_alerting = False
         self._rlsn._auth.check_auth_blocking = Mock(
+            return_value=defer.succeed(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
-            )
+            ),
         )
+
         self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
             return_value=defer.succeed((True, []))
         )
@@ -256,7 +254,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
     def test_server_notice_only_sent_once(self):
         self.store.get_monthly_active_count = Mock(return_value=1000)
 
-        self.store.user_last_seen_monthly_active = Mock(return_value=1000)
+        self.store.user_last_seen_monthly_active = Mock(
+            return_value=defer.succeed(1000)
+        )
 
         # Call the function multiple times to ensure we only send the notice once
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 9b5cf562f3..f297de95f1 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -27,8 +27,10 @@ class MessageAcceptTests(unittest.TestCase):
         user_id = UserID("us", "test")
         our_user = Requester(user_id, None, False, None, None)
         room_creator = self.homeserver.get_room_creation_handler()
-        room = room_creator.create_room(
-            our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+        room = ensureDeferred(
+            room_creator.create_room(
+                our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+            )
         )
         self.reactor.advance(0.1)
         self.room_id = self.successResultOf(room)["room_id"]
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 8f6872761a..431e9f8e5e 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -14,12 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Optional, Tuple
 
 import synapse.server
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
 from synapse.types import Collection
 
 from tests.test_utils import get_awaitable_result
@@ -75,6 +76,23 @@ def inject_event(
     """
     test_reactor = hs.get_reactor()
 
+    event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
+
+    d = hs.get_storage().persistence.persist_event(event, context)
+    test_reactor.advance(0)
+    get_awaitable_result(d)
+
+    return event
+
+
+def create_event(
+    hs: synapse.server.HomeServer,
+    room_version: Optional[str] = None,
+    prev_event_ids: Optional[Collection[str]] = None,
+    **kwargs
+) -> Tuple[EventBase, EventContext]:
+    test_reactor = hs.get_reactor()
+
     if room_version is None:
         d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
         test_reactor.advance(0)
@@ -89,8 +107,4 @@ def inject_event(
     test_reactor.advance(0)
     event, context = get_awaitable_result(d)
 
-    d = hs.get_storage().persistence.persist_event(event, context)
-    test_reactor.advance(0)
-    get_awaitable_result(d)
-
-    return event
+    return event, context