summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_filtering.py67
-rw-r--r--tests/appservice/test_appservice.py5
-rw-r--r--tests/appservice/test_scheduler.py11
-rw-r--r--tests/config/test_generate.py17
-rw-r--r--tests/config/test_load.py2
-rw-r--r--tests/crypto/test_keyring.py229
-rw-r--r--tests/handlers/test_appservice.py16
-rw-r--r--tests/handlers/test_device.py3
-rw-r--r--tests/handlers/test_directory.py12
-rw-r--r--tests/handlers/test_e2e_keys.py133
-rw-r--r--tests/handlers/test_presence.py2
-rw-r--r--tests/handlers/test_profile.py19
-rw-r--r--tests/handlers/test_register.py5
-rw-r--r--tests/handlers/test_typing.py18
-rw-r--r--tests/metrics/test_metric.py33
-rw-r--r--tests/replication/slave/storage/_base.py36
-rw-r--r--tests/replication/slave/storage/test_events.py11
-rw-r--r--tests/replication/test_resource.py204
-rw-r--r--tests/rest/client/v1/test_events.py12
-rw-r--r--tests/rest/client/v1/test_profile.py7
-rw-r--r--tests/rest/client/v1/test_rooms.py55
-rw-r--r--tests/rest/client/v1/test_typing.py6
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py4
-rw-r--r--tests/rest/client/v2_alpha/test_register.py13
-rw-r--r--tests/rest/media/__init__.py14
-rw-r--r--tests/rest/media/v1/__init__.py14
-rw-r--r--tests/rest/media/v1/test_media_storage.py86
-rw-r--r--tests/storage/event_injector.py76
-rw-r--r--tests/storage/test__base.py4
-rw-r--r--tests/storage/test_appservice.py38
-rw-r--r--tests/storage/test_base.py6
-rw-r--r--tests/storage/test_client_ips.py10
-rw-r--r--tests/storage/test_directory.py2
-rw-r--r--tests/storage/test_event_federation.py68
-rw-r--r--tests/storage/test_event_push_actions.py154
-rw-r--r--tests/storage/test_events.py115
-rw-r--r--tests/storage/test_keys.py53
-rw-r--r--tests/storage/test_presence.py2
-rw-r--r--tests/storage/test_profile.py2
-rw-r--r--tests/storage/test_redaction.py9
-rw-r--r--tests/storage/test_registration.py6
-rw-r--r--tests/storage/test_roommember.py5
-rw-r--r--tests/storage/test_user_directory.py88
-rw-r--r--tests/test_distributor.py2
-rw-r--r--tests/test_dns.py17
-rw-r--r--tests/test_state.py156
-rw-r--r--tests/test_types.py24
-rw-r--r--tests/unittest.py6
-rw-r--r--tests/util/caches/__init__.py14
-rw-r--r--tests/util/caches/test_descriptors.py261
-rw-r--r--tests/util/test_dict_cache.py2
-rw-r--r--tests/util/test_file_consumer.py176
-rw-r--r--tests/util/test_linearizer.py29
-rw-r--r--tests/util/test_log_context.py35
-rw-r--r--tests/util/test_logcontext.py178
-rw-r--r--tests/util/test_logformatter.py38
-rw-r--r--tests/util/test_snapshot_cache.py4
-rw-r--r--tests/utils.py261
58 files changed, 2033 insertions, 842 deletions
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 50e8607c14..dcceca7f3e 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -23,6 +23,9 @@ from tests.utils import (
 
 from synapse.api.filtering import Filter
 from synapse.events import FrozenEvent
+from synapse.api.errors import SynapseError
+
+import jsonschema
 
 user_localpart = "test_user"
 
@@ -54,6 +57,70 @@ class FilteringTestCase(unittest.TestCase):
 
         self.datastore = hs.get_datastore()
 
+    def test_errors_on_invalid_filters(self):
+        invalid_filters = [
+            {"boom": {}},
+            {"account_data": "Hello World"},
+            {"event_fields": ["\\foo"]},
+            {"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
+            {"event_format": "other"},
+            {"room": {"not_rooms": ["#foo:pik-test"]}},
+            {"presence": {"senders": ["@bar;pik.test.com"]}}
+        ]
+        for filter in invalid_filters:
+            with self.assertRaises(SynapseError) as check_filter_error:
+                self.filtering.check_valid_filter(filter)
+                self.assertIsInstance(check_filter_error.exception, SynapseError)
+
+    def test_valid_filters(self):
+        valid_filters = [
+            {
+                "room": {
+                    "timeline": {"limit": 20},
+                    "state": {"not_types": ["m.room.member"]},
+                    "ephemeral": {"limit": 0, "not_types": ["*"]},
+                    "include_leave": False,
+                    "rooms": ["!dee:pik-test"],
+                    "not_rooms": ["!gee:pik-test"],
+                    "account_data": {"limit": 0, "types": ["*"]}
+                }
+            },
+            {
+                "room": {
+                    "state": {
+                        "types": ["m.room.*"],
+                        "not_rooms": ["!726s6s6q:example.com"]
+                    },
+                    "timeline": {
+                        "limit": 10,
+                        "types": ["m.room.message"],
+                        "not_rooms": ["!726s6s6q:example.com"],
+                        "not_senders": ["@spam:example.com"]
+                    },
+                    "ephemeral": {
+                        "types": ["m.receipt", "m.typing"],
+                        "not_rooms": ["!726s6s6q:example.com"],
+                        "not_senders": ["@spam:example.com"]
+                    }
+                },
+                "presence": {
+                    "types": ["m.presence"],
+                    "not_senders": ["@alice:example.com"]
+                },
+                "event_format": "client",
+                "event_fields": ["type", "content", "sender"]
+            }
+        ]
+        for filter in valid_filters:
+            try:
+                self.filtering.check_valid_filter(filter)
+            except jsonschema.ValidationError as e:
+                self.fail(e)
+
+    def test_limits_are_applied(self):
+        # TODO
+        pass
+
     def test_definition_types_works_with_literals(self):
         definition = {
             "types": ["m.room.message", "org.matrix.foo.bar"]
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index aa8cc50550..5b2b95860a 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -19,10 +19,12 @@ from twisted.internet import defer
 from mock import Mock
 from tests import unittest
 
+import re
+
 
 def _regex(regex, exclusive=True):
     return {
-        "regex": regex,
+        "regex": re.compile(regex),
         "exclusive": exclusive
     }
 
@@ -34,6 +36,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
             id="unique_identifier",
             url="some_url",
             token="some_token",
+            hostname="matrix.org",  # only used by get_groups_for_user
             namespaces={
                 ApplicationService.NS_USERS: [],
                 ApplicationService.NS_ROOMS: [],
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index e5a902f734..9181692771 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -17,6 +17,8 @@ from synapse.appservice.scheduler import (
     _ServiceQueuer, _TransactionController, _Recoverer
 )
 from twisted.internet import defer
+
+from synapse.util.logcontext import make_deferred_yieldable
 from ..utils import MockClock
 from mock import Mock
 from tests import unittest
@@ -204,7 +206,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
 
     def test_send_single_event_with_queue(self):
         d = defer.Deferred()
-        self.txn_ctrl.send = Mock(return_value=d)
+        self.txn_ctrl.send = Mock(
+            side_effect=lambda x, y: make_deferred_yieldable(d),
+        )
         service = Mock(id=4)
         event = Mock(event_id="first")
         event2 = Mock(event_id="second")
@@ -235,7 +239,10 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
         srv_2_event2 = Mock(event_id="srv2b")
 
         send_return_list = [srv_1_defer, srv_2_defer]
-        self.txn_ctrl.send = Mock(side_effect=lambda x, y: send_return_list.pop(0))
+
+        def do_send(x, y):
+            return make_deferred_yieldable(send_return_list.pop(0))
+        self.txn_ctrl.send = Mock(side_effect=do_send)
 
         # send events for different ASes and make sure they are sent
         self.queuer.enqueue(srv1, srv_1_event)
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 8f57fbeb23..879159ccea 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -12,9 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import os.path
+import re
 import shutil
 import tempfile
+
 from synapse.config.homeserver import HomeServerConfig
 from tests import unittest
 
@@ -23,7 +26,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
 
     def setUp(self):
         self.dir = tempfile.mkdtemp()
-        print self.dir
         self.file = os.path.join(self.dir, "homeserver.yaml")
 
     def tearDown(self):
@@ -48,3 +50,16 @@ class ConfigGenerationTestCase(unittest.TestCase):
             ]),
             set(os.listdir(self.dir))
         )
+
+        self.assert_log_filename_is(
+            os.path.join(self.dir, "lemurs.win.log.config"),
+            os.path.join(os.getcwd(), "homeserver.log"),
+        )
+
+    def assert_log_filename_is(self, log_config_file, expected):
+        with open(log_config_file) as f:
+            config = f.read()
+            # find the 'filename' line
+            matches = re.findall("^\s*filename:\s*(.*)$", config, re.M)
+            self.assertEqual(1, len(matches))
+            self.assertEqual(matches[0], expected)
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 161a87d7e3..772afd2cf9 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -24,7 +24,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
 
     def setUp(self):
         self.dir = tempfile.mkdtemp()
-        print self.dir
+        print(self.dir)
         self.file = os.path.join(self.dir, "homeserver.yaml")
 
     def tearDown(self):
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
new file mode 100644
index 0000000000..149e443022
--- /dev/null
+++ b/tests/crypto/test_keyring.py
@@ -0,0 +1,229 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+
+import signedjson.key
+import signedjson.sign
+from mock import Mock
+from synapse.api.errors import SynapseError
+from synapse.crypto import keyring
+from synapse.util import async, logcontext
+from synapse.util.logcontext import LoggingContext
+from tests import unittest, utils
+from twisted.internet import defer
+
+
+class MockPerspectiveServer(object):
+    def __init__(self):
+        self.server_name = "mock_server"
+        self.key = signedjson.key.generate_signing_key(0)
+
+    def get_verify_keys(self):
+        vk = signedjson.key.get_verify_key(self.key)
+        return {
+            "%s:%s" % (vk.alg, vk.version): vk,
+        }
+
+    def get_signed_key(self, server_name, verify_key):
+        key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+        res = {
+            "server_name": server_name,
+            "old_verify_keys": {},
+            "valid_until_ts": time.time() * 1000 + 3600,
+            "verify_keys": {
+                key_id: {
+                    "key": signedjson.key.encode_verify_key_base64(verify_key)
+                }
+            }
+        }
+        signedjson.sign.sign_json(res, self.server_name, self.key)
+        return res
+
+
+class KeyringTestCase(unittest.TestCase):
+    @defer.inlineCallbacks
+    def setUp(self):
+        self.mock_perspective_server = MockPerspectiveServer()
+        self.http_client = Mock()
+        self.hs = yield utils.setup_test_homeserver(
+            handlers=None,
+            http_client=self.http_client,
+        )
+        self.hs.config.perspectives = {
+            self.mock_perspective_server.server_name:
+                self.mock_perspective_server.get_verify_keys()
+        }
+
+    def check_context(self, _, expected):
+        self.assertEquals(
+            getattr(LoggingContext.current_context(), "request", None),
+            expected
+        )
+
+    @defer.inlineCallbacks
+    def test_wait_for_previous_lookups(self):
+        sentinel_context = LoggingContext.current_context()
+
+        kr = keyring.Keyring(self.hs)
+
+        lookup_1_deferred = defer.Deferred()
+        lookup_2_deferred = defer.Deferred()
+
+        with LoggingContext("one") as context_one:
+            context_one.request = "one"
+
+            wait_1_deferred = kr.wait_for_previous_lookups(
+                ["server1"],
+                {"server1": lookup_1_deferred},
+            )
+
+            # there were no previous lookups, so the deferred should be ready
+            self.assertTrue(wait_1_deferred.called)
+            # ... so we should have preserved the LoggingContext.
+            self.assertIs(LoggingContext.current_context(), context_one)
+            wait_1_deferred.addBoth(self.check_context, "one")
+
+        with LoggingContext("two") as context_two:
+            context_two.request = "two"
+
+            # set off another wait. It should block because the first lookup
+            # hasn't yet completed.
+            wait_2_deferred = kr.wait_for_previous_lookups(
+                ["server1"],
+                {"server1": lookup_2_deferred},
+            )
+            self.assertFalse(wait_2_deferred.called)
+            # ... so we should have reset the LoggingContext.
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+            wait_2_deferred.addBoth(self.check_context, "two")
+
+            # let the first lookup complete (in the sentinel context)
+            lookup_1_deferred.callback(None)
+
+            # now the second wait should complete and restore our
+            # loggingcontext.
+            yield wait_2_deferred
+
+    @defer.inlineCallbacks
+    def test_verify_json_objects_for_server_awaits_previous_requests(self):
+        key1 = signedjson.key.generate_signing_key(1)
+
+        kr = keyring.Keyring(self.hs)
+        json1 = {}
+        signedjson.sign.sign_json(json1, "server10", key1)
+
+        persp_resp = {
+            "server_keys": [
+                self.mock_perspective_server.get_signed_key(
+                    "server10",
+                    signedjson.key.get_verify_key(key1)
+                ),
+            ]
+        }
+        persp_deferred = defer.Deferred()
+
+        @defer.inlineCallbacks
+        def get_perspectives(**kwargs):
+            self.assertEquals(
+                LoggingContext.current_context().request, "11",
+            )
+            with logcontext.PreserveLoggingContext():
+                yield persp_deferred
+            defer.returnValue(persp_resp)
+        self.http_client.post_json.side_effect = get_perspectives
+
+        with LoggingContext("11") as context_11:
+            context_11.request = "11"
+
+            # start off a first set of lookups
+            res_deferreds = kr.verify_json_objects_for_server(
+                [("server10", json1),
+                 ("server11", {})
+                 ]
+            )
+
+            # the unsigned json should be rejected pretty quickly
+            self.assertTrue(res_deferreds[1].called)
+            try:
+                yield res_deferreds[1]
+                self.assertFalse("unsigned json didn't cause a failure")
+            except SynapseError:
+                pass
+
+            self.assertFalse(res_deferreds[0].called)
+            res_deferreds[0].addBoth(self.check_context, None)
+
+            # wait a tick for it to send the request to the perspectives server
+            # (it first tries the datastore)
+            yield async.sleep(1)   # XXX find out why this takes so long!
+            self.http_client.post_json.assert_called_once()
+
+            self.assertIs(LoggingContext.current_context(), context_11)
+
+            context_12 = LoggingContext("12")
+            context_12.request = "12"
+            with logcontext.PreserveLoggingContext(context_12):
+                # a second request for a server with outstanding requests
+                # should block rather than start a second call
+                self.http_client.post_json.reset_mock()
+                self.http_client.post_json.return_value = defer.Deferred()
+
+                res_deferreds_2 = kr.verify_json_objects_for_server(
+                    [("server10", json1)],
+                )
+                yield async.sleep(1)
+                self.http_client.post_json.assert_not_called()
+                res_deferreds_2[0].addBoth(self.check_context, None)
+
+            # complete the first request
+            with logcontext.PreserveLoggingContext():
+                persp_deferred.callback(persp_resp)
+            self.assertIs(LoggingContext.current_context(), context_11)
+
+            with logcontext.PreserveLoggingContext():
+                yield res_deferreds[0]
+                yield res_deferreds_2[0]
+
+    @defer.inlineCallbacks
+    def test_verify_json_for_server(self):
+        kr = keyring.Keyring(self.hs)
+
+        key1 = signedjson.key.generate_signing_key(1)
+        yield self.hs.datastore.store_server_verify_key(
+            "server9", "", time.time() * 1000,
+            signedjson.key.get_verify_key(key1),
+        )
+        json1 = {}
+        signedjson.sign.sign_json(json1, "server9", key1)
+
+        sentinel_context = LoggingContext.current_context()
+
+        with LoggingContext("one") as context_one:
+            context_one.request = "one"
+
+            defer = kr.verify_json_for_server("server9", {})
+            try:
+                yield defer
+                self.fail("should fail on unsigned json")
+            except SynapseError:
+                pass
+            self.assertIs(LoggingContext.current_context(), context_one)
+
+            defer = kr.verify_json_for_server("server9", json1)
+            self.assertFalse(defer.called)
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+            yield defer
+
+            self.assertIs(LoggingContext.current_context(), context_one)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 7fe88172c0..b753455943 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -31,6 +31,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         self.mock_scheduler = Mock()
         hs = Mock()
         hs.get_datastore = Mock(return_value=self.mock_store)
+        self.mock_store.get_received_ts.return_value = 0
         hs.get_application_service_api = Mock(return_value=self.mock_as_api)
         hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
         hs.get_clock.return_value = MockClock()
@@ -53,7 +54,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             type="m.room.message",
             room_id="!foo:bar"
         )
