summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/config/test_generate.py17
-rw-r--r--tests/crypto/test_keyring.py18
-rw-r--r--tests/handlers/test_directory.py11
-rw-r--r--tests/handlers/test_e2e_keys.py3
-rw-r--r--tests/handlers/test_profile.py12
-rw-r--r--tests/handlers/test_typing.py2
-rw-r--r--tests/metrics/test_metric.py12
-rw-r--r--tests/replication/slave/storage/_base.py10
-rw-r--r--tests/replication/slave/storage/test_events.py11
-rw-r--r--tests/rest/client/v1/test_events.py2
-rw-r--r--tests/rest/client/v1/test_profile.py2
-rw-r--r--tests/rest/client/v1/test_rooms.py19
-rw-r--r--tests/rest/client/v1/test_typing.py4
-rw-r--r--tests/rest/client/v2_alpha/test_register.py1
-rw-r--r--tests/rest/media/__init__.py14
-rw-r--r--tests/rest/media/v1/__init__.py14
-rw-r--r--tests/rest/media/v1/test_media_storage.py86
-rw-r--r--tests/storage/test_appservice.py10
-rw-r--r--tests/storage/test_event_push_actions.py77
-rw-r--r--tests/storage/test_redaction.py9
-rw-r--r--tests/storage/test_roommember.py5
-rw-r--r--tests/storage/test_user_directory.py88
-rw-r--r--tests/test_state.py158
-rw-r--r--tests/unittest.py6
-rw-r--r--tests/util/test_file_consumer.py176
-rw-r--r--tests/util/test_logcontext.py16
-rw-r--r--tests/utils.py249
27 files changed, 706 insertions, 326 deletions
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 8f57fbeb23..879159ccea 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -12,9 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import os.path
+import re
 import shutil
 import tempfile
+
 from synapse.config.homeserver import HomeServerConfig
 from tests import unittest
 
@@ -23,7 +26,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
 
     def setUp(self):
         self.dir = tempfile.mkdtemp()
-        print self.dir
         self.file = os.path.join(self.dir, "homeserver.yaml")
 
     def tearDown(self):
@@ -48,3 +50,16 @@ class ConfigGenerationTestCase(unittest.TestCase):
             ]),
             set(os.listdir(self.dir))
         )
+
+        self.assert_log_filename_is(
+            os.path.join(self.dir, "lemurs.win.log.config"),
+            os.path.join(os.getcwd(), "homeserver.log"),
+        )
+
+    def assert_log_filename_is(self, log_config_file, expected):
+        with open(log_config_file) as f:
+            config = f.read()
+            # find the 'filename' line
+            matches = re.findall("^\s*filename:\s*(.*)$", config, re.M)
+            self.assertEqual(1, len(matches))
+            self.assertEqual(matches[0], expected)
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 570312da84..d4ec02ffc2 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -68,7 +68,7 @@ class KeyringTestCase(unittest.TestCase):
 
     def check_context(self, _, expected):
         self.assertEquals(
-            getattr(LoggingContext.current_context(), "test_key", None),
+            getattr(LoggingContext.current_context(), "request", None),
             expected
         )
 
@@ -82,7 +82,7 @@ class KeyringTestCase(unittest.TestCase):
         lookup_2_deferred = defer.Deferred()
 
         with LoggingContext("one") as context_one:
-            context_one.test_key = "one"
+            context_one.request = "one"
 
             wait_1_deferred = kr.wait_for_previous_lookups(
                 ["server1"],
@@ -96,7 +96,7 @@ class KeyringTestCase(unittest.TestCase):
             wait_1_deferred.addBoth(self.check_context, "one")
 
         with LoggingContext("two") as context_two:
-            context_two.test_key = "two"
+            context_two.request = "two"
 
             # set off another wait. It should block because the first lookup
             # hasn't yet completed.
@@ -137,7 +137,7 @@ class KeyringTestCase(unittest.TestCase):
         @defer.inlineCallbacks
         def get_perspectives(**kwargs):
             self.assertEquals(
-                LoggingContext.current_context().test_key, "11",
+                LoggingContext.current_context().request, "11",
             )
             with logcontext.PreserveLoggingContext():
                 yield persp_deferred
@@ -145,7 +145,7 @@ class KeyringTestCase(unittest.TestCase):
         self.http_client.post_json.side_effect = get_perspectives
 
         with LoggingContext("11") as context_11:
-            context_11.test_key = "11"
+            context_11.request = "11"
 
             # start off a first set of lookups
             res_deferreds = kr.verify_json_objects_for_server(
@@ -167,13 +167,13 @@ class KeyringTestCase(unittest.TestCase):
 
             # wait a tick for it to send the request to the perspectives server
             # (it first tries the datastore)
-            yield async.sleep(0.005)
+            yield async.sleep(1)   # XXX find out why this takes so long!
             self.http_client.post_json.assert_called_once()
 
             self.assertIs(LoggingContext.current_context(), context_11)
 
             context_12 = LoggingContext("12")
-            context_12.test_key = "12"
+            context_12.request = "12"
             with logcontext.PreserveLoggingContext(context_12):
                 # a second request for a server with outstanding requests
                 # should block rather than start a second call
@@ -183,7 +183,7 @@ class KeyringTestCase(unittest.TestCase):
                 res_deferreds_2 = kr.verify_json_objects_for_server(
                     [("server10", json1)],
                 )
-                yield async.sleep(0.005)
+                yield async.sleep(01)
                 self.http_client.post_json.assert_not_called()
                 res_deferreds_2[0].addBoth(self.check_context, None)
 
@@ -211,7 +211,7 @@ class KeyringTestCase(unittest.TestCase):
         sentinel_context = LoggingContext.current_context()
 
         with LoggingContext("one") as context_one:
-            context_one.test_key = "one"
+            context_one.request = "one"
 
             defer = kr.verify_json_for_server("server9", {})
             try:
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 5712773909..7e5332e272 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -35,21 +35,20 @@ class DirectoryTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def setUp(self):
-        self.mock_federation = Mock(spec=[
-            "make_query",
-            "register_edu_handler",
-        ])
+        self.mock_federation = Mock()
+        self.mock_registry = Mock()
 
         self.query_handlers = {}
 
         def register_query_handler(query_type, handler):
             self.query_handlers[query_type] = handler
-        self.mock_federation.register_query_handler = register_query_handler
+        self.mock_registry.register_query_handler = register_query_handler
 
         hs = yield setup_test_homeserver(
             http_client=None,
             resource_for_federation=Mock(),
-            replication_layer=self.mock_federation,
+            federation_client=self.mock_federation,
+            federation_registry=self.mock_registry,
         )
         hs.handlers = DirectoryHandlers(hs)
 
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 19f5ed6bce..d1bd87b898 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -34,7 +34,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
     def setUp(self):
         self.hs = yield utils.setup_test_homeserver(
             handlers=None,
-            replication_layer=mock.Mock(),
+            federation_client=mock.Mock(),
         )
         self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
 
@@ -143,7 +143,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         except errors.SynapseError:
             pass
 
-    @unittest.DEBUG
     @defer.inlineCallbacks
     def test_claim_one_time_key(self):
         local_user = "@boris:" + self.hs.hostname
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index a5f47181d7..458296ee4c 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -37,23 +37,23 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def setUp(self):
-        self.mock_federation = Mock(spec=[
-            "make_query",
-            "register_edu_handler",
-        ])
+        self.mock_federation = Mock()
+        self.mock_registry = Mock()
 
         self.query_handlers = {}
 
         def register_query_handler(query_type, handler):
             self.query_handlers[query_type] = handler
 
-        self.mock_federation.register_query_handler = register_query_handler
+        self.mock_registry.register_query_handler = register_query_handler
 
         hs = yield setup_test_homeserver(
             http_client=None,
             handlers=None,
             resource_for_federation=Mock(),
-            replication_layer=self.mock_federation,
+            federation_client=self.mock_federation,
+            federation_server=Mock(),
+            federation_registry=self.mock_registry,
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ])
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index fcd380b03a..a433bbfa8a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -81,7 +81,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
                 "get_current_state_deltas",
             ]),
             state_handler=self.state_handler,
