diff options
Diffstat (limited to 'tests')
27 files changed, 706 insertions, 326 deletions
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index 8f57fbeb23..879159ccea 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -12,9 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os.path +import re import shutil import tempfile + from synapse.config.homeserver import HomeServerConfig from tests import unittest @@ -23,7 +26,6 @@ class ConfigGenerationTestCase(unittest.TestCase): def setUp(self): self.dir = tempfile.mkdtemp() - print self.dir self.file = os.path.join(self.dir, "homeserver.yaml") def tearDown(self): @@ -48,3 +50,16 @@ class ConfigGenerationTestCase(unittest.TestCase): ]), set(os.listdir(self.dir)) ) + + self.assert_log_filename_is( + os.path.join(self.dir, "lemurs.win.log.config"), + os.path.join(os.getcwd(), "homeserver.log"), + ) + + def assert_log_filename_is(self, log_config_file, expected): + with open(log_config_file) as f: + config = f.read() + # find the 'filename' line + matches = re.findall("^\s*filename:\s*(.*)$", config, re.M) + self.assertEqual(1, len(matches)) + self.assertEqual(matches[0], expected) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 570312da84..d4ec02ffc2 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -68,7 +68,7 @@ class KeyringTestCase(unittest.TestCase): def check_context(self, _, expected): self.assertEquals( - getattr(LoggingContext.current_context(), "test_key", None), + getattr(LoggingContext.current_context(), "request", None), expected ) @@ -82,7 +82,7 @@ class KeyringTestCase(unittest.TestCase): lookup_2_deferred = defer.Deferred() with LoggingContext("one") as context_one: - context_one.test_key = "one" + context_one.request = "one" wait_1_deferred = kr.wait_for_previous_lookups( ["server1"], @@ -96,7 +96,7 @@ class KeyringTestCase(unittest.TestCase): wait_1_deferred.addBoth(self.check_context, "one") with LoggingContext("two") as context_two: - context_two.test_key = "two" + context_two.request = "two" # set off another wait. It should block because the first lookup # hasn't yet completed. @@ -137,7 +137,7 @@ class KeyringTestCase(unittest.TestCase): @defer.inlineCallbacks def get_perspectives(**kwargs): self.assertEquals( - LoggingContext.current_context().test_key, "11", + LoggingContext.current_context().request, "11", ) with logcontext.PreserveLoggingContext(): yield persp_deferred @@ -145,7 +145,7 @@ class KeyringTestCase(unittest.TestCase): self.http_client.post_json.side_effect = get_perspectives with LoggingContext("11") as context_11: - context_11.test_key = "11" + context_11.request = "11" # start off a first set of lookups res_deferreds = kr.verify_json_objects_for_server( @@ -167,13 +167,13 @@ class KeyringTestCase(unittest.TestCase): # wait a tick for it to send the request to the perspectives server # (it first tries the datastore) - yield async.sleep(0.005) + yield async.sleep(1) # XXX find out why this takes so long! self.http_client.post_json.assert_called_once() self.assertIs(LoggingContext.current_context(), context_11) context_12 = LoggingContext("12") - context_12.test_key = "12" + context_12.request = "12" with logcontext.PreserveLoggingContext(context_12): # a second request for a server with outstanding requests # should block rather than start a second call @@ -183,7 +183,7 @@ class KeyringTestCase(unittest.TestCase): res_deferreds_2 = kr.verify_json_objects_for_server( [("server10", json1)], ) - yield async.sleep(0.005) + yield async.sleep(01) self.http_client.post_json.assert_not_called() res_deferreds_2[0].addBoth(self.check_context, None) @@ -211,7 +211,7 @@ class KeyringTestCase(unittest.TestCase): sentinel_context = LoggingContext.current_context() with LoggingContext("one") as context_one: - context_one.test_key = "one" + context_one.request = "one" defer = kr.verify_json_for_server("server9", {}) try: diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 5712773909..7e5332e272 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -35,21 +35,20 @@ class DirectoryTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.mock_federation = Mock(spec=[ - "make_query", - "register_edu_handler", - ]) + self.mock_federation = Mock() + self.mock_registry = Mock() self.query_handlers = {} def register_query_handler(query_type, handler): self.query_handlers[query_type] = handler - self.mock_federation.register_query_handler = register_query_handler + self.mock_registry.register_query_handler = register_query_handler hs = yield setup_test_homeserver( http_client=None, resource_for_federation=Mock(), - replication_layer=self.mock_federation, + federation_client=self.mock_federation, + federation_registry=self.mock_registry, ) hs.handlers = DirectoryHandlers(hs) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 19f5ed6bce..d1bd87b898 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -34,7 +34,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): def setUp(self): self.hs = yield utils.setup_test_homeserver( handlers=None, - replication_layer=mock.Mock(), + federation_client=mock.Mock(), ) self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) @@ -143,7 +143,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase): except errors.SynapseError: pass - @unittest.DEBUG @defer.inlineCallbacks def test_claim_one_time_key(self): local_user = "@boris:" + self.hs.hostname diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index a5f47181d7..458296ee4c 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -37,23 +37,23 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.mock_federation = Mock(spec=[ - "make_query", - "register_edu_handler", - ]) + self.mock_federation = Mock() + self.mock_registry = Mock() self.query_handlers = {} def register_query_handler(query_type, handler): self.query_handlers[query_type] = handler - self.mock_federation.register_query_handler = register_query_handler + self.mock_registry.register_query_handler = register_query_handler hs = yield setup_test_homeserver( http_client=None, handlers=None, resource_for_federation=Mock(), - replication_layer=self.mock_federation, + federation_client=self.mock_federation, + federation_server=Mock(), + federation_registry=self.mock_registry, ratelimiter=NonCallableMock(spec_set=[ "send_message", ]) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index fcd380b03a..a433bbfa8a 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -81,7 +81,7 @@ class TypingNotificationsTestCase(unittest.TestCase): "get_current_state_deltas", ]), state_handler=self.state_handler, - handlers=None, + handlers=Mock(), notifier=mock_notifier, resource_for_client=Mock(), resource_for_federation=self.mock_federation_resource, diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index f85455a5af..39bde6e3f8 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -141,6 +141,7 @@ class CacheMetricTestCase(unittest.TestCase): 'cache:hits{name="cache_name"} 0', 'cache:total{name="cache_name"} 0', 'cache:size{name="cache_name"} 0', + 'cache:evicted_size{name="cache_name"} 0', ]) metric.inc_misses() @@ -150,6 +151,7 @@ class CacheMetricTestCase(unittest.TestCase): 'cache:hits{name="cache_name"} 0', 'cache:total{name="cache_name"} 1', 'cache:size{name="cache_name"} 1', + 'cache:evicted_size{name="cache_name"} 0', ]) metric.inc_hits() @@ -158,4 +160,14 @@ class CacheMetricTestCase(unittest.TestCase): 'cache:hits{name="cache_name"} 1', 'cache:total{name="cache_name"} 2', 'cache:size{name="cache_name"} 1', + 'cache:evicted_size{name="cache_name"} 0', + ]) + + metric.inc_evictions(2) + + self.assertEquals(metric.render(), [ + 'cache:hits{name="cache_name"} 1', + 'cache:total{name="cache_name"} 2', + 'cache:size{name="cache_name"} 1', + 'cache:evicted_size{name="cache_name"} 2', ]) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 81063f19a1..64e07a8c93 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -15,6 +15,8 @@ from twisted.internet import defer, reactor from tests import unittest +import tempfile + from mock import Mock, NonCallableMock from tests.utils import setup_test_homeserver from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -29,7 +31,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): self.hs = yield setup_test_homeserver( "blue", http_client=None, - replication_layer=Mock(), + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), @@ -41,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.TestCase): self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) - listener = reactor.listenUNIX("\0xxx", server_factory) + # XXX: mktemp is unsafe and should never be used. but we're just a test. + path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket") + listener = reactor.listenUNIX(path, server_factory) self.addCleanup(listener.stopListening) self.streamer = server_factory.streamer @@ -49,7 +53,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) - client_connector = reactor.connectUNIX("\0xxx", client_factory) + client_connector = reactor.connectUNIX(path, client_factory) self.addCleanup(client_factory.stopTrying) self.addCleanup(client_connector.disconnect) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 105e1228bb..cb058d3142 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -226,13 +226,16 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): context = EventContext() context.current_state_ids = state_ids context.prev_state_ids = state_ids - elif not backfill: + else: state_handler = self.hs.get_state_handler() context = yield state_handler.compute_event_context(event) - else: - context = EventContext() - context.push_actions = push_actions + yield self.master_store.add_push_actions_to_staging( + event.event_id, { + user_id: actions + for user_id, actions in push_actions + }, + ) ordering = None if backfill: diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index e9698bfdc9..2b89c0a3c7 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -114,7 +114,7 @@ class EventStreamPermissionsTestCase(RestTestCase): hs = yield setup_test_homeserver( http_client=None, - replication_layer=Mock(), + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index dddcf51b69..deac7f100c 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -45,7 +45,7 @@ class ProfileTestCase(unittest.TestCase): http_client=None, resource_for_client=self.mock_resource, federation=Mock(), - replication_layer=Mock(), + federation_client=Mock(), profile_handler=self.mock_handler ) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index de376fb514..7e8966a1a8 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -46,7 +46,7 @@ class RoomPermissionsTestCase(RestTestCase): hs = yield setup_test_homeserver( "red", http_client=None, - replication_layer=Mock(), + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.ratelimiter = hs.get_ratelimiter() @@ -409,7 +409,7 @@ class RoomsMemberListTestCase(RestTestCase): hs = yield setup_test_homeserver( "red", http_client=None, - replication_layer=Mock(), + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.ratelimiter = hs.get_ratelimiter() @@ -493,7 +493,7 @@ class RoomsCreateTestCase(RestTestCase): hs = yield setup_test_homeserver( "red", http_client=None, - replication_layer=Mock(), + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.ratelimiter = hs.get_ratelimiter() @@ -515,9 +515,6 @@ class RoomsCreateTestCase(RestTestCase): synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) - def tearDown(self): - pass - @defer.inlineCallbacks def test_post_room_no_keys(self): # POST with no config keys, expect new room id @@ -585,7 +582,7 @@ class RoomTopicTestCase(RestTestCase): hs = yield setup_test_homeserver( "red", http_client=None, - replication_layer=Mock(), + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.ratelimiter = hs.get_ratelimiter() @@ -700,7 +697,7 @@ class RoomMemberStateTestCase(RestTestCase): hs = yield setup_test_homeserver( "red", http_client=None, - replication_layer=Mock(), + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.ratelimiter = hs.get_ratelimiter() @@ -832,7 +829,7 @@ class RoomMessagesTestCase(RestTestCase): hs = yield setup_test_homeserver( "red", http_client=None, - replication_layer=Mock(), + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.ratelimiter = hs.get_ratelimiter() @@ -932,7 +929,7 @@ class RoomInitialSyncTestCase(RestTestCase): hs = yield setup_test_homeserver( "red", http_client=None, - replication_layer=Mock(), + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), @@ -1006,7 +1003,7 @@ class RoomMessageListTestCase(RestTestCase): hs = yield setup_test_homeserver( "red", http_client=None, - replication_layer=Mock(), + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.ratelimiter = hs.get_ratelimiter() diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index a269e6f56e..2ec4ecab5b 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -47,7 +47,7 @@ class RoomTypingTestCase(RestTestCase): "red", clock=self.clock, http_client=None, - replication_layer=Mock(), + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), @@ -95,7 +95,7 @@ class RoomTypingTestCase(RestTestCase): else: if remotedomains is not None: remotedomains.add(member.domain) - hs.get_handlers().room_member_handler.fetch_room_distributions_into = ( + hs.get_room_member_handler().fetch_room_distributions_into = ( fetch_room_distributions_into ) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 096f771bea..8aba456510 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -49,6 +49,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.get_device_handler = Mock(return_value=self.device_handler) self.hs.config.enable_registration = True + self.hs.config.registrations_require_3pid = [] self.hs.config.auto_join_rooms = [] # init the thing we're testing diff --git a/tests/rest/media/__init__.py b/tests/rest/media/__init__.py new file mode 100644 index 0000000000..a354d38ca8 --- /dev/null +++ b/tests/rest/media/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/rest/media/v1/__init__.py b/tests/rest/media/v1/__init__.py new file mode 100644 index 0000000000..a354d38ca8 --- /dev/null +++ b/tests/rest/media/v1/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py new file mode 100644 index 0000000000..eef38b6781 --- /dev/null +++ b/tests/rest/media/v1/test_media_storage.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from synapse.rest.media.v1._base import FileInfo +from synapse.rest.media.v1.media_storage import MediaStorage +from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend + +from mock import Mock + +from tests import unittest + +import os +import shutil +import tempfile + + +class MediaStorageTests(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") + + self.primary_base_path = os.path.join(self.test_dir, "primary") + self.secondary_base_path = os.path.join(self.test_dir, "secondary") + + hs = Mock() + hs.config.media_store_path = self.primary_base_path + + storage_providers = [FileStorageProviderBackend( + hs, self.secondary_base_path + )] + + self.filepaths = MediaFilePaths(self.primary_base_path) + self.media_storage = MediaStorage( + self.primary_base_path, self.filepaths, storage_providers, + ) + + def tearDown(self): + shutil.rmtree(self.test_dir) + + @defer.inlineCallbacks + def test_ensure_media_is_in_local_cache(self): + media_id = "some_media_id" + test_body = "Test\n" + + # First we create a file that is in a storage provider but not in the + # local primary media store + rel_path = self.filepaths.local_media_filepath_rel(media_id) + secondary_path = os.path.join(self.secondary_base_path, rel_path) + + os.makedirs(os.path.dirname(secondary_path)) + + with open(secondary_path, "w") as f: + f.write(test_body) + + # Now we run ensure_media_is_in_local_cache, which should copy the file + # to the local cache. + file_info = FileInfo(None, media_id) + local_path = yield self.media_storage.ensure_media_is_in_local_cache(file_info) + + self.assertTrue(os.path.exists(local_path)) + + # Asserts the file is under the expected local cache directory + self.assertEquals( + os.path.commonprefix([self.primary_base_path, local_path]), + self.primary_base_path, + ) + + with open(local_path) as f: + body = f.read() + + self.assertEqual(test_body, body) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 13d81f972b..c2e39a7288 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -42,7 +42,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): hs = yield setup_test_homeserver( config=config, federation_sender=Mock(), - replication_layer=Mock(), + federation_client=Mock(), ) self.as_token = "token1" @@ -119,7 +119,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): hs = yield setup_test_homeserver( config=config, federation_sender=Mock(), - replication_layer=Mock(), + federation_client=Mock(), ) self.db_pool = hs.get_db_pool() @@ -455,7 +455,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): config=config, datastore=Mock(), federation_sender=Mock(), - replication_layer=Mock(), + federation_client=Mock(), ) ApplicationServiceStore(None, hs) @@ -473,7 +473,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): config=config, datastore=Mock(), federation_sender=Mock(), - replication_layer=Mock(), + federation_client=Mock(), ) with self.assertRaises(ConfigError) as cm: @@ -497,7 +497,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): config=config, datastore=Mock(), federation_sender=Mock(), - replication_layer=Mock(), + federation_client=Mock(), ) with self.assertRaises(ConfigError) as cm: diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 3135488353..575374c6a6 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -62,6 +62,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): {"notify_count": noitf_count, "highlight_count": highlight_count} ) + @defer.inlineCallbacks def _inject_actions(stream, action): event = Mock() event.room_id = room_id @@ -69,11 +70,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): event.internal_metadata.stream_ordering = stream event.depth = stream - tuples = [(user_id, action)] - - return self.store.runInteraction( + yield self.store.add_push_actions_to_staging( + event.event_id, {user_id: action}, + ) + yield self.store.runInteraction( "", self.store._set_push_actions_for_event_and_users_txn, - event, tuples + [(event, None)], [(event, None)], ) def _rotate(stream): @@ -125,3 +127,70 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _assert_counts(1, 1) yield _rotate(10) yield _assert_counts(1, 1) + + @tests.unittest.DEBUG + @defer.inlineCallbacks + def test_find_first_stream_ordering_after_ts(self): + def add_event(so, ts): + return self.store._simple_insert("events", { + "stream_ordering": so, + "received_ts": ts, + "event_id": "event%i" % so, + "type": "", + "room_id": "", + "content": "", + "processed": True, + "outlier": False, + "topological_ordering": 0, + "depth": 0, + }) + + # start with the base case where there are no events in the table + r = yield self.store.find_first_stream_ordering_after_ts(11) + self.assertEqual(r, 0) + + # now with one event + yield add_event(2, 10) + r = yield self.store.find_first_stream_ordering_after_ts(9) + self.assertEqual(r, 2) + r = yield self.store.find_first_stream_ordering_after_ts(10) + self.assertEqual(r, 2) + r = yield self.store.find_first_stream_ordering_after_ts(11) + self.assertEqual(r, 3) + + # add a bunch of dummy events to the events table + for (stream_ordering, ts) in ( + (3, 110), + (4, 120), + (5, 120), + (10, 130), + (20, 140), + ): + yield add_event(stream_ordering, ts) + + r = yield self.store.find_first_stream_ordering_after_ts(110) + self.assertEqual(r, 3, + "First event after 110ms should be 3, was %i" % r) + + # 4 and 5 are both after 120: we want 4 rather than 5 + r = yield self.store.find_first_stream_ordering_after_ts(120) + self.assertEqual(r, 4, + "First event after 120ms should be 4, was %i" % r) + + r = yield self.store.find_first_stream_ordering_after_ts(129) + self.assertEqual(r, 10, + "First event after 129ms should be 10, was %i" % r) + + # check we can get the last event + r = yield self.store.find_first_stream_ordering_after_ts(140) + self.assertEqual(r, 20, + "First event after 14ms should be 20, was %i" % r) + + # off the end + r = yield self.store.find_first_stream_ordering_after_ts(160) + self.assertEqual(r, 21) + + # check we can find an event at ordering zero + yield add_event(0, 5) + r = yield self.store.find_first_stream_ordering_after_ts(1) + self.assertEqual(r, 0) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 6afaca3a61..888ddfaddd 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -36,8 +36,7 @@ class RedactionTestCase(unittest.TestCase): self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() - self.handlers = hs.get_handlers() - self.message_handler = self.handlers.message_handler + self.event_creation_handler = hs.get_event_creation_handler() self.u_alice = UserID.from_string("@alice:test") self.u_bob = UserID.from_string("@bob:test") @@ -59,7 +58,7 @@ class RedactionTestCase(unittest.TestCase): "content": content, }) - event, context = yield self.message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder ) @@ -79,7 +78,7 @@ class RedactionTestCase(unittest.TestCase): "content": {"body": body, "msgtype": u"message"}, }) - event, context = yield self.message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder ) @@ -98,7 +97,7 @@ class RedactionTestCase(unittest.TestCase): "redacts": event_id, }) - event, context = yield self.message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 1be7d932f6..657b279e5d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -37,8 +37,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): # storage logic self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() - self.handlers = hs.get_handlers() - self.message_handler = self.handlers.message_handler + self.event_creation_handler = hs.get_event_creation_handler() self.u_alice = UserID.from_string("@alice:test") self.u_bob = UserID.from_string("@bob:test") @@ -58,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): "content": {"membership": membership}, }) - event, context = yield self.message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py new file mode 100644 index 0000000000..0891308f25 --- /dev/null +++ b/tests/storage/test_user_directory.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.storage import UserDirectoryStore +from synapse.storage.roommember import ProfileInfo +from tests import unittest +from tests.utils import setup_test_homeserver + +ALICE = "@alice:a" +BOB = "@bob:b" +BOBBY = "@bobby:a" + + +class UserDirectoryStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver() + self.store = UserDirectoryStore(None, self.hs) + + # alice and bob are both in !room_id. bobby is not but shares + # a homeserver with alice. + yield self.store.add_profiles_to_user_dir( + "!room:id", + { + ALICE: ProfileInfo(None, "alice"), + BOB: ProfileInfo(None, "bob"), + BOBBY: ProfileInfo(None, "bobby") + }, + ) + yield self.store.add_users_to_public_room( + "!room:id", + [ALICE, BOB], + ) + yield self.store.add_users_who_share_room( + "!room:id", + False, + ( + (ALICE, BOB), + (BOB, ALICE), + ), + ) + + @defer.inlineCallbacks + def test_search_user_dir(self): + # normally when alice searches the directory she should just find + # bob because bobby doesn't share a room with her. + r = yield self.store.search_user_dir(ALICE, "bob", 10) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual(r["results"][0], { + "user_id": BOB, + "display_name": "bob", + "avatar_url": None, + }) + + @defer.inlineCallbacks + def test_search_user_dir_all_users(self): + self.hs.config.user_directory_search_all_users = True + try: + r = yield self.store.search_user_dir(ALICE, "bob", 10) + self.assertFalse(r["limited"]) + self.assertEqual(2, len(r["results"])) + self.assertDictEqual(r["results"][0], { + "user_id": BOB, + "display_name": "bob", + "avatar_url": None, + }) + self.assertDictEqual(r["results"][1], { + "user_id": BOBBY, + "display_name": "bobby", + "avatar_url": None, + }) + finally: + self.hs.config.user_directory_search_all_users = False diff --git a/tests/test_state.py b/tests/test_state.py index feb84f3d48..a5c5e55951 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.events import FrozenEvent from synapse.api.auth import Auth from synapse.api.constants import EventTypes, Membership -from synapse.state import StateHandler +from synapse.state import StateHandler, StateResolutionHandler from .utils import MockClock @@ -80,14 +80,14 @@ class StateGroupStore(object): return defer.succeed(groups) - def store_state_groups(self, event, context): - if context.current_state_ids is None: - return + def store_state_group(self, event_id, room_id, prev_group, delta_ids, + current_state_ids): + state_group = self._next_group + self._next_group += 1 - state_events = dict(context.current_state_ids) + self._group_to_state[state_group] = dict(current_state_ids) - self._group_to_state[context.state_group] = state_events - self._event_to_state_group[event.event_id] = context.state_group + return state_group def get_events(self, event_ids, **kwargs): return { @@ -95,10 +95,19 @@ class StateGroupStore(object): if e_id in self._event_id_to_event } + def get_state_group_delta(self, name): + return (None, None) + def register_events(self, events): for e in events: self._event_id_to_event[e.event_id] = e + def register_event_context(self, event, context): + self._event_to_state_group[event.event_id] = context.state_group + + def register_event_id_state_group(self, event_id, state_group): + self._event_to_state_group[event_id] = state_group + class DictObj(dict): def __init__(self, **kwargs): @@ -137,25 +146,16 @@ class Graph(object): class StateTestCase(unittest.TestCase): def setUp(self): - self.store = Mock( - spec_set=[ - "get_state_groups_ids", - "add_event_hashes", - "get_events", - "get_next_state_group", - "get_state_group_delta", - ] - ) + self.store = StateGroupStore() hs = Mock(spec_set=[ "get_datastore", "get_auth", "get_state_handler", "get_clock", + "get_state_resolution_handler", ]) hs.get_datastore.return_value = self.store hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) - - self.store.get_next_state_group.side_effect = Mock - self.store.get_state_group_delta.return_value = (None, None) + hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) self.state = StateHandler(hs) self.event_id = 0 @@ -195,14 +195,13 @@ class StateTestCase(unittest.TestCase): } ) - store = StateGroupStore() - self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids + self.store.register_events(graph.walk()) context_store = {} for event in graph.walk(): context = yield self.state.compute_event_context(event) - store.store_state_groups(event, context) + self.store.register_event_context(event, context) context_store[event.event_id] = context self.assertEqual(2, len(context_store["D"].prev_state_ids)) @@ -247,16 +246,13 @@ class StateTestCase(unittest.TestCase): } ) - store = StateGroupStore() - self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids - self.store.get_events = store.get_events - store.register_events(graph.walk()) + self.store.register_events(graph.walk()) context_store = {} for event in graph.walk(): context = yield self.state.compute_event_context(event) - store.store_state_groups(event, context) + self.store.register_event_context(event, context) context_store[event.event_id] = context self.assertSetEqual( @@ -313,16 +309,13 @@ class StateTestCase(unittest.TestCase): } ) - store = StateGroupStore() - self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids - self.store.get_events = store.get_events - store.register_events(graph.walk()) + self.store.register_events(graph.walk()) context_store = {} for event in graph.walk(): context = yield self.state.compute_event_context(event) - store.store_state_groups(event, context) + self.store.register_event_context(event, context) context_store[event.event_id] = context self.assertSetEqual( @@ -396,16 +389,13 @@ class StateTestCase(unittest.TestCase): self._add_depths(nodes, edges) graph = Graph(nodes, edges) - store = StateGroupStore() - self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids - self.store.get_events = store.get_events - store.register_events(graph.walk()) + self.store.register_events(graph.walk()) context_store = {} for event in graph.walk(): context = yield self.state.compute_event_context(event) - store.store_state_groups(event, context) + self.store.register_event_context(event, context) context_store[event.event_id] = context self.assertSetEqual( @@ -465,7 +455,11 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_trivial_annotate_message(self): - event = create_event(type="test_message", name="event") + prev_event_id = "prev_event_id" + event = create_event( + type="test_message", name="event2", + prev_events=[(prev_event_id, {})], + ) old_state = [ create_event(type="test1", state_key="1"), @@ -473,11 +467,11 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = "group_name_1" - - self.store.get_state_groups_ids.return_value = { - group_name: {(e.type, e.state_key): e.event_id for e in old_state}, - } + group_name = self.store.store_state_group( + prev_event_id, event.room_id, None, None, + {(e.type, e.state_key): e.event_id for e in old_state}, + ) + self.store.register_event_id_state_group(prev_event_id, group_name) context = yield self.state.compute_event_context(event) @@ -490,7 +484,11 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_trivial_annotate_state(self): - event = create_event(type="state", state_key="", name="event") + prev_event_id = "prev_event_id" + event = create_event( + type="state", state_key="", name="event2", + prev_events=[(prev_event_id, {})], + ) old_state = [ create_event(type="test1", state_key="1"), @@ -498,11 +496,11 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = "group_name_1" - - self.store.get_state_groups_ids.return_value = { - group_name: {(e.type, e.state_key): e.event_id for e in old_state}, - } + group_name = self.store.store_state_group( + prev_event_id, event.room_id, None, None, + {(e.type, e.state_key): e.event_id for e in old_state}, + ) + self.store.register_event_id_state_group(prev_event_id, group_name) context = yield self.state.compute_event_context(event) @@ -515,7 +513,12 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_resolve_message_conflict(self): - event = create_event(type="test_message", name="event") + prev_event_id1 = "event_id1" + prev_event_id2 = "event_id2" + event = create_event( + type="test_message", name="event3", + prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], + ) creation = create_event( type=EventTypes.Create, state_key="" @@ -535,12 +538,12 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] - store = StateGroupStore() - store.register_events(old_state_1) - store.register_events(old_state_2) - self.store.get_events = store.get_events + self.store.register_events(old_state_1) + self.store.register_events(old_state_2) - context = yield self._get_context(event, old_state_1, old_state_2) + context = yield self._get_context( + event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, + ) self.assertEqual(len(context.current_state_ids), 6) @@ -548,7 +551,12 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_resolve_state_conflict(self): - event = create_event(type="test4", state_key="", name="event") + prev_event_id1 = "event_id1" + prev_event_id2 = "event_id2" + event = create_event( + type="test4", state_key="", name="event", + prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], + ) creation = create_event( type=EventTypes.Create, state_key="" @@ -573,7 +581,9 @@ class StateTestCase(unittest.TestCase): store.register_events(old_state_2) self.store.get_events = store.get_events - context = yield self._get_context(event, old_state_1, old_state_2) + context = yield self._get_context( + event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, + ) self.assertEqual(len(context.current_state_ids), 6) @@ -581,7 +591,12 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_standard_depth_conflict(self): - event = create_event(type="test4", name="event") + prev_event_id1 = "event_id1" + prev_event_id2 = "event_id2" + event = create_event( + type="test4", name="event", + prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], + ) member_event = create_event( type=EventTypes.Member, @@ -613,7 +628,9 @@ class StateTestCase(unittest.TestCase): store.register_events(old_state_2) self.store.get_events = store.get_events - context = yield self._get_context(event, old_state_1, old_state_2) + context = yield self._get_context( + event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, + ) self.assertEqual( old_state_2[2].event_id, context.current_state_ids[("test1", "1")] @@ -637,19 +654,26 @@ class StateTestCase(unittest.TestCase): store.register_events(old_state_1) store.register_events(old_state_2) - context = yield self._get_context(event, old_state_1, old_state_2) + context = yield self._get_context( + event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, + ) self.assertEqual( old_state_1[2].event_id, context.current_state_ids[("test1", "1")] ) - def _get_context(self, event, old_state_1, old_state_2): - group_name_1 = "group_name_1" - group_name_2 = "group_name_2" + def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2, + old_state_2): + sg1 = self.store.store_state_group( + prev_event_id_1, event.room_id, None, None, + {(e.type, e.state_key): e.event_id for e in old_state_1}, + ) + self.store.register_event_id_state_group(prev_event_id_1, sg1) - self.store.get_state_groups_ids.return_value = { - group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1}, - group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2}, - } + sg2 = self.store.store_state_group( + prev_event_id_2, event.room_id, None, None, + {(e.type, e.state_key): e.event_id for e in old_state_2}, + ) + self.store.register_event_id_state_group(prev_event_id_2, sg2) return self.state.compute_event_context(event) diff --git a/tests/unittest.py b/tests/unittest.py index 38715972dd..7b478c4294 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import twisted from twisted.trial import unittest import logging @@ -65,6 +65,10 @@ class TestCase(unittest.TestCase): @around(self) def setUp(orig): + # enable debugging of delayed calls - this means that we get a + # traceback when a unit test exits leaving things on the reactor. + twisted.internet.base.DelayedCall.debug = True + old_level = logging.getLogger().level if old_level != level: diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py new file mode 100644 index 0000000000..76e2234255 --- /dev/null +++ b/tests/util/test_file_consumer.py @@ -0,0 +1,176 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer, reactor +from mock import NonCallableMock + +from synapse.util.file_consumer import BackgroundFileConsumer + +from tests import unittest +from StringIO import StringIO + +import threading + + +class FileConsumerTests(unittest.TestCase): + + @defer.inlineCallbacks + def test_pull_consumer(self): + string_file = StringIO() + consumer = BackgroundFileConsumer(string_file) + + try: + producer = DummyPullProducer() + + yield producer.register_with_consumer(consumer) + + yield producer.write_and_wait("Foo") + + self.assertEqual(string_file.getvalue(), "Foo") + + yield producer.write_and_wait("Bar") + + self.assertEqual(string_file.getvalue(), "FooBar") + finally: + consumer.unregisterProducer() + + yield consumer.wait() + + self.assertTrue(string_file.closed) + + @defer.inlineCallbacks + def test_push_consumer(self): + string_file = BlockingStringWrite() + consumer = BackgroundFileConsumer(string_file) + + try: + producer = NonCallableMock(spec_set=[]) + + consumer.registerProducer(producer, True) + + consumer.write("Foo") + yield string_file.wait_for_n_writes(1) + + self.assertEqual(string_file.buffer, "Foo") + + consumer.write("Bar") + yield string_file.wait_for_n_writes(2) + + self.assertEqual(string_file.buffer, "FooBar") + finally: + consumer.unregisterProducer() + + yield consumer.wait() + + self.assertTrue(string_file.closed) + + @defer.inlineCallbacks + def test_push_producer_feedback(self): + string_file = BlockingStringWrite() + consumer = BackgroundFileConsumer(string_file) + + try: + producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) + + resume_deferred = defer.Deferred() + producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None) + + consumer.registerProducer(producer, True) + + number_writes = 0 + with string_file.write_lock: + for _ in range(consumer._PAUSE_ON_QUEUE_SIZE): + consumer.write("Foo") + number_writes += 1 + + producer.pauseProducing.assert_called_once() + + yield string_file.wait_for_n_writes(number_writes) + + yield resume_deferred + producer.resumeProducing.assert_called_once() + finally: + consumer.unregisterProducer() + + yield consumer.wait() + + self.assertTrue(string_file.closed) + + +class DummyPullProducer(object): + def __init__(self): + self.consumer = None + self.deferred = defer.Deferred() + + def resumeProducing(self): + d = self.deferred + self.deferred = defer.Deferred() + d.callback(None) + + def write_and_wait(self, bytes): + d = self.deferred + self.consumer.write(bytes) + return d + + def register_with_consumer(self, consumer): + d = self.deferred + self.consumer = consumer + self.consumer.registerProducer(self, False) + return d + + +class BlockingStringWrite(object): + def __init__(self): + self.buffer = "" + self.closed = False + self.write_lock = threading.Lock() + + self._notify_write_deferred = None + self._number_of_writes = 0 + + def write(self, bytes): + with self.write_lock: + self.buffer += bytes + self._number_of_writes += 1 + + reactor.callFromThread(self._notify_write) + + def close(self): + self.closed = True + + def _notify_write(self): + "Called by write to indicate a write happened" + with self.write_lock: + if not self._notify_write_deferred: + return + d = self._notify_write_deferred + self._notify_write_deferred = None + d.callback(None) + + @defer.inlineCallbacks + def wait_for_n_writes(self, n): + "Wait for n writes to have happened" + while True: + with self.write_lock: + if n <= self._number_of_writes: + return + + if not self._notify_write_deferred: + self._notify_write_deferred = defer.Deferred() + + d = self._notify_write_deferred + + yield d diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index e2f7765f49..4850722bc5 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -12,12 +12,12 @@ class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): self.assertEquals( - LoggingContext.current_context().test_key, value + LoggingContext.current_context().request, value ) def test_with_context(self): with LoggingContext() as context_one: - context_one.test_key = "test" + context_one.request = "test" self._check_test_key("test") @defer.inlineCallbacks @@ -25,14 +25,14 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def competing_callback(): with LoggingContext() as competing_context: - competing_context.test_key = "competing" + competing_context.request = "competing" yield sleep(0) self._check_test_key("competing") reactor.callLater(0, competing_callback) with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" yield sleep(0) self._check_test_key("one") @@ -43,14 +43,14 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def cb(): - context_one.test_key = "one" + context_one.request = "one" yield function() self._check_test_key("one") callback_completed[0] = True with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" # fire off function, but don't wait on it. logcontext.preserve_fn(cb)() @@ -107,7 +107,7 @@ class LoggingContextTestCase(unittest.TestCase): sentinel_context = LoggingContext.current_context() with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" d1 = logcontext.make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable @@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase): argument isn't actually a deferred""" with LoggingContext() as context_one: - context_one.test_key = "one" + context_one.request = "one" d1 = logcontext.make_deferred_yieldable("bum") self._check_test_key("one") diff --git a/tests/utils.py b/tests/utils.py index ed8a7360f5..8efd3a3475 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,27 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.server import HttpServer -from synapse.api.errors import cs_error, CodeMessageException, StoreError -from synapse.api.constants import EventTypes -from synapse.storage.prepare_database import prepare_database -from synapse.storage.engines import create_engine -from synapse.server import HomeServer -from synapse.federation.transport import server -from synapse.util.ratelimitutils import FederationRateLimiter - -from synapse.util.logcontext import LoggingContext - -from twisted.internet import defer, reactor -from twisted.enterprise.adbapi import ConnectionPool - -from collections import namedtuple -from mock import patch, Mock import hashlib +from inspect import getcallargs import urllib import urlparse -from inspect import getcallargs +from mock import Mock, patch +from twisted.internet import defer, reactor + +from synapse.api.errors import CodeMessageException, cs_error +from synapse.federation.transport import server +from synapse.http.server import HttpServer +from synapse.server import HomeServer +from synapse.storage import PostgresEngine +from synapse.storage.engines import create_engine +from synapse.storage.prepare_database import prepare_database +from synapse.util.logcontext import LoggingContext +from synapse.util.ratelimitutils import FederationRateLimiter + +# set this to True to run the tests against postgres instead of sqlite. +# It requires you to have a local postgres database called synapse_test, within +# which ALL TABLES WILL BE DROPPED +USE_POSTGRES_FOR_TESTS = False @defer.inlineCallbacks @@ -57,32 +58,70 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.worker_app = None config.email_enable_notifs = False config.block_non_admin_invites = False + config.federation_domain_whitelist = None + config.user_directory_search_all_users = False + + # disable user directory updates, because they get done in the + # background, which upsets the test runner. + config.update_user_directory = False config.use_frozen_dicts = True - config.database_config = {"name": "sqlite3"} config.ldap_enabled = False if "clock" not in kargs: kargs["clock"] = MockClock() + if USE_POSTGRES_FOR_TESTS: + config.database_config = { + "name": "psycopg2", + "args": { + "database": "synapse_test", + "cp_min": 1, + "cp_max": 5, + }, + } + else: + config.database_config = { + "name": "sqlite3", + "args": { + "database": ":memory:", + "cp_min": 1, + "cp_max": 1, + }, + } + + db_engine = create_engine(config.database_config) + + # we need to configure the connection pool to run the on_new_connection + # function, so that we can test code that uses custom sqlite functions + # (like rank). + config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection + if datastore is None: - db_pool = SQLiteMemoryDbPool() - yield db_pool.prepare() hs = HomeServer( - name, db_pool=db_pool, config=config, + name, config=config, + db_config=config.database_config, version_string="Synapse/tests", - database_engine=create_engine(config.database_config), - get_db_conn=db_pool.get_db_conn, + database_engine=db_engine, room_list_handler=object(), tls_server_context_factory=Mock(), **kargs ) + db_conn = hs.get_db_conn() + # make sure that the database is empty + if isinstance(db_engine, PostgresEngine): + cur = db_conn.cursor() + cur.execute("SELECT tablename FROM pg_tables where schemaname='public'") + rows = cur.fetchall() + for r in rows: + cur.execute("DROP TABLE %s CASCADE" % r[0]) + yield prepare_database(db_conn, db_engine, config) hs.setup() else: hs = HomeServer( name, db_pool=None, datastore=datastore, config=config, version_string="Synapse/tests", - database_engine=create_engine(config.database_config), + database_engine=db_engine, room_list_handler=object(), tls_server_context_factory=Mock(), **kargs @@ -301,168 +340,6 @@ class MockClock(object): return d -class SQLiteMemoryDbPool(ConnectionPool, object): - def __init__(self): - super(SQLiteMemoryDbPool, self).__init__( - "sqlite3", ":memory:", - cp_min=1, - cp_max=1, - ) - - self.config = Mock() - self.config.password_providers = [] - self.config.database_config = {"name": "sqlite3"} - - def prepare(self): - engine = self.create_engine() - return self.runWithConnection( - lambda conn: prepare_database(conn, engine, self.config) - ) - - def get_db_conn(self): - conn = self.connect() - engine = self.create_engine() - prepare_database(conn, engine, self.config) - return conn - - def create_engine(self): - return create_engine(self.config.database_config) - - -class MemoryDataStore(object): - - Room = namedtuple( - "Room", - ["room_id", "is_public", "creator"] - ) - - def __init__(self): - self.tokens_to_users = {} - self.paths_to_content = {} - - self.members = {} - self.rooms = {} - - self.current_state = {} - self.events = [] - - class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")): - def fill_out_prev_events(self, event): - pass - - def snapshot_room(self, room_id, user_id, state_type=None, state_key=None): - return self.Snapshot( - room_id, user_id, self.get_room_member(user_id, room_id) - ) - - def register(self, user_id, token, password_hash): - if user_id in self.tokens_to_users.values(): - raise StoreError(400, "User in use.") - self.tokens_to_users[token] = user_id - - def get_user_by_access_token(self, token): - try: - return { - "name": self.tokens_to_users[token], - } - except Exception: - raise StoreError(400, "User does not exist.") - - def get_room(self, room_id): - try: - return self.rooms[room_id] - except Exception: - return None - - def store_room(self, room_id, room_creator_user_id, is_public): - if room_id in self.rooms: - raise StoreError(409, "Conflicting room!") - - room = MemoryDataStore.Room( - room_id=room_id, - is_public=is_public, - creator=room_creator_user_id - ) - self.rooms[room_id] = room - - def get_room_member(self, user_id, room_id): - return self.members.get(room_id, {}).get(user_id) - - def get_room_members(self, room_id, membership=None): - if membership: - return [ - v for k, v in self.members.get(room_id, {}).items() - if v.membership == membership - ] - else: - return self.members.get(room_id, {}).values() - - def get_rooms_for_user_where_membership_is(self, user_id, membership_list): - return [ - m[user_id] for m in self.members.values() - if user_id in m and m[user_id].membership in membership_list - ] - - def get_room_events_stream(self, user_id=None, from_key=None, to_key=None, - limit=0, with_feedback=False): - return ([], from_key) # TODO - - def get_joined_hosts_for_room(self, room_id): - return defer.succeed([]) - - def persist_event(self, event): - if event.type == EventTypes.Member: - room_id = event.room_id - user = event.state_key - self.members.setdefault(room_id, {})[user] = event - - if hasattr(event, "state_key"): - key = (event.room_id, event.type, event.state_key) - self.current_state[key] = event - - self.events.append(event) - - def get_current_state(self, room_id, event_type=None, state_key=""): - if event_type: - key = (room_id, event_type, state_key) - if self.current_state.get(key): - return [self.current_state.get(key)] - return None - else: - return [ - e for e in self.current_state - if e[0] == room_id - ] - - def set_presence_state(self, user_localpart, state): - return defer.succeed({"state": 0}) - - def get_presence_list(self, user_localpart, accepted): - return [] - - def get_room_events_max_id(self): - return "s0" # TODO (erikj) - - def get_send_event_level(self, room_id): - return defer.succeed(0) - - def get_power_level(self, room_id, user_id): - return defer.succeed(0) - - def get_add_state_level(self, room_id): - return defer.succeed(0) - - def get_room_join_rule(self, room_id): - # TODO (erikj): This should be configurable - return defer.succeed("invite") - - def get_ops_levels(self, room_id): - return defer.succeed((5, 5, 5)) - - def insert_client_ip(self, user, access_token, ip, user_agent): - return defer.succeed(None) - - def _format_call(args, kwargs): return ", ".join( ["%r" % (a) for a in args] + |