-        self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
+        self.mock_store.get_new_events_for_appservice.side_effect = [
+            (0, [event]),
+            (0, [])
+        ]
         self.mock_as_api.push = Mock()
         yield self.handler.notify_interested_services(0)
         self.mock_scheduler.submit_event_for_as.assert_called_once_with(
@@ -75,7 +79,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         )
         self.mock_as_api.push = Mock()
         self.mock_as_api.query_user = Mock()
-        self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
+        self.mock_store.get_new_events_for_appservice.side_effect = [
+            (0, [event]),
+            (0, [])
+        ]
         yield self.handler.notify_interested_services(0)
         self.mock_as_api.query_user.assert_called_once_with(
             services[0], user_id
@@ -98,7 +105,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         )
         self.mock_as_api.push = Mock()
         self.mock_as_api.query_user = Mock()
-        self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
+        self.mock_store.get_new_events_for_appservice.side_effect = [
+            (0, [event]),
+            (0, [])
+        ]
         yield self.handler.notify_interested_services(0)
         self.assertFalse(
             self.mock_as_api.query_user.called,
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 2eaaa8253c..778ff2f6e9 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -19,7 +19,6 @@ import synapse.api.errors
 import synapse.handlers.device
 
 import synapse.storage
-from synapse import types
 from tests import unittest, utils
 
 user1 = "@boris:aaa"
@@ -179,6 +178,6 @@ class DeviceTestCase(unittest.TestCase):
 
         if ip is not None:
             yield self.store.insert_client_ip(
-                types.UserID.from_string(user_id),
+                user_id,
                 access_token, ip, "user_agent", device_id)
             self.clock.advance_time(1000)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index ceb9aa5765..7e5332e272 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -35,21 +35,20 @@ class DirectoryTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def setUp(self):
-        self.mock_federation = Mock(spec=[
-            "make_query",
-            "register_edu_handler",
-        ])
+        self.mock_federation = Mock()
+        self.mock_registry = Mock()
 
         self.query_handlers = {}
 
         def register_query_handler(query_type, handler):
             self.query_handlers[query_type] = handler
-        self.mock_federation.register_query_handler = register_query_handler
+        self.mock_registry.register_query_handler = register_query_handler
 
         hs = yield setup_test_homeserver(
             http_client=None,
             resource_for_federation=Mock(),
-            replication_layer=self.mock_federation,
+            federation_client=self.mock_federation,
+            federation_registry=self.mock_registry,
         )
         hs.handlers = DirectoryHandlers(hs)
 
@@ -93,6 +92,7 @@ class DirectoryTestCase(unittest.TestCase):
                 "room_alias": "#another:remote",
             },
             retry_on_dns_fail=False,
+            ignore_backoff=True,
         )
 
     @defer.inlineCallbacks
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 878a54dc34..d1bd87b898 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import mock
+from synapse.api import errors
 from twisted.internet import defer
 
 import synapse.api.errors
@@ -33,7 +34,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
     def setUp(self):
         self.hs = yield utils.setup_test_homeserver(
             handlers=None,
-            replication_layer=mock.Mock(),
+            federation_client=mock.Mock(),
         )
         self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
 
@@ -44,3 +45,133 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         local_user = "@boris:" + self.hs.hostname
         res = yield self.handler.query_local_devices({local_user: None})
         self.assertDictEqual(res, {local_user: {}})
+
+    @defer.inlineCallbacks
+    def test_reupload_one_time_keys(self):
+        """we should be able to re-upload the same keys"""
+        local_user = "@boris:" + self.hs.hostname
+        device_id = "xyz"
+        keys = {
+            "alg1:k1": "key1",
+            "alg2:k2": {
+                "key": "key2",
+                "signatures": {"k1": "sig1"}
+            },
+            "alg2:k3": {
+                "key": "key3",
+            },
+        }
+
+        res = yield self.handler.upload_keys_for_user(
+            local_user, device_id, {"one_time_keys": keys},
+        )
+        self.assertDictEqual(res, {
+            "one_time_key_counts": {"alg1": 1, "alg2": 2}
+        })
+
+        # we should be able to change the signature without a problem
+        keys["alg2:k2"]["signatures"]["k1"] = "sig2"
+        res = yield self.handler.upload_keys_for_user(
+            local_user, device_id, {"one_time_keys": keys},
+        )
+        self.assertDictEqual(res, {
+            "one_time_key_counts": {"alg1": 1, "alg2": 2}
+        })
+
+    @defer.inlineCallbacks
+    def test_change_one_time_keys(self):
+        """attempts to change one-time-keys should be rejected"""
+
+        local_user = "@boris:" + self.hs.hostname
+        device_id = "xyz"
+        keys = {
+            "alg1:k1": "key1",
+            "alg2:k2": {
+                "key": "key2",
+                "signatures": {"k1": "sig1"}
+            },
+            "alg2:k3": {
+                "key": "key3",
+            },
+        }
+
+        res = yield self.handler.upload_keys_for_user(
+            local_user, device_id, {"one_time_keys": keys},
+        )
+        self.assertDictEqual(res, {
+            "one_time_key_counts": {"alg1": 1, "alg2": 2}
+        })
+
+        try:
+            yield self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}},
+            )
+            self.fail("No error when changing string key")
+        except errors.SynapseError:
+            pass
+
+        try:
+            yield self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}},
+            )
+            self.fail("No error when replacing dict key with string")
+        except errors.SynapseError:
+            pass
+
+        try:
+            yield self.handler.upload_keys_for_user(
+                local_user, device_id, {
+                    "one_time_keys": {"alg1:k1": {"key": "key"}}
+                },
+            )
+            self.fail("No error when replacing string key with dict")
+        except errors.SynapseError:
+            pass
+
+        try:
+            yield self.handler.upload_keys_for_user(
+                local_user, device_id, {
+                    "one_time_keys": {
+                        "alg2:k2": {
+                            "key": "key3",
+                            "signatures": {"k1": "sig1"},
+                        }
+                    },
+                },
+            )
+            self.fail("No error when replacing dict key")
+        except errors.SynapseError:
+            pass
+
+    @defer.inlineCallbacks
+    def test_claim_one_time_key(self):
+        local_user = "@boris:" + self.hs.hostname
+        device_id = "xyz"
+        keys = {
+            "alg1:k1": "key1",
+        }
+
+        res = yield self.handler.upload_keys_for_user(
+            local_user, device_id, {"one_time_keys": keys},
+        )
+        self.assertDictEqual(res, {
+            "one_time_key_counts": {"alg1": 1}
+        })
+
+        res2 = yield self.handler.claim_one_time_keys({
+            "one_time_keys": {
+                local_user: {
+                    device_id: "alg1"
+                }
+            }
+        }, timeout=None)
+        self.assertEqual(res2, {
+            "failures": {},
+            "one_time_keys": {
+                local_user: {
+                    device_id: {
+                        "alg1:k1": "key1"
+                    }
+                }
+            }
+        })
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index d9e8f634ae..de06a6ad30 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -324,7 +324,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         state = UserPresenceState.default(user_id)
         state = state.copy_and_replace(
             state=PresenceState.ONLINE,
-            last_active_ts=now,
+            last_active_ts=0,
             last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
         )
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 979cebf600..458296ee4c 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -37,23 +37,23 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def setUp(self):
-        self.mock_federation = Mock(spec=[
-            "make_query",
-            "register_edu_handler",
-        ])
+        self.mock_federation = Mock()
+        self.mock_registry = Mock()
 
         self.query_handlers = {}
 
         def register_query_handler(query_type, handler):
             self.query_handlers[query_type] = handler
 
-        self.mock_federation.register_query_handler = register_query_handler
+        self.mock_registry.register_query_handler = register_query_handler
 
         hs = yield setup_test_homeserver(
             http_client=None,
             handlers=None,
             resource_for_federation=Mock(),
-            replication_layer=self.mock_federation,
+            federation_client=self.mock_federation,
+            federation_server=Mock(),
+            federation_registry=self.mock_registry,
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ])
@@ -62,8 +62,6 @@ class ProfileTestCase(unittest.TestCase):
         self.ratelimiter = hs.get_ratelimiter()
         self.ratelimiter.send_message.return_value = (True, 0)
 
-        hs.handlers = ProfileHandlers(hs)
-
         self.store = hs.get_datastore()
 
         self.frank = UserID.from_string("@1234ABCD:test")
@@ -72,7 +70,7 @@ class ProfileTestCase(unittest.TestCase):
 
         yield self.store.create_profile(self.frank.localpart)
 
-        self.handler = hs.get_handlers().profile_handler
+        self.handler = hs.get_profile_handler()
 
     @defer.inlineCallbacks
     def test_get_my_name(self):
@@ -119,7 +117,8 @@ class ProfileTestCase(unittest.TestCase):
         self.mock_federation.make_query.assert_called_with(
             destination="remote",
             query_type="profile",
-            args={"user_id": "@alice:remote", "field": "displayname"}
+            args={"user_id": "@alice:remote", "field": "displayname"},
+            ignore_backoff=True,
         )
 
     @defer.inlineCallbacks
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index c8cf9a63ec..e990e45220 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -40,13 +40,14 @@ class RegistrationTestCase(unittest.TestCase):
         self.hs = yield setup_test_homeserver(
             handlers=None,
             http_client=None,
-            expire_access_token=True)
+            expire_access_token=True,
+            profile_handler=Mock(),
+        )
         self.macaroon_generator = Mock(
             generate_access_token=Mock(return_value='secret'))
         self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
         self.hs.handlers = RegistrationHandlers(self.hs)
         self.handler = self.hs.get_handlers().registration_handler
-        self.hs.get_handlers().profile_handler = Mock()
 
     @defer.inlineCallbacks
     def test_user_is_created_and_logged_in_if_doesnt_exist(self):
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f88d2be7c5..a433bbfa8a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -58,7 +58,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
 
         self.mock_federation_resource = MockHttpResource()
 
-        mock_notifier = Mock(spec=["on_new_event"])
+        mock_notifier = Mock()
         self.on_new_event = mock_notifier.on_new_event
 
         self.auth = Mock(spec=[])
@@ -76,9 +76,12 @@ class TypingNotificationsTestCase(unittest.TestCase):
                 "set_received_txn_response",
                 "get_destination_retry_timings",
                 "get_devices_by_remote",
+                # Bits that user_directory needs
+                "get_user_directory_stream_pos",
+                "get_current_state_deltas",
             ]),
             state_handler=self.state_handler,
-            handlers=None,
+            handlers=Mock(),
             notifier=mock_notifier,
             resource_for_client=Mock(),
             resource_for_federation=self.mock_federation_resource,
@@ -122,6 +125,15 @@ class TypingNotificationsTestCase(unittest.TestCase):
             return set(str(u) for u in self.room_members)
         self.state_handler.get_current_user_in_room = get_current_user_in_room
 
+        self.datastore.get_user_directory_stream_pos.return_value = (
+            # we deliberately return a non-None stream pos to avoid doing an initial_spam
+            defer.succeed(1)
+        )
+
+        self.datastore.get_current_state_deltas.return_value = (
+            None
+        )
+
         self.auth.check_joined_room = check_joined_room
 
         self.datastore.get_to_device_stream_token = lambda: 0
@@ -192,6 +204,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
                 ),
                 json_data_callback=ANY,
                 long_retries=True,
+                backoff_on_404=True,
             ),
             defer.succeed((200, "OK"))
         )
@@ -263,6 +276,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
                 ),
                 json_data_callback=ANY,
                 long_retries=True,
+                backoff_on_404=True,
             ),
             defer.succeed((200, "OK"))
         )
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index f85455a5af..069c0be762 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -16,7 +16,8 @@
 from tests import unittest
 
 from synapse.metrics.metric import (
-    CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
+    CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
+    _escape_label_value,
 )
 
 
@@ -141,6 +142,7 @@ class CacheMetricTestCase(unittest.TestCase):
             'cache:hits{name="cache_name"} 0',
             'cache:total{name="cache_name"} 0',
             'cache:size{name="cache_name"} 0',
+            'cache:evicted_size{name="cache_name"} 0',
         ])
 
         metric.inc_misses()
@@ -150,6 +152,7 @@ class CacheMetricTestCase(unittest.TestCase):
             'cache:hits{name="cache_name"} 0',
             'cache:total{name="cache_name"} 1',
             'cache:size{name="cache_name"} 1',
+            'cache:evicted_size{name="cache_name"} 0',
         ])
 
         metric.inc_hits()
@@ -158,4 +161,32 @@ class CacheMetricTestCase(unittest.TestCase):
             'cache:hits{name="cache_name"} 1',
             'cache:total{name="cache_name"} 2',
             'cache:size{name="cache_name"} 1',
+            'cache:evicted_size{name="cache_name"} 0',
         ])