-            handlers=None,
+            handlers=Mock(),
             notifier=mock_notifier,
             resource_for_client=Mock(),
             resource_for_federation=self.mock_federation_resource,
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index f85455a5af..39bde6e3f8 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -141,6 +141,7 @@ class CacheMetricTestCase(unittest.TestCase):
             'cache:hits{name="cache_name"} 0',
             'cache:total{name="cache_name"} 0',
             'cache:size{name="cache_name"} 0',
+            'cache:evicted_size{name="cache_name"} 0',
         ])
 
         metric.inc_misses()
@@ -150,6 +151,7 @@ class CacheMetricTestCase(unittest.TestCase):
             'cache:hits{name="cache_name"} 0',
             'cache:total{name="cache_name"} 1',
             'cache:size{name="cache_name"} 1',
+            'cache:evicted_size{name="cache_name"} 0',
         ])
 
         metric.inc_hits()
@@ -158,4 +160,14 @@ class CacheMetricTestCase(unittest.TestCase):
             'cache:hits{name="cache_name"} 1',
             'cache:total{name="cache_name"} 2',
             'cache:size{name="cache_name"} 1',
+            'cache:evicted_size{name="cache_name"} 0',
+        ])
+
+        metric.inc_evictions(2)
+
+        self.assertEquals(metric.render(), [
+            'cache:hits{name="cache_name"} 1',
+            'cache:total{name="cache_name"} 2',
+            'cache:size{name="cache_name"} 1',
+            'cache:evicted_size{name="cache_name"} 2',
         ])
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 81063f19a1..64e07a8c93 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -15,6 +15,8 @@
 from twisted.internet import defer, reactor
 from tests import unittest
 
