diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/crypto/test_keyring.py | 18 | ||||
-rw-r--r-- | tests/handlers/test_e2e_keys.py | 1 | ||||
-rw-r--r-- | tests/replication/slave/storage/_base.py | 8 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_register.py | 1 | ||||
-rw-r--r-- | tests/storage/test_user_directory.py | 88 | ||||
-rw-r--r-- | tests/test_state.py | 4 | ||||
-rw-r--r-- | tests/util/test_logcontext.py | 16 | ||||
-rw-r--r-- | tests/utils.py | 245 |
8 files changed, 174 insertions, 207 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 570312da84..d4ec02ffc2 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -68,7 +68,7 @@ class KeyringTestCase(unittest.TestCase): def check_context(self, _, expected): self.assertEquals( - getattr(LoggingContext.current_context(), "test_key", None), + getattr(LoggingContext.current_context(), "request", None), expected ) @@ -82,7 +82,7 @@ class KeyringTestCase(unittest.TestCase): lookup_2_deferred = defer.Deferred() with LoggingContext("one") as context_one: - context_one.test_key = "one" + context_one.request = "one" wait_1_deferred = kr.wait_for_previous_lookups( ["server1"], @@ -96,7 +96,7 @@ class KeyringTestCase(unittest.TestCase): wait_1_deferred.addBoth(self.check_context, "one") with LoggingContext("two") as context_two: - context_two.test_key = "two" + context_two.request = "two" # set off another wait. It should block because the first lookup # hasn't yet completed. @@ -137,7 +137,7 @@ class KeyringTestCase(unittest.TestCase): @defer.inlineCallbacks def get_perspectives(**kwargs): self.assertEquals( - LoggingContext.current_context().test_key, "11", + LoggingContext.current_context().request, "11", ) with logcontext.PreserveLoggingContext(): yield persp_deferred @@ -145,7 +145,7 @@ class KeyringTestCase(unittest.TestCase): self.http_client.post_json.side_effect = get_perspectives with LoggingContext("11") as context_11: - context_11.test_key = "11" + context_11.request = "11" # start off a first set of lookups res_deferreds = kr.verify_json_objects_for_server( @@ -167,13 +167,13 @@ class KeyringTestCase(unittest.TestCase): # wait a tick for it to send the request to the perspectives server # (it first tries the datastore) - yield async.sleep(0.005) + yield async.sleep(1) # XXX find out why this takes so long! self.http_client.post_json.assert_called_once() self.assertIs(LoggingContext.current_context(), context_11) context_12 = LoggingContext("12") - context_12.test_key = "12" + context_12.request = "12" with logcontext.PreserveLoggingContext(context_12): # a second request for a server with outstanding requests # should block rather than start a second call @@ -183,7 +183,7 @@ class KeyringTestCase(unittest.TestCase): res_deferreds_2 = kr.verify_json_objects_for_server( [("server10", json1)], ) - yield async.sleep(0.005) + yield async.sleep(01) self.http_client.post_json.assert_not_called() res_deferreds_2[0].addBoth(self.check_context, None) @@ -211,7 +211,7 @@ class KeyringTestCase(unittest.TestCase): sentinel_context = LoggingContext.current_context() with LoggingContext("one") as context_one: - context_one.test_key = "one" + context_one.request = "one" defer = kr.verify_json_for_server("server9", {}) try: diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 19f5ed6bce..d92bf240b1 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -143,7 +143,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase): except errors.SynapseError: pass - @unittest.DEBUG @defer.inlineCallbacks def test_claim_one_time_key(self): local_user = "@boris:" + self.hs.hostname diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 81063f19a1..74f104e3b8 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -15,6 +15,8 @@ from twisted.internet import defer, reactor from tests import unittest +import tempfile + from mock import Mock, NonCallableMock from tests.utils import setup_test_homeserver from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -41,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.TestCase): self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) - listener = reactor.listenUNIX("\0xxx", server_factory) + # XXX: mktemp is unsafe and should never be used. but we're just a test. + path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket") + listener = reactor.listenUNIX(path, server_factory) self.addCleanup(listener.stopListening) self.streamer = server_factory.streamer @@ -49,7 +53,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) - client_connector = reactor.connectUNIX("\0xxx", client_factory) + client_connector = reactor.connectUNIX(path, client_factory) self.addCleanup(client_factory.stopTrying) self.addCleanup(client_connector.disconnect) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 096f771bea..8aba456510 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -49,6 +49,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.get_device_handler = Mock(return_value=self.device_handler) self.hs.config.enable_registration = True + self.hs.config.registrations_require_3pid = [] self.hs.config.auto_join_rooms = [] # init the thing we're testing diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py new file mode 100644 index 0000000000..0891308f25 --- /dev/null +++ b/tests/storage/test_user_directory.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.storage import UserDirectoryStore +from synapse.storage.roommember import ProfileInfo +from tests import unittest +from tests.utils import setup_test_homeserver + +ALICE = "@alice:a" +BOB = "@bob:b" +BOBBY = "@bobby:a" + + +class UserDirectoryStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver() + self.store = UserDirectoryStore(None, self.hs) + + # alice and bob are both in !room_id. bobby is not but shares + # a homeserver with alice. + yield self.store.add_profiles_to_user_dir( + "!room:id", + { + ALICE: ProfileInfo(None, "alice"), + BOB: ProfileInfo(None, "bob"), + BOBBY: ProfileInfo(None, "bobby") + }, + ) + yield self.store.add_users_to_public_room( + "!room:id", + [ALICE, BOB], + ) + yield self.store.add_users_who_share_room( + "!room:id", + False, + ( + (ALICE, BOB), + (BOB, ALICE), + ), + ) + + @defer.inlineCallbacks + def test_search_user_dir(self): + # normally when alice searches the directory she should just find + # bob because bobby doesn't share a room with her. + r = yield self.store.search_user_dir(ALICE, "bob", 10) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual(r["results"][0], { + "user_id": BOB, + "display_name": "bob", + "avatar_url": None, + }) + + @defer.inlineCallbacks + def test_search_user_dir_all_users(self): + self.hs.config.user_directory_search_all_users = True + try: + r = yield self.store.search_user_dir(ALICE, "bob", 10) + self.assertFalse(r["limited"]) + self.assertEqual(2, len(r["results"])) + self.assertDictEqual(r["results"][0], { + "user_id": BOB, + "display_name": "bob", + "avatar_url": None, + }) + self.assertDictEqual(r["results"][1], { + "user_id": BOBBY, + "display_name": "bobby", + "avatar_url": None, + }) + finally: + self.hs.config.user_directory_search_all_users = False diff --git a/tests/test_state.py b/tests/test_state.py index feb84f3d48..d16e1b3b8b 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.events import FrozenEvent from synapse.api.auth import Auth from synapse.api.constants import EventTypes, Membership -from synapse.state import StateHandler +from synapse.state import StateHandler, StateResolutionHandler from .utils import MockClock @@ -148,11 +148,13 @@ class StateTestCase(unittest.TestCase): ) hs = Mock(spec_set=[ "get_datastore", "get_auth", "get_state_handler", "get_clock", + "get_state_resolution_handler", ]) hs.get_datastore.return_value = self.store hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) + hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) self.store.get_next_state_group.side_effect = Mock self.store.get_state_group_delta.return_value = (None, None) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index e2f7765f49..4850722bc5 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -12,12 +12,12 @@ class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): self.assertEquals( - LoggingContext.current_context().test_key, value + LoggingContext.current_context().request, value ) def test_with_context(self): with LoggingContext() as context_one: - context_one.test_key = "test" + context_one.request = "test" self._check_test_key("test") @defer.inlineCallbacks @@ -25,14 +25,14 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def competing_callback(): with LoggingContext() as competing_context: - competing_context.test_key = "competing" + competing_context.request = "competing" yield sleep(0) self._check_test_key("competing") reactor.callLater(0, competing_callback) with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" yield sleep(0) self._check_test_key("one") @@ -43,14 +43,14 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def cb(): - context_one.test_key = "one" + context_one.request = "one" yield function() self._check_test_key("one") callback_completed[0] = True with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" # fire off function, but don't wait on it. logcontext.preserve_fn(cb)() @@ -107,7 +107,7 @@ class LoggingContextTestCase(unittest.TestCase): sentinel_context = LoggingContext.current_context() with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" d1 = logcontext.make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable @@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase): argument isn't actually a deferred""" with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" d1 = logcontext.make_deferred_yieldable("bum") self._check_test_key("one") diff --git a/tests/utils.py b/tests/utils.py index 44e5f75093..8efd3a3475 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,27 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.server import HttpServer -from synapse.api.errors import cs_error, CodeMessageException, StoreError -from synapse.api.constants import EventTypes -from synapse.storage.prepare_database import prepare_database -from synapse.storage.engines import create_engine -from synapse.server import HomeServer -from synapse.federation.transport import server -from synapse.util.ratelimitutils import FederationRateLimiter - -from synapse.util.logcontext import LoggingContext - -from twisted.internet import defer, reactor -from twisted.enterprise.adbapi import ConnectionPool - -from collections import namedtuple -from mock import patch, Mock import hashlib +from inspect import getcallargs import urllib import urlparse -from inspect import getcallargs +from mock import Mock, patch +from twisted.internet import defer, reactor + +from synapse.api.errors import CodeMessageException, cs_error +from synapse.federation.transport import server +from synapse.http.server import HttpServer +from synapse.server import HomeServer +from synapse.storage import PostgresEngine +from synapse.storage.engines import create_engine +from synapse.storage.prepare_database import prepare_database +from synapse.util.logcontext import LoggingContext +from synapse.util.ratelimitutils import FederationRateLimiter + +# set this to True to run the tests against postgres instead of sqlite. +# It requires you to have a local postgres database called synapse_test, within +# which ALL TABLES WILL BE DROPPED +USE_POSTGRES_FOR_TESTS = False @defer.inlineCallbacks @@ -57,36 +58,70 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.worker_app = None config.email_enable_notifs = False config.block_non_admin_invites = False + config.federation_domain_whitelist = None + config.user_directory_search_all_users = False # disable user directory updates, because they get done in the # background, which upsets the test runner. config.update_user_directory = False config.use_frozen_dicts = True - config.database_config = {"name": "sqlite3"} config.ldap_enabled = False if "clock" not in kargs: kargs["clock"] = MockClock() + if USE_POSTGRES_FOR_TESTS: + config.database_config = { + "name": "psycopg2", + "args": { + "database": "synapse_test", + "cp_min": 1, + "cp_max": 5, + }, + } + else: + config.database_config = { + "name": "sqlite3", + "args": { + "database": ":memory:", + "cp_min": 1, + "cp_max": 1, + }, + } + + db_engine = create_engine(config.database_config) + + # we need to configure the connection pool to run the on_new_connection + # function, so that we can test code that uses custom sqlite functions + # (like rank). + config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection + if datastore is None: - db_pool = SQLiteMemoryDbPool() - yield db_pool.prepare() hs = HomeServer( - name, db_pool=db_pool, config=config, + name, config=config, + db_config=config.database_config, version_string="Synapse/tests", - database_engine=create_engine(config.database_config), - get_db_conn=db_pool.get_db_conn, + database_engine=db_engine, room_list_handler=object(), tls_server_context_factory=Mock(), **kargs ) + db_conn = hs.get_db_conn() + # make sure that the database is empty + if isinstance(db_engine, PostgresEngine): + cur = db_conn.cursor() + cur.execute("SELECT tablename FROM pg_tables where schemaname='public'") + rows = cur.fetchall() + for r in rows: + cur.execute("DROP TABLE %s CASCADE" % r[0]) + yield prepare_database(db_conn, db_engine, config) hs.setup() else: hs = HomeServer( name, db_pool=None, datastore=datastore, config=config, version_string="Synapse/tests", - database_engine=create_engine(config.database_config), + database_engine=db_engine, room_list_handler=object(), tls_server_context_factory=Mock(), **kargs @@ -305,168 +340,6 @@ class MockClock(object): return d -class SQLiteMemoryDbPool(ConnectionPool, object): - def __init__(self): - super(SQLiteMemoryDbPool, self).__init__( - "sqlite3", ":memory:", - cp_min=1, - cp_max=1, - ) - - self.config = Mock() - self.config.password_providers = [] - self.config.database_config = {"name": "sqlite3"} - - def prepare(self): - engine = self.create_engine() - return self.runWithConnection( - lambda conn: prepare_database(conn, engine, self.config) - ) - - def get_db_conn(self): - conn = self.connect() - engine = self.create_engine() - prepare_database(conn, engine, self.config) - return conn - - def create_engine(self): - return create_engine(self.config.database_config) - - -class MemoryDataStore(object): - - Room = namedtuple( - "Room", - ["room_id", "is_public", "creator"] - ) - - def __init__(self): - self.tokens_to_users = {} - self.paths_to_content = {} - - self.members = {} - self.rooms = {} - - self.current_state = {} - self.events = [] - - class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")): - def fill_out_prev_events(self, event): - pass - - def snapshot_room(self, room_id, user_id, state_type=None, state_key=None): - return self.Snapshot( - room_id, user_id, self.get_room_member(user_id, room_id) - ) - - def register(self, user_id, token, password_hash): - if user_id in self.tokens_to_users.values(): - raise StoreError(400, "User in use.") - self.tokens_to_users[token] = user_id - - def get_user_by_access_token(self, token): - try: - return { - "name": self.tokens_to_users[token], - } - except Exception: - raise StoreError(400, "User does not exist.") - - def get_room(self, room_id): - try: - return self.rooms[room_id] - except Exception: - return None - - def store_room(self, room_id, room_creator_user_id, is_public): - if room_id in self.rooms: - raise StoreError(409, "Conflicting room!") - - room = MemoryDataStore.Room( - room_id=room_id, - is_public=is_public, - creator=room_creator_user_id - ) - self.rooms[room_id] = room - - def get_room_member(self, user_id, room_id): - return self.members.get(room_id, {}).get(user_id) - - def get_room_members(self, room_id, membership=None): - if membership: - return [ - v for k, v in self.members.get(room_id, {}).items() - if v.membership == membership - ] - else: - return self.members.get(room_id, {}).values() - - def get_rooms_for_user_where_membership_is(self, user_id, membership_list): - return [ - m[user_id] for m in self.members.values() - if user_id in m and m[user_id].membership in membership_list - ] - - def get_room_events_stream(self, user_id=None, from_key=None, to_key=None, - limit=0, with_feedback=False): - return ([], from_key) # TODO - - def get_joined_hosts_for_room(self, room_id): - return defer.succeed([]) - - def persist_event(self, event): - if event.type == EventTypes.Member: - room_id = event.room_id - user = event.state_key - self.members.setdefault(room_id, {})[user] = event - - if hasattr(event, "state_key"): - key = (event.room_id, event.type, event.state_key) - self.current_state[key] = event - - self.events.append(event) - - def get_current_state(self, room_id, event_type=None, state_key=""): - if event_type: - key = (room_id, event_type, state_key) - if self.current_state.get(key): - return [self.current_state.get(key)] - return None - else: - return [ - e for e in self.current_state - if e[0] == room_id - ] - - def set_presence_state(self, user_localpart, state): - return defer.succeed({"state": 0}) - - def get_presence_list(self, user_localpart, accepted): - return [] - - def get_room_events_max_id(self): - return "s0" # TODO (erikj) - - def get_send_event_level(self, room_id): - return defer.succeed(0) - - def get_power_level(self, room_id, user_id): - return defer.succeed(0) - - def get_add_state_level(self, room_id): - return defer.succeed(0) - - def get_room_join_rule(self, room_id): - # TODO (erikj): This should be configurable - return defer.succeed("invite") - - def get_ops_levels(self, room_id): - return defer.succeed((5, 5, 5)) - - def insert_client_ip(self, user, access_token, ip, user_agent): - return defer.succeed(None) - - def _format_call(args, kwargs): return ", ".join( ["%r" % (a) for a in args] + |