+
+        metric.inc_evictions(2)
+
+        self.assertEquals(metric.render(), [
+            'cache:hits{name="cache_name"} 1',
+            'cache:total{name="cache_name"} 2',
+            'cache:size{name="cache_name"} 1',
+            'cache:evicted_size{name="cache_name"} 2',
+        ])
+
+
+class LabelValueEscapeTestCase(unittest.TestCase):
+    def test_simple(self):
+        string = "safjhsdlifhyskljfksdfh"
+        self.assertEqual(string, _escape_label_value(string))
+
+    def test_escape(self):
+        self.assertEqual(
+            "abc\\\"def\\nghi\\\\",
+            _escape_label_value("abc\"def\nghi\\"),
+        )
+
+    def test_sequence_of_escapes(self):
+        self.assertEqual(
+            "abc\\\"def\\nghi\\\\\\n",
+            _escape_label_value("abc\"def\nghi\\\n"),
+        )
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index b82868054d..64e07a8c93 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -12,12 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 from tests import unittest
 
+import tempfile
+
 from mock import Mock, NonCallableMock
 from tests.utils import setup_test_homeserver
-from synapse.replication.resource import ReplicationResource
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.replication.tcp.client import (
+    ReplicationClientHandler, ReplicationClientFactory,
+)
 
 
 class BaseSlavedStoreTestCase(unittest.TestCase):
@@ -26,25 +31,38 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
         self.hs = yield setup_test_homeserver(
             "blue",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
         )
         self.hs.get_ratelimiter().send_message.return_value = (True, 0)
 
-        self.replication = ReplicationResource(self.hs)
-
         self.master_store = self.hs.get_datastore()
         self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
         self.event_id = 0
 
+        server_factory = ReplicationStreamProtocolFactory(self.hs)
+        # XXX: mktemp is unsafe and should never be used. but we're just a test.
+        path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
+        listener = reactor.listenUNIX(path, server_factory)
+        self.addCleanup(listener.stopListening)
+        self.streamer = server_factory.streamer
+
+        self.replication_handler = ReplicationClientHandler(self.slaved_store)
+        client_factory = ReplicationClientFactory(
+            self.hs, "client_name", self.replication_handler
+        )
+        client_connector = reactor.connectUNIX(path, client_factory)
+        self.addCleanup(client_factory.stopTrying)
+        self.addCleanup(client_connector.disconnect)
+
     @defer.inlineCallbacks
     def replicate(self):
-        streams = self.slaved_store.stream_positions()
-        writer = yield self.replication.replicate(streams, 100)
-        result = writer.finish()
-        yield self.slaved_store.process_replication(result)
+        yield self.streamer.on_notifier_poke()
+        d = self.replication_handler.await_sync("replication_test")
+        self.streamer.send_sync_to_all_connections("replication_test")
+        yield d
 
     @defer.inlineCallbacks
     def check(self, method, args, expected_result=None):
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 105e1228bb..cb058d3142 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -226,13 +226,16 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             context = EventContext()
             context.current_state_ids = state_ids
             context.prev_state_ids = state_ids
-        elif not backfill:
+        else:
             state_handler = self.hs.get_state_handler()
             context = yield state_handler.compute_event_context(event)
-        else:
-            context = EventContext()
 
-        context.push_actions = push_actions
+        yield self.master_store.add_push_actions_to_staging(
+            event.event_id, {
+                user_id: actions
+                for user_id, actions in push_actions
+            },
+        )
 
         ordering = None
         if backfill:
diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py
deleted file mode 100644
index 93b9fad012..0000000000
--- a/tests/replication/test_resource.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import contextlib
-import json
-
-from mock import Mock, NonCallableMock
-from twisted.internet import defer
-
-import synapse.types
-from synapse.replication.resource import ReplicationResource
-from synapse.types import UserID
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
-
-class ReplicationResourceCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(
-            "red",
-            http_client=None,
-            replication_layer=Mock(),
-            ratelimiter=NonCallableMock(spec_set=[
-                "send_message",
-            ]),
-        )
-        self.user_id = "@seeing:red"
-        self.user = UserID.from_string(self.user_id)
-
-        self.hs.get_ratelimiter().send_message.return_value = (True, 0)
-
-        self.resource = ReplicationResource(self.hs)
-
-    @defer.inlineCallbacks
-    def test_streams(self):
-        # Passing "-1" returns the current stream positions
-        code, body = yield self.get(streams="-1")
-        self.assertEquals(code, 200)
-        self.assertEquals(body["streams"]["field_names"], ["name", "position"])
-        position = body["streams"]["position"]
-        # Passing the current position returns an empty response after the
-        # timeout
-        get = self.get(streams=str(position), timeout="0")
-        self.hs.clock.advance_time_msec(1)
-        code, body = yield get
-        self.assertEquals(code, 200)
-        self.assertEquals(body, {})
-
-    @defer.inlineCallbacks
-    def test_events(self):
-        get = self.get(events="-1", timeout="0")
-        yield self.hs.get_handlers().room_creation_handler.create_room(
-            synapse.types.create_requester(self.user), {}
-        )
-        code, body = yield get
-        self.assertEquals(code, 200)
-        self.assertEquals(body["events"]["field_names"], [
-            "position", "internal", "json", "state_group"
-        ])
-
-    @defer.inlineCallbacks
-    def test_presence(self):
-        get = self.get(presence="-1")
-        yield self.hs.get_presence_handler().set_state(
-            self.user, {"presence": "online"}
-        )
-        code, body = yield get
-        self.assertEquals(code, 200)
-        self.assertEquals(body["presence"]["field_names"], [
-            "position", "user_id", "state", "last_active_ts",
-            "last_federation_update_ts", "last_user_sync_ts",
-            "status_msg", "currently_active",
-        ])
-
-    @defer.inlineCallbacks
-    def test_typing(self):
-        room_id = yield self.create_room()
-        get = self.get(typing="-1")
-        yield self.hs.get_typing_handler().started_typing(
-            self.user, self.user, room_id, timeout=2
-        )
-        code, body = yield get
-        self.assertEquals(code, 200)
-        self.assertEquals(body["typing"]["field_names"], [
-            "position", "room_id", "typing"
-        ])
-
-    @defer.inlineCallbacks
-    def test_receipts(self):
-        room_id = yield self.create_room()
-        event_id = yield self.send_text_message(room_id, "Hello, World")
-        get = self.get(receipts="-1")
-        yield self.hs.get_receipts_handler().received_client_receipt(
-            room_id, "m.read", self.user_id, event_id
-        )
-        code, body = yield get
-        self.assertEquals(code, 200)
-        self.assertEquals(body["receipts"]["field_names"], [
-            "position", "room_id", "receipt_type", "user_id", "event_id", "data"
-        ])
-
-    def _test_timeout(stream):
-        """Check that a request for the given stream timesout"""
-        @defer.inlineCallbacks
-        def test_timeout(self):
-            get = self.get(**{stream: "-1", "timeout": "0"})
-            self.hs.clock.advance_time_msec(1)
-            code, body = yield get
-            self.assertEquals(code, 200)
-            self.assertEquals(body.get("rows", []), [])
-        test_timeout.__name__ = "test_timeout_%s" % (stream)
-        return test_timeout
-
-    test_timeout_events = _test_timeout("events")
-    test_timeout_presence = _test_timeout("presence")
-    test_timeout_typing = _test_timeout("typing")
-    test_timeout_receipts = _test_timeout("receipts")
-    test_timeout_user_account_data = _test_timeout("user_account_data")
-    test_timeout_room_account_data = _test_timeout("room_account_data")
-    test_timeout_tag_account_data = _test_timeout("tag_account_data")
-    test_timeout_backfill = _test_timeout("backfill")
-    test_timeout_push_rules = _test_timeout("push_rules")
-    test_timeout_pushers = _test_timeout("pushers")
-    test_timeout_state = _test_timeout("state")
-
-    @defer.inlineCallbacks
-    def send_text_message(self, room_id, message):
-        handler = self.hs.get_handlers().message_handler
-        event = yield handler.create_and_send_nonmember_event(
-            synapse.types.create_requester(self.user),
-            {
-                "type": "m.room.message",
-                "content": {"body": "message", "msgtype": "m.text"},
-                "room_id": room_id,
-                "sender": self.user.to_string(),
-            }
-        )
-        defer.returnValue(event.event_id)
-
-    @defer.inlineCallbacks
-    def create_room(self):
-        result = yield self.hs.get_handlers().room_creation_handler.create_room(
-            synapse.types.create_requester(self.user), {}
-        )
-        defer.returnValue(result["room_id"])
-
-    @defer.inlineCallbacks
-    def get(self, **params):
-        request = NonCallableMock(spec_set=[
-            "write", "finish", "setResponseCode", "setHeader", "args",
-            "method", "processing"
-        ])
-
-        request.method = "GET"
-        request.args = {k: [v] for k, v in params.items()}
-
-        @contextlib.contextmanager
-        def processing():
-            yield
-        request.processing = processing
-
-        yield self.resource._async_render_GET(request)
-        self.assertTrue(request.finish.called)
-
-        if request.setResponseCode.called:
-            response_code = request.setResponseCode.call_args[0][0]
-        else:
-            response_code = 200
-
-        response_json = "".join(
-            call[0][0] for call in request.write.call_args_list
-        )
-        response_body = json.loads(response_json)
-
-        if response_code == 200:
-            self.check_response(response_body)
-
-        defer.returnValue((response_code, response_body))
-
-    def check_response(self, response_body):
-        for name, stream in response_body.items():
-            self.assertIn("field_names", stream)
-            field_names = stream["field_names"]
-            self.assertIn("rows", stream)
-            for row in stream["rows"]:
-                self.assertEquals(
-                    len(row), len(field_names),
-                    "%s: len(row = %r) == len(field_names = %r)" % (
-                        name, row, field_names
-                    )
-                )
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index e9698bfdc9..f5a7258e68 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -114,7 +114,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
 
         hs = yield setup_test_homeserver(
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
@@ -123,6 +123,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
         self.ratelimiter.send_message.return_value = (True, 0)
         hs.config.enable_registration_captcha = False
         hs.config.enable_registration = True
+        hs.config.auto_join_rooms = []
 
         hs.get_handlers().federation_handler = Mock()
 
@@ -147,11 +148,16 @@ class EventStreamPermissionsTestCase(RestTestCase):
 
     @defer.inlineCallbacks
     def test_stream_basic_permissions(self):
-        # invalid token, expect 403
+        # invalid token, expect 401
+        # note: this is in violation of the original v1 spec, which expected
+        # 403. However, since the v1 spec no longer exists and the v1
+        # implementation is now part of the r0 implementation, the newer
+        # behaviour is used instead to be consistent with the r0 spec.
+        # see issue #2602
         (code, response) = yield self.mock_resource.trigger_get(
             "/events?access_token=%s" % ("invalid" + self.token, )
         )
-        self.assertEquals(403, code, msg=str(response))
+        self.assertEquals(401, code, msg=str(response))
 
         # valid token, expect content
         (code, response) = yield self.mock_resource.trigger_get(
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 1e95e97538..dc94b8bd19 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -45,15 +45,14 @@ class ProfileTestCase(unittest.TestCase):
             http_client=None,
             resource_for_client=self.mock_resource,
             federation=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
+            profile_handler=self.mock_handler
         )
 
         def _get_user_by_req(request=None, allow_guest=False):
             return synapse.types.create_requester(myid)
 
-        hs.get_v1auth().get_user_by_req = _get_user_by_req
-
-        hs.get_handlers().profile_handler = self.mock_handler
+        hs.get_auth().get_user_by_req = _get_user_by_req
 
         profile.register_servlets(hs, self.mock_resource)
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index d746ea8568..61d737725b 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -24,7 +24,7 @@ from synapse.api.constants import Membership
 from synapse.types import UserID
 
 import json
-import urllib
+from six.moves.urllib import parse as urlparse
 
 from ....utils import MockHttpResource, setup_test_homeserver
 from .utils import RestTestCase
@@ -46,7 +46,7 @@ class RoomPermissionsTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.ratelimiter = hs.get_ratelimiter()
@@ -60,7 +60,7 @@ class RoomPermissionsTestCase(RestTestCase):
                 "token_id": 1,
                 "is_guest": False,
             }
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -70,7 +70,7 @@ class RoomPermissionsTestCase(RestTestCase):
 
         synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
 
-        self.auth = hs.get_v1auth()
+        self.auth = hs.get_auth()
 
         # create some rooms under the name rmcreator_id
         self.uncreated_rmid = "!aa:test"
@@ -409,7 +409,7 @@ class RoomsMemberListTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.ratelimiter = hs.get_ratelimiter()
@@ -425,7 +425,7 @@ class RoomsMemberListTestCase(RestTestCase):
                 "token_id": 1,
                 "is_guest": False,
             }
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -493,7 +493,7 @@ class RoomsCreateTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.ratelimiter = hs.get_ratelimiter()
@@ -507,7 +507,7 @@ class RoomsCreateTestCase(RestTestCase):
                 "token_id": 1,
                 "is_guest": False,
             }
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -515,9 +515,6 @@ class RoomsCreateTestCase(RestTestCase):
 
         synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
 
-    def tearDown(self):
-        pass
-
     @defer.inlineCallbacks
     def test_post_room_no_keys(self):
         # POST with no config keys, expect new room id
@@ -585,7 +582,7 @@ class RoomTopicTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.ratelimiter = hs.get_ratelimiter()
@@ -600,7 +597,7 @@ class RoomTopicTestCase(RestTestCase):
                 "is_guest": False,
             }
 
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -700,7 +697,7 @@ class RoomMemberStateTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.ratelimiter = hs.get_ratelimiter()
@@ -714,7 +711,7 @@ class RoomMemberStateTestCase(RestTestCase):
                 "token_id": 1,
                 "is_guest": False,
             }
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -769,7 +766,7 @@ class RoomMemberStateTestCase(RestTestCase):
     @defer.inlineCallbacks
     def test_rooms_members_self(self):
         path = "/rooms/%s/state/m.room.member/%s" % (
-            urllib.quote(self.room_id), self.user_id
+            urlparse.quote(self.room_id), self.user_id
         )
 
         # valid join message (NOOP since we made the room)
@@ -789,7 +786,7 @@ class RoomMemberStateTestCase(RestTestCase):
     def test_rooms_members_other(self):
         self.other_id = "@zzsid1:red"
         path = "/rooms/%s/state/m.room.member/%s" % (
-            urllib.quote(self.room_id), self.other_id
+            urlparse.quote(self.room_id), self.other_id
         )
 
         # valid invite message
@@ -805,7 +802,7 @@ class RoomMemberStateTestCase(RestTestCase):
     def test_rooms_members_other_custom_keys(self):
         self.other_id = "@zzsid1:red"
         path = "/rooms/%s/state/m.room.member/%s" % (
-            urllib.quote(self.room_id), self.other_id
+            urlparse.quote(self.room_id), self.other_id
         )
 
         # valid invite message with custom key