+import tempfile
+
 from mock import Mock, NonCallableMock
 from tests.utils import setup_test_homeserver
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -29,7 +31,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
         self.hs = yield setup_test_homeserver(
             "blue",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
@@ -41,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
         self.event_id = 0
 
         server_factory = ReplicationStreamProtocolFactory(self.hs)
-        listener = reactor.listenUNIX("\0xxx", server_factory)
+        # XXX: mktemp is unsafe and should never be used. but we're just a test.
+        path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
+        listener = reactor.listenUNIX(path, server_factory)
         self.addCleanup(listener.stopListening)
         self.streamer = server_factory.streamer
 
@@ -49,7 +53,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
         client_factory = ReplicationClientFactory(
             self.hs, "client_name", self.replication_handler
         )
-        client_connector = reactor.connectUNIX("\0xxx", client_factory)
+        client_connector = reactor.connectUNIX(path, client_factory)
         self.addCleanup(client_factory.stopTrying)
         self.addCleanup(client_connector.disconnect)
 
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 105e1228bb..cb058d3142 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -226,13 +226,16 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             context = EventContext()
             context.current_state_ids = state_ids
             context.prev_state_ids = state_ids
-        elif not backfill:
+        else:
             state_handler = self.hs.get_state_handler()
             context = yield state_handler.compute_event_context(event)
-        else:
-            context = EventContext()
 
-        context.push_actions = push_actions
+        yield self.master_store.add_push_actions_to_staging(
+            event.event_id, {
+                user_id: actions
+                for user_id, actions in push_actions
+            },
+        )
 
         ordering = None
         if backfill:
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index e9698bfdc9..2b89c0a3c7 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -114,7 +114,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
 
         hs = yield setup_test_homeserver(
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index dddcf51b69..deac7f100c 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -45,7 +45,7 @@ class ProfileTestCase(unittest.TestCase):
             http_client=None,
             resource_for_client=self.mock_resource,
             federation=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
             profile_handler=self.mock_handler
         )
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index de376fb514..7e8966a1a8 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -46,7 +46,7 @@ class RoomPermissionsTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.ratelimiter = hs.get_ratelimiter()
@@ -409,7 +409,7 @@ class RoomsMemberListTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.ratelimiter = hs.get_ratelimiter()
@@ -493,7 +493,7 @@ class RoomsCreateTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.ratelimiter = hs.get_ratelimiter()
@@ -515,9 +515,6 @@ class RoomsCreateTestCase(RestTestCase):
 
         synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
 
-    def tearDown(self):
-        pass
-
     @defer.inlineCallbacks
     def test_post_room_no_keys(self):
         # POST with no config keys, expect new room id
@@ -585,7 +582,7 @@ class RoomTopicTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.ratelimiter = hs.get_ratelimiter()
@@ -700,7 +697,7 @@ class RoomMemberStateTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.ratelimiter = hs.get_ratelimiter()
@@ -832,7 +829,7 @@ class RoomMessagesTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.ratelimiter = hs.get_ratelimiter()
@@ -932,7 +929,7 @@ class RoomInitialSyncTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
@@ -1006,7 +1003,7 @@ class RoomMessageListTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             "red",
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.ratelimiter = hs.get_ratelimiter()
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index a269e6f56e..2ec4ecab5b 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -47,7 +47,7 @@ class RoomTypingTestCase(RestTestCase):
             "red",
             clock=self.clock,
             http_client=None,
-            replication_layer=Mock(),
+            federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
@@ -95,7 +95,7 @@ class RoomTypingTestCase(RestTestCase):
                 else:
                     if remotedomains is not None:
                         remotedomains.add(member.domain)
-        hs.get_handlers().room_member_handler.fetch_room_distributions_into = (
+        hs.get_room_member_handler().fetch_room_distributions_into = (
             fetch_room_distributions_into
         )
 
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 096f771bea..8aba456510 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -49,6 +49,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
         self.hs.get_device_handler = Mock(return_value=self.device_handler)
         self.hs.config.enable_registration = True
+        self.hs.config.registrations_require_3pid = []
         self.hs.config.auto_join_rooms = []
 
         # init the thing we're testing
diff --git a/tests/rest/media/__init__.py b/tests/rest/media/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rest/media/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rest/media/v1/__init__.py b/tests/rest/media/v1/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rest/media/v1/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
new file mode 100644
index 0000000000..eef38b6781
--- /dev/null
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from synapse.rest.media.v1._base import FileInfo
+from synapse.rest.media.v1.media_storage import MediaStorage
+from synapse.rest.media.v1.filepath import MediaFilePaths
+from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+
+from mock import Mock
+
+from tests import unittest
+
+import os
+import shutil
+import tempfile
+
+
+class MediaStorageTests(unittest.TestCase):
+    def setUp(self):
+        self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
+
+        self.primary_base_path = os.path.join(self.test_dir, "primary")
+        self.secondary_base_path = os.path.join(self.test_dir, "secondary")
+
+        hs = Mock()
+        hs.config.media_store_path = self.primary_base_path
+
+        storage_providers = [FileStorageProviderBackend(
+            hs, self.secondary_base_path
+        )]
+
+        self.filepaths = MediaFilePaths(self.primary_base_path)
+        self.media_storage = MediaStorage(
+            self.primary_base_path, self.filepaths, storage_providers,
+        )
+
+    def tearDown(self):
+        shutil.rmtree(self.test_dir)
+
+    @defer.inlineCallbacks
+    def test_ensure_media_is_in_local_cache(self):
+        media_id = "some_media_id"
+        test_body = "Test\n"
+
+        # First we create a file that is in a storage provider but not in the
+        # local primary media store
+        rel_path = self.filepaths.local_media_filepath_rel(media_id)
+        secondary_path = os.path.join(self.secondary_base_path, rel_path)
+
+        os.makedirs(os.path.dirname(secondary_path))
+
+        with open(secondary_path, "w") as f:
+            f.write(test_body)
+
+        # Now we run ensure_media_is_in_local_cache, which should copy the file
+        # to the local cache.
+        file_info = FileInfo(None, media_id)
+        local_path = yield self.media_storage.ensure_media_is_in_local_cache(file_info)
+
+        self.assertTrue(os.path.exists(local_path))
+
+        # Asserts the file is under the expected local cache directory
+        self.assertEquals(
+            os.path.commonprefix([self.primary_base_path, local_path]),
+            self.primary_base_path,
+        )
+
+        with open(local_path) as f:
+            body = f.read()
+
+        self.assertEqual(test_body, body)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 13d81f972b..c2e39a7288 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -42,7 +42,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             config=config,
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
 
         self.as_token = "token1"
@@ -119,7 +119,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             config=config,
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
         self.db_pool = hs.get_db_pool()
 
@@ -455,7 +455,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
             config=config,
             datastore=Mock(),
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
 
         ApplicationServiceStore(None, hs)
@@ -473,7 +473,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
             config=config,
             datastore=Mock(),
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
 
         with self.assertRaises(ConfigError) as cm:
@@ -497,7 +497,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
             config=config,
             datastore=Mock(),
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
 
         with self.assertRaises(ConfigError) as cm:
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 3135488353..575374c6a6 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -62,6 +62,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
                 {"notify_count": noitf_count, "highlight_count": highlight_count}
             )
 
+        @defer.inlineCallbacks
         def _inject_actions(stream, action):
             event = Mock()
             event.room_id = room_id
@@ -69,11 +70,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             event.internal_metadata.stream_ordering = stream
             event.depth = stream
 
-            tuples = [(user_id, action)]
-
-            return self.store.runInteraction(
+            yield self.store.add_push_actions_to_staging(
+                event.event_id, {user_id: action},
+            )
+            yield self.store.runInteraction(
                 "", self.store._set_push_actions_for_event_and_users_txn,
-                event, tuples
+                [(event, None)], [(event, None)],
             )
 
         def _rotate(stream):
@@ -125,3 +127,70 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         yield _assert_counts(1, 1)
         yield _rotate(10)
         yield _assert_counts(1, 1)
+
+    @tests.unittest.DEBUG
+    @defer.inlineCallbacks
+    def test_find_first_stream_ordering_after_ts(self):
+        def add_event(so, ts):
+            return self.store._simple_insert("events", {
+                "stream_ordering": so,
+                "received_ts": ts,
+                "event_id": "event%i" % so,
+                "type": "",
+                "room_id": "",
+                "content": "",
+                "processed": True,
+                "outlier": False,
+                "topological_ordering": 0,
+                "depth": 0,
+            })
+
+        # start with the base case where there are no events in the table
+        r = yield self.store.find_first_stream_ordering_after_ts(11)
+        self.assertEqual(r, 0)
+
+        # now with one event
+        yield add_event(2, 10)
+        r = yield self.store.find_first_stream_ordering_after_ts(9)
+        self.assertEqual(r, 2)
+        r = yield self.store.find_first_stream_ordering_after_ts(10)
+        self.assertEqual(r, 2)
+        r = yield self.store.find_first_stream_ordering_after_ts(11)
+        self.assertEqual(r, 3)
+
+        # add a bunch of dummy events to the events table
+        for (stream_ordering, ts) in (
+                (3, 110),
+                (4, 120),
+                (5, 120),
+                (10, 130),
+                (20, 140),
+        ):
+            yield add_event(stream_ordering, ts)
+
+        r = yield self.store.find_first_stream_ordering_after_ts(110)
+        self.assertEqual(r, 3,
+                         "First event after 110ms should be 3, was %i" % r)
+
+        # 4 and 5 are both after 120: we want 4 rather than 5
+        r = yield self.store.find_first_stream_ordering_after_ts(120)
+        self.assertEqual(r, 4,
+                         "First event after 120ms should be 4, was %i" % r)
+
+        r = yield self.store.find_first_stream_ordering_after_ts(129)
+        self.assertEqual(r, 10,
+                         "First event after 129ms should be 10, was %i" % r)
+
+        # check we can get the last event
+        r = yield self.store.find_first_stream_ordering_after_ts(140)
+        self.assertEqual(r, 20,
+                         "First event after 14ms should be 20, was %i" % r)
+
+        # off the end
+        r = yield self.store.find_first_stream_ordering_after_ts(160)
+        self.assertEqual(r, 21)
+
+        # check we can find an event at ordering zero
+        yield add_event(0, 5)
+        r = yield self.store.find_first_stream_ordering_after_ts(1)
+        self.assertEqual(r, 0)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 6afaca3a61..888ddfaddd 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -36,8 +36,7 @@ class RedactionTestCase(unittest.TestCase):
 
         self.store = hs.get_datastore()
         self.event_builder_factory = hs.get_event_builder_factory()
-        self.handlers = hs.get_handlers()
-        self.message_handler = self.handlers.message_handler
+        self.event_creation_handler = hs.get_event_creation_handler()
 
         self.u_alice = UserID.from_string("@alice:test")
         self.u_bob = UserID.from_string("@bob:test")
@@ -59,7 +58,7 @@ class RedactionTestCase(unittest.TestCase):
             "content": content,
         })
 
