summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/event_injector.py76
-rw-r--r--tests/storage/test__base.py8
-rw-r--r--tests/storage/test_appservice.py58
-rw-r--r--tests/storage/test_background_update.py6
-rw-r--r--tests/storage/test_base.py10
-rw-r--r--tests/storage/test_client_ips.py5
-rw-r--r--tests/storage/test_devices.py1
-rw-r--r--tests/storage/test_directory.py6
-rw-r--r--tests/storage/test_event_federation.py68
-rw-r--r--tests/storage/test_event_push_actions.py83
-rw-r--r--tests/storage/test_keys.py1
-rw-r--r--tests/storage/test_presence.py6
-rw-r--r--tests/storage/test_profile.py4
-rw-r--r--tests/storage/test_redaction.py17
-rw-r--r--tests/storage/test_registration.py20
-rw-r--r--tests/storage/test_room.py4
-rw-r--r--tests/storage/test_roommember.py13
-rw-r--r--tests/storage/test_user_directory.py89
18 files changed, 316 insertions, 159 deletions
diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py
deleted file mode 100644
index 024ac15069..0000000000
--- a/tests/storage/event_injector.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes
-
-
-class EventInjector:
-    def __init__(self, hs):
-        self.hs = hs
-        self.store = hs.get_datastore()
-        self.message_handler = hs.get_handlers().message_handler
-        self.event_builder_factory = hs.get_event_builder_factory()
-
-    @defer.inlineCallbacks
-    def create_room(self, room, user):
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Create,
-            "sender": user.to_string(),
-            "room_id": room.to_string(),
-            "content": {},
-        })
-
-        event, context = yield self.message_handler._create_new_client_event(
-            builder
-        )
-
-        yield self.store.persist_event(event, context)
-
-    @defer.inlineCallbacks
-    def inject_room_member(self, room, user, membership):
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Member,
-            "sender": user.to_string(),
-            "state_key": user.to_string(),
-            "room_id": room.to_string(),
-            "content": {"membership": membership},
-        })
-
-        event, context = yield self.message_handler._create_new_client_event(
-            builder
-        )
-
-        yield self.store.persist_event(event, context)
-
-        defer.returnValue(event)
-
-    @defer.inlineCallbacks
-    def inject_message(self, room, user, body):
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Message,
-            "sender": user.to_string(),
-            "state_key": user.to_string(),
-            "room_id": room.to_string(),
-            "content": {"body": body, "msgtype": u"message"},
-        })
-
-        event, context = yield self.message_handler._create_new_client_event(
-            builder
-        )
-
-        yield self.store.persist_event(event, context)
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 3cfa21c9f8..6d6f00c5c5 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -14,15 +14,15 @@
 # limitations under the License.
 
 
-from tests import unittest
-from twisted.internet import defer
-
 from mock import Mock
 
-from synapse.util.async import ObservableDeferred
+from twisted.internet import defer
 
+from synapse.util.async import ObservableDeferred
 from synapse.util.caches.descriptors import Cache, cached
 
+from tests import unittest
+
 
 class CacheTestCase(unittest.TestCase):
 
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 9e98d0e330..099861b27c 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -12,21 +12,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
+import os
 import tempfile
-from synapse.config._base import ConfigError
-from tests import unittest
+
+from mock import Mock
+
+import yaml
+
 from twisted.internet import defer
 
-from tests.utils import setup_test_homeserver
 from synapse.appservice import ApplicationService, ApplicationServiceState
+from synapse.config._base import ConfigError
 from synapse.storage.appservice import (
-    ApplicationServiceStore, ApplicationServiceTransactionStore
+    ApplicationServiceStore,
+    ApplicationServiceTransactionStore,
 )
 
-import json
-import os
-import yaml
-from mock import Mock
+from tests import unittest
+from tests.utils import setup_test_homeserver
 
 
 class ApplicationServiceStoreTestCase(unittest.TestCase):
@@ -42,7 +46,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             config=config,
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
 
         self.as_token = "token1"
@@ -58,14 +62,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        self.store = ApplicationServiceStore(hs)
+        self.store = ApplicationServiceStore(None, hs)
 
     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
         for f in self.as_yaml_files:
             try:
                 os.remove(f)
-            except:
+            except Exception:
                 pass
 
     def _add_appservice(self, as_token, id, url, hs_token, sender):
@@ -119,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             config=config,
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
         self.db_pool = hs.get_db_pool()
 
@@ -150,7 +154,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         self.as_yaml_files = []
 