@@ -832,7 +829,7 @@ class RoomMessagesTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.ratelimiter = hs.get_ratelimiter()
@@ -846,7 +843,7 @@ class RoomMessagesTestCase(RestTestCase):
                 "token_id": 1,
                 "is_guest": False,
             }
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -862,7 +859,7 @@ class RoomMessagesTestCase(RestTestCase):
     @defer.inlineCallbacks
     def test_invalid_puts(self):
         path = "/rooms/%s/send/m.room.message/mid1" % (
-            urllib.quote(self.room_id))
+            urlparse.quote(self.room_id))
         # missing keys or invalid json
         (code, response) = yield self.mock_resource.trigger(
             "PUT", path, '{}'
@@ -897,7 +894,7 @@ class RoomMessagesTestCase(RestTestCase):
     @defer.inlineCallbacks
     def test_rooms_messages_sent(self):
         path = "/rooms/%s/send/m.room.message/mid1" % (
-            urllib.quote(self.room_id))
+            urlparse.quote(self.room_id))
 
         content = '{"body":"test","msgtype":{"type":"a"}}'
         (code, response) = yield self.mock_resource.trigger("PUT", path, content)
@@ -914,7 +911,7 @@ class RoomMessagesTestCase(RestTestCase):
 
         # m.text message type
         path = "/rooms/%s/send/m.room.message/mid2" % (
-            urllib.quote(self.room_id))
+            urlparse.quote(self.room_id))
         content = '{"body":"test2","msgtype":"m.text"}'
         (code, response) = yield self.mock_resource.trigger("PUT", path, content)
         self.assertEquals(200, code, msg=str(response))
@@ -932,7 +929,7 @@ class RoomInitialSyncTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
@@ -948,7 +945,7 @@ class RoomInitialSyncTestCase(RestTestCase):
                 "token_id": 1,
                 "is_guest": False,
             }
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -1006,7 +1003,7 @@ class RoomMessageListTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.ratelimiter = hs.get_ratelimiter()
@@ -1020,7 +1017,7 @@ class RoomMessageListTestCase(RestTestCase):
                 "token_id": 1,
                 "is_guest": False,
             }
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -1032,7 +1029,7 @@ class RoomMessageListTestCase(RestTestCase):
 
     @defer.inlineCallbacks
     def test_topo_token_is_accepted(self):
-        token = "t1-0_0_0_0_0_0_0_0"
+        token = "t1-0_0_0_0_0_0_0_0_0"
         (code, response) = yield self.mock_resource.trigger_get(
             "/rooms/%s/messages?access_token=x&from=%s" %
             (self.room_id, token))
@@ -1044,7 +1041,7 @@ class RoomMessageListTestCase(RestTestCase):
 
     @defer.inlineCallbacks
     def test_stream_token_is_accepted_for_fwd_pagianation(self):
-        token = "s0_0_0_0_0_0_0_0"
+        token = "s0_0_0_0_0_0_0_0_0"
         (code, response) = yield self.mock_resource.trigger_get(
             "/rooms/%s/messages?access_token=x&from=%s" %
             (self.room_id, token))
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index a269e6f56e..fe161ee5cb 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -47,7 +47,7 @@ class RoomTypingTestCase(RestTestCase):
             "red",
             clock=self.clock,
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
@@ -68,7 +68,7 @@ class RoomTypingTestCase(RestTestCase):
                 "is_guest": False,
             }
 
-        hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -95,7 +95,7 @@ class RoomTypingTestCase(RestTestCase):
                 else:
                     if remotedomains is not None:
                         remotedomains.add(member.domain)
-        hs.get_handlers().room_member_handler.fetch_room_distributions_into = (
+        hs.get_room_member_handler().fetch_room_distributions_into = (
             fetch_room_distributions_into
         )
 
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index 3d27d03cbf..76b833e119 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -33,8 +33,8 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
 class FilterTestCase(unittest.TestCase):
 
     USER_ID = "@apple:test"
-    EXAMPLE_FILTER = {"type": ["m.*"]}
-    EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}'
+    EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
+    EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}'
     TO_REGISTER = [filter]
 
     @defer.inlineCallbacks
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index b6173ab2ee..8aba456510 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -1,5 +1,7 @@
+from twisted.python import failure
+
 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
 from twisted.internet import defer
 from mock import Mock
 from tests import unittest
@@ -24,7 +26,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
             side_effect=lambda x: self.appservice)
         )
 
-        self.auth_result = (False, None, None, None)
+        self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
         self.auth_handler = Mock(
             check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
             get_session_data=Mock(return_value=None)
@@ -47,6 +49,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
         self.hs.get_device_handler = Mock(return_value=self.device_handler)
         self.hs.config.enable_registration = True
+        self.hs.config.registrations_require_3pid = []
+        self.hs.config.auto_join_rooms = []
 
         # init the thing we're testing
         self.servlet = RegisterRestServlet(self.hs)
@@ -85,6 +89,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.request.args = {
             "access_token": "i_am_an_app_service"
         }
+
         self.request_data = json.dumps({
             "username": "kermit"
         })
@@ -119,7 +124,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
             "device_id": device_id,
         })
         self.registration_handler.check_username = Mock(return_value=True)
-        self.auth_result = (True, None, {
+        self.auth_result = (None, {
             "username": "kermit",
             "password": "monkey"
         }, None)
@@ -149,7 +154,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
             "password": "monkey"
         })
         self.registration_handler.check_username = Mock(return_value=True)
-        self.auth_result = (True, None, {
+        self.auth_result = (None, {
             "username": "kermit",
             "password": "monkey"
         }, None)
diff --git a/tests/rest/media/__init__.py b/tests/rest/media/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rest/media/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rest/media/v1/__init__.py b/tests/rest/media/v1/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rest/media/v1/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
new file mode 100644
index 0000000000..eef38b6781
--- /dev/null
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from synapse.rest.media.v1._base import FileInfo
+from synapse.rest.media.v1.media_storage import MediaStorage
+from synapse.rest.media.v1.filepath import MediaFilePaths
+from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+
+from mock import Mock
+
+from tests import unittest
+
+import os
+import shutil
+import tempfile
+
+
+class MediaStorageTests(unittest.TestCase):
+    def setUp(self):
+        self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
+
+        self.primary_base_path = os.path.join(self.test_dir, "primary")
+        self.secondary_base_path = os.path.join(self.test_dir, "secondary")
+
+        hs = Mock()
+        hs.config.media_store_path = self.primary_base_path
+
+        storage_providers = [FileStorageProviderBackend(
+            hs, self.secondary_base_path
+        )]
+
+        self.filepaths = MediaFilePaths(self.primary_base_path)
+        self.media_storage = MediaStorage(
+            self.primary_base_path, self.filepaths, storage_providers,
+        )
+
+    def tearDown(self):
+        shutil.rmtree(self.test_dir)
+
+    @defer.inlineCallbacks
+    def test_ensure_media_is_in_local_cache(self):
+        media_id = "some_media_id"
+        test_body = "Test\n"
+
+        # First we create a file that is in a storage provider but not in the
+        # local primary media store
+        rel_path = self.filepaths.local_media_filepath_rel(media_id)
+        secondary_path = os.path.join(self.secondary_base_path, rel_path)
+
+        os.makedirs(os.path.dirname(secondary_path))
+
+        with open(secondary_path, "w") as f:
+            f.write(test_body)
+
+        # Now we run ensure_media_is_in_local_cache, which should copy the file
+        # to the local cache.
+        file_info = FileInfo(None, media_id)
+        local_path = yield self.media_storage.ensure_media_is_in_local_cache(file_info)
+
+        self.assertTrue(os.path.exists(local_path))
+
+        # Asserts the file is under the expected local cache directory
+        self.assertEquals(
+            os.path.commonprefix([self.primary_base_path, local_path]),
+            self.primary_base_path,
+        )
+
+        with open(local_path) as f:
+            body = f.read()
+
+        self.assertEqual(test_body, body)
diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py
deleted file mode 100644
index 38556da9a7..0000000000
--- a/tests/storage/event_injector.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes
-
-
-class EventInjector:
-    def __init__(self, hs):
-        self.hs = hs
-        self.store = hs.get_datastore()
-        self.message_handler = hs.get_handlers().message_handler
-        self.event_builder_factory = hs.get_event_builder_factory()
-
-    @defer.inlineCallbacks
-    def create_room(self, room):
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Create,
-            "sender": "",
-            "room_id": room.to_string(),
-            "content": {},
-        })
-
-        event, context = yield self.message_handler._create_new_client_event(
-            builder
-        )
-
-        yield self.store.persist_event(event, context)
-
-    @defer.inlineCallbacks
-    def inject_room_member(self, room, user, membership):
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Member,
-            "sender": user.to_string(),
-            "state_key": user.to_string(),
-            "room_id": room.to_string(),
-            "content": {"membership": membership},
-        })
-
-        event, context = yield self.message_handler._create_new_client_event(
-            builder
-        )
-
-        yield self.store.persist_event(event, context)
-
-        defer.returnValue(event)
-
-    @defer.inlineCallbacks
-    def inject_message(self, room, user, body):
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Message,
-            "sender": user.to_string(),
-            "state_key": user.to_string(),
-            "room_id": room.to_string(),
-            "content": {"body": body, "msgtype": u"message"},
-        })
-
-        event, context = yield self.message_handler._create_new_client_event(
-            builder
-        )
-
-        yield self.store.persist_event(event, context)
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8361dd8cee..3cfa21c9f8 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -199,7 +199,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         a.func.prefill(("foo",), ObservableDeferred(d))
 
-        self.assertEquals(a.func("foo").result, d.result)
+        self.assertEquals(a.func("foo"), d.result)
         self.assertEquals(callcount[0], 0)
 
     @defer.inlineCallbacks
@@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
         callcount2 = [0]
 
         class A(object):
-            @cached(max_entries=20)  # HACK: This makes it 2 due to cache factor
+            @cached(max_entries=4)  # HACK: This makes it 2 due to cache factor
             def func(self, key):
                 callcount[0] += 1
                 return key
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 9e98d0e330..00825498b1 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -42,7 +42,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             config=config,
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
 
         self.as_token = "token1"
@@ -58,14 +58,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        self.store = ApplicationServiceStore(hs)
+        self.store = ApplicationServiceStore(None, hs)
 
     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
         for f in self.as_yaml_files:
             try:
                 os.remove(f)
-            except:
+            except Exception:
                 pass
 
     def _add_appservice(self, as_token, id, url, hs_token, sender):
@@ -119,7 +119,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             config=config,
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
         self.db_pool = hs.get_db_pool()
 
@@ -150,7 +150,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         self.as_yaml_files = []
 
-        self.store = TestTransactionStore(hs)
+        self.store = TestTransactionStore(None, hs)
 
     def _add_service(self, url, as_token, id):
         as_yaml = dict(url=url, as_token=as_token, hs_token="something",
@@ -420,8 +420,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 class TestTransactionStore(ApplicationServiceTransactionStore,
                            ApplicationServiceStore):
 
-    def __init__(self, hs):
-        super(TestTransactionStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(TestTransactionStore, self).__init__(db_conn, hs)
 
 
 class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@@ -455,10 +455,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
             config=config,
             datastore=Mock(),
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
 
-        ApplicationServiceStore(hs)
+        ApplicationServiceStore(None, hs)
 
     @defer.inlineCallbacks
     def test_duplicate_ids(self):
@@ -473,16 +473,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
             config=config,
             datastore=Mock(),
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs)
+            ApplicationServiceStore(None, hs)
 
         e = cm.exception
-        self.assertIn(f1, e.message)
-        self.assertIn(f2, e.message)
-        self.assertIn("id", e.message)
+        self.assertIn(f1, str(e))
+        self.assertIn(f2, str(e))
+        self.assertIn("id", str(e))
 
     @defer.inlineCallbacks
     def test_duplicate_as_tokens(self):
@@ -497,13 +497,13 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
             config=config,
             datastore=Mock(),
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs)
+            ApplicationServiceStore(None, hs)
 
         e = cm.exception
-        self.assertIn(f1, e.message)
-        self.assertIn(f2, e.message)
-        self.assertIn("as_token", e.message)
+        self.assertIn(f1, str(e))
+        self.assertIn(f2, str(e))
+        self.assertIn("as_token", str(e))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index afbefb2e2d..0ac910e76f 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -56,7 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             database_engine=create_engine(config.database_config),
         )
 
-        self.datastore = SQLBaseStore(hs)
+        self.datastore = SQLBaseStore(None, hs)
 
     @defer.inlineCallbacks
     def test_insert_1col(self):
@@ -89,7 +89,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_select_one_1col(self):
         self.mock_txn.rowcount = 1
-        self.mock_txn.fetchall.return_value = [("Value",)]
+        self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
 
         value = yield self.datastore._simple_select_one_onecol(
             table="tablename",
@@ -136,7 +136,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_select_list(self):
         self.mock_txn.rowcount = 3
-        self.mock_txn.fetchall.return_value = ((1,), (2,), (3,))
+        self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
         self.mock_txn.description = (
             ("colA", None, None, None, None, None, None),
         )
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 1f0c0e7c37..bd6fda6cb1 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -15,9 +15,6 @@
 
 from twisted.internet import defer
 
-import synapse.server
-import synapse.storage
-import synapse.types
 import tests.unittest
 import tests.utils
 
@@ -39,14 +36,11 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
         self.clock.now = 12345678
         user_id = "@user:id"
         yield self.store.insert_client_ip(
-            synapse.types.UserID.from_string(user_id),
+            user_id,
             "access_token", "ip", "user_agent", "device_id",
         )
 
-        # deliberately use an iterable here to make sure that the lookup
-        # method doesn't iterate it twice
-        device_list = iter(((user_id, "device_id"),))
-        result = yield self.store.get_last_client_ip_by_device(device_list)
+        result = yield self.store.get_last_client_ip_by_device(user_id, "device_id")
 
         r = result[(user_id, "device_id")]
         self.assertDictContainsSubset(
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index b087892e0b..95709cd50a 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -29,7 +29,7 @@ class DirectoryStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver()
 
-        self.store = DirectoryStore(hs)
+        self.store = DirectoryStore(None, hs)
 
         self.room = RoomID.from_string("!abcde:test")
         self.alias = RoomAlias.from_string("#my-room:test")
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
new file mode 100644
index 0000000000..30683e7888
--- /dev/null
+++ b/tests/storage/test_event_federation.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import tests.unittest
+import tests.utils
+
+
+class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
+    @defer.inlineCallbacks
+    def setUp(self):
+        hs = yield tests.utils.setup_test_homeserver()
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def test_get_prev_events_for_room(self):
+        room_id = '@ROOM:local'
+
+        # add a bunch of events and hashes to act as forward extremities
+        def insert_event(txn, i):
+            event_id = '$event_%i:local' % i
+
+            txn.execute((
+                "INSERT INTO events ("
+                "   room_id, event_id, type, depth, topological_ordering,"
+                "   content, processed, outlier) "
+                "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)"
+            ), (room_id, event_id, i, i, True, False))
+
+            txn.execute((
+                'INSERT INTO event_forward_extremities (room_id, event_id) '
+                'VALUES (?, ?)'
+            ), (room_id, event_id))
+
+            txn.execute((
+                'INSERT INTO event_reference_hashes '
+                '(event_id, algorithm, hash) '
+                "VALUES (?, 'sha256', ?)"
+            ), (event_id, 'ffff'))
+
+        for i in range(0, 11):
+            yield self.store.runInteraction("insert", insert_event, i)
+
+        # this should get the last five and five others
+        r = yield self.store.get_prev_events_for_room(room_id)
+        self.assertEqual(10, len(r))
+        for i in range(0, 5):
+            el = r[i]
+            depth = el[2]
+            self.assertEqual(10 - i, depth)
+
+        for i in range(5, 5):
+            el = r[i]
+            depth = el[2]
+            self.assertLessEqual(5, depth)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index e9044afa2e..9962ce8a5d 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -17,9 +17,15 @@ from twisted.internet import defer
 
 import tests.unittest
 import tests.utils
+from mock import Mock
 
 USER_ID = "@user:example.com"
 
+PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}]
+HIGHLIGHT = [
+    "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}
+]
+
 
 class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