-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
 
@@ -79,7 +78,7 @@ class RedactionTestCase(unittest.TestCase):
             "content": {"body": body, "msgtype": u"message"},
         })
 
-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
 
@@ -98,7 +97,7 @@ class RedactionTestCase(unittest.TestCase):
             "redacts": event_id,
         })
 
-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
 
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 1be7d932f6..657b279e5d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -37,8 +37,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
         # storage logic
         self.store = hs.get_datastore()
         self.event_builder_factory = hs.get_event_builder_factory()
-        self.handlers = hs.get_handlers()
-        self.message_handler = self.handlers.message_handler
+        self.event_creation_handler = hs.get_event_creation_handler()
 
         self.u_alice = UserID.from_string("@alice:test")
         self.u_bob = UserID.from_string("@bob:test")
@@ -58,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
             "content": {"membership": membership},
         })
 
-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
 
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
new file mode 100644
index 0000000000..0891308f25
--- /dev/null
+++ b/tests/storage/test_user_directory.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.storage import UserDirectoryStore
+from synapse.storage.roommember import ProfileInfo
+from tests import unittest
+from tests.utils import setup_test_homeserver
+
+ALICE = "@alice:a"
+BOB = "@bob:b"
+BOBBY = "@bobby:a"
+
+
+class UserDirectoryStoreTestCase(unittest.TestCase):
+    @defer.inlineCallbacks
+    def setUp(self):
+        self.hs = yield setup_test_homeserver()
+        self.store = UserDirectoryStore(None, self.hs)
+
+        # alice and bob are both in !room_id. bobby is not but shares
+        # a homeserver with alice.
+        yield self.store.add_profiles_to_user_dir(
+            "!room:id",
+            {
+                ALICE: ProfileInfo(None, "alice"),
+                BOB: ProfileInfo(None, "bob"),
+                BOBBY: ProfileInfo(None, "bobby")
+            },
+        )
+        yield self.store.add_users_to_public_room(
+            "!room:id",
+            [ALICE, BOB],
+        )
+        yield self.store.add_users_who_share_room(
+            "!room:id",
+            False,
+            (
+                (ALICE, BOB),
+                (BOB, ALICE),
+            ),
+        )
+
+    @defer.inlineCallbacks
+    def test_search_user_dir(self):
+        # normally when alice searches the directory she should just find
+        # bob because bobby doesn't share a room with her.
+        r = yield self.store.search_user_dir(ALICE, "bob", 10)
+        self.assertFalse(r["limited"])
+        self.assertEqual(1, len(r["results"]))
+        self.assertDictEqual(r["results"][0], {
+            "user_id": BOB,
+            "display_name": "bob",
+            "avatar_url": None,
+        })
+
+    @defer.inlineCallbacks
+    def test_search_user_dir_all_users(self):
+        self.hs.config.user_directory_search_all_users = True
+        try:
+            r = yield self.store.search_user_dir(ALICE, "bob", 10)
+            self.assertFalse(r["limited"])
+            self.assertEqual(2, len(r["results"]))
+            self.assertDictEqual(r["results"][0], {
+                "user_id": BOB,
+                "display_name": "bob",
+                "avatar_url": None,
+            })
+            self.assertDictEqual(r["results"][1], {
+                "user_id": BOBBY,
+                "display_name": "bobby",
+                "avatar_url": None,
+            })
+        finally:
+            self.hs.config.user_directory_search_all_users = False
diff --git a/tests/test_state.py b/tests/test_state.py
index feb84f3d48..a5c5e55951 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
 from synapse.events import FrozenEvent
 from synapse.api.auth import Auth
 from synapse.api.constants import EventTypes, Membership