-        self.store = TestTransactionStore(hs)
+        self.store = TestTransactionStore(None, hs)
 
     def _add_service(self, url, as_token, id):
         as_yaml = dict(url=url, as_token=as_token, hs_token="something",
@@ -420,8 +424,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 class TestTransactionStore(ApplicationServiceTransactionStore,
                            ApplicationServiceStore):
 
-    def __init__(self, hs):
-        super(TestTransactionStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(TestTransactionStore, self).__init__(db_conn, hs)
 
 
 class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@@ -455,10 +459,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
             config=config,
             datastore=Mock(),
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
 
-        ApplicationServiceStore(hs)
+        ApplicationServiceStore(None, hs)
 
     @defer.inlineCallbacks
     def test_duplicate_ids(self):
@@ -473,16 +477,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
             config=config,
             datastore=Mock(),
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs)
+            ApplicationServiceStore(None, hs)
 
         e = cm.exception
-        self.assertIn(f1, e.message)
-        self.assertIn(f2, e.message)
-        self.assertIn("id", e.message)
+        self.assertIn(f1, str(e))
+        self.assertIn(f2, str(e))
+        self.assertIn("id", str(e))
 
     @defer.inlineCallbacks
     def test_duplicate_as_tokens(self):
@@ -497,13 +501,13 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
             config=config,
             datastore=Mock(),
             federation_sender=Mock(),
-            replication_layer=Mock(),
+            federation_client=Mock(),
         )
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs)
+            ApplicationServiceStore(None, hs)
 
         e = cm.exception
-        self.assertIn(f1, e.message)
-        self.assertIn(f2, e.message)
-        self.assertIn("as_token", e.message)
+        self.assertIn(f1, str(e))
+        self.assertIn(f2, str(e))
+        self.assertIn("as_token", str(e))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 1286b4ce2d..ab1f310572 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,10 +1,10 @@
-from tests import unittest
+from mock import Mock
+
 from twisted.internet import defer
 
+from tests import unittest
 from tests.utils import setup_test_homeserver
 
-from mock import Mock
-
 
 class BackgroundUpdateTestCase(unittest.TestCase):
 
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 91e971190c..1d1234ee39 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -14,18 +14,18 @@
 # limitations under the License.
 
 
-from tests import unittest
-from twisted.internet import defer
+from collections import OrderedDict
 
 from mock import Mock
 
-from collections import OrderedDict
+from twisted.internet import defer
 
 from synapse.server import HomeServer
-
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import create_engine
 
+from tests import unittest
+
 
 class SQLBaseStoreTestCase(unittest.TestCase):
     """ Test the "simple" SQL generating methods in SQLBaseStore. """
@@ -56,7 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             database_engine=create_engine(config.database_config),
         )
 
-        self.datastore = SQLBaseStore(hs)
+        self.datastore = SQLBaseStore(None, hs)
 
     @defer.inlineCallbacks
     def test_insert_1col(self):
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 03df697575..bd6fda6cb1 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -15,9 +15,6 @@
 
 from twisted.internet import defer
 
-import synapse.server
-import synapse.storage
-import synapse.types
 import tests.unittest
 import tests.utils
 
@@ -39,7 +36,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
         self.clock.now = 12345678
         user_id = "@user:id"
         yield self.store.insert_client_ip(
-            synapse.types.UserID.from_string(user_id),
+            user_id,
             "access_token", "ip", "user_agent", "device_id",
         )
 
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index f8725acea0..a54cc6bc32 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -16,6 +16,7 @@
 from twisted.internet import defer
 
 import synapse.api.errors
+
 import tests.unittest
 import tests.utils
 
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index b087892e0b..129ebaf343 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -14,12 +14,12 @@
 # limitations under the License.
 
 
-from tests import unittest
 from twisted.internet import defer
 
 from synapse.storage.directory import DirectoryStore
-from synapse.types import RoomID, RoomAlias
+from synapse.types import RoomAlias, RoomID
 
+from tests import unittest
 from tests.utils import setup_test_homeserver
 
 
@@ -29,7 +29,7 @@ class DirectoryStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver()
 