@@ -39,3 +45,151 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         yield self.store.get_unread_push_actions_for_user_in_range_for_email(
             USER_ID, 0, 1000, 20
         )
+
+    @defer.inlineCallbacks
+    def test_count_aggregation(self):
+        room_id = "!foo:example.com"
+        user_id = "@user1235:example.com"
+
+        @defer.inlineCallbacks
+        def _assert_counts(noitf_count, highlight_count):
+            counts = yield self.store.runInteraction(
+                "", self.store._get_unread_counts_by_pos_txn,
+                room_id, user_id, 0, 0
+            )
+            self.assertEquals(
+                counts,
+                {"notify_count": noitf_count, "highlight_count": highlight_count}
+            )
+
+        @defer.inlineCallbacks
+        def _inject_actions(stream, action):
+            event = Mock()
+            event.room_id = room_id
+            event.event_id = "$test:example.com"
+            event.internal_metadata.stream_ordering = stream
+            event.depth = stream
+
+            yield self.store.add_push_actions_to_staging(
+                event.event_id, {user_id: action},
+            )
+            yield self.store.runInteraction(
+                "", self.store._set_push_actions_for_event_and_users_txn,
+                [(event, None)], [(event, None)],
+            )
+
+        def _rotate(stream):
+            return self.store.runInteraction(
+                "", self.store._rotate_notifs_before_txn, stream
+            )
+
+        def _mark_read(stream, depth):
+            return self.store.runInteraction(
+                "", self.store._remove_old_push_actions_before_txn,
+                room_id, user_id, depth, stream
+            )
+
+        yield _assert_counts(0, 0)
+        yield _inject_actions(1, PlAIN_NOTIF)
+        yield _assert_counts(1, 0)
+        yield _rotate(2)
+        yield _assert_counts(1, 0)
+
+        yield _inject_actions(3, PlAIN_NOTIF)
+        yield _assert_counts(2, 0)
+        yield _rotate(4)
+        yield _assert_counts(2, 0)
+
+        yield _inject_actions(5, PlAIN_NOTIF)
+        yield _mark_read(3, 3)
+        yield _assert_counts(1, 0)
+
+        yield _mark_read(5, 5)
+        yield _assert_counts(0, 0)
+
+        yield _inject_actions(6, PlAIN_NOTIF)
+        yield _rotate(7)
+
+        yield self.store._simple_delete(
+            table="event_push_actions",
+            keyvalues={"1": 1},
+            desc="",
+        )
+
+        yield _assert_counts(1, 0)
+
+        yield _mark_read(7, 7)
+        yield _assert_counts(0, 0)
+
+        yield _inject_actions(8, HIGHLIGHT)
+        yield _assert_counts(1, 1)
+        yield _rotate(9)
+        yield _assert_counts(1, 1)
+        yield _rotate(10)
+        yield _assert_counts(1, 1)
+
+    @defer.inlineCallbacks
+    def test_find_first_stream_ordering_after_ts(self):
+        def add_event(so, ts):
+            return self.store._simple_insert("events", {
+                "stream_ordering": so,
+                "received_ts": ts,
+                "event_id": "event%i" % so,
+                "type": "",
+                "room_id": "",
+                "content": "",
+                "processed": True,
+                "outlier": False,
+                "topological_ordering": 0,
+                "depth": 0,
+            })
+
+        # start with the base case where there are no events in the table
+        r = yield self.store.find_first_stream_ordering_after_ts(11)
+        self.assertEqual(r, 0)
+
+        # now with one event
+        yield add_event(2, 10)
+        r = yield self.store.find_first_stream_ordering_after_ts(9)
+        self.assertEqual(r, 2)
+        r = yield self.store.find_first_stream_ordering_after_ts(10)
+        self.assertEqual(r, 2)
+        r = yield self.store.find_first_stream_ordering_after_ts(11)
+        self.assertEqual(r, 3)
+
+        # add a bunch of dummy events to the events table
+        for (stream_ordering, ts) in (
+                (3, 110),
+                (4, 120),
+                (5, 120),
+                (10, 130),
+                (20, 140),
+        ):
+            yield add_event(stream_ordering, ts)
+
+        r = yield self.store.find_first_stream_ordering_after_ts(110)
+        self.assertEqual(r, 3,
+                         "First event after 110ms should be 3, was %i" % r)
+
+        # 4 and 5 are both after 120: we want 4 rather than 5
+        r = yield self.store.find_first_stream_ordering_after_ts(120)
+        self.assertEqual(r, 4,
+                         "First event after 120ms should be 4, was %i" % r)
+
+        r = yield self.store.find_first_stream_ordering_after_ts(129)
+        self.assertEqual(r, 10,
+                         "First event after 129ms should be 10, was %i" % r)
+
+        # check we can get the last event
+        r = yield self.store.find_first_stream_ordering_after_ts(140)
+        self.assertEqual(r, 20,
+                         "First event after 14ms should be 20, was %i" % r)
+
+        # off the end
+        r = yield self.store.find_first_stream_ordering_after_ts(160)
+        self.assertEqual(r, 21)
+
+        # check we can find an event at ordering zero
+        yield add_event(0, 5)
+        r = yield self.store.find_first_stream_ordering_after_ts(1)
+        self.assertEqual(r, 0)
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
deleted file mode 100644
index 3762b38e37..0000000000
--- a/tests/storage/test_events.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from mock import Mock
-from synapse.types import RoomID, UserID
-
-from tests import unittest
-from twisted.internet import defer
-from tests.storage.event_injector import EventInjector
-
-from tests.utils import setup_test_homeserver
-
-
-class EventsStoreTestCase(unittest.TestCase):
-
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(
-            resource_for_federation=Mock(),
-            http_client=None,
-        )
-        self.store = self.hs.get_datastore()
-        self.db_pool = self.hs.get_db_pool()
-        self.message_handler = self.hs.get_handlers().message_handler
-        self.event_injector = EventInjector(self.hs)
-
-    @defer.inlineCallbacks
-    def test_count_daily_messages(self):
-        yield self.db_pool.runQuery("DELETE FROM stats_reporting")
-
-        self.hs.clock.now = 100
-
-        # Never reported before, and nothing which could be reported
-        count = yield self.store.count_daily_messages()
-        self.assertIsNone(count)
-        count = yield self.db_pool.runQuery("SELECT COUNT(*) FROM stats_reporting")
-        self.assertEqual([(0,)], count)
-
-        # Create something to report
-        room = RoomID.from_string("!abc123:test")
-        user = UserID.from_string("@raccoonlover:test")
-        yield self.event_injector.create_room(room)
-
-        self.base_event = yield self._get_last_stream_token()
-
-        yield self.event_injector.inject_message(room, user, "Raccoons are really cute")
-
-        # Never reported before, something could be reported, but isn't because
-        # it isn't old enough.
-        count = yield self.store.count_daily_messages()
-        self.assertIsNone(count)
-        yield self._assert_stats_reporting(1, self.hs.clock.now)
-
-        # Already reported yesterday, two new events from today.
-        yield self.event_injector.inject_message(room, user, "Yeah they are!")
-        yield self.event_injector.inject_message(room, user, "Incredibly!")
-        self.hs.clock.now += 60 * 60 * 24
-        count = yield self.store.count_daily_messages()
-        self.assertEqual(2, count)  # 2 since yesterday
-        yield self._assert_stats_reporting(3, self.hs.clock.now)  # 3 ever
-
-        # Last reported too recently.
-        yield self.event_injector.inject_message(room, user, "Who could disagree?")
-        self.hs.clock.now += 60 * 60 * 22
-        count = yield self.store.count_daily_messages()
-        self.assertIsNone(count)
-        yield self._assert_stats_reporting(4, self.hs.clock.now)
-
-        # Last reported too long ago
-        yield self.event_injector.inject_message(room, user, "No one.")
-        self.hs.clock.now += 60 * 60 * 26
-        count = yield self.store.count_daily_messages()
-        self.assertIsNone(count)
-        yield self._assert_stats_reporting(5, self.hs.clock.now)
-
-        # And now let's actually report something
-        yield self.event_injector.inject_message(room, user, "Indeed.")
-        yield self.event_injector.inject_message(room, user, "Indeed.")
-        yield self.event_injector.inject_message(room, user, "Indeed.")
-        # A little over 24 hours is fine :)
-        self.hs.clock.now += (60 * 60 * 24) + 50
-        count = yield self.store.count_daily_messages()
-        self.assertEqual(3, count)
-        yield self._assert_stats_reporting(8, self.hs.clock.now)
-
-    @defer.inlineCallbacks
-    def _get_last_stream_token(self):
-        rows = yield self.db_pool.runQuery(
-            "SELECT stream_ordering"
-            " FROM events"
-            " ORDER BY stream_ordering DESC"
-            " LIMIT 1"
-        )
-        if not rows:
-            defer.returnValue(0)
-        else:
-            defer.returnValue(rows[0][0])
-
-    @defer.inlineCallbacks
-    def _assert_stats_reporting(self, messages, time):
-        rows = yield self.db_pool.runQuery(
-            "SELECT reported_stream_token, reported_time FROM stats_reporting"
-        )
-        self.assertEqual([(self.base_event + messages, time,)], rows)
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
new file mode 100644
index 0000000000..0be790d8f8
--- /dev/null
+++ b/tests/storage/test_keys.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import signedjson.key
+from twisted.internet import defer
+
+import tests.unittest
+import tests.utils
+
+
+class KeyStoreTestCase(tests.unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super(KeyStoreTestCase, self).__init__(*args, **kwargs)
+        self.store = None  # type: synapse.storage.keys.KeyStore
+
+    @defer.inlineCallbacks
+    def setUp(self):
+        hs = yield tests.utils.setup_test_homeserver()
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def test_get_server_verify_keys(self):
+        key1 = signedjson.key.decode_verify_key_base64(
+            "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
+        )
+        key2 = signedjson.key.decode_verify_key_base64(
+            "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
+        )
+        yield self.store.store_server_verify_key(
+            "server1", "from_server", 0, key1
+        )
+        yield self.store.store_server_verify_key(
+            "server1", "from_server", 0, key2
+        )
+
+        res = yield self.store.get_server_verify_keys(
+            "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"])
+
+        self.assertEqual(len(res.keys()), 2)
+        self.assertEqual(res["ed25519:key1"].version, "key1")
+        self.assertEqual(res["ed25519:key2"].version, "key2")
diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py
index 63203cea35..f5fcb611d4 100644
--- a/tests/storage/test_presence.py
+++ b/tests/storage/test_presence.py
@@ -29,7 +29,7 @@ class PresenceStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver(clock=MockClock())
 
-        self.store = PresenceStore(hs)
+        self.store = PresenceStore(None, hs)
 
         self.u_apple = UserID.from_string("@apple:test")
         self.u_banana = UserID.from_string("@banana:test")
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 24118bbc86..423710c9c1 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -29,7 +29,7 @@ class ProfileStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver()
 
-        self.store = ProfileStore(hs)
+        self.store = ProfileStore(None, hs)
 
         self.u_frank = UserID.from_string("@frank:test")
 
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 6afaca3a61..888ddfaddd 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -36,8 +36,7 @@ class RedactionTestCase(unittest.TestCase):
 
         self.store = hs.get_datastore()
         self.event_builder_factory = hs.get_event_builder_factory()
-        self.handlers = hs.get_handlers()
-        self.message_handler = self.handlers.message_handler
+        self.event_creation_handler = hs.get_event_creation_handler()
 
         self.u_alice = UserID.from_string("@alice:test")
         self.u_bob = UserID.from_string("@bob:test")
@@ -59,7 +58,7 @@ class RedactionTestCase(unittest.TestCase):
             "content": content,
         })
 
-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
 
@@ -79,7 +78,7 @@ class RedactionTestCase(unittest.TestCase):
             "content": {"body": body, "msgtype": u"message"},
         })
 
-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
 
@@ -98,7 +97,7 @@ class RedactionTestCase(unittest.TestCase):
             "redacts": event_id,
         })
 
-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
 
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 316ecdb32d..7c7b164ee6 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -86,7 +86,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
 
         # now delete some
         yield self.store.user_delete_access_tokens(
-            self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
+            self.user_id, device_id=self.device_id,
+        )
 
         # check they were deleted
         user = yield self.store.get_user_by_access_token(self.tokens[1])
@@ -97,8 +98,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
         self.assertEqual(self.user_id, user["name"])
 
         # now delete the rest
-        yield self.store.user_delete_access_tokens(
-            self.user_id, delete_refresh_tokens=True)
+        yield self.store.user_delete_access_tokens(self.user_id)
 
         user = yield self.store.get_user_by_access_token(self.tokens[0])
         self.assertIsNone(user,
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 1be7d932f6..657b279e5d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -37,8 +37,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
         # storage logic
         self.store = hs.get_datastore()
         self.event_builder_factory = hs.get_event_builder_factory()
-        self.handlers = hs.get_handlers()
-        self.message_handler = self.handlers.message_handler
+        self.event_creation_handler = hs.get_event_creation_handler()
 
         self.u_alice = UserID.from_string("@alice:test")
         self.u_bob = UserID.from_string("@bob:test")
@@ -58,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
             "content": {"membership": membership},
         })
 
