summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_profile.py60
-rw-r--r--tests/handlers/test_register.py29
-rw-r--r--tests/replication/tcp/streams/_base.py54
-rw-r--r--tests/replication/tcp/streams/test_events.py24
-rw-r--r--tests/replication/tcp/streams/test_federation.py81
-rw-r--r--tests/replication/tcp/streams/test_receipts.py7
-rw-r--r--tests/replication/tcp/streams/test_typing.py12
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py50
-rw-r--r--tests/storage/test_id_generators.py184
-rw-r--r--tests/test_federation.py6
10 files changed, 392 insertions, 115 deletions
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index be665262c6..8aa56f1496 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -82,18 +82,26 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_set_my_name(self):
-        yield self.handler.set_displayname(
-            self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+        yield defer.ensureDeferred(
+            self.handler.set_displayname(
+                self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+            )
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_displayname(self.frank.localpart)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.frank.localpart)
+                )
+            ),
             "Frank Jr.",
         )
 
         # Set displayname again
-        yield self.handler.set_displayname(
-            self.frank, synapse.types.create_requester(self.frank), "Frank"
+        yield defer.ensureDeferred(
+            self.handler.set_displayname(
+                self.frank, synapse.types.create_requester(self.frank), "Frank"
+            )
         )
 
         self.assertEquals(
@@ -112,16 +120,20 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Setting displayname a second time is forbidden
-        d = self.handler.set_displayname(
-            self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+        d = defer.ensureDeferred(
+            self.handler.set_displayname(
+                self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+            )
         )
 
         yield self.assertFailure(d, SynapseError)
 
     @defer.inlineCallbacks
     def test_set_my_name_noauth(self):
-        d = self.handler.set_displayname(
-            self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
+        d = defer.ensureDeferred(
+            self.handler.set_displayname(
+                self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
+            )
         )
 
         yield self.assertFailure(d, AuthError)
@@ -165,10 +177,12 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_set_my_avatar(self):
-        yield self.handler.set_avatar_url(
-            self.frank,
-            synapse.types.create_requester(self.frank),
-            "http://my.server/pic.gif",
+        yield defer.ensureDeferred(
+            self.handler.set_avatar_url(
+                self.frank,
+                synapse.types.create_requester(self.frank),
+                "http://my.server/pic.gif",
+            )
         )
 
         self.assertEquals(
@@ -177,10 +191,12 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Set avatar again
-        yield self.handler.set_avatar_url(
-            self.frank,
-            synapse.types.create_requester(self.frank),
-            "http://my.server/me.png",
+        yield defer.ensureDeferred(
+            self.handler.set_avatar_url(
+                self.frank,
+                synapse.types.create_requester(self.frank),
+                "http://my.server/me.png",
+            )
         )
 
         self.assertEquals(
@@ -203,10 +219,12 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Set avatar a second time is forbidden
-        d = self.handler.set_avatar_url(
-            self.frank,
-            synapse.types.create_requester(self.frank),
-            "http://my.server/pic.gif",
+        d = defer.ensureDeferred(
+            self.handler.set_avatar_url(
+                self.frank,
+                synapse.types.create_requester(self.frank),
+                "http://my.server/pic.gif",
+            )
         )
 
         yield self.assertFailure(d, SynapseError)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index f1dc51d6c9..1b7935cef2 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
 
-        self.store.is_real_user = Mock(return_value=False)
+        self.store.is_real_user = Mock(return_value=defer.succeed(False))
         user_id = self.get_success(self.handler.register_user(localpart="support"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
@@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
 
-        self.store.count_real_users = Mock(return_value=1)
-        self.store.is_real_user = Mock(return_value=True)
+        self.store.count_real_users = Mock(return_value=defer.succeed(1))
+        self.store.is_real_user = Mock(return_value=defer.succeed(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         directory_handler = self.hs.get_handlers().directory_handler
@@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
 
-        self.store.count_real_users = Mock(return_value=2)
-        self.store.is_real_user = Mock(return_value=True)
+        self.store.count_real_users = Mock(return_value=defer.succeed(2))
+        self.store.is_real_user = Mock(return_value=defer.succeed(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
@@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             self.handler.register_user(localpart=invalid_user_id), SynapseError
         )
 
-    @defer.inlineCallbacks
-    def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
+    async def get_or_create_user(
+        self, requester, localpart, displayname, password_hash=None
+    ):
         """Creates a new user if the user does not exist,
         else revokes all previous access tokens and generates a new one.
 
@@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         """
         if localpart is None:
             raise SynapseError(400, "Request must include user id")
-        yield self.hs.get_auth().check_auth_blocking()
+        await self.hs.get_auth().check_auth_blocking()
         need_register = True
 
         try:
-            yield self.handler.check_username(localpart)
+            await self.handler.check_username(localpart)
         except SynapseError as e:
             if e.errcode == Codes.USER_IN_USE:
                 need_register = False
@@ -288,23 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         token = self.macaroon_generator.generate_access_token(user_id)
 
         if need_register:
-            yield self.handler.register_with_store(
+            await self.handler.register_with_store(
                 user_id=user_id,
                 password_hash=password_hash,
                 create_profile_with_displayname=user.localpart,
             )
         else:
-            yield defer.ensureDeferred(
-                self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
-            )
+            await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
 
-        yield self.store.add_access_token_to_user(
+        await self.store.add_access_token_to_user(
             user_id=user_id, token=token, device_id=None, valid_until_ms=None
         )
 
         if displayname is not None:
             # logger.info("setting user display name: %s -> %s", user_id, displayname)
-            yield self.hs.get_profile_handler().set_displayname(
+            await self.hs.get_profile_handler().set_displayname(
                 user, requester, displayname, by_admin=True
             )
 
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index 83e16cfe3d..9d4f0bbe44 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import attr
 
@@ -22,13 +22,16 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
 
-from synapse.app.generic_worker import GenericWorkerServer
+from synapse.app.generic_worker import (
+    GenericWorkerReplicationHandler,
+    GenericWorkerServer,
+)
 from synapse.http.site import SynapseRequest
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.tcp.client import ReplicationDataHandler
+from synapse.replication.http import streams
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
@@ -40,6 +43,10 @@ logger = logging.getLogger(__name__)
 class BaseStreamTestCase(unittest.HomeserverTestCase):
     """Base class for tests of the replication streams"""
 
+    servlets = [
+        streams.register_servlets,
+    ]
+
     def prepare(self, reactor, clock, hs):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
@@ -47,17 +54,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.server = server_factory.buildProtocol(None)
 
         # Make a new HomeServer object for the worker
-        config = self.default_config()
-        config["worker_app"] = "synapse.app.generic_worker"
-        config["worker_replication_host"] = "testserv"
-        config["worker_replication_http_port"] = "8765"
-
         self.reactor.lookups["testserv"] = "1.2.3.4"
-
         self.worker_hs = self.setup_test_homeserver(
             http_client=None,
             homeserverToUse=GenericWorkerServer,
-            config=config,
+            config=self._get_worker_hs_config(),
             reactor=self.reactor,
         )
 
@@ -76,8 +77,15 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._client_transport = None
         self._server_transport = None
 
+    def _get_worker_hs_config(self) -> dict:
+        config = self.default_config()
+        config["worker_app"] = "synapse.app.generic_worker"
+        config["worker_replication_host"] = "testserv"
+        config["worker_replication_http_port"] = "8765"
+        return config
+
     def _build_replication_data_handler(self):
-        return TestReplicationDataHandler(self.worker_hs.get_datastore())
+        return TestReplicationDataHandler(self.worker_hs)
 
     def reconnect(self):
         if self._client_transport:
@@ -172,32 +180,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.assertEqual(request.method, b"GET")
 
 
-class TestReplicationDataHandler(ReplicationDataHandler):
+class TestReplicationDataHandler(GenericWorkerReplicationHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
 
-    def __init__(self, store: BaseSlavedStore):
-        super().__init__(store)
-
-        # streams to subscribe to: map from stream id to position
-        self.stream_positions = {}  # type: Dict[str, int]
+    def __init__(self, hs: HomeServer):
+        super().__init__(hs)
 
         # list of received (stream_name, token, row) tuples
         self.received_rdata_rows = []  # type: List[Tuple[str, int, Any]]
 
-    def get_streams_to_replicate(self):
-        return self.stream_positions
-
-    async def on_rdata(self, stream_name, token, rows):
-        await super().on_rdata(stream_name, token, rows)
+    async def on_rdata(self, stream_name, instance_name, token, rows):
+        await super().on_rdata(stream_name, instance_name, token, rows)
         for r in rows:
             self.received_rdata_rows.append((stream_name, token, r))
 
-        if (
-            stream_name in self.stream_positions
-            and token > self.stream_positions[stream_name]
-        ):
-            self.stream_positions[stream_name] = token
-
 
 @attr.s()
 class OneShotRequestFactory:
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 1fa28084f9..8bd67bb9f1 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.user_tok = self.login("u1", "pass")
 
         self.reconnect()
-        self.test_handler.stream_positions["events"] = 0
 
         self.room_id = self.helper.create_room_as(tok=self.user_tok)
         self.test_handler.received_rdata_rows.clear()
@@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.reconnect()
         self.replicate()
 
-        # we should have received all the expected rows in the right order
-        received_rows = self.test_handler.received_rdata_rows
+        # we should have received all the expected rows in the right order (as
+        # well as various cache invalidation updates which we ignore)
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
+
         for event in events:
             stream_name, token, row = received_rows.pop(0)
             self.assertEqual("events", stream_name)
@@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.reconnect()
         self.replicate()
 
-        # now we should have received all the expected rows in the right order.
+        # we should have received all the expected rows in the right order (as
+        # well as various cache invalidation updates which we ignore)
         #
         # we expect:
         #
@@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
         #       of the states that got reverted.
         # - two rows for state2
 
-        received_rows = self.test_handler.received_rdata_rows
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
 
         # first check the first two rows, which should be state1
 
@@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.reconnect()
         self.replicate()
 
-        # we should have received all the expected rows in the right order
-
-        received_rows = self.test_handler.received_rdata_rows
+        # we should have received all the expected rows in the right order (as
+        # well as various cache invalidation updates which we ignore)
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
         self.assertGreaterEqual(len(received_rows), len(events))
         for i in range(NUM_USERS):
             # for each user, we expect the PL event row, followed by state rows for
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
new file mode 100644
index 0000000000..eea4565da3
--- /dev/null
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.federation.send_queue import EduRow
+from synapse.replication.tcp.streams.federation import FederationStream
+
+from tests.replication.tcp.streams._base import BaseStreamTestCase
+
+
+class FederationStreamTestCase(BaseStreamTestCase):
+    def _get_worker_hs_config(self) -> dict:
+        # enable federation sending on the worker
+        config = super()._get_worker_hs_config()
+        # TODO: make it so we don't need both of these
+        config["send_federation"] = True
+        config["worker_app"] = "synapse.app.federation_sender"
+        return config
+
+    def test_catchup(self):
+        """Basic test of catchup on reconnect
+
+        Makes sure that updates sent while we are offline are received later.
+        """
+        fed_sender = self.hs.get_federation_sender()
+        received_rows = self.test_handler.received_rdata_rows
+
+        fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"})
+
+        self.reconnect()
+        self.reactor.advance(0)
+
+        # check we're testing what we think we are: no rows should yet have been
+        # received
+        self.assertEqual(received_rows, [])
+
+        # We should now see an attempt to connect to the master
+        request = self.handle_http_replication_attempt()
+        self.assert_request_is_get_repl_stream_updates(request, "federation")
+
+        # we should have received an update row
+        stream_name, token, row = received_rows.pop()
+        self.assertEqual(stream_name, "federation")
+        self.assertIsInstance(row, FederationStream.FederationStreamRow)
+        self.assertEqual(row.type, EduRow.TypeId)
+        edurow = EduRow.from_data(row.data)
+        self.assertEqual(edurow.edu.edu_type, "m.test_edu")
+        self.assertEqual(edurow.edu.origin, self.hs.hostname)
+        self.assertEqual(edurow.edu.destination, "testdest")
+        self.assertEqual(edurow.edu.content, {"a": "b"})
+
+        self.assertEqual(received_rows, [])
+
+        # additional updates should be transferred without an HTTP hit
+        fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"})
+        self.reactor.advance(0)
+        # there should be no http hit
+        self.assertEqual(len(self.reactor.tcpClients), 0)
+        # ... but we should have a row
+        self.assertEqual(len(received_rows), 1)
+
+        stream_name, token, row = received_rows.pop()
+        self.assertEqual(stream_name, "federation")
+        self.assertIsInstance(row, FederationStream.FederationStreamRow)
+        self.assertEqual(row.type, EduRow.TypeId)
+        edurow = EduRow.from_data(row.data)
+        self.assertEqual(edurow.edu.edu_type, "m.test1")
+        self.assertEqual(edurow.edu.origin, self.hs.hostname)
+        self.assertEqual(edurow.edu.destination, "testdest")
+        self.assertEqual(edurow.edu.content, {"c": "d"})
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index c122b8589c..5853314fd4 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
     def test_receipt(self):
         self.reconnect()
 
-        # make the client subscribe to the receipts stream
-        self.test_handler.stream_positions.update({"receipts": 0})
-
         # tell the master to send a new receipt
         self.get_success(
             self.hs.get_datastore().insert_receipt(
@@ -44,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
 
         # there should be one RDATA command
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "receipts")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow
@@ -74,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
 
         # We should now have caught up and get the missing data
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "receipts")
         self.assertEqual(token, 3)
         self.assertEqual(1, len(rdata_rows))
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 4d354a9db8..125c63dab5 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -15,7 +15,6 @@
 from mock import Mock
 
 from synapse.handlers.typing import RoomMember
-from synapse.replication.http import streams
 from synapse.replication.tcp.streams import TypingStream
 
 from tests.replication.tcp.streams._base import BaseStreamTestCase
@@ -24,10 +23,6 @@ USER_ID = "@feeling:blue"
 
 
 class TypingStreamTestCase(BaseStreamTestCase):
-    servlets = [
-        streams.register_servlets,
-    ]
-
     def _build_replication_data_handler(self):
         return Mock(wraps=super()._build_replication_data_handler())
 
@@ -38,9 +33,6 @@ class TypingStreamTestCase(BaseStreamTestCase):
 
         self.reconnect()
 
-        # make the client subscribe to the typing stream
-        self.test_handler.stream_positions.update({"typing": 0})
-
         typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
 
         self.reactor.advance(0)
@@ -50,7 +42,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
@@ -77,7 +69,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assertEqual(int(request.args[b"from_token"][0]), token)
 
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 93eb053b8c..406f29a7c0 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -55,26 +55,19 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self._rlsn._store.user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(1000)
         )
-        self._send_notice = self._rlsn._server_notices_manager.send_notice
-        self._rlsn._server_notices_manager.send_notice = Mock()
-        self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None))
-        self._rlsn._store.get_events = Mock(return_value=defer.succeed({}))
-
+        self._rlsn._server_notices_manager.send_notice = Mock(
+            return_value=defer.succeed(Mock())
+        )
         self._send_notice = self._rlsn._server_notices_manager.send_notice
 
         self.hs.config.limit_usage_by_mau = True
         self.user_id = "@user_id:test"
 
-        # self.server_notices_mxid = "@server:test"
-        # self.server_notices_mxid_display_name = None
-        # self.server_notices_mxid_avatar_url = None
-        # self.server_notices_room_name = "Server Notices"
-
         self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
-            returnValue=""
+            return_value=defer.succeed("!something:localhost")
         )
-        self._rlsn._store.add_tag_to_room = Mock()
-        self._rlsn._store.get_tags_for_room = Mock(return_value={})
+        self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
+        self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
         self.hs.config.admin_contact = "mailto:user@test.com"
 
     def test_maybe_send_server_notice_to_user_flag_off(self):
@@ -95,14 +88,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
     def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
         """Test when user has blocked notice, but should have it removed"""
 
-        self._rlsn._auth.check_auth_blocking = Mock()
+        self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
         mock_event = Mock(
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
         self._rlsn._store.get_events = Mock(
             return_value=defer.succeed({"123": mock_event})
         )
-
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
         # Would be better to check the content, but once == remove blocking event
         self._send_notice.assert_called_once()
@@ -112,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test when user has blocked notice, but notice ought to be there (NOOP)
         """
         self._rlsn._auth.check_auth_blocking = Mock(
-            side_effect=ResourceLimitError(403, "foo")
+            return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
         )
 
         mock_event = Mock(
@@ -121,6 +113,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self._rlsn._store.get_events = Mock(
             return_value=defer.succeed({"123": mock_event})
         )
+
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
         self._send_notice.assert_not_called()
@@ -129,9 +122,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user does not have blocked notice, but should have one
         """
-
         self._rlsn._auth.check_auth_blocking = Mock(
-            side_effect=ResourceLimitError(403, "foo")
+            return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -142,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user does not have blocked notice, nor should they (NOOP)
         """
-        self._rlsn._auth.check_auth_blocking = Mock()
+        self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
 
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -153,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test when user is not part of the MAU cohort - this should not ever
         happen - but ...
         """
-        self._rlsn._auth.check_auth_blocking = Mock()
+        self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
         self._rlsn._store.user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(None)
         )
@@ -167,24 +159,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         an alert message is not sent into the room
         """
         self.hs.config.mau_limit_alerting = False
+
         self._rlsn._auth.check_auth_blocking = Mock(
+            return_value=defer.succeed(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
-            )
+            ),
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
-        self.assertTrue(self._send_notice.call_count == 0)
+        self.assertEqual(self._send_notice.call_count, 0)
 
     def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
         """
         Test that when a server is disabled, that MAU limit alerting is ignored.
         """
         self.hs.config.mau_limit_alerting = False
+
         self._rlsn._auth.check_auth_blocking = Mock(
+            return_value=defer.succeed(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
-            )
+            ),
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -198,10 +194,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         self.hs.config.mau_limit_alerting = False
         self._rlsn._auth.check_auth_blocking = Mock(
+            return_value=defer.succeed(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
-            )
+            ),
         )
+
         self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
             return_value=defer.succeed((True, []))
         )
@@ -256,7 +254,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
     def test_server_notice_only_sent_once(self):
         self.store.get_monthly_active_count = Mock(return_value=1000)
 
-        self.store.user_last_seen_monthly_active = Mock(return_value=1000)
+        self.store.user_last_seen_monthly_active = Mock(
+            return_value=defer.succeed(1000)
+        )
 
         # Call the function multiple times to ensure we only send the notice once
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
new file mode 100644
index 0000000000..55e9ecf264
--- /dev/null
+++ b/tests/storage/test_id_generators.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.storage.database import Database
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+
+from tests.unittest import HomeserverTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+
+class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.db = self.store.db  # type: Database
+
+        self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+
+    def _setup_db(self, txn):
+        txn.execute("CREATE SEQUENCE foobar_seq")
+        txn.execute(
+            """
+            CREATE TABLE foobar (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+    def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+        def _create(conn):
+            return MultiWriterIdGenerator(
+                conn,
+                self.db,
+                instance_name=instance_name,
+                table="foobar",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="foobar_seq",
+            )
+
+        return self.get_success(self.db.runWithConnection(_create))
+
+    def _insert_rows(self, instance_name: str, number: int):
+        def _insert(txn):
+            for _ in range(number):
+                txn.execute(
+                    "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
+                    (instance_name,),
+                )
+
+        self.get_success(self.db.runInteraction("test_single_instance", _insert))
+
+    def test_empty(self):
+        """Test an ID generator against an empty database gives sensible
+        current positions.
+        """
+
+        id_gen = self._create_id_generator()
+
+        # The table is empty so we expect an empty map for positions
+        self.assertEqual(id_gen.get_positions(), {})
+
+    def test_single_instance(self):
+        """Test that reads and writes from a single process are handled
+        correctly.
+        """
+
+        # Prefill table with 7 rows written by 'master'
+        self._insert_rows("master", 7)
+
+        id_gen = self._create_id_generator()
+
+        self.assertEqual(id_gen.get_positions(), {"master": 7})
+        self.assertEqual(id_gen.get_current_token("master"), 7)
+
+        # Try allocating a new ID gen and check that we only see position
+        # advanced after we leave the context manager.
+
+        async def _get_next_async():
+            with await id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 8)
+
+                self.assertEqual(id_gen.get_positions(), {"master": 7})
+                self.assertEqual(id_gen.get_current_token("master"), 7)
+
+        self.get_success(_get_next_async())
+
+        self.assertEqual(id_gen.get_positions(), {"master": 8})
+        self.assertEqual(id_gen.get_current_token("master"), 8)
+
+    def test_multi_instance(self):
+        """Test that reads and writes from multiple processes are handled
+        correctly.
+        """
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 4)
+
+        first_id_gen = self._create_id_generator("first")
+        second_id_gen = self._create_id_generator("second")
+
+        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(first_id_gen.get_current_token("first"), 3)
+        self.assertEqual(first_id_gen.get_current_token("second"), 7)
+
+        # Try allocating a new ID gen and check that we only see position
+        # advanced after we leave the context manager.
+
+        async def _get_next_async():
+            with await first_id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 8)
+
+                self.assertEqual(
+                    first_id_gen.get_positions(), {"first": 3, "second": 7}
+                )
+
+        self.get_success(_get_next_async())
+
+        self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
+
+        # However the ID gen on the second instance won't have seen the update
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+
+        # ... but calling `get_next` on the second instance should give a unique
+        # stream ID
+
+        async def _get_next_async():
+            with await second_id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 9)
+
+                self.assertEqual(
+                    second_id_gen.get_positions(), {"first": 3, "second": 7}
+                )
+
+        self.get_success(_get_next_async())
+
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
+
+        # If the second ID gen gets told about the first, it correctly updates
+        second_id_gen.advance("first", 8)
+        self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+
+    def test_get_next_txn(self):
+        """Test that the `get_next_txn` function works correctly.
+        """
+
+        # Prefill table with 7 rows written by 'master'
+        self._insert_rows("master", 7)
+
+        id_gen = self._create_id_generator()
+
+        self.assertEqual(id_gen.get_positions(), {"master": 7})
+        self.assertEqual(id_gen.get_current_token("master"), 7)
+
+        # Try allocating a new ID gen and check that we only see position
+        # advanced after we leave the context manager.
+
+        def _get_next_txn(txn):
+            stream_id = id_gen.get_next_txn(txn)
+            self.assertEqual(stream_id, 8)
+
+            self.assertEqual(id_gen.get_positions(), {"master": 7})
+            self.assertEqual(id_gen.get_current_token("master"), 7)
+
+        self.get_success(self.db.runInteraction("test", _get_next_txn))
+
+        self.assertEqual(id_gen.get_positions(), {"master": 8})
+        self.assertEqual(id_gen.get_current_token("master"), 8)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 9b5cf562f3..f297de95f1 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -27,8 +27,10 @@ class MessageAcceptTests(unittest.TestCase):
         user_id = UserID("us", "test")
         our_user = Requester(user_id, None, False, None, None)
         room_creator = self.homeserver.get_room_creation_handler()
-        room = room_creator.create_room(
-            our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+        room = ensureDeferred(
+            room_creator.create_room(
+                our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+            )
         )
         self.reactor.advance(0.1)
         self.room_id = self.successResultOf(room)["room_id"]