-from synapse.state import StateHandler
+from synapse.state import StateHandler, StateResolutionHandler
 
 from .utils import MockClock
 
@@ -80,14 +80,14 @@ class StateGroupStore(object):
 
         return defer.succeed(groups)
 
-    def store_state_groups(self, event, context):
-        if context.current_state_ids is None:
-            return
+    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+                          current_state_ids):
+        state_group = self._next_group
+        self._next_group += 1
 
-        state_events = dict(context.current_state_ids)
+        self._group_to_state[state_group] = dict(current_state_ids)
 
-        self._group_to_state[context.state_group] = state_events
-        self._event_to_state_group[event.event_id] = context.state_group
+        return state_group
 
     def get_events(self, event_ids, **kwargs):
         return {
@@ -95,10 +95,19 @@ class StateGroupStore(object):
             if e_id in self._event_id_to_event
         }
 
+    def get_state_group_delta(self, name):
+        return (None, None)
+
     def register_events(self, events):
         for e in events:
             self._event_id_to_event[e.event_id] = e
 
+    def register_event_context(self, event, context):
+        self._event_to_state_group[event.event_id] = context.state_group
+
+    def register_event_id_state_group(self, event_id, state_group):
+        self._event_to_state_group[event_id] = state_group
+
 
 class DictObj(dict):
     def __init__(self, **kwargs):
@@ -137,25 +146,16 @@ class Graph(object):
 
 class StateTestCase(unittest.TestCase):
     def setUp(self):
-        self.store = Mock(
-            spec_set=[
-                "get_state_groups_ids",
-                "add_event_hashes",
-                "get_events",
-                "get_next_state_group",
-                "get_state_group_delta",
-            ]
-        )
+        self.store = StateGroupStore()
         hs = Mock(spec_set=[
             "get_datastore", "get_auth", "get_state_handler", "get_clock",
+            "get_state_resolution_handler",
         ])
         hs.get_datastore.return_value = self.store
         hs.get_state_handler.return_value = None
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
-
-        self.store.get_next_state_group.side_effect = Mock
-        self.store.get_state_group_delta.return_value = (None, None)
+        hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
 
         self.state = StateHandler(hs)
         self.event_id = 0
@@ -195,14 +195,13 @@ class StateTestCase(unittest.TestCase):
             }
         )
 
-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+        self.store.register_events(graph.walk())
 
         context_store = {}
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         self.assertEqual(2, len(context_store["D"].prev_state_ids))
@@ -247,16 +246,13 @@ class StateTestCase(unittest.TestCase):
             }
         )
 
-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
-        self.store.get_events = store.get_events
-        store.register_events(graph.walk())
+        self.store.register_events(graph.walk())
 
         context_store = {}
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         self.assertSetEqual(
@@ -313,16 +309,13 @@ class StateTestCase(unittest.TestCase):
             }
         )
 
-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
-        self.store.get_events = store.get_events
-        store.register_events(graph.walk())
+        self.store.register_events(graph.walk())
 
         context_store = {}
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         self.assertSetEqual(
@@ -396,16 +389,13 @@ class StateTestCase(unittest.TestCase):
         self._add_depths(nodes, edges)
         graph = Graph(nodes, edges)
 
-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
-        self.store.get_events = store.get_events
-        store.register_events(graph.walk())
+        self.store.register_events(graph.walk())
 
         context_store = {}
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         self.assertSetEqual(
@@ -465,7 +455,11 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_trivial_annotate_message(self):
-        event = create_event(type="test_message", name="event")
+        prev_event_id = "prev_event_id"
+        event = create_event(
+            type="test_message", name="event2",
+            prev_events=[(prev_event_id, {})],
+        )
 
         old_state = [
             create_event(type="test1", state_key="1"),
@@ -473,11 +467,11 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = "group_name_1"
-
-        self.store.get_state_groups_ids.return_value = {
-            group_name: {(e.type, e.state_key): e.event_id for e in old_state},
-        }
+        group_name = self.store.store_state_group(
+            prev_event_id, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state},
+        )
+        self.store.register_event_id_state_group(prev_event_id, group_name)
 
         context = yield self.state.compute_event_context(event)
 
@@ -490,7 +484,11 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_trivial_annotate_state(self):
-        event = create_event(type="state", state_key="", name="event")
+        prev_event_id = "prev_event_id"
+        event = create_event(
+            type="state", state_key="", name="event2",
+            prev_events=[(prev_event_id, {})],
+        )
 
         old_state = [
             create_event(type="test1", state_key="1"),
@@ -498,11 +496,11 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = "group_name_1"
-
-        self.store.get_state_groups_ids.return_value = {
-            group_name: {(e.type, e.state_key): e.event_id for e in old_state},
-        }
+        group_name = self.store.store_state_group(
+            prev_event_id, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state},
+        )
+        self.store.register_event_id_state_group(prev_event_id, group_name)
 
         context = yield self.state.compute_event_context(event)
 
@@ -515,7 +513,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve_message_conflict(self):
-        event = create_event(type="test_message", name="event")
+        prev_event_id1 = "event_id1"
+        prev_event_id2 = "event_id2"
+        event = create_event(
+            type="test_message", name="event3",
+            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+        )
 
         creation = create_event(
             type=EventTypes.Create, state_key=""
@@ -535,12 +538,12 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
-        store = StateGroupStore()
-        store.register_events(old_state_1)
-        store.register_events(old_state_2)
-        self.store.get_events = store.get_events
+        self.store.register_events(old_state_1)
+        self.store.register_events(old_state_2)
 
-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )
 
         self.assertEqual(len(context.current_state_ids), 6)
 
@@ -548,7 +551,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve_state_conflict(self):
-        event = create_event(type="test4", state_key="", name="event")
+        prev_event_id1 = "event_id1"
+        prev_event_id2 = "event_id2"
+        event = create_event(
+            type="test4", state_key="", name="event",
+            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+        )
 
         creation = create_event(
             type=EventTypes.Create, state_key=""
@@ -573,7 +581,9 @@ class StateTestCase(unittest.TestCase):
         store.register_events(old_state_2)
         self.store.get_events = store.get_events
 
-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )
 
         self.assertEqual(len(context.current_state_ids), 6)
 
@@ -581,7 +591,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_standard_depth_conflict(self):
-        event = create_event(type="test4", name="event")
+        prev_event_id1 = "event_id1"
+        prev_event_id2 = "event_id2"
+        event = create_event(
+            type="test4", name="event",
+            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+        )
 
         member_event = create_event(
             type=EventTypes.Member,
@@ -613,7 +628,9 @@ class StateTestCase(unittest.TestCase):
         store.register_events(old_state_2)
         self.store.get_events = store.get_events
 
-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )
 
         self.assertEqual(
             old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
@@ -637,19 +654,26 @@ class StateTestCase(unittest.TestCase):
         store.register_events(old_state_1)
         store.register_events(old_state_2)
 
-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )
 
         self.assertEqual(
             old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
         )
 
-    def _get_context(self, event, old_state_1, old_state_2):
-        group_name_1 = "group_name_1"
-        group_name_2 = "group_name_2"
+    def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
+                     old_state_2):
+        sg1 = self.store.store_state_group(
+            prev_event_id_1, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state_1},
+        )
+        self.store.register_event_id_state_group(prev_event_id_1, sg1)
 