-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
 
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
new file mode 100644
index 0000000000..0891308f25
--- /dev/null
+++ b/tests/storage/test_user_directory.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.storage import UserDirectoryStore
+from synapse.storage.roommember import ProfileInfo
+from tests import unittest
+from tests.utils import setup_test_homeserver
+
+ALICE = "@alice:a"
+BOB = "@bob:b"
+BOBBY = "@bobby:a"
+
+
+class UserDirectoryStoreTestCase(unittest.TestCase):
+    @defer.inlineCallbacks
+    def setUp(self):
+        self.hs = yield setup_test_homeserver()
+        self.store = UserDirectoryStore(None, self.hs)
+
+        # alice and bob are both in !room_id. bobby is not but shares
+        # a homeserver with alice.
+        yield self.store.add_profiles_to_user_dir(
+            "!room:id",
+            {
+                ALICE: ProfileInfo(None, "alice"),
+                BOB: ProfileInfo(None, "bob"),
+                BOBBY: ProfileInfo(None, "bobby")
+            },
+        )
+        yield self.store.add_users_to_public_room(
+            "!room:id",
+            [ALICE, BOB],
+        )
+        yield self.store.add_users_who_share_room(
+            "!room:id",
+            False,
+            (
+                (ALICE, BOB),
+                (BOB, ALICE),
+            ),
+        )
+
+    @defer.inlineCallbacks
+    def test_search_user_dir(self):
+        # normally when alice searches the directory she should just find
+        # bob because bobby doesn't share a room with her.
+        r = yield self.store.search_user_dir(ALICE, "bob", 10)
+        self.assertFalse(r["limited"])
+        self.assertEqual(1, len(r["results"]))
+        self.assertDictEqual(r["results"][0], {
+            "user_id": BOB,
+            "display_name": "bob",
+            "avatar_url": None,
+        })
+
+    @defer.inlineCallbacks
+    def test_search_user_dir_all_users(self):
+        self.hs.config.user_directory_search_all_users = True
+        try:
+            r = yield self.store.search_user_dir(ALICE, "bob", 10)
+            self.assertFalse(r["limited"])
+            self.assertEqual(2, len(r["results"]))
+            self.assertDictEqual(r["results"][0], {
+                "user_id": BOB,
+                "display_name": "bob",
+                "avatar_url": None,
+            })
+            self.assertDictEqual(r["results"][1], {
+                "user_id": BOBBY,
+                "display_name": "bobby",
+                "avatar_url": None,
+            })
+        finally:
+            self.hs.config.user_directory_search_all_users = False
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index acebcf4a86..010aeaee7e 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -62,7 +62,7 @@ class DistributorTestCase(unittest.TestCase):
     def test_signal_catch(self):
         self.dist.declare("alarm")
 
-        observers = [Mock() for i in 1, 2]
+        observers = [Mock() for i in (1, 2)]
         for o in observers:
             self.dist.observe("alarm", o)
 
diff --git a/tests/test_dns.py b/tests/test_dns.py
index c394c57ee7..af607d626f 100644
--- a/tests/test_dns.py
+++ b/tests/test_dns.py
@@ -24,15 +24,15 @@ from synapse.http.endpoint import resolve_service
 from tests.utils import MockClock
 
 
+@unittest.DEBUG
 class DnsTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve(self):
         dns_client_mock = Mock()
 
-        service_name = "test_service.examle.com"
+        service_name = "test_service.example.com"
         host_name = "example.com"
-        ip_address = "127.0.0.1"
 
         answer_srv = dns.RRHeader(
             type=dns.SRV,
@@ -41,16 +41,10 @@ class DnsTestCase(unittest.TestCase):
             )
         )
 
-        answer_a = dns.RRHeader(
-            type=dns.A,
-            payload=dns.Record_A(
-                address=ip_address,
-            )
+        dns_client_mock.lookupService.return_value = defer.succeed(
+            ([answer_srv], None, None),
         )
 
-        dns_client_mock.lookupService.return_value = ([answer_srv], None, None)
-        dns_client_mock.lookupAddress.return_value = ([answer_a], None, None)
-
         cache = {}
 
         servers = yield resolve_service(
@@ -58,11 +52,10 @@ class DnsTestCase(unittest.TestCase):
         )
 
         dns_client_mock.lookupService.assert_called_once_with(service_name)
-        dns_client_mock.lookupAddress.assert_called_once_with(host_name)
 
         self.assertEquals(len(servers), 1)
         self.assertEquals(servers, cache[service_name])
-        self.assertEquals(servers[0].host, ip_address)
+        self.assertEquals(servers[0].host, host_name)
 
     @defer.inlineCallbacks
     def test_from_cache_expired_and_dns_fail(self):
diff --git a/tests/test_state.py b/tests/test_state.py
index 6454f994e3..a5c5e55951 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
 from synapse.events import FrozenEvent
 from synapse.api.auth import Auth
 from synapse.api.constants import EventTypes, Membership
-from synapse.state import StateHandler
+from synapse.state import StateHandler, StateResolutionHandler
 
 from .utils import MockClock
 
@@ -80,14 +80,14 @@ class StateGroupStore(object):
 
         return defer.succeed(groups)
 
-    def store_state_groups(self, event, context):
-        if context.current_state_ids is None:
-            return
+    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+                          current_state_ids):
+        state_group = self._next_group
+        self._next_group += 1
 
-        state_events = dict(context.current_state_ids)
+        self._group_to_state[state_group] = dict(current_state_ids)
 
-        self._group_to_state[context.state_group] = state_events
-        self._event_to_state_group[event.event_id] = context.state_group
+        return state_group
 
     def get_events(self, event_ids, **kwargs):
         return {
@@ -95,10 +95,19 @@ class StateGroupStore(object):
             if e_id in self._event_id_to_event
         }
 
+    def get_state_group_delta(self, name):
+        return (None, None)
+
     def register_events(self, events):
         for e in events:
             self._event_id_to_event[e.event_id] = e
 
+    def register_event_context(self, event, context):
+        self._event_to_state_group[event.event_id] = context.state_group
+
+    def register_event_id_state_group(self, event_id, state_group):
+        self._event_to_state_group[event_id] = state_group
+
 
 class DictObj(dict):
     def __init__(self, **kwargs):
@@ -137,23 +146,16 @@ class Graph(object):
 
 class StateTestCase(unittest.TestCase):
     def setUp(self):
-        self.store = Mock(
-            spec_set=[
-                "get_state_groups_ids",
-                "add_event_hashes",
-                "get_events",
-                "get_next_state_group",
-            ]
-        )
+        self.store = StateGroupStore()
         hs = Mock(spec_set=[
             "get_datastore", "get_auth", "get_state_handler", "get_clock",
+            "get_state_resolution_handler",
         ])
         hs.get_datastore.return_value = self.store
         hs.get_state_handler.return_value = None
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
-
-        self.store.get_next_state_group.side_effect = Mock
+        hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
 
         self.state = StateHandler(hs)
         self.event_id = 0
@@ -193,14 +195,13 @@ class StateTestCase(unittest.TestCase):
             }
         )
 
-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+        self.store.register_events(graph.walk())
 
         context_store = {}
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         self.assertEqual(2, len(context_store["D"].prev_state_ids))
@@ -245,16 +246,13 @@ class StateTestCase(unittest.TestCase):
             }
         )
 
-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
-        self.store.get_events = store.get_events
-        store.register_events(graph.walk())
+        self.store.register_events(graph.walk())
 
         context_store = {}
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         self.assertSetEqual(
@@ -311,16 +309,13 @@ class StateTestCase(unittest.TestCase):
             }
         )
 
-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
-        self.store.get_events = store.get_events
-        store.register_events(graph.walk())
+        self.store.register_events(graph.walk())
 
         context_store = {}
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         self.assertSetEqual(
@@ -394,16 +389,13 @@ class StateTestCase(unittest.TestCase):
         self._add_depths(nodes, edges)
         graph = Graph(nodes, edges)
 
-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
-        self.store.get_events = store.get_events
-        store.register_events(graph.walk())
+        self.store.register_events(graph.walk())
 
         context_store = {}
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         self.assertSetEqual(
@@ -463,7 +455,11 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_trivial_annotate_message(self):
-        event = create_event(type="test_message", name="event")
+        prev_event_id = "prev_event_id"
+        event = create_event(
+            type="test_message", name="event2",
+            prev_events=[(prev_event_id, {})],
+        )
 
         old_state = [
             create_event(type="test1", state_key="1"),
@@ -471,11 +467,11 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = "group_name_1"
-
-        self.store.get_state_groups_ids.return_value = {
-            group_name: {(e.type, e.state_key): e.event_id for e in old_state},
-        }
+        group_name = self.store.store_state_group(
+            prev_event_id, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state},
+        )
+        self.store.register_event_id_state_group(prev_event_id, group_name)
 
         context = yield self.state.compute_event_context(event)
 
@@ -488,7 +484,11 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_trivial_annotate_state(self):
-        event = create_event(type="state", state_key="", name="event")
+        prev_event_id = "prev_event_id"
+        event = create_event(
+            type="state", state_key="", name="event2",
+            prev_events=[(prev_event_id, {})],
+        )
 
         old_state = [
             create_event(type="test1", state_key="1"),
@@ -496,11 +496,11 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = "group_name_1"
-
-        self.store.get_state_groups_ids.return_value = {
-            group_name: {(e.type, e.state_key): e.event_id for e in old_state},
-        }
+        group_name = self.store.store_state_group(
+            prev_event_id, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state},
+        )
+        self.store.register_event_id_state_group(prev_event_id, group_name)
 
         context = yield self.state.compute_event_context(event)
 
@@ -513,7 +513,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve_message_conflict(self):
-        event = create_event(type="test_message", name="event")
+        prev_event_id1 = "event_id1"
+        prev_event_id2 = "event_id2"
+        event = create_event(
+            type="test_message", name="event3",
+            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+        )
 
         creation = create_event(
             type=EventTypes.Create, state_key=""
@@ -533,12 +538,12 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
-        store = StateGroupStore()
-        store.register_events(old_state_1)
-        store.register_events(old_state_2)
-        self.store.get_events = store.get_events
+        self.store.register_events(old_state_1)
+        self.store.register_events(old_state_2)
 
-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )
 
         self.assertEqual(len(context.current_state_ids), 6)
 
@@ -546,7 +551,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve_state_conflict(self):
-        event = create_event(type="test4", state_key="", name="event")
+        prev_event_id1 = "event_id1"
+        prev_event_id2 = "event_id2"
+        event = create_event(
+            type="test4", state_key="", name="event",
+            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+        )
 
         creation = create_event(
             type=EventTypes.Create, state_key=""
@@ -571,7 +581,9 @@ class StateTestCase(unittest.TestCase):
         store.register_events(old_state_2)
         self.store.get_events = store.get_events
 
-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )
 
         self.assertEqual(len(context.current_state_ids), 6)
 
@@ -579,7 +591,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_standard_depth_conflict(self):
-        event = create_event(type="test4", name="event")
+        prev_event_id1 = "event_id1"
+        prev_event_id2 = "event_id2"
+        event = create_event(
+            type="test4", name="event",
+            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+        )
 
         member_event = create_event(
             type=EventTypes.Member,
@@ -611,7 +628,9 @@ class StateTestCase(unittest.TestCase):
         store.register_events(old_state_2)
         self.store.get_events = store.get_events
 
-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )
 
         self.assertEqual(
             old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
@@ -635,19 +654,26 @@ class StateTestCase(unittest.TestCase):
         store.register_events(old_state_1)
         store.register_events(old_state_2)
 
-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )
 
         self.assertEqual(
             old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
         )
 
-    def _get_context(self, event, old_state_1, old_state_2):
-        group_name_1 = "group_name_1"
-        group_name_2 = "group_name_2"
+    def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
+                     old_state_2):
+        sg1 = self.store.store_state_group(
+            prev_event_id_1, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state_1},
+        )
+        self.store.register_event_id_state_group(prev_event_id_1, sg1)
 
-        self.store.get_state_groups_ids.return_value = {
-            group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
-            group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
-        }
+        sg2 = self.store.store_state_group(
+            prev_event_id_2, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state_2},
+        )
+        self.store.register_event_id_state_group(prev_event_id_2, sg2)
 
         return self.state.compute_event_context(event)
diff --git a/tests/test_types.py b/tests/test_types.py
index 24d61dbe54..115def2287 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -17,7 +17,7 @@ from tests import unittest
 
 from synapse.api.errors import SynapseError
 from synapse.server import HomeServer
-from synapse.types import UserID, RoomAlias
+from synapse.types import UserID, RoomAlias, GroupID
 
 mock_homeserver = HomeServer(hostname="my.domain")
 
@@ -60,3 +60,25 @@ class RoomAliasTestCase(unittest.TestCase):
         room = RoomAlias("channel", "my.domain")
 
         self.assertEquals(room.to_string(), "#channel:my.domain")
+
+
+class GroupIDTestCase(unittest.TestCase):
+    def test_parse(self):
+        group_id = GroupID.from_string("+group/=_-.123:my.domain")
+        self.assertEqual("group/=_-.123", group_id.localpart)
+        self.assertEqual("my.domain", group_id.domain)
+
+    def test_validate(self):
+        bad_ids = [
+            "$badsigil:domain",
+            "+:empty",
+        ] + [
+            "+group" + c + ":domain" for c in "A%?æ£"
+        ]
+        for id_string in bad_ids:
+            try:
+                GroupID.from_string(id_string)
+                self.fail("Parsing '%s' should raise exception" % id_string)
+            except SynapseError as exc:
+                self.assertEqual(400, exc.code)
+                self.assertEqual("M_UNKNOWN", exc.errcode)
diff --git a/tests/unittest.py b/tests/unittest.py
index 38715972dd..7b478c4294 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import twisted
 from twisted.trial import unittest
 
 import logging
@@ -65,6 +65,10 @@ class TestCase(unittest.TestCase):
 
         @around(self)
         def setUp(orig):
+            # enable debugging of delayed calls - this means that we get a
+            # traceback when a unit test exits leaving things on the reactor.
+            twisted.internet.base.DelayedCall.debug = True
+
             old_level = logging.getLogger().level
 
             if old_level != level:
diff --git a/tests/util/caches/__init__.py b/tests/util/caches/__init__.py
new file mode 100644
index 0000000000..451dae3b6c
--- /dev/null
+++ b/tests/util/caches/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
new file mode 100644
index 0000000000..2516fe40f4
--- /dev/null
+++ b/tests/util/caches/test_descriptors.py
@@ -0,0 +1,261 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+import logging
+
+import mock
+from synapse.api.errors import SynapseError
+from synapse.util import async
+from synapse.util import logcontext
+from twisted.internet import defer
+from synapse.util.caches import descriptors
+from tests import unittest
+
+logger = logging.getLogger(__name__)
+
+
+class CacheTestCase(unittest.TestCase):
+    def test_invalidate_all(self):
+        cache = descriptors.Cache("testcache")
+
+        callback_record = [False, False]
+
+        def record_callback(idx):
+            callback_record[idx] = True
+
+        # add a couple of pending entries
+        d1 = defer.Deferred()
+        cache.set("key1", d1, partial(record_callback, 0))
+
+        d2 = defer.Deferred()
+        cache.set("key2", d2, partial(record_callback, 1))
+
+        # lookup should return the deferreds
+        self.assertIs(cache.get("key1"), d1)
+        self.assertIs(cache.get("key2"), d2)
+
+        # let one of the lookups complete
+        d2.callback("result2")
+        self.assertEqual(cache.get("key2"), "result2")
+
+        # now do the invalidation
+        cache.invalidate_all()
+
+        # lookup should return none
+        self.assertIsNone(cache.get("key1", None))
+        self.assertIsNone(cache.get("key2", None))
+
+        # both callbacks should have been callbacked
+        self.assertTrue(
+            callback_record[0], "Invalidation callback for key1 not called",
+        )
+        self.assertTrue(
+            callback_record[1], "Invalidation callback for key2 not called",
+        )
+
+        # letting the other lookup complete should do nothing
+        d1.callback("result1")
+        self.assertIsNone(cache.get("key1", None))
+
+
+class DescriptorTestCase(unittest.TestCase):
+    @defer.inlineCallbacks
+    def test_cache(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2):
+                return self.mock(arg1, arg2)
+
+        obj = Cls()
+
+        obj.mock.return_value = 'fish'
+        r = yield obj.fn(1, 2)
+        self.assertEqual(r, 'fish')
+        obj.mock.assert_called_once_with(1, 2)
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = 'chips'
+        r = yield obj.fn(1, 3)
+        self.assertEqual(r, 'chips')
+        obj.mock.assert_called_once_with(1, 3)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached
+        r = yield obj.fn(1, 2)
+        self.assertEqual(r, 'fish')
+        r = yield obj.fn(1, 3)
+        self.assertEqual(r, 'chips')
+        obj.mock.assert_not_called()
+
+    @defer.inlineCallbacks
+    def test_cache_num_args(self):
+        """Only the first num_args arguments should matter to the cache"""
+
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached(num_args=1)
+            def fn(self, arg1, arg2):
+                return self.mock(arg1, arg2)
+
+        obj = Cls()
+        obj.mock.return_value = 'fish'
+        r = yield obj.fn(1, 2)
+        self.assertEqual(r, 'fish')
+        obj.mock.assert_called_once_with(1, 2)
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = 'chips'
+        r = yield obj.fn(2, 3)
+        self.assertEqual(r, 'chips')
+        obj.mock.assert_called_once_with(2, 3)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached; we should be able to vary
+        # the second argument and still get the cached result.
+        r = yield obj.fn(1, 4)
+        self.assertEqual(r, 'fish')
+        r = yield obj.fn(2, 5)
+        self.assertEqual(r, 'chips')
+        obj.mock.assert_not_called()
+
+    def test_cache_logcontexts(self):
+        """Check that logcontexts are set and restored correctly when
+        using the cache."""
+
+        complete_lookup = defer.Deferred()
+
+        class Cls(object):
+            @descriptors.cached()
+            def fn(self, arg1):
+                @defer.inlineCallbacks
+                def inner_fn():
+                    with logcontext.PreserveLoggingContext():
+                        yield complete_lookup
+                    defer.returnValue(1)
+
+                return inner_fn()
+
+        @defer.inlineCallbacks
+        def do_lookup():
+            with logcontext.LoggingContext() as c1:
+                c1.name = "c1"
+                r = yield obj.fn(1)
+                self.assertEqual(logcontext.LoggingContext.current_context(),
+                                 c1)
+            defer.returnValue(r)
+
+        def check_result(r):
+            self.assertEqual(r, 1)
+
+        obj = Cls()
+
+        # set off a deferred which will do a cache lookup
+        d1 = do_lookup()
+        self.assertEqual(logcontext.LoggingContext.current_context(),
+                         logcontext.LoggingContext.sentinel)
+        d1.addCallback(check_result)
+
+        # and another
+        d2 = do_lookup()
+        self.assertEqual(logcontext.LoggingContext.current_context(),
+                         logcontext.LoggingContext.sentinel)
+        d2.addCallback(check_result)
+
+        # let the lookup complete
+        complete_lookup.callback(None)
+
+        return defer.gatherResults([d1, d2])
+
+    def test_cache_logcontexts_with_exception(self):
+        """Check that the cache sets and restores logcontexts correctly when
+        the lookup function throws an exception"""
+
+        class Cls(object):
+            @descriptors.cached()
+            def fn(self, arg1):
+                @defer.inlineCallbacks
+                def inner_fn():
+                    yield async.run_on_reactor()
+                    raise SynapseError(400, "blah")
+
+                return inner_fn()
+
+        @defer.inlineCallbacks
+        def do_lookup():
+            with logcontext.LoggingContext() as c1:
+                c1.name = "c1"
+                try:
+                    yield obj.fn(1)
+                    self.fail("No exception thrown")
+                except SynapseError:
+                    pass
+
+                self.assertEqual(logcontext.LoggingContext.current_context(),
+                                 c1)
+
+        obj = Cls()
+
+        # set off a deferred which will do a cache lookup
+        d1 = do_lookup()
+        self.assertEqual(logcontext.LoggingContext.current_context(),
+                         logcontext.LoggingContext.sentinel)
+
+        return d1
+
+    @defer.inlineCallbacks
+    def test_cache_default_args(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2=2, arg3=3):
+                return self.mock(arg1, arg2, arg3)
+
+        obj = Cls()
+
+        obj.mock.return_value = 'fish'
+        r = yield obj.fn(1, 2, 3)
+        self.assertEqual(r, 'fish')
+        obj.mock.assert_called_once_with(1, 2, 3)
+        obj.mock.reset_mock()
+
+        # a call with same params shouldn't call the mock again
+        r = yield obj.fn(1, 2)
+        self.assertEqual(r, 'fish')
+        obj.mock.assert_not_called()
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = 'chips'
+        r = yield obj.fn(2, 3)
+        self.assertEqual(r, 'chips')
+        obj.mock.assert_called_once_with(2, 3, 3)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached
+        r = yield obj.fn(1, 2)
+        self.assertEqual(r, 'fish')
+        r = yield obj.fn(2, 3)
+        self.assertEqual(r, 'chips')
+        obj.mock.assert_not_called()
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index 272b71034a..bc92f85fa6 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -28,7 +28,7 @@ class DictCacheTestCase(unittest.TestCase):
         key = "test_simple_cache_hit_full"
 
         v = self.cache.get(key)
-        self.assertEqual((False, {}), v)
+        self.assertEqual((False, set(), {}), v)
 
         seq = self.cache.sequence
         test_value = {"test": "test_simple_cache_hit_full"}
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
new file mode 100644
index 0000000000..d6e1082779
--- /dev/null
+++ b/tests/util/test_file_consumer.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer, reactor
+from mock import NonCallableMock
+
+from synapse.util.file_consumer import BackgroundFileConsumer
+
+from tests import unittest
+from six import StringIO
+
+import threading
+
+
+class FileConsumerTests(unittest.TestCase):
+
+    @defer.inlineCallbacks
+    def test_pull_consumer(self):
+        string_file = StringIO()
+        consumer = BackgroundFileConsumer(string_file)
+
+        try:
+            producer = DummyPullProducer()
+
+            yield producer.register_with_consumer(consumer)
+
+            yield producer.write_and_wait("Foo")
+
+            self.assertEqual(string_file.getvalue(), "Foo")
+
+            yield producer.write_and_wait("Bar")
+
+            self.assertEqual(string_file.getvalue(), "FooBar")
+        finally:
+            consumer.unregisterProducer()
+
+        yield consumer.wait()
+
+        self.assertTrue(string_file.closed)
+
+    @defer.inlineCallbacks
+    def test_push_consumer(self):
+        string_file = BlockingStringWrite()
+        consumer = BackgroundFileConsumer(string_file)
+
+        try:
+            producer = NonCallableMock(spec_set=[])
+
+            consumer.registerProducer(producer, True)
+
+            consumer.write("Foo")
+            yield string_file.wait_for_n_writes(1)
+
+            self.assertEqual(string_file.buffer, "Foo")
+
+            consumer.write("Bar")
+            yield string_file.wait_for_n_writes(2)
+
+            self.assertEqual(string_file.buffer, "FooBar")
+        finally:
+            consumer.unregisterProducer()
+
+        yield consumer.wait()
+
+        self.assertTrue(string_file.closed)
+
+    @defer.inlineCallbacks
+    def test_push_producer_feedback(self):
+        string_file = BlockingStringWrite()
+        consumer = BackgroundFileConsumer(string_file)
+
+        try:
+            producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
+
+            resume_deferred = defer.Deferred()
+            producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None)
+
+            consumer.registerProducer(producer, True)
+
+            number_writes = 0
+            with string_file.write_lock:
+                for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
+                    consumer.write("Foo")
+                    number_writes += 1
+
+                producer.pauseProducing.assert_called_once()
+
+            yield string_file.wait_for_n_writes(number_writes)
+
+            yield resume_deferred
+            producer.resumeProducing.assert_called_once()
+        finally:
+            consumer.unregisterProducer()
+
+        yield consumer.wait()
+
+        self.assertTrue(string_file.closed)
+
+
+class DummyPullProducer(object):
+    def __init__(self):
+        self.consumer = None
+        self.deferred = defer.Deferred()
+
+    def resumeProducing(self):
+        d = self.deferred
+        self.deferred = defer.Deferred()
+        d.callback(None)
+
+    def write_and_wait(self, bytes):
+        d = self.deferred
+        self.consumer.write(bytes)
+        return d
+
+    def register_with_consumer(self, consumer):
+        d = self.deferred
+        self.consumer = consumer
+        self.consumer.registerProducer(self, False)
+        return d
+
+
+class BlockingStringWrite(object):
+    def __init__(self):
+        self.buffer = ""
+        self.closed = False
+        self.write_lock = threading.Lock()
+
+        self._notify_write_deferred = None
+        self._number_of_writes = 0
+
+    def write(self, bytes):
+        with self.write_lock:
+            self.buffer += bytes
+            self._number_of_writes += 1
+
+        reactor.callFromThread(self._notify_write)
+
+    def close(self):
+        self.closed = True
+
+    def _notify_write(self):
+        "Called by write to indicate a write happened"
+        with self.write_lock:
+            if not self._notify_write_deferred:
+                return
+            d = self._notify_write_deferred
+            self._notify_write_deferred = None
+        d.callback(None)
+
+    @defer.inlineCallbacks
+    def wait_for_n_writes(self, n):
+        "Wait for n writes to have happened"
+        while True:
+            with self.write_lock:
+                if n <= self._number_of_writes:
+                    return
+
+                if not self._notify_write_deferred:
+                    self._notify_write_deferred = defer.Deferred()
+
+                d = self._notify_write_deferred
+
+            yield d
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index afcba482f9..4865eb4bc6 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -12,13 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
+from synapse.util import async, logcontext
 from tests import unittest
 
 from twisted.internet import defer
 
 from synapse.util.async import Linearizer
+from six.moves import range
 
 
 class LinearizerTestCase(unittest.TestCase):
@@ -38,7 +38,28 @@ class LinearizerTestCase(unittest.TestCase):
         with cm1:
             self.assertFalse(d2.called)
 
-        self.assertTrue(d2.called)
-
         with (yield d2):
             pass