-        self.store = DirectoryStore(hs)
+        self.store = DirectoryStore(None, hs)
 
         self.room = RoomID.from_string("!abcde:test")
         self.alias = RoomAlias.from_string("#my-room:test")
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
new file mode 100644
index 0000000000..30683e7888
--- /dev/null
+++ b/tests/storage/test_event_federation.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import tests.unittest
+import tests.utils
+
+
+class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
+    @defer.inlineCallbacks
+    def setUp(self):
+        hs = yield tests.utils.setup_test_homeserver()
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def test_get_prev_events_for_room(self):
+        room_id = '@ROOM:local'
+
+        # add a bunch of events and hashes to act as forward extremities
+        def insert_event(txn, i):
+            event_id = '$event_%i:local' % i
+
+            txn.execute((
+                "INSERT INTO events ("
+                "   room_id, event_id, type, depth, topological_ordering,"
+                "   content, processed, outlier) "
+                "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)"
+            ), (room_id, event_id, i, i, True, False))
+
+            txn.execute((
+                'INSERT INTO event_forward_extremities (room_id, event_id) '
+                'VALUES (?, ?)'
+            ), (room_id, event_id))
+
+            txn.execute((
+                'INSERT INTO event_reference_hashes '
+                '(event_id, algorithm, hash) '
+                "VALUES (?, 'sha256', ?)"
+            ), (event_id, 'ffff'))
+
+        for i in range(0, 11):
+            yield self.store.runInteraction("insert", insert_event, i)
+
+        # this should get the last five and five others
+        r = yield self.store.get_prev_events_for_room(room_id)
+        self.assertEqual(10, len(r))
+        for i in range(0, 5):
+            el = r[i]
+            depth = el[2]
+            self.assertEqual(10 - i, depth)
+
+        for i in range(5, 5):
+            el = r[i]
+            depth = el[2]
+            self.assertLessEqual(5, depth)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 3135488353..8430fc7ba6 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -13,11 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from mock import Mock
+
 from twisted.internet import defer
 
 import tests.unittest
 import tests.utils
-from mock import Mock
 
 USER_ID = "@user:example.com"
 
@@ -55,13 +56,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         def _assert_counts(noitf_count, highlight_count):
             counts = yield self.store.runInteraction(
                 "", self.store._get_unread_counts_by_pos_txn,
-                room_id, user_id, 0, 0
+                room_id, user_id, 0
             )
             self.assertEquals(
                 counts,
                 {"notify_count": noitf_count, "highlight_count": highlight_count}
             )
 
+        @defer.inlineCallbacks
         def _inject_actions(stream, action):
             event = Mock()
             event.room_id = room_id
@@ -69,11 +71,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             event.internal_metadata.stream_ordering = stream
             event.depth = stream
 
