diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 8f57fbeb23..879159ccea 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -12,9 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import os.path
+import re
import shutil
import tempfile
+
from synapse.config.homeserver import HomeServerConfig
from tests import unittest
@@ -23,7 +26,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
def setUp(self):
self.dir = tempfile.mkdtemp()
- print self.dir
self.file = os.path.join(self.dir, "homeserver.yaml")
def tearDown(self):
@@ -48,3 +50,16 @@ class ConfigGenerationTestCase(unittest.TestCase):
]),
set(os.listdir(self.dir))
)
+
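+        # the generated log config should point 'filename' at homeserver.log
+        # in the current working directory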
+ self.assert_log_filename_is(
+ os.path.join(self.dir, "lemurs.win.log.config"),
+ os.path.join(os.getcwd(), "homeserver.log"),
+ )
+
+ def assert_log_filename_is(self, log_config_file, expected):
+ with open(log_config_file) as f:
+ config = f.read()
+ # find the 'filename' line
+        matches = re.findall(r"^\s*filename:\s*(.*)$", config, re.M)
+ self.assertEqual(1, len(matches))
+ self.assertEqual(matches[0], expected)
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 570312da84..d4ec02ffc2 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -68,7 +68,7 @@ class KeyringTestCase(unittest.TestCase):
def check_context(self, _, expected):
self.assertEquals(
- getattr(LoggingContext.current_context(), "test_key", None),
+ getattr(LoggingContext.current_context(), "request", None),
expected
)
@@ -82,7 +82,7 @@ class KeyringTestCase(unittest.TestCase):
lookup_2_deferred = defer.Deferred()
with LoggingContext("one") as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
wait_1_deferred = kr.wait_for_previous_lookups(
["server1"],
@@ -96,7 +96,7 @@ class KeyringTestCase(unittest.TestCase):
wait_1_deferred.addBoth(self.check_context, "one")
with LoggingContext("two") as context_two:
- context_two.test_key = "two"
+ context_two.request = "two"
# set off another wait. It should block because the first lookup
# hasn't yet completed.
@@ -137,7 +137,7 @@ class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def get_perspectives(**kwargs):
self.assertEquals(
- LoggingContext.current_context().test_key, "11",
+ LoggingContext.current_context().request, "11",
)
with logcontext.PreserveLoggingContext():
yield persp_deferred
@@ -145,7 +145,7 @@ class KeyringTestCase(unittest.TestCase):
self.http_client.post_json.side_effect = get_perspectives
with LoggingContext("11") as context_11:
- context_11.test_key = "11"
+ context_11.request = "11"
# start off a first set of lookups
res_deferreds = kr.verify_json_objects_for_server(
@@ -167,13 +167,13 @@ class KeyringTestCase(unittest.TestCase):
# wait a tick for it to send the request to the perspectives server
# (it first tries the datastore)
- yield async.sleep(0.005)
+ yield async.sleep(1) # XXX find out why this takes so long!
self.http_client.post_json.assert_called_once()
self.assertIs(LoggingContext.current_context(), context_11)
context_12 = LoggingContext("12")
- context_12.test_key = "12"
+ context_12.request = "12"
with logcontext.PreserveLoggingContext(context_12):
# a second request for a server with outstanding requests
# should block rather than start a second call
@@ -183,7 +183,7 @@ class KeyringTestCase(unittest.TestCase):
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)],
)
- yield async.sleep(0.005)
+            yield async.sleep(1)
self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None)
@@ -211,7 +211,7 @@ class KeyringTestCase(unittest.TestCase):
sentinel_context = LoggingContext.current_context()
with LoggingContext("one") as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
defer = kr.verify_json_for_server("server9", {})
try:
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 19f5ed6bce..d92bf240b1 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -143,7 +143,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
except errors.SynapseError:
pass
- @unittest.DEBUG
@defer.inlineCallbacks
def test_claim_one_time_key(self):
local_user = "@boris:" + self.hs.hostname
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index f85455a5af..39bde6e3f8 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -141,6 +141,7 @@ class CacheMetricTestCase(unittest.TestCase):
'cache:hits{name="cache_name"} 0',
'cache:total{name="cache_name"} 0',
'cache:size{name="cache_name"} 0',
+ 'cache:evicted_size{name="cache_name"} 0',
])
metric.inc_misses()
@@ -150,6 +151,7 @@ class CacheMetricTestCase(unittest.TestCase):
'cache:hits{name="cache_name"} 0',
'cache:total{name="cache_name"} 1',
'cache:size{name="cache_name"} 1',
+ 'cache:evicted_size{name="cache_name"} 0',
])
metric.inc_hits()
@@ -158,4 +160,14 @@ class CacheMetricTestCase(unittest.TestCase):
'cache:hits{name="cache_name"} 1',
'cache:total{name="cache_name"} 2',
'cache:size{name="cache_name"} 1',
+ 'cache:evicted_size{name="cache_name"} 0',
+ ])
+
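+        # evictions should be reported in evicted_size while size stays unchanged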
+ metric.inc_evictions(2)
+
+ self.assertEquals(metric.render(), [
+ 'cache:hits{name="cache_name"} 1',
+ 'cache:total{name="cache_name"} 2',
+ 'cache:size{name="cache_name"} 1',
+ 'cache:evicted_size{name="cache_name"} 2',
])
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 81063f19a1..74f104e3b8 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -15,6 +15,8 @@
from twisted.internet import defer, reactor
from tests import unittest
+import tempfile
+
from mock import Mock, NonCallableMock
from tests.utils import setup_test_homeserver
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -41,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
- listener = reactor.listenUNIX("\0xxx", server_factory)
+        # XXX: mktemp is unsafe and should never be used, but this is just a test.
+ path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
+ listener = reactor.listenUNIX(path, server_factory)
self.addCleanup(listener.stopListening)
self.streamer = server_factory.streamer
@@ -49,7 +53,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
- client_connector = reactor.connectUNIX("\0xxx", client_factory)
+ client_connector = reactor.connectUNIX(path, client_factory)
self.addCleanup(client_factory.stopTrying)
self.addCleanup(client_connector.disconnect)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 105e1228bb..f430cce931 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -226,11 +226,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
context = EventContext()
context.current_state_ids = state_ids
context.prev_state_ids = state_ids
- elif not backfill:
+ else:
state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event)
- else:
- context = EventContext()
context.push_actions = push_actions
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index de376fb514..9f37255381 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -515,9 +515,6 @@ class RoomsCreateTestCase(RestTestCase):
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
- def tearDown(self):
- pass
-
@defer.inlineCallbacks
def test_post_room_no_keys(self):
# POST with no config keys, expect new room id
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 096f771bea..8aba456510 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -49,6 +49,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.get_device_handler = Mock(return_value=self.device_handler)
self.hs.config.enable_registration = True
+ self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = []
# init the thing we're testing
diff --git a/tests/rest/media/__init__.py b/tests/rest/media/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rest/media/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rest/media/v1/__init__.py b/tests/rest/media/v1/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rest/media/v1/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
new file mode 100644
index 0000000000..eef38b6781
--- /dev/null
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from synapse.rest.media.v1._base import FileInfo
+from synapse.rest.media.v1.media_storage import MediaStorage
+from synapse.rest.media.v1.filepath import MediaFilePaths
+from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+
+from mock import Mock
+
+from tests import unittest
+
+import os
+import shutil
+import tempfile
+
+
+class MediaStorageTests(unittest.TestCase):
+ def setUp(self):
+ self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
+
+ self.primary_base_path = os.path.join(self.test_dir, "primary")
+ self.secondary_base_path = os.path.join(self.test_dir, "secondary")
+
+ hs = Mock()
+ hs.config.media_store_path = self.primary_base_path
+
+ storage_providers = [FileStorageProviderBackend(
+ hs, self.secondary_base_path
+ )]
+
+ self.filepaths = MediaFilePaths(self.primary_base_path)
+ self.media_storage = MediaStorage(
+ self.primary_base_path, self.filepaths, storage_providers,
+ )
+
+ def tearDown(self):
+ shutil.rmtree(self.test_dir)
+
+ @defer.inlineCallbacks
+ def test_ensure_media_is_in_local_cache(self):
+ media_id = "some_media_id"
+ test_body = "Test\n"
+
+ # First we create a file that is in a storage provider but not in the
+ # local primary media store
+ rel_path = self.filepaths.local_media_filepath_rel(media_id)
+ secondary_path = os.path.join(self.secondary_base_path, rel_path)
+
+ os.makedirs(os.path.dirname(secondary_path))
+
+ with open(secondary_path, "w") as f:
+ f.write(test_body)
+
+ # Now we run ensure_media_is_in_local_cache, which should copy the file
+ # to the local cache.
+ file_info = FileInfo(None, media_id)
+ local_path = yield self.media_storage.ensure_media_is_in_local_cache(file_info)
+
+ self.assertTrue(os.path.exists(local_path))
+
+ # Asserts the file is under the expected local cache directory
+ self.assertEquals(
+ os.path.commonprefix([self.primary_base_path, local_path]),
+ self.primary_base_path,
+ )
+
+ with open(local_path) as f:
+ body = f.read()
+
+ self.assertEqual(test_body, body)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 6afaca3a61..888ddfaddd 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -36,8 +36,7 @@ class RedactionTestCase(unittest.TestCase):
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
- self.handlers = hs.get_handlers()
- self.message_handler = self.handlers.message_handler
+ self.event_creation_handler = hs.get_event_creation_handler()
self.u_alice = UserID.from_string("@alice:test")
self.u_bob = UserID.from_string("@bob:test")
@@ -59,7 +58,7 @@ class RedactionTestCase(unittest.TestCase):
"content": content,
})
- event, context = yield self.message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
@@ -79,7 +78,7 @@ class RedactionTestCase(unittest.TestCase):
"content": {"body": body, "msgtype": u"message"},
})
- event, context = yield self.message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
@@ -98,7 +97,7 @@ class RedactionTestCase(unittest.TestCase):
"redacts": event_id,
})
- event, context = yield self.message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 1be7d932f6..657b279e5d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -37,8 +37,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
# storage logic
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
- self.handlers = hs.get_handlers()
- self.message_handler = self.handlers.message_handler
+ self.event_creation_handler = hs.get_event_creation_handler()
self.u_alice = UserID.from_string("@alice:test")
self.u_bob = UserID.from_string("@bob:test")
@@ -58,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
"content": {"membership": membership},
})
- event, context = yield self.message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
new file mode 100644
index 0000000000..0891308f25
--- /dev/null
+++ b/tests/storage/test_user_directory.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.storage import UserDirectoryStore
+from synapse.storage.roommember import ProfileInfo
+from tests import unittest
+from tests.utils import setup_test_homeserver
+
+ALICE = "@alice:a"
+BOB = "@bob:b"
+BOBBY = "@bobby:a"
+
+
+class UserDirectoryStoreTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver()
+ self.store = UserDirectoryStore(None, self.hs)
+
+ # alice and bob are both in !room_id. bobby is not but shares
+ # a homeserver with alice.
+ yield self.store.add_profiles_to_user_dir(
+ "!room:id",
+ {
+ ALICE: ProfileInfo(None, "alice"),
+ BOB: ProfileInfo(None, "bob"),
+ BOBBY: ProfileInfo(None, "bobby")
+ },
+ )
+ yield self.store.add_users_to_public_room(
+ "!room:id",
+ [ALICE, BOB],
+ )
+ yield self.store.add_users_who_share_room(
+ "!room:id",
+ False,
+ (
+ (ALICE, BOB),
+ (BOB, ALICE),
+ ),
+ )
+
+ @defer.inlineCallbacks
+ def test_search_user_dir(self):
+ # normally when alice searches the directory she should just find
+ # bob because bobby doesn't share a room with her.
+ r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(r["results"][0], {
+ "user_id": BOB,
+ "display_name": "bob",
+ "avatar_url": None,
+ })
+
+ @defer.inlineCallbacks
+ def test_search_user_dir_all_users(self):
+ self.hs.config.user_directory_search_all_users = True
+ try:
+ r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ self.assertFalse(r["limited"])
+ self.assertEqual(2, len(r["results"]))
+ self.assertDictEqual(r["results"][0], {
+ "user_id": BOB,
+ "display_name": "bob",
+ "avatar_url": None,
+ })
+ self.assertDictEqual(r["results"][1], {
+ "user_id": BOBBY,
+ "display_name": "bobby",
+ "avatar_url": None,
+ })
+ finally:
+ self.hs.config.user_directory_search_all_users = False
diff --git a/tests/test_state.py b/tests/test_state.py
index feb84f3d48..a5c5e55951 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.events import FrozenEvent
from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
-from synapse.state import StateHandler
+from synapse.state import StateHandler, StateResolutionHandler
from .utils import MockClock
@@ -80,14 +80,14 @@ class StateGroupStore(object):
return defer.succeed(groups)
- def store_state_groups(self, event, context):
- if context.current_state_ids is None:
- return
+ def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+ current_state_ids):
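+        # allocate a fresh state group id and remember its state, much as
+        # the real datastore would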
+ state_group = self._next_group
+ self._next_group += 1
- state_events = dict(context.current_state_ids)
+ self._group_to_state[state_group] = dict(current_state_ids)
- self._group_to_state[context.state_group] = state_events
- self._event_to_state_group[event.event_id] = context.state_group
+ return state_group
def get_events(self, event_ids, **kwargs):
return {
@@ -95,10 +95,19 @@ class StateGroupStore(object):
if e_id in self._event_id_to_event
}
+ def get_state_group_delta(self, name):
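+        # this test store keeps no deltas, so there is never a prev_group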
+ return (None, None)
+
def register_events(self, events):
for e in events:
self._event_id_to_event[e.event_id] = e
+ def register_event_context(self, event, context):
+ self._event_to_state_group[event.event_id] = context.state_group
+
+ def register_event_id_state_group(self, event_id, state_group):
+ self._event_to_state_group[event_id] = state_group
+
class DictObj(dict):
def __init__(self, **kwargs):
@@ -137,25 +146,16 @@ class Graph(object):
class StateTestCase(unittest.TestCase):
def setUp(self):
- self.store = Mock(
- spec_set=[
- "get_state_groups_ids",
- "add_event_hashes",
- "get_events",
- "get_next_state_group",
- "get_state_group_delta",
- ]
- )
+ self.store = StateGroupStore()
hs = Mock(spec_set=[
"get_datastore", "get_auth", "get_state_handler", "get_clock",
+ "get_state_resolution_handler",
])
hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
-
- self.store.get_next_state_group.side_effect = Mock
- self.store.get_state_group_delta.return_value = (None, None)
+ hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
self.state = StateHandler(hs)
self.event_id = 0
@@ -195,14 +195,13 @@ class StateTestCase(unittest.TestCase):
}
)
- store = StateGroupStore()
- self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+ self.store.register_events(graph.walk())
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
- store.store_state_groups(event, context)
+ self.store.register_event_context(event, context)
context_store[event.event_id] = context
self.assertEqual(2, len(context_store["D"].prev_state_ids))
@@ -247,16 +246,13 @@ class StateTestCase(unittest.TestCase):
}
)
- store = StateGroupStore()
- self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
- self.store.get_events = store.get_events
- store.register_events(graph.walk())
+ self.store.register_events(graph.walk())
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
- store.store_state_groups(event, context)
+ self.store.register_event_context(event, context)
context_store[event.event_id] = context
self.assertSetEqual(
@@ -313,16 +309,13 @@ class StateTestCase(unittest.TestCase):
}
)
- store = StateGroupStore()
- self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
- self.store.get_events = store.get_events
- store.register_events(graph.walk())
+ self.store.register_events(graph.walk())
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
- store.store_state_groups(event, context)
+ self.store.register_event_context(event, context)
context_store[event.event_id] = context
self.assertSetEqual(
@@ -396,16 +389,13 @@ class StateTestCase(unittest.TestCase):
self._add_depths(nodes, edges)
graph = Graph(nodes, edges)
- store = StateGroupStore()
- self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
- self.store.get_events = store.get_events
- store.register_events(graph.walk())
+ self.store.register_events(graph.walk())
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
- store.store_state_groups(event, context)
+ self.store.register_event_context(event, context)
context_store[event.event_id] = context
self.assertSetEqual(
@@ -465,7 +455,11 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
- event = create_event(type="test_message", name="event")
+ prev_event_id = "prev_event_id"
+ event = create_event(
+ type="test_message", name="event2",
+ prev_events=[(prev_event_id, {})],
+ )
old_state = [
create_event(type="test1", state_key="1"),
@@ -473,11 +467,11 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = "group_name_1"
-
- self.store.get_state_groups_ids.return_value = {
- group_name: {(e.type, e.state_key): e.event_id for e in old_state},
- }
+ group_name = self.store.store_state_group(
+ prev_event_id, event.room_id, None, None,
+ {(e.type, e.state_key): e.event_id for e in old_state},
+ )
+ self.store.register_event_id_state_group(prev_event_id, group_name)
context = yield self.state.compute_event_context(event)
@@ -490,7 +484,11 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_trivial_annotate_state(self):
- event = create_event(type="state", state_key="", name="event")
+ prev_event_id = "prev_event_id"
+ event = create_event(
+ type="state", state_key="", name="event2",
+ prev_events=[(prev_event_id, {})],
+ )
old_state = [
create_event(type="test1", state_key="1"),
@@ -498,11 +496,11 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = "group_name_1"
-
- self.store.get_state_groups_ids.return_value = {
- group_name: {(e.type, e.state_key): e.event_id for e in old_state},
- }
+ group_name = self.store.store_state_group(
+ prev_event_id, event.room_id, None, None,
+ {(e.type, e.state_key): e.event_id for e in old_state},
+ )
+ self.store.register_event_id_state_group(prev_event_id, group_name)
context = yield self.state.compute_event_context(event)
@@ -515,7 +513,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_resolve_message_conflict(self):
- event = create_event(type="test_message", name="event")
+ prev_event_id1 = "event_id1"
+ prev_event_id2 = "event_id2"
+ event = create_event(
+ type="test_message", name="event3",
+ prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+ )
creation = create_event(
type=EventTypes.Create, state_key=""
@@ -535,12 +538,12 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""),
]
- store = StateGroupStore()
- store.register_events(old_state_1)
- store.register_events(old_state_2)
- self.store.get_events = store.get_events
+ self.store.register_events(old_state_1)
+ self.store.register_events(old_state_2)
- context = yield self._get_context(event, old_state_1, old_state_2)
+ context = yield self._get_context(
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ )
self.assertEqual(len(context.current_state_ids), 6)
@@ -548,7 +551,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_resolve_state_conflict(self):
- event = create_event(type="test4", state_key="", name="event")
+ prev_event_id1 = "event_id1"
+ prev_event_id2 = "event_id2"
+ event = create_event(
+ type="test4", state_key="", name="event",
+ prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+ )
creation = create_event(
type=EventTypes.Create, state_key=""
@@ -573,7 +581,9 @@ class StateTestCase(unittest.TestCase):
store.register_events(old_state_2)
self.store.get_events = store.get_events
- context = yield self._get_context(event, old_state_1, old_state_2)
+ context = yield self._get_context(
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ )
self.assertEqual(len(context.current_state_ids), 6)
@@ -581,7 +591,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_standard_depth_conflict(self):
- event = create_event(type="test4", name="event")
+ prev_event_id1 = "event_id1"
+ prev_event_id2 = "event_id2"
+ event = create_event(
+ type="test4", name="event",
+ prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+ )
member_event = create_event(
type=EventTypes.Member,
@@ -613,7 +628,9 @@ class StateTestCase(unittest.TestCase):
store.register_events(old_state_2)
self.store.get_events = store.get_events
- context = yield self._get_context(event, old_state_1, old_state_2)
+ context = yield self._get_context(
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ )
self.assertEqual(
old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
@@ -637,19 +654,26 @@ class StateTestCase(unittest.TestCase):
store.register_events(old_state_1)
store.register_events(old_state_2)
- context = yield self._get_context(event, old_state_1, old_state_2)
+ context = yield self._get_context(
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ )
self.assertEqual(
old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
)
- def _get_context(self, event, old_state_1, old_state_2):
- group_name_1 = "group_name_1"
- group_name_2 = "group_name_2"
+ def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
+ old_state_2):
+ sg1 = self.store.store_state_group(
+ prev_event_id_1, event.room_id, None, None,
+ {(e.type, e.state_key): e.event_id for e in old_state_1},
+ )
+ self.store.register_event_id_state_group(prev_event_id_1, sg1)
- self.store.get_state_groups_ids.return_value = {
- group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
- group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
- }
+ sg2 = self.store.store_state_group(
+ prev_event_id_2, event.room_id, None, None,
+ {(e.type, e.state_key): e.event_id for e in old_state_2},
+ )
+ self.store.register_event_id_state_group(prev_event_id_2, sg2)
return self.state.compute_event_context(event)
diff --git a/tests/unittest.py b/tests/unittest.py
index 38715972dd..7b478c4294 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import twisted.internet.base
from twisted.trial import unittest
import logging
@@ -65,6 +65,10 @@ class TestCase(unittest.TestCase):
@around(self)
def setUp(orig):
+ # enable debugging of delayed calls - this means that we get a
+ # traceback when a unit test exits leaving things on the reactor.
+ twisted.internet.base.DelayedCall.debug = True
+
old_level = logging.getLogger().level
if old_level != level:
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
new file mode 100644
index 0000000000..76e2234255
--- /dev/null
+++ b/tests/util/test_file_consumer.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer, reactor
+from mock import NonCallableMock
+
+from synapse.util.file_consumer import BackgroundFileConsumer
+
+from tests import unittest
+from StringIO import StringIO
+
+import threading
+
+
+class FileConsumerTests(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def test_pull_consumer(self):
+ string_file = StringIO()
+ consumer = BackgroundFileConsumer(string_file)
+
+ try:
+ producer = DummyPullProducer()
+
+ yield producer.register_with_consumer(consumer)
+
+ yield producer.write_and_wait("Foo")
+
+ self.assertEqual(string_file.getvalue(), "Foo")
+
+ yield producer.write_and_wait("Bar")
+
+ self.assertEqual(string_file.getvalue(), "FooBar")
+ finally:
+ consumer.unregisterProducer()
+
+ yield consumer.wait()
+
+ self.assertTrue(string_file.closed)
+
+ @defer.inlineCallbacks
+ def test_push_consumer(self):
+ string_file = BlockingStringWrite()
+ consumer = BackgroundFileConsumer(string_file)
+
+ try:
+ producer = NonCallableMock(spec_set=[])
+
+ consumer.registerProducer(producer, True)
+
+ consumer.write("Foo")
+ yield string_file.wait_for_n_writes(1)
+
+ self.assertEqual(string_file.buffer, "Foo")
+
+ consumer.write("Bar")
+ yield string_file.wait_for_n_writes(2)
+
+ self.assertEqual(string_file.buffer, "FooBar")
+ finally:
+ consumer.unregisterProducer()
+
+ yield consumer.wait()
+
+ self.assertTrue(string_file.closed)
+
+ @defer.inlineCallbacks
+ def test_push_producer_feedback(self):
+ string_file = BlockingStringWrite()
+ consumer = BackgroundFileConsumer(string_file)
+
+ try:
+ producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
+
+ resume_deferred = defer.Deferred()
+ producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None)
+
+ consumer.registerProducer(producer, True)
+
+ number_writes = 0
+ with string_file.write_lock:
+ for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
+ consumer.write("Foo")
+ number_writes += 1
+
+ producer.pauseProducing.assert_called_once()
+
+ yield string_file.wait_for_n_writes(number_writes)
+
+ yield resume_deferred
+ producer.resumeProducing.assert_called_once()
+ finally:
+ consumer.unregisterProducer()
+
+ yield consumer.wait()
+
+ self.assertTrue(string_file.closed)
+
+
+class DummyPullProducer(object):
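+    """Pull producer whose write_and_wait deferreds fire the next time the
+    consumer calls resumeProducing."""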
+ def __init__(self):
+ self.consumer = None
+ self.deferred = defer.Deferred()
+
+ def resumeProducing(self):
+ d = self.deferred
+ self.deferred = defer.Deferred()
+ d.callback(None)
+
+ def write_and_wait(self, bytes):
+ d = self.deferred
+ self.consumer.write(bytes)
+ return d
+
+ def register_with_consumer(self, consumer):
+ d = self.deferred
+ self.consumer = consumer
+ self.consumer.registerProducer(self, False)
+ return d
+
+
+class BlockingStringWrite(object):
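+    """Fake file object which records writes and lets tests wait for them."""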
+ def __init__(self):
+ self.buffer = ""
+ self.closed = False
+ self.write_lock = threading.Lock()
+
+ self._notify_write_deferred = None
+ self._number_of_writes = 0
+
+ def write(self, bytes):
+ with self.write_lock:
+ self.buffer += bytes
+ self._number_of_writes += 1
+
+ reactor.callFromThread(self._notify_write)
+
+ def close(self):
+ self.closed = True
+
+ def _notify_write(self):
+ "Called by write to indicate a write happened"
+ with self.write_lock:
+ if not self._notify_write_deferred:
+ return
+ d = self._notify_write_deferred
+ self._notify_write_deferred = None
+ d.callback(None)
+
+ @defer.inlineCallbacks
+ def wait_for_n_writes(self, n):
+ "Wait for n writes to have happened"
+ while True:
+ with self.write_lock:
+ if n <= self._number_of_writes:
+ return
+
+ if not self._notify_write_deferred:
+ self._notify_write_deferred = defer.Deferred()
+
+ d = self._notify_write_deferred
+
+ yield d
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index e2f7765f49..4850722bc5 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -12,12 +12,12 @@ class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
self.assertEquals(
- LoggingContext.current_context().test_key, value
+ LoggingContext.current_context().request, value
)
def test_with_context(self):
with LoggingContext() as context_one:
- context_one.test_key = "test"
+ context_one.request = "test"
self._check_test_key("test")
@defer.inlineCallbacks
@@ -25,14 +25,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def competing_callback():
with LoggingContext() as competing_context:
- competing_context.test_key = "competing"
+ competing_context.request = "competing"
yield sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
yield sleep(0)
self._check_test_key("one")
@@ -43,14 +43,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def cb():
- context_one.test_key = "one"
+ context_one.request = "one"
yield function()
self._check_test_key("one")
callback_completed[0] = True
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
# fire off function, but don't wait on it.
logcontext.preserve_fn(cb)()
@@ -107,7 +107,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = LoggingContext.current_context()
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
d1 = logcontext.make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
@@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase):
argument isn't actually a deferred"""
with LoggingContext() as context_one:
- context_one.test_key = "one"
+ context_one.request = "one"
d1 = logcontext.make_deferred_yieldable("bum")
self._check_test_key("one")
diff --git a/tests/utils.py b/tests/utils.py
index ed8a7360f5..8efd3a3475 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -13,27 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import HttpServer
-from synapse.api.errors import cs_error, CodeMessageException, StoreError
-from synapse.api.constants import EventTypes
-from synapse.storage.prepare_database import prepare_database
-from synapse.storage.engines import create_engine
-from synapse.server import HomeServer
-from synapse.federation.transport import server
-from synapse.util.ratelimitutils import FederationRateLimiter
-
-from synapse.util.logcontext import LoggingContext
-
-from twisted.internet import defer, reactor
-from twisted.enterprise.adbapi import ConnectionPool
-
-from collections import namedtuple
-from mock import patch, Mock
import hashlib
+from inspect import getcallargs
import urllib
import urlparse
-from inspect import getcallargs
+from mock import Mock, patch
+from twisted.internet import defer, reactor
+
+from synapse.api.errors import CodeMessageException, cs_error
+from synapse.federation.transport import server
+from synapse.http.server import HttpServer
+from synapse.server import HomeServer
+from synapse.storage import PostgresEngine
+from synapse.storage.engines import create_engine
+from synapse.storage.prepare_database import prepare_database
+from synapse.util.logcontext import LoggingContext
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+# set this to True to run the tests against postgres instead of sqlite.
+# It requires you to have a local postgres database called synapse_test, within
+# which ALL TABLES WILL BE DROPPED
+USE_POSTGRES_FOR_TESTS = False
@defer.inlineCallbacks
@@ -57,32 +58,70 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.worker_app = None
config.email_enable_notifs = False
config.block_non_admin_invites = False
+ config.federation_domain_whitelist = None
+ config.user_directory_search_all_users = False
+
+ # disable user directory updates, because they get done in the
+ # background, which upsets the test runner.
+ config.update_user_directory = False
config.use_frozen_dicts = True
- config.database_config = {"name": "sqlite3"}
config.ldap_enabled = False
if "clock" not in kargs:
kargs["clock"] = MockClock()
+ if USE_POSTGRES_FOR_TESTS:
+ config.database_config = {
+ "name": "psycopg2",
+ "args": {
+ "database": "synapse_test",
+ "cp_min": 1,
+ "cp_max": 5,
+ },
+ }
+ else:
+ config.database_config = {
+ "name": "sqlite3",
+ "args": {
+ "database": ":memory:",
+ "cp_min": 1,
+ "cp_max": 1,
+ },
+ }
+
+ db_engine = create_engine(config.database_config)
+
+ # we need to configure the connection pool to run the on_new_connection
+ # function, so that we can test code that uses custom sqlite functions
+ # (like rank).
+ config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
+
if datastore is None:
- db_pool = SQLiteMemoryDbPool()
- yield db_pool.prepare()
hs = HomeServer(
- name, db_pool=db_pool, config=config,
+ name, config=config,
+ db_config=config.database_config,
version_string="Synapse/tests",
- database_engine=create_engine(config.database_config),
- get_db_conn=db_pool.get_db_conn,
+ database_engine=db_engine,
room_list_handler=object(),
tls_server_context_factory=Mock(),
**kargs
)
+ db_conn = hs.get_db_conn()
+ # make sure that the database is empty
+ if isinstance(db_engine, PostgresEngine):
+ cur = db_conn.cursor()
+ cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
+ rows = cur.fetchall()
+ for r in rows:
+ cur.execute("DROP TABLE %s CASCADE" % r[0])
+ yield prepare_database(db_conn, db_engine, config)
hs.setup()
else:
hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests",
- database_engine=create_engine(config.database_config),
+ database_engine=db_engine,
room_list_handler=object(),
tls_server_context_factory=Mock(),
**kargs
@@ -301,168 +340,6 @@ class MockClock(object):
return d
-class SQLiteMemoryDbPool(ConnectionPool, object):
- def __init__(self):
- super(SQLiteMemoryDbPool, self).__init__(
- "sqlite3", ":memory:",
- cp_min=1,
- cp_max=1,
- )
-
- self.config = Mock()
- self.config.password_providers = []
- self.config.database_config = {"name": "sqlite3"}
-
- def prepare(self):
- engine = self.create_engine()
- return self.runWithConnection(
- lambda conn: prepare_database(conn, engine, self.config)
- )
-
- def get_db_conn(self):
- conn = self.connect()
- engine = self.create_engine()
- prepare_database(conn, engine, self.config)
- return conn
-
- def create_engine(self):
- return create_engine(self.config.database_config)
-
-
-class MemoryDataStore(object):
-
- Room = namedtuple(
- "Room",
- ["room_id", "is_public", "creator"]
- )
-
- def __init__(self):
- self.tokens_to_users = {}
- self.paths_to_content = {}
-
- self.members = {}
- self.rooms = {}
-
- self.current_state = {}
- self.events = []
-
- class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")):
- def fill_out_prev_events(self, event):
- pass
-
- def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
- return self.Snapshot(
- room_id, user_id, self.get_room_member(user_id, room_id)
- )
-
- def register(self, user_id, token, password_hash):
- if user_id in self.tokens_to_users.values():
- raise StoreError(400, "User in use.")
- self.tokens_to_users[token] = user_id
-
- def get_user_by_access_token(self, token):
- try:
- return {
- "name": self.tokens_to_users[token],
- }
- except Exception:
- raise StoreError(400, "User does not exist.")
-
- def get_room(self, room_id):
- try:
- return self.rooms[room_id]
- except Exception:
- return None
-
- def store_room(self, room_id, room_creator_user_id, is_public):
- if room_id in self.rooms:
- raise StoreError(409, "Conflicting room!")
-
- room = MemoryDataStore.Room(
- room_id=room_id,
- is_public=is_public,
- creator=room_creator_user_id
- )
- self.rooms[room_id] = room
-
- def get_room_member(self, user_id, room_id):
- return self.members.get(room_id, {}).get(user_id)
-
- def get_room_members(self, room_id, membership=None):
- if membership:
- return [
- v for k, v in self.members.get(room_id, {}).items()
- if v.membership == membership
- ]
- else:
- return self.members.get(room_id, {}).values()
-
- def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
- return [
- m[user_id] for m in self.members.values()
- if user_id in m and m[user_id].membership in membership_list
- ]
-
- def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
- limit=0, with_feedback=False):
- return ([], from_key) # TODO
-
- def get_joined_hosts_for_room(self, room_id):
- return defer.succeed([])
-
- def persist_event(self, event):
- if event.type == EventTypes.Member:
- room_id = event.room_id
- user = event.state_key
- self.members.setdefault(room_id, {})[user] = event
-
- if hasattr(event, "state_key"):
- key = (event.room_id, event.type, event.state_key)
- self.current_state[key] = event
-
- self.events.append(event)
-
- def get_current_state(self, room_id, event_type=None, state_key=""):
- if event_type:
- key = (room_id, event_type, state_key)
- if self.current_state.get(key):
- return [self.current_state.get(key)]
- return None
- else:
- return [
- e for e in self.current_state
- if e[0] == room_id
- ]
-
- def set_presence_state(self, user_localpart, state):
- return defer.succeed({"state": 0})
-
- def get_presence_list(self, user_localpart, accepted):
- return []
-
- def get_room_events_max_id(self):
- return "s0" # TODO (erikj)
-
- def get_send_event_level(self, room_id):
- return defer.succeed(0)
-
- def get_power_level(self, room_id, user_id):
- return defer.succeed(0)
-
- def get_add_state_level(self, room_id):
- return defer.succeed(0)
-
- def get_room_join_rule(self, room_id):
- # TODO (erikj): This should be configurable
- return defer.succeed("invite")
-
- def get_ops_levels(self, room_id):
- return defer.succeed((5, 5, 5))
-
- def insert_client_ip(self, user, access_token, ip, user_agent):
- return defer.succeed(None)
-
-
def _format_call(args, kwargs):
return ", ".join(
["%r" % (a) for a in args] +
|