diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 570312da84..d4ec02ffc2 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -68,7 +68,7 @@ class KeyringTestCase(unittest.TestCase):
def check_context(self, _, expected):
self.assertEquals(
- getattr(LoggingContext.current_context(), "test_key", None),
+ getattr(LoggingContext.current_context(), "request", None),
expected
)
@@ -82,7 +82,7 @@ class KeyringTestCase(unittest.TestCase):
lookup_2_deferred = defer.Deferred()
with LoggingContext("one") as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
wait_1_deferred = kr.wait_for_previous_lookups(
["server1"],
@@ -96,7 +96,7 @@ class KeyringTestCase(unittest.TestCase):
wait_1_deferred.addBoth(self.check_context, "one")
with LoggingContext("two") as context_two:
- context_two.test_key = "two"
+ context_two.request = "two"
# set off another wait. It should block because the first lookup
# hasn't yet completed.
@@ -137,7 +137,7 @@ class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def get_perspectives(**kwargs):
self.assertEquals(
- LoggingContext.current_context().test_key, "11",
+ LoggingContext.current_context().request, "11",
)
with logcontext.PreserveLoggingContext():
yield persp_deferred
@@ -145,7 +145,7 @@ class KeyringTestCase(unittest.TestCase):
self.http_client.post_json.side_effect = get_perspectives
with LoggingContext("11") as context_11:
- context_11.test_key = "11"
+ context_11.request = "11"
# start off a first set of lookups
res_deferreds = kr.verify_json_objects_for_server(
@@ -167,13 +167,13 @@ class KeyringTestCase(unittest.TestCase):
# wait a tick for it to send the request to the perspectives server
# (it first tries the datastore)
- yield async.sleep(0.005)
+ yield async.sleep(1) # XXX find out why this takes so long!
self.http_client.post_json.assert_called_once()
self.assertIs(LoggingContext.current_context(), context_11)
context_12 = LoggingContext("12")
- context_12.test_key = "12"
+ context_12.request = "12"
with logcontext.PreserveLoggingContext(context_12):
# a second request for a server with outstanding requests
# should block rather than start a second call
@@ -183,7 +183,7 @@ class KeyringTestCase(unittest.TestCase):
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)],
)
- yield async.sleep(0.005)
+ yield async.sleep(01)
self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None)
@@ -211,7 +211,7 @@ class KeyringTestCase(unittest.TestCase):
sentinel_context = LoggingContext.current_context()
with LoggingContext("one") as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
defer = kr.verify_json_for_server("server9", {})
try:
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 19f5ed6bce..d92bf240b1 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -143,7 +143,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
except errors.SynapseError:
pass
- @unittest.DEBUG
@defer.inlineCallbacks
def test_claim_one_time_key(self):
local_user = "@boris:" + self.hs.hostname
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 81063f19a1..74f104e3b8 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -15,6 +15,8 @@
from twisted.internet import defer, reactor
from tests import unittest
+import tempfile
+
from mock import Mock, NonCallableMock
from tests.utils import setup_test_homeserver
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -41,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
- listener = reactor.listenUNIX("\0xxx", server_factory)
+ # XXX: mktemp is unsafe and should never be used. but we're just a test.
+ path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
+ listener = reactor.listenUNIX(path, server_factory)
self.addCleanup(listener.stopListening)
self.streamer = server_factory.streamer
@@ -49,7 +53,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
- client_connector = reactor.connectUNIX("\0xxx", client_factory)
+ client_connector = reactor.connectUNIX(path, client_factory)
self.addCleanup(client_factory.stopTrying)
self.addCleanup(client_connector.disconnect)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 096f771bea..8aba456510 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -49,6 +49,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.get_device_handler = Mock(return_value=self.device_handler)
self.hs.config.enable_registration = True
+ self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = []
# init the thing we're testing
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
new file mode 100644
index 0000000000..0891308f25
--- /dev/null
+++ b/tests/storage/test_user_directory.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.storage import UserDirectoryStore
+from synapse.storage.roommember import ProfileInfo
+from tests import unittest
+from tests.utils import setup_test_homeserver
+
+ALICE = "@alice:a"
+BOB = "@bob:b"
+BOBBY = "@bobby:a"
+
+
+class UserDirectoryStoreTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver()
+ self.store = UserDirectoryStore(None, self.hs)
+
+ # alice and bob are both in !room_id. bobby is not but shares
+ # a homeserver with alice.
+ yield self.store.add_profiles_to_user_dir(
+ "!room:id",
+ {
+ ALICE: ProfileInfo(None, "alice"),
+ BOB: ProfileInfo(None, "bob"),
+ BOBBY: ProfileInfo(None, "bobby")
+ },
+ )
+ yield self.store.add_users_to_public_room(
+ "!room:id",
+ [ALICE, BOB],
+ )
+ yield self.store.add_users_who_share_room(
+ "!room:id",
+ False,
+ (
+ (ALICE, BOB),
+ (BOB, ALICE),
+ ),
+ )
+
+ @defer.inlineCallbacks
+ def test_search_user_dir(self):
+ # normally when alice searches the directory she should just find
+ # bob because bobby doesn't share a room with her.
+ r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(r["results"][0], {
+ "user_id": BOB,
+ "display_name": "bob",
+ "avatar_url": None,
+ })
+
+ @defer.inlineCallbacks
+ def test_search_user_dir_all_users(self):
+ self.hs.config.user_directory_search_all_users = True
+ try:
+ r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ self.assertFalse(r["limited"])
+ self.assertEqual(2, len(r["results"]))
+ self.assertDictEqual(r["results"][0], {
+ "user_id": BOB,
+ "display_name": "bob",
+ "avatar_url": None,
+ })
+ self.assertDictEqual(r["results"][1], {
+ "user_id": BOBBY,
+ "display_name": "bobby",
+ "avatar_url": None,
+ })
+ finally:
+ self.hs.config.user_directory_search_all_users = False
diff --git a/tests/test_state.py b/tests/test_state.py
index feb84f3d48..d16e1b3b8b 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.events import FrozenEvent
from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
-from synapse.state import StateHandler
+from synapse.state import StateHandler, StateResolutionHandler
from .utils import MockClock
@@ -148,11 +148,13 @@ class StateTestCase(unittest.TestCase):
)
hs = Mock(spec_set=[
"get_datastore", "get_auth", "get_state_handler", "get_clock",
+ "get_state_resolution_handler",
])
hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
+ hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
self.store.get_next_state_group.side_effect = Mock
self.store.get_state_group_delta.return_value = (None, None)
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index e2f7765f49..4850722bc5 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -12,12 +12,12 @@ class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
self.assertEquals(
- LoggingContext.current_context().test_key, value
+ LoggingContext.current_context().request, value
)
def test_with_context(self):
with LoggingContext() as context_one:
- context_one.test_key = "test"
+ context_one.request = "test"
self._check_test_key("test")
@defer.inlineCallbacks
@@ -25,14 +25,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def competing_callback():
with LoggingContext() as competing_context:
- competing_context.test_key = "competing"
+ competing_context.request = "competing"
yield sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
yield sleep(0)
self._check_test_key("one")
@@ -43,14 +43,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def cb():
- context_one.test_key = "one"
+ context_one.request = "one"
yield function()
self._check_test_key("one")
callback_completed[0] = True
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
# fire off function, but don't wait on it.
logcontext.preserve_fn(cb)()
@@ -107,7 +107,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = LoggingContext.current_context()
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
d1 = logcontext.make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
@@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase):
argument isn't actually a deferred"""
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
d1 = logcontext.make_deferred_yieldable("bum")
self._check_test_key("one")
diff --git a/tests/utils.py b/tests/utils.py
index 44e5f75093..8efd3a3475 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -13,27 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import HttpServer
-from synapse.api.errors import cs_error, CodeMessageException, StoreError
-from synapse.api.constants import EventTypes
-from synapse.storage.prepare_database import prepare_database
-from synapse.storage.engines import create_engine
-from synapse.server import HomeServer
-from synapse.federation.transport import server
-from synapse.util.ratelimitutils import FederationRateLimiter
-
-from synapse.util.logcontext import LoggingContext
-
-from twisted.internet import defer, reactor
-from twisted.enterprise.adbapi import ConnectionPool
-
-from collections import namedtuple
-from mock import patch, Mock
import hashlib
+from inspect import getcallargs
import urllib
import urlparse
-from inspect import getcallargs
+from mock import Mock, patch
+from twisted.internet import defer, reactor
+
+from synapse.api.errors import CodeMessageException, cs_error
+from synapse.federation.transport import server
+from synapse.http.server import HttpServer
+from synapse.server import HomeServer
+from synapse.storage import PostgresEngine
+from synapse.storage.engines import create_engine
+from synapse.storage.prepare_database import prepare_database
+from synapse.util.logcontext import LoggingContext
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+# set this to True to run the tests against postgres instead of sqlite.
+# It requires you to have a local postgres database called synapse_test, within
+# which ALL TABLES WILL BE DROPPED
+USE_POSTGRES_FOR_TESTS = False
@defer.inlineCallbacks
@@ -57,36 +58,70 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.worker_app = None
config.email_enable_notifs = False
config.block_non_admin_invites = False
+ config.federation_domain_whitelist = None
+ config.user_directory_search_all_users = False
# disable user directory updates, because they get done in the
# background, which upsets the test runner.
config.update_user_directory = False
config.use_frozen_dicts = True
- config.database_config = {"name": "sqlite3"}
config.ldap_enabled = False
if "clock" not in kargs:
kargs["clock"] = MockClock()
+ if USE_POSTGRES_FOR_TESTS:
+ config.database_config = {
+ "name": "psycopg2",
+ "args": {
+ "database": "synapse_test",
+ "cp_min": 1,
+ "cp_max": 5,
+ },
+ }
+ else:
+ config.database_config = {
+ "name": "sqlite3",
+ "args": {
+ "database": ":memory:",
+ "cp_min": 1,
+ "cp_max": 1,
+ },
+ }
+
+ db_engine = create_engine(config.database_config)
+
+ # we need to configure the connection pool to run the on_new_connection
+ # function, so that we can test code that uses custom sqlite functions
+ # (like rank).
+ config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
+
if datastore is None:
- db_pool = SQLiteMemoryDbPool()
- yield db_pool.prepare()
hs = HomeServer(
- name, db_pool=db_pool, config=config,
+ name, config=config,
+ db_config=config.database_config,
version_string="Synapse/tests",
- database_engine=create_engine(config.database_config),
- get_db_conn=db_pool.get_db_conn,
+ database_engine=db_engine,
room_list_handler=object(),
tls_server_context_factory=Mock(),
**kargs
)
+ db_conn = hs.get_db_conn()
+ # make sure that the database is empty
+ if isinstance(db_engine, PostgresEngine):
+ cur = db_conn.cursor()
+ cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
+ rows = cur.fetchall()
+ for r in rows:
+ cur.execute("DROP TABLE %s CASCADE" % r[0])
+ yield prepare_database(db_conn, db_engine, config)
hs.setup()
else:
hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests",
- database_engine=create_engine(config.database_config),
+ database_engine=db_engine,
room_list_handler=object(),
tls_server_context_factory=Mock(),
**kargs
@@ -305,168 +340,6 @@ class MockClock(object):
return d
-class SQLiteMemoryDbPool(ConnectionPool, object):
- def __init__(self):
- super(SQLiteMemoryDbPool, self).__init__(
- "sqlite3", ":memory:",
- cp_min=1,
- cp_max=1,
- )
-
- self.config = Mock()
- self.config.password_providers = []
- self.config.database_config = {"name": "sqlite3"}
-
- def prepare(self):
- engine = self.create_engine()
- return self.runWithConnection(
- lambda conn: prepare_database(conn, engine, self.config)
- )
-
- def get_db_conn(self):
- conn = self.connect()
- engine = self.create_engine()
- prepare_database(conn, engine, self.config)
- return conn
-
- def create_engine(self):
- return create_engine(self.config.database_config)
-
-
-class MemoryDataStore(object):
-
- Room = namedtuple(
- "Room",
- ["room_id", "is_public", "creator"]
- )
-
- def __init__(self):
- self.tokens_to_users = {}
- self.paths_to_content = {}
-
- self.members = {}
- self.rooms = {}
-
- self.current_state = {}
- self.events = []
-
- class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")):
- def fill_out_prev_events(self, event):
- pass
-
- def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
- return self.Snapshot(
- room_id, user_id, self.get_room_member(user_id, room_id)
- )
-
- def register(self, user_id, token, password_hash):
- if user_id in self.tokens_to_users.values():
- raise StoreError(400, "User in use.")
- self.tokens_to_users[token] = user_id
-
- def get_user_by_access_token(self, token):
- try:
- return {
- "name": self.tokens_to_users[token],
- }
- except Exception:
- raise StoreError(400, "User does not exist.")
-
- def get_room(self, room_id):
- try:
- return self.rooms[room_id]
- except Exception:
- return None
-
- def store_room(self, room_id, room_creator_user_id, is_public):
- if room_id in self.rooms:
- raise StoreError(409, "Conflicting room!")
-
- room = MemoryDataStore.Room(
- room_id=room_id,
- is_public=is_public,
- creator=room_creator_user_id
- )
- self.rooms[room_id] = room
-
- def get_room_member(self, user_id, room_id):
- return self.members.get(room_id, {}).get(user_id)
-
- def get_room_members(self, room_id, membership=None):
- if membership:
- return [
- v for k, v in self.members.get(room_id, {}).items()
- if v.membership == membership
- ]
- else:
- return self.members.get(room_id, {}).values()
-
- def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
- return [
- m[user_id] for m in self.members.values()
- if user_id in m and m[user_id].membership in membership_list
- ]
-
- def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
- limit=0, with_feedback=False):
- return ([], from_key) # TODO
-
- def get_joined_hosts_for_room(self, room_id):
- return defer.succeed([])
-
- def persist_event(self, event):
- if event.type == EventTypes.Member:
- room_id = event.room_id
- user = event.state_key
- self.members.setdefault(room_id, {})[user] = event
-
- if hasattr(event, "state_key"):
- key = (event.room_id, event.type, event.state_key)
- self.current_state[key] = event
-
- self.events.append(event)
-
- def get_current_state(self, room_id, event_type=None, state_key=""):
- if event_type:
- key = (room_id, event_type, state_key)
- if self.current_state.get(key):
- return [self.current_state.get(key)]
- return None
- else:
- return [
- e for e in self.current_state
- if e[0] == room_id
- ]
-
- def set_presence_state(self, user_localpart, state):
- return defer.succeed({"state": 0})
-
- def get_presence_list(self, user_localpart, accepted):
- return []
-
- def get_room_events_max_id(self):
- return "s0" # TODO (erikj)
-
- def get_send_event_level(self, room_id):
- return defer.succeed(0)
-
- def get_power_level(self, room_id, user_id):
- return defer.succeed(0)
-
- def get_add_state_level(self, room_id):
- return defer.succeed(0)
-
- def get_room_join_rule(self, room_id):
- # TODO (erikj): This should be configurable
- return defer.succeed("invite")
-
- def get_ops_levels(self, room_id):
- return defer.succeed((5, 5, 5))
-
- def insert_client_ip(self, user, access_token, ip, user_agent):
- return defer.succeed(None)
-
-
def _format_call(args, kwargs):
return ", ".join(
["%r" % (a) for a in args] +
|