-            tuples = [(user_id, action)]
-
-            return self.store.runInteraction(
+            yield self.store.add_push_actions_to_staging(
+                event.event_id, {user_id: action},
+            )
+            yield self.store.runInteraction(
                 "", self.store._set_push_actions_for_event_and_users_txn,
-                event, tuples
+                [(event, None)], [(event, None)],
             )
 
         def _rotate(stream):
@@ -84,7 +87,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         def _mark_read(stream, depth):
             return self.store.runInteraction(
                 "", self.store._remove_old_push_actions_before_txn,
-                room_id, user_id, depth, stream
+                room_id, user_id, stream
             )
 
         yield _assert_counts(0, 0)
@@ -125,3 +128,69 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         yield _assert_counts(1, 1)
         yield _rotate(10)
         yield _assert_counts(1, 1)
+
+    @defer.inlineCallbacks
+    def test_find_first_stream_ordering_after_ts(self):
+        def add_event(so, ts):
+            return self.store._simple_insert("events", {
+                "stream_ordering": so,
+                "received_ts": ts,
+                "event_id": "event%i" % so,
+                "type": "",
+                "room_id": "",
+                "content": "",
+                "processed": True,
+                "outlier": False,
+                "topological_ordering": 0,
+                "depth": 0,
+            })
+
+        # start with the base case where there are no events in the table
+        r = yield self.store.find_first_stream_ordering_after_ts(11)
+        self.assertEqual(r, 0)
+
+        # now with one event
+        yield add_event(2, 10)
+        r = yield self.store.find_first_stream_ordering_after_ts(9)
+        self.assertEqual(r, 2)
+        r = yield self.store.find_first_stream_ordering_after_ts(10)
+        self.assertEqual(r, 2)
+        r = yield self.store.find_first_stream_ordering_after_ts(11)
+        self.assertEqual(r, 3)
+
+        # add a bunch of dummy events to the events table
+        for (stream_ordering, ts) in (
+                (3, 110),
+                (4, 120),
+                (5, 120),
+                (10, 130),
+                (20, 140),
+        ):
+            yield add_event(stream_ordering, ts)
+
+        r = yield self.store.find_first_stream_ordering_after_ts(110)
+        self.assertEqual(r, 3,
+                         "First event after 110ms should be 3, was %i" % r)
+
+        # 4 and 5 are both after 120: we want 4 rather than 5
+        r = yield self.store.find_first_stream_ordering_after_ts(120)
+        self.assertEqual(r, 4,
+                         "First event after 120ms should be 4, was %i" % r)
+
+        r = yield self.store.find_first_stream_ordering_after_ts(129)
+        self.assertEqual(r, 10,
+                         "First event after 129ms should be 10, was %i" % r)
+
+        # check we can get the last event
+        r = yield self.store.find_first_stream_ordering_after_ts(140)
+        self.assertEqual(r, 20,
+                         "First event after 14ms should be 20, was %i" % r)
+
+        # off the end
+        r = yield self.store.find_first_stream_ordering_after_ts(160)
+        self.assertEqual(r, 21)
+
+        # check we can find an event at ordering zero
+        yield add_event(0, 5)
+        r = yield self.store.find_first_stream_ordering_after_ts(1)
+        self.assertEqual(r, 0)
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 0be790d8f8..3a3d002782 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import signedjson.key
+
 from twisted.internet import defer
 
 import tests.unittest
diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py
index 63203cea35..3276b39504 100644
--- a/tests/storage/test_presence.py
+++ b/tests/storage/test_presence.py
@@ -14,13 +14,13 @@
 # limitations under the License.
 
 
-from tests import unittest
 from twisted.internet import defer
 
 from synapse.storage.presence import PresenceStore
 from synapse.types import UserID
 
-from tests.utils import setup_test_homeserver, MockClock
+from tests import unittest
+from tests.utils import MockClock, setup_test_homeserver
 
 
 class PresenceStoreTestCase(unittest.TestCase):
@@ -29,7 +29,7 @@ class PresenceStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver(clock=MockClock())
 
-        self.store = PresenceStore(hs)
+        self.store = PresenceStore(None, hs)
 
         self.u_apple = UserID.from_string("@apple:test")
         self.u_banana = UserID.from_string("@banana:test")
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 24118bbc86..2c95e5e95a 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -14,12 +14,12 @@
 # limitations under the License.
 
 
-from tests import unittest
 from twisted.internet import defer
 
 from synapse.storage.profile import ProfileStore
 from synapse.types import UserID
 
+from tests import unittest
 from tests.utils import setup_test_homeserver
 
 
@@ -29,7 +29,7 @@ class ProfileStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver()
 
-        self.store = ProfileStore(hs)
+        self.store = ProfileStore(None, hs)
 
         self.u_frank = UserID.from_string("@frank:test")
 
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 6afaca3a61..475ec900c4 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -14,16 +14,16 @@
 # limitations under the License.
 
 
-from tests import unittest
+from mock import Mock
+
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.types import UserID, RoomID
+from synapse.types import RoomID, UserID
 
+from tests import unittest
 from tests.utils import setup_test_homeserver
 
-from mock import Mock
-
 
 class RedactionTestCase(unittest.TestCase):
 
@@ -36,8 +36,7 @@ class RedactionTestCase(unittest.TestCase):
 
         self.store = hs.get_datastore()
         self.event_builder_factory = hs.get_event_builder_factory()
-        self.handlers = hs.get_handlers()
-        self.message_handler = self.handlers.message_handler
+        self.event_creation_handler = hs.get_event_creation_handler()
 
         self.u_alice = UserID.from_string("@alice:test")
         self.u_bob = UserID.from_string("@bob:test")
@@ -59,7 +58,7 @@ class RedactionTestCase(unittest.TestCase):
             "content": content,
         })
 
-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
 
@@ -79,7 +78,7 @@ class RedactionTestCase(unittest.TestCase):
             "content": {"body": body, "msgtype": u"message"},
         })
 
-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
 
@@ -98,7 +97,7 @@ class RedactionTestCase(unittest.TestCase):
             "redacts": event_id,
         })
 
-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
 
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 316ecdb32d..7821ea3fa3 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -14,9 +14,9 @@
 # limitations under the License.
 
 
-from tests import unittest
 from twisted.internet import defer
 
+from tests import unittest
 from tests.utils import setup_test_homeserver
 
 
@@ -42,9 +42,15 @@ class RegistrationStoreTestCase(unittest.TestCase):
         yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
 
         self.assertEquals(
-            # TODO(paul): Surely this field should be 'user_id', not 'name'
-            #  Additionally surely it shouldn't come in a 1-element list
-            {"name": self.user_id, "password_hash": self.pwhash, "is_guest": 0},
+            {
+                # TODO(paul): Surely this field should be 'user_id', not 'name'
+                "name": self.user_id,
+                "password_hash": self.pwhash,
+                "is_guest": 0,
+                "consent_version": None,
+                "consent_server_notice_sent": None,
+                "appservice_id": None,
+            },
             (yield self.store.get_user_by_id(self.user_id))
         )
 
