diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py
deleted file mode 100644
index 024ac15069..0000000000
--- a/tests/storage/event_injector.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes
-
-
-class EventInjector:
- def __init__(self, hs):
- self.hs = hs
- self.store = hs.get_datastore()
- self.message_handler = hs.get_handlers().message_handler
- self.event_builder_factory = hs.get_event_builder_factory()
-
- @defer.inlineCallbacks
- def create_room(self, room, user):
- builder = self.event_builder_factory.new({
- "type": EventTypes.Create,
- "sender": user.to_string(),
- "room_id": room.to_string(),
- "content": {},
- })
-
- event, context = yield self.message_handler._create_new_client_event(
- builder
- )
-
- yield self.store.persist_event(event, context)
-
- @defer.inlineCallbacks
- def inject_room_member(self, room, user, membership):
- builder = self.event_builder_factory.new({
- "type": EventTypes.Member,
- "sender": user.to_string(),
- "state_key": user.to_string(),
- "room_id": room.to_string(),
- "content": {"membership": membership},
- })
-
- event, context = yield self.message_handler._create_new_client_event(
- builder
- )
-
- yield self.store.persist_event(event, context)
-
- defer.returnValue(event)
-
- @defer.inlineCallbacks
- def inject_message(self, room, user, body):
- builder = self.event_builder_factory.new({
- "type": EventTypes.Message,
- "sender": user.to_string(),
- "state_key": user.to_string(),
- "room_id": room.to_string(),
- "content": {"body": body, "msgtype": u"message"},
- })
-
- event, context = yield self.message_handler._create_new_client_event(
- builder
- )
-
- yield self.store.persist_event(event, context)
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 3cfa21c9f8..6d6f00c5c5 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -14,15 +14,15 @@
# limitations under the License.
-from tests import unittest
-from twisted.internet import defer
-
from mock import Mock
-from synapse.util.async import ObservableDeferred
+from twisted.internet import defer
+from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached
+from tests import unittest
+
class CacheTestCase(unittest.TestCase):
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 9e98d0e330..099861b27c 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -12,21 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
+import os
import tempfile
-from synapse.config._base import ConfigError
-from tests import unittest
+
+from mock import Mock
+
+import yaml
+
from twisted.internet import defer
-from tests.utils import setup_test_homeserver
from synapse.appservice import ApplicationService, ApplicationServiceState
+from synapse.config._base import ConfigError
from synapse.storage.appservice import (
- ApplicationServiceStore, ApplicationServiceTransactionStore
+ ApplicationServiceStore,
+ ApplicationServiceTransactionStore,
)
-import json
-import os
-import yaml
-from mock import Mock
+from tests import unittest
+from tests.utils import setup_test_homeserver
class ApplicationServiceStoreTestCase(unittest.TestCase):
@@ -42,7 +46,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
- replication_layer=Mock(),
+ federation_client=Mock(),
)
self.as_token = "token1"
@@ -58,14 +62,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
- self.store = ApplicationServiceStore(hs)
+ self.store = ApplicationServiceStore(None, hs)
def tearDown(self):
# TODO: suboptimal that we need to create files for tests!
for f in self.as_yaml_files:
try:
os.remove(f)
- except:
+ except Exception:
pass
def _add_appservice(self, as_token, id, url, hs_token, sender):
@@ -119,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
- replication_layer=Mock(),
+ federation_client=Mock(),
)
self.db_pool = hs.get_db_pool()
@@ -150,7 +154,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files = []
- self.store = TestTransactionStore(hs)
+ self.store = TestTransactionStore(None, hs)
def _add_service(self, url, as_token, id):
as_yaml = dict(url=url, as_token=as_token, hs_token="something",
@@ -420,8 +424,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
class TestTransactionStore(ApplicationServiceTransactionStore,
ApplicationServiceStore):
- def __init__(self, hs):
- super(TestTransactionStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
+ super(TestTransactionStore, self).__init__(db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@@ -455,10 +459,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
config=config,
datastore=Mock(),
federation_sender=Mock(),
- replication_layer=Mock(),
+ federation_client=Mock(),
)
- ApplicationServiceStore(hs)
+ ApplicationServiceStore(None, hs)
@defer.inlineCallbacks
def test_duplicate_ids(self):
@@ -473,16 +477,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
config=config,
datastore=Mock(),
federation_sender=Mock(),
- replication_layer=Mock(),
+ federation_client=Mock(),
)
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(hs)
+ ApplicationServiceStore(None, hs)
e = cm.exception
- self.assertIn(f1, e.message)
- self.assertIn(f2, e.message)
- self.assertIn("id", e.message)
+ self.assertIn(f1, str(e))
+ self.assertIn(f2, str(e))
+ self.assertIn("id", str(e))
@defer.inlineCallbacks
def test_duplicate_as_tokens(self):
@@ -497,13 +501,13 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
config=config,
datastore=Mock(),
federation_sender=Mock(),
- replication_layer=Mock(),
+ federation_client=Mock(),
)
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(hs)
+ ApplicationServiceStore(None, hs)
e = cm.exception
- self.assertIn(f1, e.message)
- self.assertIn(f2, e.message)
- self.assertIn("as_token", e.message)
+ self.assertIn(f1, str(e))
+ self.assertIn(f2, str(e))
+ self.assertIn("as_token", str(e))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 1286b4ce2d..ab1f310572 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,10 +1,10 @@
-from tests import unittest
+from mock import Mock
+
from twisted.internet import defer
+from tests import unittest
from tests.utils import setup_test_homeserver
-from mock import Mock
-
class BackgroundUpdateTestCase(unittest.TestCase):
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 91e971190c..1d1234ee39 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -14,18 +14,18 @@
# limitations under the License.
-from tests import unittest
-from twisted.internet import defer
+from collections import OrderedDict
from mock import Mock
-from collections import OrderedDict
+from twisted.internet import defer
from synapse.server import HomeServer
-
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import create_engine
+from tests import unittest
+
class SQLBaseStoreTestCase(unittest.TestCase):
""" Test the "simple" SQL generating methods in SQLBaseStore. """
@@ -56,7 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
database_engine=create_engine(config.database_config),
)
- self.datastore = SQLBaseStore(hs)
+ self.datastore = SQLBaseStore(None, hs)
@defer.inlineCallbacks
def test_insert_1col(self):
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 03df697575..bd6fda6cb1 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -15,9 +15,6 @@
from twisted.internet import defer
-import synapse.server
-import synapse.storage
-import synapse.types
import tests.unittest
import tests.utils
@@ -39,7 +36,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
self.clock.now = 12345678
user_id = "@user:id"
yield self.store.insert_client_ip(
- synapse.types.UserID.from_string(user_id),
+ user_id,
"access_token", "ip", "user_agent", "device_id",
)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index f8725acea0..a54cc6bc32 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -16,6 +16,7 @@
from twisted.internet import defer
import synapse.api.errors
+
import tests.unittest
import tests.utils
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index b087892e0b..129ebaf343 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -14,12 +14,12 @@
# limitations under the License.
-from tests import unittest
from twisted.internet import defer
from synapse.storage.directory import DirectoryStore
-from synapse.types import RoomID, RoomAlias
+from synapse.types import RoomAlias, RoomID
+from tests import unittest
from tests.utils import setup_test_homeserver
@@ -29,7 +29,7 @@ class DirectoryStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver()
- self.store = DirectoryStore(hs)
+ self.store = DirectoryStore(None, hs)
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
new file mode 100644
index 0000000000..30683e7888
--- /dev/null
+++ b/tests/storage/test_event_federation.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import tests.unittest
+import tests.utils
+
+
+class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield tests.utils.setup_test_homeserver()
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def test_get_prev_events_for_room(self):
+ room_id = '@ROOM:local'
+
+ # add a bunch of events and hashes to act as forward extremities
+ def insert_event(txn, i):
+ event_id = '$event_%i:local' % i
+
+ txn.execute((
+ "INSERT INTO events ("
+ " room_id, event_id, type, depth, topological_ordering,"
+ " content, processed, outlier) "
+ "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)"
+ ), (room_id, event_id, i, i, True, False))
+
+ txn.execute((
+ 'INSERT INTO event_forward_extremities (room_id, event_id) '
+ 'VALUES (?, ?)'
+ ), (room_id, event_id))
+
+ txn.execute((
+ 'INSERT INTO event_reference_hashes '
+ '(event_id, algorithm, hash) '
+ "VALUES (?, 'sha256', ?)"
+ ), (event_id, 'ffff'))
+
+ for i in range(0, 11):
+ yield self.store.runInteraction("insert", insert_event, i)
+
+ # this should get the last five and five others
+ r = yield self.store.get_prev_events_for_room(room_id)
+ self.assertEqual(10, len(r))
+ for i in range(0, 5):
+ el = r[i]
+ depth = el[2]
+ self.assertEqual(10 - i, depth)
+
+ for i in range(5, 5):
+ el = r[i]
+ depth = el[2]
+ self.assertLessEqual(5, depth)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 3135488353..8430fc7ba6 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from mock import Mock
+
from twisted.internet import defer
import tests.unittest
import tests.utils
-from mock import Mock
USER_ID = "@user:example.com"
@@ -55,13 +56,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
def _assert_counts(noitf_count, highlight_count):
counts = yield self.store.runInteraction(
"", self.store._get_unread_counts_by_pos_txn,
- room_id, user_id, 0, 0
+ room_id, user_id, 0
)
self.assertEquals(
counts,
{"notify_count": noitf_count, "highlight_count": highlight_count}
)
+ @defer.inlineCallbacks
def _inject_actions(stream, action):
event = Mock()
event.room_id = room_id
@@ -69,11 +71,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream
event.depth = stream
- tuples = [(user_id, action)]
-
- return self.store.runInteraction(
+ yield self.store.add_push_actions_to_staging(
+ event.event_id, {user_id: action},
+ )
+ yield self.store.runInteraction(
"", self.store._set_push_actions_for_event_and_users_txn,
- event, tuples
+ [(event, None)], [(event, None)],
)
def _rotate(stream):
@@ -84,7 +87,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
def _mark_read(stream, depth):
return self.store.runInteraction(
"", self.store._remove_old_push_actions_before_txn,
- room_id, user_id, depth, stream
+ room_id, user_id, stream
)
yield _assert_counts(0, 0)
@@ -125,3 +128,69 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield _assert_counts(1, 1)
yield _rotate(10)
yield _assert_counts(1, 1)
+
+ @defer.inlineCallbacks
+ def test_find_first_stream_ordering_after_ts(self):
+ def add_event(so, ts):
+ return self.store._simple_insert("events", {
+ "stream_ordering": so,
+ "received_ts": ts,
+ "event_id": "event%i" % so,
+ "type": "",
+ "room_id": "",
+ "content": "",
+ "processed": True,
+ "outlier": False,
+ "topological_ordering": 0,
+ "depth": 0,
+ })
+
+ # start with the base case where there are no events in the table
+ r = yield self.store.find_first_stream_ordering_after_ts(11)
+ self.assertEqual(r, 0)
+
+ # now with one event
+ yield add_event(2, 10)
+ r = yield self.store.find_first_stream_ordering_after_ts(9)
+ self.assertEqual(r, 2)
+ r = yield self.store.find_first_stream_ordering_after_ts(10)
+ self.assertEqual(r, 2)
+ r = yield self.store.find_first_stream_ordering_after_ts(11)
+ self.assertEqual(r, 3)
+
+ # add a bunch of dummy events to the events table
+ for (stream_ordering, ts) in (
+ (3, 110),
+ (4, 120),
+ (5, 120),
+ (10, 130),
+ (20, 140),
+ ):
+ yield add_event(stream_ordering, ts)
+
+ r = yield self.store.find_first_stream_ordering_after_ts(110)
+ self.assertEqual(r, 3,
+ "First event after 110ms should be 3, was %i" % r)
+
+ # 4 and 5 are both after 120: we want 4 rather than 5
+ r = yield self.store.find_first_stream_ordering_after_ts(120)
+ self.assertEqual(r, 4,
+ "First event after 120ms should be 4, was %i" % r)
+
+ r = yield self.store.find_first_stream_ordering_after_ts(129)
+ self.assertEqual(r, 10,
+ "First event after 129ms should be 10, was %i" % r)
+
+ # check we can get the last event
+ r = yield self.store.find_first_stream_ordering_after_ts(140)
+ self.assertEqual(r, 20,
+ "First event after 14ms should be 20, was %i" % r)
+
+ # off the end
+ r = yield self.store.find_first_stream_ordering_after_ts(160)
+ self.assertEqual(r, 21)
+
+ # check we can find an event at ordering zero
+ yield add_event(0, 5)
+ r = yield self.store.find_first_stream_ordering_after_ts(1)
+ self.assertEqual(r, 0)
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 0be790d8f8..3a3d002782 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -14,6 +14,7 @@
# limitations under the License.
import signedjson.key
+
from twisted.internet import defer
import tests.unittest
diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py
index 63203cea35..3276b39504 100644
--- a/tests/storage/test_presence.py
+++ b/tests/storage/test_presence.py
@@ -14,13 +14,13 @@
# limitations under the License.
-from tests import unittest
from twisted.internet import defer
from synapse.storage.presence import PresenceStore
from synapse.types import UserID
-from tests.utils import setup_test_homeserver, MockClock
+from tests import unittest
+from tests.utils import MockClock, setup_test_homeserver
class PresenceStoreTestCase(unittest.TestCase):
@@ -29,7 +29,7 @@ class PresenceStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver(clock=MockClock())
- self.store = PresenceStore(hs)
+ self.store = PresenceStore(None, hs)
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 24118bbc86..2c95e5e95a 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -14,12 +14,12 @@
# limitations under the License.
-from tests import unittest
from twisted.internet import defer
from synapse.storage.profile import ProfileStore
from synapse.types import UserID
+from tests import unittest
from tests.utils import setup_test_homeserver
@@ -29,7 +29,7 @@ class ProfileStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver()
- self.store = ProfileStore(hs)
+ self.store = ProfileStore(None, hs)
self.u_frank = UserID.from_string("@frank:test")
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 6afaca3a61..475ec900c4 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -14,16 +14,16 @@
# limitations under the License.
-from tests import unittest
+from mock import Mock
+
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.types import UserID, RoomID
+from synapse.types import RoomID, UserID
+from tests import unittest
from tests.utils import setup_test_homeserver
-from mock import Mock
-
class RedactionTestCase(unittest.TestCase):
@@ -36,8 +36,7 @@ class RedactionTestCase(unittest.TestCase):
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
- self.handlers = hs.get_handlers()
- self.message_handler = self.handlers.message_handler
+ self.event_creation_handler = hs.get_event_creation_handler()
self.u_alice = UserID.from_string("@alice:test")
self.u_bob = UserID.from_string("@bob:test")
@@ -59,7 +58,7 @@ class RedactionTestCase(unittest.TestCase):
"content": content,
})
- event, context = yield self.message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
@@ -79,7 +78,7 @@ class RedactionTestCase(unittest.TestCase):
"content": {"body": body, "msgtype": u"message"},
})
- event, context = yield self.message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
@@ -98,7 +97,7 @@ class RedactionTestCase(unittest.TestCase):
"redacts": event_id,
})
- event, context = yield self.message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 316ecdb32d..7821ea3fa3 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -14,9 +14,9 @@
# limitations under the License.
-from tests import unittest
from twisted.internet import defer
+from tests import unittest
from tests.utils import setup_test_homeserver
@@ -42,9 +42,15 @@ class RegistrationStoreTestCase(unittest.TestCase):
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
self.assertEquals(
- # TODO(paul): Surely this field should be 'user_id', not 'name'
- # Additionally surely it shouldn't come in a 1-element list
- {"name": self.user_id, "password_hash": self.pwhash, "is_guest": 0},
+ {
+ # TODO(paul): Surely this field should be 'user_id', not 'name'
+ "name": self.user_id,
+ "password_hash": self.pwhash,
+ "is_guest": 0,
+ "consent_version": None,
+ "consent_server_notice_sent": None,
+ "appservice_id": None,
+ },
(yield self.store.get_user_by_id(self.user_id))
)
@@ -86,7 +92,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
# now delete some
yield self.store.user_delete_access_tokens(
- self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
+ self.user_id, device_id=self.device_id,
+ )
# check they were deleted
user = yield self.store.get_user_by_access_token(self.tokens[1])
@@ -97,8 +104,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.assertEqual(self.user_id, user["name"])
# now delete the rest
- yield self.store.user_delete_access_tokens(
- self.user_id, delete_refresh_tokens=True)
+ yield self.store.user_delete_access_tokens(self.user_id)
user = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertIsNone(user,
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index ef8a4d234f..ae8ae94b6d 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -14,12 +14,12 @@
# limitations under the License.
-from tests import unittest
from twisted.internet import defer
from synapse.api.constants import EventTypes
-from synapse.types import UserID, RoomID, RoomAlias
+from synapse.types import RoomAlias, RoomID, UserID
+from tests import unittest
from tests.utils import setup_test_homeserver
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 1be7d932f6..c5fd54f67e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -14,16 +14,16 @@
# limitations under the License.
-from tests import unittest
+from mock import Mock
+
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.types import UserID, RoomID
+from synapse.types import RoomID, UserID
+from tests import unittest
from tests.utils import setup_test_homeserver
-from mock import Mock
-
class RoomMemberStoreTestCase(unittest.TestCase):
@@ -37,8 +37,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
# storage logic
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
- self.handlers = hs.get_handlers()
- self.message_handler = self.handlers.message_handler
+ self.event_creation_handler = hs.get_event_creation_handler()
self.u_alice = UserID.from_string("@alice:test")
self.u_bob = UserID.from_string("@bob:test")
@@ -58,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
"content": {"membership": membership},
})
- event, context = yield self.message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
new file mode 100644
index 0000000000..23fad12bca
--- /dev/null
+++ b/tests/storage/test_user_directory.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.storage import UserDirectoryStore
+from synapse.storage.roommember import ProfileInfo
+
+from tests import unittest
+from tests.utils import setup_test_homeserver
+
+ALICE = "@alice:a"
+BOB = "@bob:b"
+BOBBY = "@bobby:a"
+
+
+class UserDirectoryStoreTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver()
+ self.store = UserDirectoryStore(None, self.hs)
+
+ # alice and bob are both in !room_id. bobby is not but shares
+ # a homeserver with alice.
+ yield self.store.add_profiles_to_user_dir(
+ "!room:id",
+ {
+ ALICE: ProfileInfo(None, "alice"),
+ BOB: ProfileInfo(None, "bob"),
+ BOBBY: ProfileInfo(None, "bobby")
+ },
+ )
+ yield self.store.add_users_to_public_room(
+ "!room:id",
+ [ALICE, BOB],
+ )
+ yield self.store.add_users_who_share_room(
+ "!room:id",
+ False,
+ (
+ (ALICE, BOB),
+ (BOB, ALICE),
+ ),
+ )
+
+ @defer.inlineCallbacks
+ def test_search_user_dir(self):
+ # normally when alice searches the directory she should just find
+ # bob because bobby doesn't share a room with her.
+ r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(r["results"][0], {
+ "user_id": BOB,
+ "display_name": "bob",
+ "avatar_url": None,
+ })
+
+ @defer.inlineCallbacks
+ def test_search_user_dir_all_users(self):
+ self.hs.config.user_directory_search_all_users = True
+ try:
+ r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ self.assertFalse(r["limited"])
+ self.assertEqual(2, len(r["results"]))
+ self.assertDictEqual(r["results"][0], {
+ "user_id": BOB,
+ "display_name": "bob",
+ "avatar_url": None,
+ })
+ self.assertDictEqual(r["results"][1], {
+ "user_id": BOBBY,
+ "display_name": "bobby",
+ "avatar_url": None,
+ })
+ finally:
+ self.hs.config.user_directory_search_all_users = False
|