-        self.store.get_state_groups_ids.return_value = {
-            group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
-            group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
-        }
+        sg2 = self.store.store_state_group(
+            prev_event_id_2, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state_2},
+        )
+        self.store.register_event_id_state_group(prev_event_id_2, sg2)
 
         return self.state.compute_event_context(event)
diff --git a/tests/unittest.py b/tests/unittest.py
index 38715972dd..7b478c4294 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import twisted
 from twisted.trial import unittest
 
 import logging
@@ -65,6 +65,10 @@ class TestCase(unittest.TestCase):
 
         @around(self)
         def setUp(orig):
+            # enable debugging of delayed calls - this means that we get a
+            # traceback when a unit test exits leaving things on the reactor.
+            twisted.internet.base.DelayedCall.debug = True
+
             old_level = logging.getLogger().level
 
             if old_level != level:
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
new file mode 100644
index 0000000000..76e2234255
--- /dev/null
+++ b/tests/util/test_file_consumer.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer, reactor
+from mock import NonCallableMock
+
+from synapse.util.file_consumer import BackgroundFileConsumer
+
+from tests import unittest
+from StringIO import StringIO
+
+import threading
+
+
+class FileConsumerTests(unittest.TestCase):
+
+    @defer.inlineCallbacks
+    def test_pull_consumer(self):
+        string_file = StringIO()
+        consumer = BackgroundFileConsumer(string_file)
+
+        try:
+            producer = DummyPullProducer()
+
+            yield producer.register_with_consumer(consumer)
+
+            yield producer.write_and_wait("Foo")
+
+            self.assertEqual(string_file.getvalue(), "Foo")
+
+            yield producer.write_and_wait("Bar")
+
+            self.assertEqual(string_file.getvalue(), "FooBar")
+        finally:
+            consumer.unregisterProducer()
+
+        yield consumer.wait()
+
+        self.assertTrue(string_file.closed)
+
+    @defer.inlineCallbacks
+    def test_push_consumer(self):
+        string_file = BlockingStringWrite()
+        consumer = BackgroundFileConsumer(string_file)
+
+        try:
+            producer = NonCallableMock(spec_set=[])
+
+            consumer.registerProducer(producer, True)
+
+            consumer.write("Foo")
+            yield string_file.wait_for_n_writes(1)
+
+            self.assertEqual(string_file.buffer, "Foo")
+
+            consumer.write("Bar")
+            yield string_file.wait_for_n_writes(2)
+
+            self.assertEqual(string_file.buffer, "FooBar")
+        finally:
+            consumer.unregisterProducer()
+
+        yield consumer.wait()
+
+        self.assertTrue(string_file.closed)
+
+    @defer.inlineCallbacks
+    def test_push_producer_feedback(self):
+        string_file = BlockingStringWrite()
+        consumer = BackgroundFileConsumer(string_file)
+
+        try:
+            producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
+
+            resume_deferred = defer.Deferred()
+            producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None)
+
+            consumer.registerProducer(producer, True)
+
+            number_writes = 0
+            with string_file.write_lock:
+                for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
+                    consumer.write("Foo")
+                    number_writes += 1
+
+                producer.pauseProducing.assert_called_once()
+
+            yield string_file.wait_for_n_writes(number_writes)
+
+            yield resume_deferred
+            producer.resumeProducing.assert_called_once()
+        finally:
+            consumer.unregisterProducer()
+
+        yield consumer.wait()
+
+        self.assertTrue(string_file.closed)
+
+
+class DummyPullProducer(object):
+    def __init__(self):
+        self.consumer = None
+        self.deferred = defer.Deferred()
+
+    def resumeProducing(self):
+        d = self.deferred
+        self.deferred = defer.Deferred()
+        d.callback(None)
+
+    def write_and_wait(self, bytes):
+        d = self.deferred
+        self.consumer.write(bytes)
+        return d
+
+    def register_with_consumer(self, consumer):
+        d = self.deferred
+        self.consumer = consumer
+        self.consumer.registerProducer(self, False)
+        return d
+
+
+class BlockingStringWrite(object):
+    def __init__(self):
+        self.buffer = ""
+        self.closed = False
+        self.write_lock = threading.Lock()
+
+        self._notify_write_deferred = None
+        self._number_of_writes = 0
+
+    def write(self, bytes):
+        with self.write_lock:
+            self.buffer += bytes
+            self._number_of_writes += 1
+
+        reactor.callFromThread(self._notify_write)
+
+    def close(self):
+        self.closed = True
+
+    def _notify_write(self):
+        "Called by write to indicate a write happened"
+        with self.write_lock:
+            if not self._notify_write_deferred:
+                return
+            d = self._notify_write_deferred
+            self._notify_write_deferred = None
+        d.callback(None)
+
+    @defer.inlineCallbacks
+    def wait_for_n_writes(self, n):
+        "Wait for n writes to have happened"
+        while True:
+            with self.write_lock:
+                if n <= self._number_of_writes:
+                    return
+
+                if not self._notify_write_deferred:
+                    self._notify_write_deferred = defer.Deferred()
+
+                d = self._notify_write_deferred
+
+            yield d
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index e2f7765f49..4850722bc5 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -12,12 +12,12 @@ class LoggingContextTestCase(unittest.TestCase):
 
     def _check_test_key(self, value):
         self.assertEquals(
-            LoggingContext.current_context().test_key, value
+            LoggingContext.current_context().request, value
         )
 
     def test_with_context(self):
         with LoggingContext() as context_one:
-            context_one.test_key = "test"
+            context_one.request = "test"
             self._check_test_key("test")
 
     @defer.inlineCallbacks
@@ -25,14 +25,14 @@ class LoggingContextTestCase(unittest.TestCase):
         @defer.inlineCallbacks
         def competing_callback():
             with LoggingContext() as competing_context:
-                competing_context.test_key = "competing"
+                competing_context.request = "competing"
                 yield sleep(0)
                 self._check_test_key("competing")
 
         reactor.callLater(0, competing_callback)
 
         with LoggingContext() as context_one:
-            context_one.test_key = "one"
+            context_one.request = "one"
             yield sleep(0)
             self._check_test_key("one")
 
@@ -43,14 +43,14 @@ class LoggingContextTestCase(unittest.TestCase):
 
         @defer.inlineCallbacks
         def cb():
-            context_one.test_key = "one"
+            context_one.request = "one"
             yield function()
             self._check_test_key("one")
 
             callback_completed[0] = True
 
         with LoggingContext() as context_one:
-            context_one.test_key = "one"
+            context_one.request = "one"
 
             # fire off function, but don't wait on it.
             logcontext.preserve_fn(cb)()
@@ -107,7 +107,7 @@ class LoggingContextTestCase(unittest.TestCase):
         sentinel_context = LoggingContext.current_context()
 
         with LoggingContext() as context_one:
-            context_one.test_key = "one"
+            context_one.request = "one"
 
             d1 = logcontext.make_deferred_yieldable(blocking_function())
             # make sure that the context was reset by make_deferred_yieldable
@@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase):
         argument isn't actually a deferred"""
 
         with LoggingContext() as context_one:
-            context_one.test_key = "one"
+            context_one.request = "one"
 
             d1 = logcontext.make_deferred_yieldable("bum")
             self._check_test_key("one")
diff --git a/tests/utils.py b/tests/utils.py
index ed8a7360f5..8efd3a3475 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -13,27 +13,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.http.server import HttpServer
-from synapse.api.errors import cs_error, CodeMessageException, StoreError
-from synapse.api.constants import EventTypes
-from synapse.storage.prepare_database import prepare_database
-from synapse.storage.engines import create_engine
-from synapse.server import HomeServer
-from synapse.federation.transport import server
-from synapse.util.ratelimitutils import FederationRateLimiter
-
-from synapse.util.logcontext import LoggingContext
-
-from twisted.internet import defer, reactor
-from twisted.enterprise.adbapi import ConnectionPool
-
-from collections import namedtuple
-from mock import patch, Mock
 import hashlib
+from inspect import getcallargs
 import urllib
 import urlparse
 
-from inspect import getcallargs
+from mock import Mock, patch
+from twisted.internet import defer, reactor
+
+from synapse.api.errors import CodeMessageException, cs_error
+from synapse.federation.transport import server
+from synapse.http.server import HttpServer
+from synapse.server import HomeServer
+from synapse.storage import PostgresEngine
+from synapse.storage.engines import create_engine
+from synapse.storage.prepare_database import prepare_database
+from synapse.util.logcontext import LoggingContext
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+# set this to True to run the tests against postgres instead of sqlite.
+# It requires you to have a local postgres database called synapse_test, within
+# which ALL TABLES WILL BE DROPPED
+USE_POSTGRES_FOR_TESTS = False
 
 
 @defer.inlineCallbacks
@@ -57,32 +58,70 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config.worker_app = None
         config.email_enable_notifs = False
         config.block_non_admin_invites = False
+        config.federation_domain_whitelist = None
+        config.user_directory_search_all_users = False
+
+        # disable user directory updates, because they get done in the
+        # background, which upsets the test runner.
+        config.update_user_directory = False
 
     config.use_frozen_dicts = True
-    config.database_config = {"name": "sqlite3"}
     config.ldap_enabled = False
 
     if "clock" not in kargs:
         kargs["clock"] = MockClock()
 
+    if USE_POSTGRES_FOR_TESTS:
+        config.database_config = {
+            "name": "psycopg2",
+            "args": {
+                "database": "synapse_test",
+                "cp_min": 1,
+                "cp_max": 5,
+            },
+        }
+    else:
+        config.database_config = {
+            "name": "sqlite3",
+            "args": {
+                "database": ":memory:",
+                "cp_min": 1,
+                "cp_max": 1,
+            },
+        }
+
+    db_engine = create_engine(config.database_config)
+
+    # we need to configure the connection pool to run the on_new_connection
+    # function, so that we can test code that uses custom sqlite functions
+    # (like rank).
+    config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
+
     if datastore is None:
-        db_pool = SQLiteMemoryDbPool()
-        yield db_pool.prepare()
         hs = HomeServer(
-            name, db_pool=db_pool, config=config,
+            name, config=config,
+            db_config=config.database_config,
             version_string="Synapse/tests",
-            database_engine=create_engine(config.database_config),
-            get_db_conn=db_pool.get_db_conn,
+            database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
             **kargs
         )
+        db_conn = hs.get_db_conn()
+        # make sure that the database is empty
+        if isinstance(db_engine, PostgresEngine):
+            cur = db_conn.cursor()
+            cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
+            rows = cur.fetchall()
+            for r in rows:
+                cur.execute("DROP TABLE %s CASCADE" % r[0])
+        yield prepare_database(db_conn, db_engine, config)
         hs.setup()
     else:
         hs = HomeServer(
             name, db_pool=None, datastore=datastore, config=config,
             version_string="Synapse/tests",
-            database_engine=create_engine(config.database_config),
+            database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
             **kargs
@@ -301,168 +340,6 @@ class MockClock(object):
         return d
 
 
-class SQLiteMemoryDbPool(ConnectionPool, object):
-    def __init__(self):
-        super(SQLiteMemoryDbPool, self).__init__(
-            "sqlite3", ":memory:",
-            cp_min=1,
-            cp_max=1,
-        )
-
-        self.config = Mock()
-        self.config.password_providers = []
-        self.config.database_config = {"name": "sqlite3"}
-
-    def prepare(self):
-        engine = self.create_engine()
-        return self.runWithConnection(
-            lambda conn: prepare_database(conn, engine, self.config)
-        )
-
-    def get_db_conn(self):
-        conn = self.connect()
-        engine = self.create_engine()
-        prepare_database(conn, engine, self.config)
-        return conn
-
-    def create_engine(self):
-        return create_engine(self.config.database_config)
-
-
-class MemoryDataStore(object):
-
-    Room = namedtuple(
-        "Room",
-        ["room_id", "is_public", "creator"]
-    )
-
-    def __init__(self):
-        self.tokens_to_users = {}
-        self.paths_to_content = {}
-
-        self.members = {}
-        self.rooms = {}
-
-        self.current_state = {}
-        self.events = []
-
-    class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")):
-        def fill_out_prev_events(self, event):
-            pass
-
-    def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
-        return self.Snapshot(
-            room_id, user_id, self.get_room_member(user_id, room_id)
-        )
-
-    def register(self, user_id, token, password_hash):
-        if user_id in self.tokens_to_users.values():
-            raise StoreError(400, "User in use.")
-        self.tokens_to_users[token] = user_id
-
-    def get_user_by_access_token(self, token):
-        try:
-            return {
-                "name": self.tokens_to_users[token],
-            }
-        except Exception:
-            raise StoreError(400, "User does not exist.")
-
-    def get_room(self, room_id):
-        try:
-            return self.rooms[room_id]
-        except Exception:
-            return None
-
-    def store_room(self, room_id, room_creator_user_id, is_public):
-        if room_id in self.rooms:
-            raise StoreError(409, "Conflicting room!")
-
-        room = MemoryDataStore.Room(
-            room_id=room_id,
-            is_public=is_public,
-            creator=room_creator_user_id
-        )
-        self.rooms[room_id] = room
-
-    def get_room_member(self, user_id, room_id):
-        return self.members.get(room_id, {}).get(user_id)
-
-    def get_room_members(self, room_id, membership=None):
-        if membership:
-            return [
-                v for k, v in self.members.get(room_id, {}).items()
-                if v.membership == membership
-            ]
-        else:
-            return self.members.get(room_id, {}).values()
-
-    def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
-        return [
-            m[user_id] for m in self.members.values()
-            if user_id in m and m[user_id].membership in membership_list
-        ]
-
-    def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
-                               limit=0, with_feedback=False):
-        return ([], from_key)  # TODO
-
-    def get_joined_hosts_for_room(self, room_id):
-        return defer.succeed([])
-
-    def persist_event(self, event):
-        if event.type == EventTypes.Member:
-            room_id = event.room_id
-            user = event.state_key
-            self.members.setdefault(room_id, {})[user] = event
-
-        if hasattr(event, "state_key"):
-            key = (event.room_id, event.type, event.state_key)
-            self.current_state[key] = event
-
-        self.events.append(event)
-
-    def get_current_state(self, room_id, event_type=None, state_key=""):
-        if event_type:
-            key = (room_id, event_type, state_key)
-            if self.current_state.get(key):
-                return [self.current_state.get(key)]
-            return None
-        else:
-            return [
-                e for e in self.current_state
-                if e[0] == room_id
-            ]
-
-    def set_presence_state(self, user_localpart, state):
-        return defer.succeed({"state": 0})
-
-    def get_presence_list(self, user_localpart, accepted):
-        return []
-
-    def get_room_events_max_id(self):
-        return "s0"  # TODO (erikj)
-
-    def get_send_event_level(self, room_id):
-        return defer.succeed(0)
-
-    def get_power_level(self, room_id, user_id):
-        return defer.succeed(0)
-
-    def get_add_state_level(self, room_id):
-        return defer.succeed(0)
-
-    def get_room_join_rule(self, room_id):
-        # TODO (erikj): This should be configurable
-        return defer.succeed("invite")
-
-    def get_ops_levels(self, room_id):
-        return defer.succeed((5, 5, 5))
-
-    def insert_client_ip(self, user, access_token, ip, user_agent):
-        return defer.succeed(None)
-
-
 def _format_call(args, kwargs):
     return ", ".join(
         ["%r" % (a) for a in args] +