@@ -86,7 +92,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
 
         # now delete some
         yield self.store.user_delete_access_tokens(
-            self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
+            self.user_id, device_id=self.device_id,
+        )
 
         # check they were deleted
         user = yield self.store.get_user_by_access_token(self.tokens[1])
@@ -97,8 +104,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
         self.assertEqual(self.user_id, user["name"])
 
         # now delete the rest
-        yield self.store.user_delete_access_tokens(
-            self.user_id, delete_refresh_tokens=True)
+        yield self.store.user_delete_access_tokens(self.user_id)
 
         user = yield self.store.get_user_by_access_token(self.tokens[0])
         self.assertIsNone(user,
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index ef8a4d234f..ae8ae94b6d 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -14,12 +14,12 @@
 # limitations under the License.
 
 
-from tests import unittest
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
-from synapse.types import UserID, RoomID, RoomAlias
+from synapse.types import RoomAlias, RoomID, UserID
 
+from tests import unittest
 from tests.utils import setup_test_homeserver
 
 
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 1be7d932f6..c5fd54f67e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -14,16 +14,16 @@
 # limitations under the License.
 
 
-from tests import unittest
+from mock import Mock
+
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.types import UserID, RoomID
+from synapse.types import RoomID, UserID
 
+from tests import unittest
 from tests.utils import setup_test_homeserver
 
-from mock import Mock
-
 
 class RoomMemberStoreTestCase(unittest.TestCase):
 
@@ -37,8 +37,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
         # storage logic
         self.store = hs.get_datastore()
         self.event_builder_factory = hs.get_event_builder_factory()
-        self.handlers = hs.get_handlers()
-        self.message_handler = self.handlers.message_handler
+        self.event_creation_handler = hs.get_event_creation_handler()
 
         self.u_alice = UserID.from_string("@alice:test")
         self.u_bob = UserID.from_string("@bob:test")
@@ -58,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
             "content": {"membership": membership},
         })
 
-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
 
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
new file mode 100644
index 0000000000..23fad12bca
--- /dev/null
+++ b/tests/storage/test_user_directory.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.storage import UserDirectoryStore
+from synapse.storage.roommember import ProfileInfo
+
+from tests import unittest
+from tests.utils import setup_test_homeserver
+
+ALICE = "@alice:a"
+BOB = "@bob:b"
+BOBBY = "@bobby:a"
+
+
+class UserDirectoryStoreTestCase(unittest.TestCase):
+    @defer.inlineCallbacks
+    def setUp(self):
+        self.hs = yield setup_test_homeserver()
+        self.store = UserDirectoryStore(None, self.hs)
+
+        # alice and bob are both in !room_id. bobby is not but shares
+        # a homeserver with alice.
+        yield self.store.add_profiles_to_user_dir(
+            "!room:id",
+            {
+                ALICE: ProfileInfo(None, "alice"),
+                BOB: ProfileInfo(None, "bob"),
+                BOBBY: ProfileInfo(None, "bobby")
+            },
+        )
+        yield self.store.add_users_to_public_room(
+            "!room:id",
+            [ALICE, BOB],
+        )
+        yield self.store.add_users_who_share_room(
+            "!room:id",
+            False,
+            (
+                (ALICE, BOB),
+                (BOB, ALICE),
+            ),
+        )
+
+    @defer.inlineCallbacks
+    def test_search_user_dir(self):
+        # normally when alice searches the directory she should just find
+        # bob because bobby doesn't share a room with her.
+        r = yield self.store.search_user_dir(ALICE, "bob", 10)
+        self.assertFalse(r["limited"])
+        self.assertEqual(1, len(r["results"]))
+        self.assertDictEqual(r["results"][0], {
+            "user_id": BOB,
+            "display_name": "bob",
+            "avatar_url": None,
+        })
+
+    @defer.inlineCallbacks
+    def test_search_user_dir_all_users(self):
+        self.hs.config.user_directory_search_all_users = True
+        try:
+            r = yield self.store.search_user_dir(ALICE, "bob", 10)
+            self.assertFalse(r["limited"])
+            self.assertEqual(2, len(r["results"]))
+            self.assertDictEqual(r["results"][0], {
+                "user_id": BOB,
+                "display_name": "bob",
+                "avatar_url": None,
+            })
+            self.assertDictEqual(r["results"][1], {
+                "user_id": BOBBY,
+                "display_name": "bobby",
+                "avatar_url": None,
+            })
+        finally:
+            self.hs.config.user_directory_search_all_users = False