+
+    def test_lots_of_queued_things(self):
+        # we have one slow thing, and lots of fast things queued up behind it.
+        # it should *not* explode the stack.
+        linearizer = Linearizer()
+
+        @defer.inlineCallbacks
+        def func(i, sleep=False):
+            with logcontext.LoggingContext("func(%s)" % i) as lc:
+                with (yield linearizer.queue("")):
+                    self.assertEqual(
+                        logcontext.LoggingContext.current_context(), lc)
+                    if sleep:
+                        yield async.sleep(0)
+
+                self.assertEqual(
+                    logcontext.LoggingContext.current_context(), lc)
+
+        func(0, sleep=True)
+        for i in range(1, 100):
+            func(i)
+
+        return func(1000)
diff --git a/tests/util/test_log_context.py b/tests/util/test_log_context.py
deleted file mode 100644
index 65a330a0e9..0000000000
--- a/tests/util/test_log_context.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from twisted.internet import defer
-from twisted.internet import reactor
-from .. import unittest
-
-from synapse.util.async import sleep
-from synapse.util.logcontext import LoggingContext
-
-
-class LoggingContextTestCase(unittest.TestCase):
-
-    def _check_test_key(self, value):
-        self.assertEquals(
-            LoggingContext.current_context().test_key, value
-        )
-
-    def test_with_context(self):
-        with LoggingContext() as context_one:
-            context_one.test_key = "test"
-            self._check_test_key("test")
-
-    @defer.inlineCallbacks
-    def test_sleep(self):
-        @defer.inlineCallbacks
-        def competing_callback():
-            with LoggingContext() as competing_context:
-                competing_context.test_key = "competing"
-                yield sleep(0)
-                self._check_test_key("competing")
-
-        reactor.callLater(0, competing_callback)
-
-        with LoggingContext() as context_one:
-            context_one.test_key = "one"
-            yield sleep(0)
-            self._check_test_key("one")
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
new file mode 100644
index 0000000000..ad78d884e0
--- /dev/null
+++ b/tests/util/test_logcontext.py
@@ -0,0 +1,178 @@
+import twisted.python.failure
+from twisted.internet import defer
+from twisted.internet import reactor
+from .. import unittest
+
+from synapse.util.async import sleep
+from synapse.util import logcontext
+from synapse.util.logcontext import LoggingContext
+
+
+class LoggingContextTestCase(unittest.TestCase):
+
+    def _check_test_key(self, value):
+        self.assertEquals(
+            LoggingContext.current_context().request, value
+        )
+
+    def test_with_context(self):
+        with LoggingContext() as context_one:
+            context_one.request = "test"
+            self._check_test_key("test")
+
+    @defer.inlineCallbacks
+    def test_sleep(self):
+        @defer.inlineCallbacks
+        def competing_callback():
+            with LoggingContext() as competing_context:
+                competing_context.request = "competing"
+                yield sleep(0)
+                self._check_test_key("competing")
+
+        reactor.callLater(0, competing_callback)
+
+        with LoggingContext() as context_one:
+            context_one.request = "one"
+            yield sleep(0)
+            self._check_test_key("one")
+
+    def _test_run_in_background(self, function):
+        sentinel_context = LoggingContext.current_context()
+
+        callback_completed = [False]
+
+        def test():
+            context_one.request = "one"
+            d = function()
+
+            def cb(res):
+                self._check_test_key("one")
+                callback_completed[0] = True
+                return res
+            d.addCallback(cb)
+
+            return d
+
+        with LoggingContext() as context_one:
+            context_one.request = "one"
+
+            # fire off function, but don't wait on it.
+            logcontext.run_in_background(test)
+
+            self._check_test_key("one")
+
+        # now wait for the function under test to have run, and check that
+        # the logcontext is left in a sane state.
+        d2 = defer.Deferred()
+
+        def check_logcontext():
+            if not callback_completed[0]:
+                reactor.callLater(0.01, check_logcontext)
+                return
+
+            # make sure that the context was reset before it got thrown back
+            # into the reactor
+            try:
+                self.assertIs(LoggingContext.current_context(),
+                              sentinel_context)
+                d2.callback(None)
+            except BaseException:
+                d2.errback(twisted.python.failure.Failure())
+
+        reactor.callLater(0.01, check_logcontext)
+
+        # test is done once d2 finishes
+        return d2
+
+    def test_run_in_background_with_blocking_fn(self):
+        @defer.inlineCallbacks
+        def blocking_function():
+            yield sleep(0)
+
+        return self._test_run_in_background(blocking_function)
+
+    def test_run_in_background_with_non_blocking_fn(self):
+        @defer.inlineCallbacks
+        def nonblocking_function():
+            with logcontext.PreserveLoggingContext():
+                yield defer.succeed(None)
+
+        return self._test_run_in_background(nonblocking_function)
+
+    def test_run_in_background_with_chained_deferred(self):
+        # a function which returns a deferred which looks like it has been
+        # called, but is actually paused
+        def testfunc():
+            return logcontext.make_deferred_yieldable(
+                _chained_deferred_function()
+            )
+
+        return self._test_run_in_background(testfunc)
+
+    @defer.inlineCallbacks
+    def test_make_deferred_yieldable(self):
+        # a function which retuns an incomplete deferred, but doesn't follow
+        # the synapse rules.
+        def blocking_function():
+            d = defer.Deferred()
+            reactor.callLater(0, d.callback, None)
+            return d
+
+        sentinel_context = LoggingContext.current_context()
+
+        with LoggingContext() as context_one:
+            context_one.request = "one"
+
+            d1 = logcontext.make_deferred_yieldable(blocking_function())
+            # make sure that the context was reset by make_deferred_yieldable
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+            yield d1
+
+            # now it should be restored
+            self._check_test_key("one")
+
+    @defer.inlineCallbacks
+    def test_make_deferred_yieldable_with_chained_deferreds(self):
+        sentinel_context = LoggingContext.current_context()
+
+        with LoggingContext() as context_one:
+            context_one.request = "one"
+
+            d1 = logcontext.make_deferred_yieldable(_chained_deferred_function())
+            # make sure that the context was reset by make_deferred_yieldable
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+            yield d1
+
+            # now it should be restored
+            self._check_test_key("one")
+
+    @defer.inlineCallbacks
+    def test_make_deferred_yieldable_on_non_deferred(self):
+        """Check that make_deferred_yieldable does the right thing when its
+        argument isn't actually a deferred"""
+
+        with LoggingContext() as context_one:
+            context_one.request = "one"
+
+            d1 = logcontext.make_deferred_yieldable("bum")
+            self._check_test_key("one")
+
+            r = yield d1
+            self.assertEqual(r, "bum")
+            self._check_test_key("one")
+
+
+# a function which returns a deferred which has been "called", but
+# which had a function which returned another incomplete deferred on
+# its callback list, so won't yet call any other new callbacks.
+def _chained_deferred_function():
+    d = defer.succeed(None)
+
+    def cb(res):
+        d2 = defer.Deferred()
+        reactor.callLater(0, d2.callback, res)
+        return d2
+    d.addCallback(cb)
+    return d
diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py
new file mode 100644
index 0000000000..1a1a8412f2
--- /dev/null
+++ b/tests/util/test_logformatter.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from synapse.util.logformatter import LogFormatter
+from tests import unittest
+
+
+class TestException(Exception):
+    pass
+
+
+class LogFormatterTestCase(unittest.TestCase):
+    def test_formatter(self):
+        formatter = LogFormatter()
+
+        try:
+            raise TestException("testytest")
+        except TestException:
+            ei = sys.exc_info()
+
+        output = formatter.formatException(ei)
+
+        # check the output looks vaguely sane
+        self.assertIn("testytest", output)
+        self.assertIn("Capture point", output)
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
index 7e289715ba..d3a8630c2f 100644
--- a/tests/util/test_snapshot_cache.py
+++ b/tests/util/test_snapshot_cache.py
@@ -53,7 +53,9 @@ class SnapshotCacheTestCase(unittest.TestCase):
         # before the cache expires returns a resolved deferred.
         get_result_at_11 = self.cache.get(11, "key")
         self.assertIsNotNone(get_result_at_11)
-        self.assertTrue(get_result_at_11.called)
+        if isinstance(get_result_at_11, Deferred):
+            # The cache may return the actual result rather than a deferred
+            self.assertTrue(get_result_at_11.called)
 
         # Check that getting the key after the deferred has resolved
         # after the cache expires returns None
diff --git a/tests/utils.py b/tests/utils.py
index d3d6c8021d..c2beb5d9f7 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -13,27 +13,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.http.server import HttpServer
-from synapse.api.errors import cs_error, CodeMessageException, StoreError
-from synapse.api.constants import EventTypes
-from synapse.storage.prepare_database import prepare_database
-from synapse.storage.engines import create_engine
-from synapse.server import HomeServer
-from synapse.federation.transport import server
-from synapse.util.ratelimitutils import FederationRateLimiter
-
-from synapse.util.logcontext import LoggingContext
+import hashlib
+from inspect import getcallargs
+from six.moves.urllib import parse as urlparse
 
+from mock import Mock, patch
 from twisted.internet import defer, reactor
-from twisted.enterprise.adbapi import ConnectionPool
 
-from collections import namedtuple
-from mock import patch, Mock
-import hashlib
-import urllib
-import urlparse
+from synapse.api.errors import CodeMessageException, cs_error
+from synapse.federation.transport import server
+from synapse.http.server import HttpServer
+from synapse.server import HomeServer
+from synapse.storage import PostgresEngine
+from synapse.storage.engines import create_engine
+from synapse.storage.prepare_database import prepare_database
+from synapse.util.logcontext import LoggingContext
+from synapse.util.ratelimitutils import FederationRateLimiter
 
-from inspect import getcallargs
+# set this to True to run the tests against postgres instead of sqlite.
+# It requires you to have a local postgres database called synapse_test, within
+# which ALL TABLES WILL BE DROPPED
+USE_POSTGRES_FOR_TESTS = False
 
 
 @defer.inlineCallbacks
@@ -55,32 +55,76 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config.password_providers = []
         config.worker_replication_url = ""
         config.worker_app = None
+        config.email_enable_notifs = False
+        config.block_non_admin_invites = False
+        config.federation_domain_whitelist = None
+        config.federation_rc_reject_limit = 10
+        config.federation_rc_sleep_limit = 10
+        config.federation_rc_concurrent = 10
+        config.filter_timeline_limit = 5000
+        config.user_directory_search_all_users = False
+
+        # disable user directory updates, because they get done in the
+        # background, which upsets the test runner.
+        config.update_user_directory = False
 
     config.use_frozen_dicts = True
-    config.database_config = {"name": "sqlite3"}
     config.ldap_enabled = False
 
     if "clock" not in kargs:
         kargs["clock"] = MockClock()
 
+    if USE_POSTGRES_FOR_TESTS:
+        config.database_config = {
+            "name": "psycopg2",
+            "args": {
+                "database": "synapse_test",
+                "cp_min": 1,
+                "cp_max": 5,
+            },
+        }
+    else:
+        config.database_config = {
+            "name": "sqlite3",
+            "args": {
+                "database": ":memory:",
+                "cp_min": 1,
+                "cp_max": 1,
+            },
+        }
+
+    db_engine = create_engine(config.database_config)
+
+    # we need to configure the connection pool to run the on_new_connection
+    # function, so that we can test code that uses custom sqlite functions
+    # (like rank).
+    config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
+
     if datastore is None:
-        db_pool = SQLiteMemoryDbPool()
-        yield db_pool.prepare()
         hs = HomeServer(
-            name, db_pool=db_pool, config=config,
+            name, config=config,
+            db_config=config.database_config,
             version_string="Synapse/tests",
-            database_engine=create_engine(config.database_config),
-            get_db_conn=db_pool.get_db_conn,
+            database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
             **kargs
         )
+        db_conn = hs.get_db_conn()
+        # make sure that the database is empty
+        if isinstance(db_engine, PostgresEngine):
+            cur = db_conn.cursor()
+            cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
+            rows = cur.fetchall()
+            for r in rows:
+                cur.execute("DROP TABLE %s CASCADE" % r[0])
+        yield prepare_database(db_conn, db_engine, config)
         hs.setup()
     else:
         hs = HomeServer(
             name, db_pool=None, datastore=datastore, config=config,
             version_string="Synapse/tests",
-            database_engine=create_engine(config.database_config),
+            database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
             **kargs
@@ -171,7 +215,7 @@ class MockHttpResource(HttpServer):
 
         headers = {}
         if federation_auth:
-            headers["Authorization"] = ["X-Matrix origin=test,key=,sig="]
+            headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="]
         mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
 
         # return the right path if the event requires it
@@ -182,7 +226,7 @@ class MockHttpResource(HttpServer):
             mock_request.args = urlparse.parse_qs(path.split('?')[1])
             mock_request.path = path.split('?')[0]
             path = mock_request.path
-        except:
+        except Exception:
             pass
 
         for (method, pattern, func) in self.callbacks:
@@ -193,7 +237,7 @@ class MockHttpResource(HttpServer):
             if matcher:
                 try:
                     args = [
-                        urllib.unquote(u).decode("UTF-8")
+                        urlparse.unquote(u).decode("UTF-8")
                         for u in matcher.groups()
                     ]
 
@@ -299,167 +343,6 @@ class MockClock(object):
         return d
 
 
-class SQLiteMemoryDbPool(ConnectionPool, object):
-    def __init__(self):
-        super(SQLiteMemoryDbPool, self).__init__(
-            "sqlite3", ":memory:",
-            cp_min=1,
-            cp_max=1,
-        )
-
-        self.config = Mock()
-        self.config.database_config = {"name": "sqlite3"}
-
-    def prepare(self):
-        engine = self.create_engine()
-        return self.runWithConnection(
-            lambda conn: prepare_database(conn, engine, self.config)
-        )
-
-    def get_db_conn(self):
-        conn = self.connect()
-        engine = self.create_engine()
-        prepare_database(conn, engine, self.config)
-        return conn
-
-    def create_engine(self):
-        return create_engine(self.config.database_config)
-
-
-class MemoryDataStore(object):
-
-    Room = namedtuple(
-        "Room",
-        ["room_id", "is_public", "creator"]
-    )
-
-    def __init__(self):
-        self.tokens_to_users = {}
-        self.paths_to_content = {}
-
-        self.members = {}
-        self.rooms = {}
-
-        self.current_state = {}
-        self.events = []
-
-    class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")):
-        def fill_out_prev_events(self, event):
-            pass
-
-    def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
-        return self.Snapshot(
-            room_id, user_id, self.get_room_member(user_id, room_id)
-        )
-
-    def register(self, user_id, token, password_hash):
-        if user_id in self.tokens_to_users.values():
-            raise StoreError(400, "User in use.")
-        self.tokens_to_users[token] = user_id
-
-    def get_user_by_access_token(self, token):
-        try:
-            return {
-                "name": self.tokens_to_users[token],
-            }
-        except:
-            raise StoreError(400, "User does not exist.")
-
-    def get_room(self, room_id):
-        try:
-            return self.rooms[room_id]
-        except:
-            return None
-
-    def store_room(self, room_id, room_creator_user_id, is_public):
-        if room_id in self.rooms:
-            raise StoreError(409, "Conflicting room!")
-
-        room = MemoryDataStore.Room(
-            room_id=room_id,
-            is_public=is_public,
-            creator=room_creator_user_id
-        )
-        self.rooms[room_id] = room
-
-    def get_room_member(self, user_id, room_id):
-        return self.members.get(room_id, {}).get(user_id)
-
-    def get_room_members(self, room_id, membership=None):
-        if membership:
-            return [
-                v for k, v in self.members.get(room_id, {}).items()
-                if v.membership == membership
-            ]
-        else:
-            return self.members.get(room_id, {}).values()
-
-    def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
-        return [
-            m[user_id] for m in self.members.values()
-            if user_id in m and m[user_id].membership in membership_list
-        ]
-
-    def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
-                               limit=0, with_feedback=False):
-        return ([], from_key)  # TODO
-
-    def get_joined_hosts_for_room(self, room_id):
-        return defer.succeed([])
-
-    def persist_event(self, event):
-        if event.type == EventTypes.Member:
-            room_id = event.room_id
-            user = event.state_key
-            self.members.setdefault(room_id, {})[user] = event
-
-        if hasattr(event, "state_key"):
-            key = (event.room_id, event.type, event.state_key)
-            self.current_state[key] = event
-
-        self.events.append(event)
-
-    def get_current_state(self, room_id, event_type=None, state_key=""):
-        if event_type:
-            key = (room_id, event_type, state_key)
-            if self.current_state.get(key):
-                return [self.current_state.get(key)]
-            return None
-        else:
-            return [
-                e for e in self.current_state
-                if e[0] == room_id
-            ]
-
-    def set_presence_state(self, user_localpart, state):
-        return defer.succeed({"state": 0})
-
-    def get_presence_list(self, user_localpart, accepted):
-        return []
-
-    def get_room_events_max_id(self):
-        return "s0"  # TODO (erikj)
-
-    def get_send_event_level(self, room_id):
-        return defer.succeed(0)
-
-    def get_power_level(self, room_id, user_id):
-        return defer.succeed(0)
-
-    def get_add_state_level(self, room_id):
-        return defer.succeed(0)
-
-    def get_room_join_rule(self, room_id):
-        # TODO (erikj): This should be configurable
-        return defer.succeed("invite")
-
-    def get_ops_levels(self, room_id):
-        return defer.succeed((5, 5, 5))
-
-    def insert_client_ip(self, user, access_token, ip, user_agent):
-        return defer.succeed(None)
-
-
 def _format_call(args, kwargs):
     return ", ".join(
         ["%r" % (a) for a in args] +
@@ -497,7 +380,7 @@ class DeferredMockCallable(object):
         for _, _, d in self.expectations:
             try:
                 d.errback(failure)
-            except:
+            except Exception:
                 pass
 
         raise failure