diff --git a/tests/config/test_server.py b/tests/config/test_server.py
new file mode 100644
index 0000000000..f5836d73ac
--- /dev/null
+++ b/tests/config/test_server.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config.server import is_threepid_reserved
+
+from tests import unittest
+
+
+class ServerConfigTestCase(unittest.TestCase):
+
+ def test_is_threepid_reserved(self):
+ user1 = {'medium': 'email', 'address': 'user1@example.com'}
+ user2 = {'medium': 'email', 'address': 'user2@example.com'}
+ user3 = {'medium': 'email', 'address': 'user3@example.com'}
+ user1_msisdn = {'medium': 'msisdn', 'address': '447700000000'}
+ config = [user1, user2]
+
+ self.assertTrue(is_threepid_reserved(config, user1))
+ self.assertFalse(is_threepid_reserved(config, user3))
+ self.assertFalse(is_threepid_reserved(config, user1_msisdn))
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index d643bec887..f5bd7a1aa1 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd.
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,17 +16,19 @@ import time
from mock import Mock
+import canonicaljson
import signedjson.key
import signedjson.sign
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
-from synapse.util import Clock, logcontext
+from synapse.crypto.keyring import KeyLookupError
+from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext
-from tests import unittest, utils
+from tests import unittest
class MockPerspectiveServer(object):
@@ -48,79 +50,57 @@ class MockPerspectiveServer(object):
key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
},
}
+ return self.get_signed_response(res)
+
+ def get_signed_response(self, res):
signedjson.sign.sign_json(res, self.server_name, self.key)
return res
-class KeyringTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
+class KeyringTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
self.mock_perspective_server = MockPerspectiveServer()
self.http_client = Mock()
- self.hs = yield utils.setup_test_homeserver(
- self.addCleanup, handlers=None, http_client=self.http_client
- )
+ hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
keys = self.mock_perspective_server.get_verify_keys()
- self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
-
- def assert_sentinel_context(self):
- if LoggingContext.current_context() != LoggingContext.sentinel:
- self.fail(
- "Expected sentinel context but got %s" % (
- LoggingContext.current_context(),
- )
- )
+ hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
+ return hs
def check_context(self, _, expected):
self.assertEquals(
getattr(LoggingContext.current_context(), "request", None), expected
)
- @defer.inlineCallbacks
def test_wait_for_previous_lookups(self):
kr = keyring.Keyring(self.hs)
lookup_1_deferred = defer.Deferred()
lookup_2_deferred = defer.Deferred()
- with LoggingContext("one") as context_one:
- context_one.request = "one"
-
- wait_1_deferred = kr.wait_for_previous_lookups(
- ["server1"], {"server1": lookup_1_deferred}
- )
-
- # there were no previous lookups, so the deferred should be ready
- self.assertTrue(wait_1_deferred.called)
- # ... so we should have preserved the LoggingContext.
- self.assertIs(LoggingContext.current_context(), context_one)
- wait_1_deferred.addBoth(self.check_context, "one")
-
- with LoggingContext("two") as context_two:
- context_two.request = "two"
+ # we run the lookup in a logcontext so that the patched inlineCallbacks can check
+ # it is doing the right thing with logcontexts.
+ wait_1_deferred = run_in_context(
+ kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_1_deferred}
+ )
- # set off another wait. It should block because the first lookup
- # hasn't yet completed.
- wait_2_deferred = kr.wait_for_previous_lookups(
- ["server1"], {"server1": lookup_2_deferred}
- )
- self.assertFalse(wait_2_deferred.called)
+ # there were no previous lookups, so the deferred should be ready
+ self.successResultOf(wait_1_deferred)
- # ... so we should have reset the LoggingContext.
- self.assert_sentinel_context()
+ # set off another wait. It should block because the first lookup
+ # hasn't yet completed.
+ wait_2_deferred = run_in_context(
+ kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_2_deferred}
+ )
- wait_2_deferred.addBoth(self.check_context, "two")
+ self.assertFalse(wait_2_deferred.called)
- # let the first lookup complete (in the sentinel context)
- lookup_1_deferred.callback(None)
+ # let the first lookup complete (in the sentinel context)
+ lookup_1_deferred.callback(None)
- # now the second wait should complete and restore our
- # loggingcontext.
- yield wait_2_deferred
+ # now the second wait should complete.
+ self.successResultOf(wait_2_deferred)
- @defer.inlineCallbacks
def test_verify_json_objects_for_server_awaits_previous_requests(self):
- clock = Clock(reactor)
key1 = signedjson.key.generate_signing_key(1)
kr = keyring.Keyring(self.hs)
@@ -145,81 +125,229 @@ class KeyringTestCase(unittest.TestCase):
self.http_client.post_json.side_effect = get_perspectives
- with LoggingContext("11") as context_11:
- context_11.request = "11"
-
- # start off a first set of lookups
- res_deferreds = kr.verify_json_objects_for_server(
- [("server10", json1), ("server11", {})]
- )
-
- # the unsigned json should be rejected pretty quickly
- self.assertTrue(res_deferreds[1].called)
- try:
- yield res_deferreds[1]
- self.assertFalse("unsigned json didn't cause a failure")
- except SynapseError:
- pass
-
- self.assertFalse(res_deferreds[0].called)
- res_deferreds[0].addBoth(self.check_context, None)
-
- # wait a tick for it to send the request to the perspectives server
- # (it first tries the datastore)
- yield clock.sleep(1) # XXX find out why this takes so long!
- self.http_client.post_json.assert_called_once()
-
- self.assertIs(LoggingContext.current_context(), context_11)
-
- context_12 = LoggingContext("12")
- context_12.request = "12"
- with logcontext.PreserveLoggingContext(context_12):
- # a second request for a server with outstanding requests
- # should block rather than start a second call
+ # start off a first set of lookups
+ @defer.inlineCallbacks
+ def first_lookup():
+ with LoggingContext("11") as context_11:
+ context_11.request = "11"
+
+ res_deferreds = kr.verify_json_objects_for_server(
+ [("server10", json1), ("server11", {})]
+ )
+
+ # the unsigned json should be rejected pretty quickly
+ self.assertTrue(res_deferreds[1].called)
+ try:
+ yield res_deferreds[1]
+ self.assertFalse("unsigned json didn't cause a failure")
+ except SynapseError:
+ pass
+
+ self.assertFalse(res_deferreds[0].called)
+ res_deferreds[0].addBoth(self.check_context, None)
+
+ yield logcontext.make_deferred_yieldable(res_deferreds[0])
+
+ # let verify_json_objects_for_server finish its work before we kill the
+ # logcontext
+ yield self.clock.sleep(0)
+
+ d0 = first_lookup()
+
+ # wait a tick for it to send the request to the perspectives server
+ # (it first tries the datastore)
+ self.pump()
+ self.http_client.post_json.assert_called_once()
+
+ # a second request for a server with outstanding requests
+ # should block rather than start a second call
+ @defer.inlineCallbacks
+ def second_lookup():
+ with LoggingContext("12") as context_12:
+ context_12.request = "12"
self.http_client.post_json.reset_mock()
self.http_client.post_json.return_value = defer.Deferred()
res_deferreds_2 = kr.verify_json_objects_for_server(
- [("server10", json1)]
+ [("server10", json1, )]
)
- yield clock.sleep(1)
- self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None)
+ yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
- # complete the first request
- with logcontext.PreserveLoggingContext():
- persp_deferred.callback(persp_resp)
- self.assertIs(LoggingContext.current_context(), context_11)
+ # let verify_json_objects_for_server finish its work before we kill the
+ # logcontext
+ yield self.clock.sleep(0)
- with logcontext.PreserveLoggingContext():
- yield res_deferreds[0]
- yield res_deferreds_2[0]
+ d2 = second_lookup()
+
+ self.pump()
+ self.http_client.post_json.assert_not_called()
+
+ # complete the first request
+ persp_deferred.callback(persp_resp)
+ self.get_success(d0)
+ self.get_success(d2)
- @defer.inlineCallbacks
def test_verify_json_for_server(self):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
- yield self.hs.datastore.store_server_verify_key(
+ r = self.hs.datastore.store_server_verify_key(
"server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
)
+ self.get_success(r)
json1 = {}
signedjson.sign.sign_json(json1, "server9", key1)
- with LoggingContext("one") as context_one:
- context_one.request = "one"
+ # should fail immediately on an unsigned object
+ d = _verify_json_for_server(kr, "server9", {})
+ self.failureResultOf(d, SynapseError)
- defer = kr.verify_json_for_server("server9", {})
- try:
- yield defer
- self.fail("should fail on unsigned json")
- except SynapseError:
- pass
- self.assertIs(LoggingContext.current_context(), context_one)
+ d = _verify_json_for_server(kr, "server9", json1)
+ self.assertFalse(d.called)
+ self.get_success(d)
- defer = kr.verify_json_for_server("server9", json1)
- self.assertFalse(defer.called)
- self.assert_sentinel_context()
- yield defer
+ def test_get_keys_from_server(self):
+ # arbitrarily advance the clock a bit
+ self.reactor.advance(100)
+
+ SERVER_NAME = "server2"
+ kr = keyring.Keyring(self.hs)
+ testkey = signedjson.key.generate_signing_key("ver1")
+ testverifykey = signedjson.key.get_verify_key(testkey)
+ testverifykey_id = "ed25519:ver1"
+ VALID_UNTIL_TS = 1000
+
+ # valid response
+ response = {
+ "server_name": SERVER_NAME,
+ "old_verify_keys": {},
+ "valid_until_ts": VALID_UNTIL_TS,
+ "verify_keys": {
+ testverifykey_id: {
+ "key": signedjson.key.encode_verify_key_base64(testverifykey)
+ }
+ },
+ }
+ signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+
+ def get_json(destination, path, **kwargs):
+ self.assertEqual(destination, SERVER_NAME)
+ self.assertEqual(path, "/_matrix/key/v2/server/key1")
+ return response
+
+ self.http_client.get_json.side_effect = get_json
+
+ server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
+ keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
+ k = keys[SERVER_NAME][testverifykey_id]
+ self.assertEqual(k, testverifykey)
+ self.assertEqual(k.alg, "ed25519")
+ self.assertEqual(k.version, "ver1")
+
+ # check that the perspectives store is correctly updated
+ lookup_triplet = (SERVER_NAME, testverifykey_id, None)
+ key_json = self.get_success(
+ self.hs.get_datastore().get_server_keys_json([lookup_triplet])
+ )
+ res = key_json[lookup_triplet]
+ self.assertEqual(len(res), 1)
+ res = res[0]
+ self.assertEqual(res["key_id"], testverifykey_id)
+ self.assertEqual(res["from_server"], SERVER_NAME)
+ self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
+ self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+
+ # we expect it to be encoded as canonical json *before* it hits the db
+ self.assertEqual(
+ bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
+ )
+
+ # change the server name: it should cause a rejection
+ response["server_name"] = "OTHER_SERVER"
+ self.get_failure(
+ kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError
+ )
+
+ def test_get_keys_from_perspectives(self):
+ # arbitrarily advance the clock a bit
+ self.reactor.advance(100)
+
+ SERVER_NAME = "server2"
+ kr = keyring.Keyring(self.hs)
+ testkey = signedjson.key.generate_signing_key("ver1")
+ testverifykey = signedjson.key.get_verify_key(testkey)
+ testverifykey_id = "ed25519:ver1"
+ VALID_UNTIL_TS = 200 * 1000
+
+ # valid response
+ response = {
+ "server_name": SERVER_NAME,
+ "old_verify_keys": {},
+ "valid_until_ts": VALID_UNTIL_TS,
+ "verify_keys": {
+ testverifykey_id: {
+ "key": signedjson.key.encode_verify_key_base64(testverifykey)
+ }
+ },
+ }
+
+ persp_resp = {
+ "server_keys": [self.mock_perspective_server.get_signed_response(response)]
+ }
+
+ def post_json(destination, path, data, **kwargs):
+ self.assertEqual(destination, self.mock_perspective_server.server_name)
+ self.assertEqual(path, "/_matrix/key/v2/query")
+
+ # check that the request is for the expected key
+ q = data["server_keys"]
+ self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
+ return persp_resp
+
+ self.http_client.post_json.side_effect = post_json
+
+ server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
+ keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
+ self.assertIn(SERVER_NAME, keys)
+ k = keys[SERVER_NAME][testverifykey_id]
+ self.assertEqual(k, testverifykey)
+ self.assertEqual(k.alg, "ed25519")
+ self.assertEqual(k.version, "ver1")
+
+ # check that the perspectives store is correctly updated
+ lookup_triplet = (SERVER_NAME, testverifykey_id, None)
+ key_json = self.get_success(
+ self.hs.get_datastore().get_server_keys_json([lookup_triplet])
+ )
+ res = key_json[lookup_triplet]
+ self.assertEqual(len(res), 1)
+ res = res[0]
+ self.assertEqual(res["key_id"], testverifykey_id)
+ self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
+ self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
+ self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+
+ self.assertEqual(
+ bytes(res["key_json"]),
+ canonicaljson.encode_canonical_json(persp_resp["server_keys"][0]),
+ )
+
+
+@defer.inlineCallbacks
+def run_in_context(f, *args, **kwargs):
+ with LoggingContext("testctx"):
+ rv = yield f(*args, **kwargs)
+ defer.returnValue(rv)
+
+
+def _verify_json_for_server(keyring, server_name, json_object):
+ """thin wrapper around verify_json_for_server which makes sure it is wrapped
+ with the patched defer.inlineCallbacks.
+ """
+ @defer.inlineCallbacks
+ def v():
+ rv1 = yield keyring.verify_json_for_server(server_name, json_object)
+ defer.returnValue(rv1)
- self.assertIs(LoggingContext.current_context(), context_one)
+ return run_in_context(v)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index fc2b646ba2..94c6080e34 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -16,7 +16,11 @@
from mock import Mock, call
-from synapse.api.constants import PresenceState
+from signedjson.key import generate_signing_key
+
+from synapse.api.constants import EventTypes, Membership, PresenceState
+from synapse.events import room_version_to_event_format
+from synapse.events.builder import EventBuilder
from synapse.handlers.presence import (
FEDERATION_PING_INTERVAL,
FEDERATION_TIMEOUT,
@@ -26,7 +30,9 @@ from synapse.handlers.presence import (
handle_timeout,
handle_update,
)
+from synapse.rest.client.v1 import room
from synapse.storage.presence import UserPresenceState
+from synapse.types import UserID, get_domain_from_id
from tests import unittest
@@ -405,3 +411,171 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNotNone(new_state)
self.assertEquals(state, new_state)
+
+
+class PresenceJoinTestCase(unittest.HomeserverTestCase):
+ """Tests remote servers get told about presence of users in the room when
+ they join and when new local users join.
+ """
+
+ user_id = "@test:server"
+
+ servlets = [room.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver(
+ "server", http_client=None,
+ federation_sender=Mock(),
+ )
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.federation_sender = hs.get_federation_sender()
+ self.event_builder_factory = hs.get_event_builder_factory()
+ self.federation_handler = hs.get_handlers().federation_handler
+ self.presence_handler = hs.get_presence_handler()
+
+ # self.event_builder_for_2 = EventBuilderFactory(hs)
+ # self.event_builder_for_2.hostname = "test2"
+
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self.auth = hs.get_auth()
+
+ # We don't actually check signatures in tests, so lets just create a
+ # random key to use.
+ self.random_signing_key = generate_signing_key("ver")
+
+ def test_remote_joins(self):
+ # We advance time to something that isn't 0, as we use 0 as a special
+ # value.
+ self.reactor.advance(1000000000000)
+
+ # Create a room with two local users
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.join(room_id, "@test2:server")
+
+ # Mark test2 as online, test will be offline with a last_active of 0
+ self.presence_handler.set_state(
+ UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE},
+ )
+ self.reactor.pump([0]) # Wait for presence updates to be handled
+
+ #
+ # Test that a new server gets told about existing presence
+ #
+
+ self.federation_sender.reset_mock()
+
+ # Add a new remote server to the room
+ self._add_new_user(room_id, "@alice:server2")
+
+ # We shouldn't have sent out any local presence *updates*
+ self.federation_sender.send_presence.assert_not_called()
+
+ # When new server is joined we send it the local users presence states.
+ # We expect to only see user @test2:server, as @test:server is offline
+ # and has a zero last_active_ts
+ expected_state = self.get_success(
+ self.presence_handler.current_state_for_user("@test2:server")
+ )
+ self.assertEqual(expected_state.state, PresenceState.ONLINE)
+ self.federation_sender.send_presence_to_destinations.assert_called_once_with(
+ destinations=["server2"], states=[expected_state]
+ )
+
+ #
+ # Test that only the new server gets sent presence and not existing servers
+ #
+
+ self.federation_sender.reset_mock()
+ self._add_new_user(room_id, "@bob:server3")
+
+ self.federation_sender.send_presence.assert_not_called()
+ self.federation_sender.send_presence_to_destinations.assert_called_once_with(
+ destinations=["server3"], states=[expected_state]
+ )
+
+ def test_remote_gets_presence_when_local_user_joins(self):
+ # We advance time to something that isn't 0, as we use 0 as a special
+ # value.
+ self.reactor.advance(1000000000000)
+
+ # Create a room with one local users
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # Mark test as online
+ self.presence_handler.set_state(
+ UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE},
+ )
+
+ # Mark test2 as online, test will be offline with a last_active of 0.
+ # Note we don't join them to the room yet
+ self.presence_handler.set_state(
+ UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE},
+ )
+
+ # Add servers to the room
+ self._add_new_user(room_id, "@alice:server2")
+ self._add_new_user(room_id, "@bob:server3")
+
+ self.reactor.pump([0]) # Wait for presence updates to be handled
+
+ #
+ # Test that when a local join happens remote servers get told about it
+ #
+
+ self.federation_sender.reset_mock()
+
+ # Join local user to room
+ self.helper.join(room_id, "@test2:server")
+
+ self.reactor.pump([0]) # Wait for presence updates to be handled
+
+ # We shouldn't have sent out any local presence *updates*
+ self.federation_sender.send_presence.assert_not_called()
+
+ # We expect to only send test2 presence to server2 and server3
+ expected_state = self.get_success(
+ self.presence_handler.current_state_for_user("@test2:server")
+ )
+ self.assertEqual(expected_state.state, PresenceState.ONLINE)
+ self.federation_sender.send_presence_to_destinations.assert_called_once_with(
+ destinations=set(("server2", "server3")),
+ states=[expected_state]
+ )
+
+ def _add_new_user(self, room_id, user_id):
+ """Add new user to the room by creating an event and poking the federation API.
+ """
+
+ hostname = get_domain_from_id(user_id)
+
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ builder = EventBuilder(
+ state=self.state,
+ auth=self.auth,
+ store=self.store,
+ clock=self.clock,
+ hostname=hostname,
+ signing_key=self.random_signing_key,
+ format_version=room_version_to_event_format(room_version),
+ room_id=room_id,
+ type=EventTypes.Member,
+ sender=user_id,
+ state_key=user_id,
+ content={"membership": Membership.JOIN}
+ )
+
+ prev_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(room_id)
+ )
+
+ event = self.get_success(builder.build(prev_event_ids))
+
+ self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
+
+ # Check that it was successfully persisted.
+ self.get_success(self.store.get_event(event.event_id))
+ self.get_success(self.store.get_event(event.event_id))
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 6460cbc708..5a0b6c201c 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -121,9 +121,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
- def get_current_user_in_room(room_id):
+ def get_current_users_in_room(room_id):
return set(str(u) for u in self.room_members)
- hs.get_state_handler().get_current_user_in_room = get_current_user_in_room
+ hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
self.datastore.get_user_directory_stream_pos.return_value = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 524af4f8d1..1f72a2a04f 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -56,7 +56,9 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
client = client_factory.buildProtocol(None)
client.makeConnection(FakeTransport(server, reactor))
- server.makeConnection(FakeTransport(client, reactor))
+
+ self.server_to_client_transport = FakeTransport(client, reactor)
+ server.makeConnection(self.server_to_client_transport)
def replicate(self):
"""Tell the master side of replication that something has happened, and then
@@ -69,6 +71,24 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None:
- self.assertEqual(master_result, expected_result)
- self.assertEqual(slaved_result, expected_result)
- self.assertEqual(master_result, slaved_result)
+ self.assertEqual(
+ master_result,
+ expected_result,
+ "Expected master result to be %r but was %r" % (
+ expected_result, master_result
+ ),
+ )
+ self.assertEqual(
+ slaved_result,
+ expected_result,
+ "Expected slave result to be %r but was %r" % (
+ expected_result, slaved_result
+ ),
+ )
+ self.assertEqual(
+ master_result,
+ slaved_result,
+ "Slave result %r does not match master result %r" % (
+ slaved_result, master_result
+ ),
+ )
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 1688a741d1..65ecff3bd6 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from canonicaljson import encode_canonical_json
from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext
+from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
@@ -26,6 +28,8 @@ USER_ID_2 = "@bright:blue"
OUTLIER = {"outlier": True}
ROOM_ID = "!room:blue"
+logger = logging.getLogger(__name__)
+
def dict_equals(self, other):
me = encode_canonical_json(self.get_pdu_json())
@@ -172,18 +176,142 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
{"highlight_count": 1, "notify_count": 2},
)
+ def test_get_rooms_for_user_with_stream_ordering(self):
+ """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
+ by rows in the events stream
+ """
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
+ self.replicate()
+ self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
+
+ j2 = self.persist(
+ type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
+ )
+ self.replicate()
+ self.check(
+ "get_rooms_for_user_with_stream_ordering",
+ (USER_ID_2,),
+ {(ROOM_ID, j2.internal_metadata.stream_ordering)},
+ )
+
+ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
+ """Check that current_state invalidation happens correctly with multiple events
+ in the persistence batch.
+
+ This test attempts to reproduce a race condition between the event persistence
+ loop and a worker-based Sync handler.
+
+ The problem occurred when the master persisted several events in one batch. It
+ only updates the current_state at the end of each batch, so the obvious thing
+ to do is then to issue a current_state_delta stream update corresponding to the
+ last stream_id in the batch.
+
+ However, that raises the possibility that a worker will see the replication
+ notification for a join event before the current_state caches are invalidated.
+
+ The test involves:
+ * creating a join and a message event for a user, and persisting them in the
+ same batch
+
+ * controlling the replication stream so that updates are sent gradually
+
+ * between each bunch of replication updates, check that we see a consistent
+ snapshot of the state.
+ """
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
+ self.replicate()
+ self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
+
+ # limit the replication rate
+ repl_transport = self.server_to_client_transport
+ repl_transport.autoflush = False
+
+ # build the join and message events and persist them in the same batch.
+ logger.info("----- build test events ------")
+ j2, j2ctx = self.build_event(
+ type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
+ )
+ msg, msgctx = self.build_event()
+ self.get_success(self.master_store.persist_events([
+ (j2, j2ctx),
+ (msg, msgctx),
+ ]))
+ self.replicate()
+
+ event_source = RoomEventSource(self.hs)
+ event_source.store = self.slaved_store
+ current_token = self.get_success(event_source.get_current_key())
+
+ # gradually stream out the replication
+ while repl_transport.buffer:
+ logger.info("------ flush ------")
+ repl_transport.flush(30)
+ self.pump(0)
+
+ prev_token = current_token
+ current_token = self.get_success(event_source.get_current_key())
+
+ # attempt to replicate the behaviour of the sync handler.
+ #
+ # First, we get a list of the rooms we are joined to
+ joined_rooms = self.get_success(
+ self.slaved_store.get_rooms_for_user_with_stream_ordering(
+ USER_ID_2,
+ ),
+ )
+
+ # Then, we get a list of the events since the last sync
+ membership_changes = self.get_success(
+ self.slaved_store.get_membership_changes_for_user(
+ USER_ID_2, prev_token, current_token,
+ )
+ )
+
+ logger.info(
+ "%s->%s: joined_rooms=%r membership_changes=%r",
+ prev_token,
+ current_token,
+ joined_rooms,
+ membership_changes,
+ )
+
+ # the membership change is only any use to us if the room is in the
+ # joined_rooms list.
+ if membership_changes:
+ self.assertEqual(
+ joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)}
+ )
+
event_id = 0
- def persist(
+ def persist(self, backfill=False, **kwargs):
+ """
+ Returns:
+ synapse.events.FrozenEvent: The event that was persisted.
+ """
+ event, context = self.build_event(**kwargs)
+
+ if backfill:
+ self.get_success(
+ self.master_store.persist_events([(event, context)], backfilled=True)
+ )
+ else:
+ self.get_success(
+ self.master_store.persist_event(event, context)
+ )
+
+ return event
+
+ def build_event(
self,
sender=USER_ID,
room_id=ROOM_ID,
- type={},
+ type="m.room.message",
key=None,
internal={},
state=None,
- reset_state=False,
- backfill=False,
depth=None,
prev_events=[],
auth_events=[],
@@ -192,10 +320,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
push_actions=[],
**content
):
- """
- Returns:
- synapse.events.FrozenEvent: The event that was persisted.
- """
+
if depth is None:
depth = self.event_id
@@ -234,23 +359,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
else:
state_handler = self.hs.get_state_handler()
- context = self.get_success(state_handler.compute_event_context(event))
+ context = self.get_success(state_handler.compute_event_context(
+ event
+ ))
self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
)
-
- ordering = None
- if backfill:
- self.get_success(
- self.master_store.persist_events([(event, context)], backfilled=True)
- )
- else:
- ordering, _ = self.get_success(
- self.master_store.persist_event(event, context)
- )
-
- if ordering:
- event.internal_metadata.stream_ordering = ordering
-
- return event
+ return event, context
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 9aa9dfe82e..d5a99f6caa 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.replication.tcp.streams import ReceiptsStreamRow
+from synapse.replication.tcp.streams._base import ReceiptsStreamRow
from tests.replication.tcp.streams._base import BaseStreamTestCase
diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py
index ef38473bd6..c00ef21d75 100644
--- a/tests/rest/client/v1/test_admin.py
+++ b/tests/rest/client/v1/test_admin.py
@@ -21,6 +21,7 @@ from mock import Mock
from synapse.api.constants import UserTypes
from synapse.rest.client.v1 import admin, events, login, room
+from synapse.rest.client.v2_alpha import groups
from tests import unittest
@@ -490,3 +491,126 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(
expect_code, int(channel.result["code"]), msg=channel.result["body"],
)
+
+
+class DeleteGroupTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ groups.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ def test_delete_group(self):
+ # Create a new group
+ request, channel = self.make_request(
+ "POST",
+ "/create_group".encode('ascii'),
+ access_token=self.admin_user_tok,
+ content={
+ "localpart": "test",
+ }
+ )
+
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"],
+ )
+
+ group_id = channel.json_body["group_id"]
+
+ self._check_group(group_id, expect_code=200)
+
+ # Invite/join another user
+
+ url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode('ascii'),
+ access_token=self.admin_user_tok,
+ content={}
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"],
+ )
+
+ url = "/groups/%s/self/accept_invite" % (group_id,)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode('ascii'),
+ access_token=self.other_user_token,
+ content={}
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"],
+ )
+
+ # Check other user knows they're in the group
+ self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
+ self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token))
+
+ # Now delete the group
+ url = "/admin/delete_group/" + group_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode('ascii'),
+ access_token=self.admin_user_tok,
+ content={
+ "localpart": "test",
+ }
+ )
+
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"],
+ )
+
+ # Check group returns 404
+ self._check_group(group_id, expect_code=404)
+
+ # Check users don't think they're in the group
+ self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
+ self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token))
+
+ def _check_group(self, group_id, expect_code):
+ """Assert that trying to fetch the given group results in the given
+ HTTP status code
+ """
+
+ url = "/groups/%s/profile" % (group_id,)
+ request, channel = self.make_request(
+ "GET",
+ url.encode('ascii'),
+ access_token=self.admin_user_tok,
+ )
+
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"],
+ )
+
+ def _get_groups_user_is_in(self, access_token):
+ """Returns the list of groups the user is in (given their access token)
+ """
+ request, channel = self.make_request(
+ "GET",
+ "/joined_groups".encode('ascii'),
+ access_token=access_token,
+ )
+
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"],
+ )
+
+ return channel.json_body["groups"]
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index d3d43970fb..bbfc77e829 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.constants import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
from synapse.rest.client.v1 import admin, login
from synapse.rest.client.v2_alpha import capabilities
@@ -52,7 +52,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
for room_version in capabilities['m.room_versions']['available'].keys():
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
self.assertEqual(
- DEFAULT_ROOM_VERSION, capabilities['m.room_versions']['default']
+ DEFAULT_ROOM_VERSION.identifier, capabilities['m.room_versions']['default']
)
def test_get_change_password_capabilities(self):
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index a45e6e5e1f..d3611ed21f 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -1,15 +1,18 @@
+import datetime
import json
from synapse.api.constants import LoginType
+from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
-from synapse.rest.client.v2_alpha.register import register_servlets
+from synapse.rest.client.v1 import admin, login
+from synapse.rest.client.v2_alpha import register, sync
from tests import unittest
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
- servlets = [register_servlets]
+ servlets = [register.register_servlets]
def make_homeserver(self, reactor, clock):
@@ -181,3 +184,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
+
+
+class AccountValidityTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ register.register_servlets,
+ admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config.enable_registration = True
+ config.account_validity.enabled = True
+ config.account_validity.period = 604800000 # Time in ms for 1 week
+ self.hs = self.setup_test_homeserver(config=config)
+
+ return self.hs
+
+ def test_validity_period(self):
+ self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+
+ # The specific endpoint doesn't matter, all we need is an authenticated
+ # endpoint.
+ request, channel = self.make_request(
+ b"GET", "/sync", access_token=tok,
+ )
+ self.render(request)
+
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
+
+ request, channel = self.make_request(
+ b"GET", "/sync", access_token=tok,
+ )
+ self.render(request)
+
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+ self.assertEquals(
+ channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
+ )
diff --git a/tests/server.py b/tests/server.py
index ea26dea623..8f89f4a83d 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -365,6 +365,7 @@ class FakeTransport(object):
disconnected = False
buffer = attr.ib(default=b'')
producer = attr.ib(default=None)
+ autoflush = attr.ib(default=True)
def getPeer(self):
return None
@@ -415,31 +416,44 @@ class FakeTransport(object):
def write(self, byt):
self.buffer = self.buffer + byt
- def _write():
- if not self.buffer:
- # nothing to do. Don't write empty buffers: it upsets the
- # TLSMemoryBIOProtocol
- return
-
- if self.disconnected:
- return
- logger.info("%s->%s: %s", self._protocol, self.other, self.buffer)
-
- if getattr(self.other, "transport") is not None:
- try:
- self.other.dataReceived(self.buffer)
- self.buffer = b""
- except Exception as e:
- logger.warning("Exception writing to protocol: %s", e)
- return
-
- self._reactor.callLater(0.0, _write)
-
# always actually do the write asynchronously. Some protocols (notably the
# TLSMemoryBIOProtocol) get very confused if a read comes back while they are
# still doing a write. Doing a callLater here breaks the cycle.
- self._reactor.callLater(0.0, _write)
+ if self.autoflush:
+ self._reactor.callLater(0.0, self.flush)
def writeSequence(self, seq):
for x in seq:
self.write(x)
+
+ def flush(self, maxbytes=None):
+ if not self.buffer:
+ # nothing to do. Don't write empty buffers: it upsets the
+ # TLSMemoryBIOProtocol
+ return
+
+ if self.disconnected:
+ return
+
+ if getattr(self.other, "transport") is None:
+ # the other has no transport yet; reschedule
+ if self.autoflush:
+ self._reactor.callLater(0.0, self.flush)
+ return
+
+ if maxbytes is not None:
+ to_write = self.buffer[:maxbytes]
+ else:
+ to_write = self.buffer
+
+ logger.info("%s->%s: %s", self._protocol, self.other, to_write)
+
+ try:
+ self.other.dataReceived(to_write)
+ except Exception as e:
+ logger.warning("Exception writing to protocol: %s", e)
+ return
+
+ self.buffer = self.buffer[len(to_write):]
+ if self.buffer and self.autoflush:
+ self._reactor.callLater(0.0, self.flush)
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 9a5c816927..f448b01326 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -19,7 +19,8 @@ from six.moves import zip
import attr
-from synapse.api.constants import EventTypes, JoinRules, Membership, RoomVersions
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
from synapse.events import FrozenEvent
from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
@@ -539,7 +540,7 @@ class StateTestCase(unittest.TestCase):
state_before = dict(state_at_event[prev_events[0]])
else:
state_d = resolve_events_with_store(
- RoomVersions.V2,
+ RoomVersions.V2.identifier,
[state_at_event[n] for n in prev_events],
event_map=event_map,
state_res_store=TestStateResolutionStore(event_map),
@@ -686,7 +687,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
# Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store(
- RoomVersions.V2,
+ RoomVersions.V2.identifier,
[self.state_at_bob, self.state_at_charlie],
event_map=None,
state_res_store=TestStateResolutionStore(self.event_map),
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 0d2dc9f325..6bfaa00fe9 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -15,34 +15,77 @@
import signedjson.key
-from twisted.internet import defer
+from twisted.internet.defer import Deferred
import tests.unittest
-import tests.utils
+KEY_1 = signedjson.key.decode_verify_key_base64(
+ "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
+)
+KEY_2 = signedjson.key.decode_verify_key_base64(
+ "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
+)
-class KeyStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
- self.store = hs.get_datastore()
-
- @defer.inlineCallbacks
+class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_server_verify_keys(self):
- key1 = signedjson.key.decode_verify_key_base64(
- "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
- )
- key2 = signedjson.key.decode_verify_key_base64(
- "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
+ store = self.hs.get_datastore()
+
+ d = store.store_server_verify_key("server1", "from_server", 0, KEY_1)
+ self.get_success(d)
+ d = store.store_server_verify_key("server1", "from_server", 0, KEY_2)
+ self.get_success(d)
+
+ d = store.get_server_verify_keys(
+ [
+ ("server1", "ed25519:key1"),
+ ("server1", "ed25519:key2"),
+ ("server1", "ed25519:key3"),
+ ]
)
- yield self.store.store_server_verify_key("server1", "from_server", 0, key1)
- yield self.store.store_server_verify_key("server1", "from_server", 0, key2)
+ res = self.get_success(d)
+
+ self.assertEqual(len(res.keys()), 3)
+ self.assertEqual(res[("server1", "ed25519:key1")].version, "key1")
+ self.assertEqual(res[("server1", "ed25519:key2")].version, "key2")
+
+ # non-existent result gives None
+ self.assertIsNone(res[("server1", "ed25519:key3")])
+
+ def test_cache(self):
+ """Check that updates correctly invalidate the cache."""
+
+ store = self.hs.get_datastore()
+
+ key_id_1 = "ed25519:key1"
+ key_id_2 = "ed25519:key2"
+
+ d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1)
+ self.get_success(d)
+ d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2)
+ self.get_success(d)
+
+ d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+ res = self.get_success(d)
+ self.assertEqual(len(res.keys()), 2)
+ self.assertEqual(res[("srv1", key_id_1)], KEY_1)
+ self.assertEqual(res[("srv1", key_id_2)], KEY_2)
+
+ # we should be able to look up the same thing again without a db hit
+ res = store.get_server_verify_keys([("srv1", key_id_1)])
+ if isinstance(res, Deferred):
+ res = self.successResultOf(res)
+ self.assertEqual(len(res.keys()), 1)
+ self.assertEqual(res[("srv1", key_id_1)], KEY_1)
- res = yield self.store.get_server_verify_keys(
- "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"]
+ new_key_2 = signedjson.key.get_verify_key(
+ signedjson.key.generate_signing_key("key2")
)
+ d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2)
+ self.get_success(d)
+ d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+ res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
- self.assertEqual(res["ed25519:key1"].version, "key1")
- self.assertEqual(res["ed25519:key2"].version, "key2")
+ self.assertEqual(res[("srv1", key_id_1)], KEY_1)
+ self.assertEqual(res[("srv1", key_id_2)], new_key_2)
diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py
deleted file mode 100644
index c7a63f39b9..0000000000
--- a/tests/storage/test_presence.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet import defer
-
-from synapse.types import UserID
-
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
-
-class PresenceStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
-
- self.store = hs.get_datastore()
-
- self.u_apple = UserID.from_string("@apple:test")
- self.u_banana = UserID.from_string("@banana:test")
-
- @defer.inlineCallbacks
- def test_presence_list(self):
- self.assertEquals(
- [],
- (
- yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart
- )
- ),
- )
- self.assertEquals(
- [],
- (
- yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart, accepted=True
- )
- ),
- )
-
- yield self.store.add_presence_list_pending(
- observer_localpart=self.u_apple.localpart,
- observed_userid=self.u_banana.to_string(),
- )
-
- self.assertEquals(
- [{"observed_user_id": "@banana:test", "accepted": 0}],
- (
- yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart
- )
- ),
- )
- self.assertEquals(
- [],
- (
- yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart, accepted=True
- )
- ),
- )
-
- yield self.store.set_presence_list_accepted(
- observer_localpart=self.u_apple.localpart,
- observed_userid=self.u_banana.to_string(),
- )
-
- self.assertEquals(
- [{"observed_user_id": "@banana:test", "accepted": 1}],
- (
- yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart
- )
- ),
- )
- self.assertEquals(
- [{"observed_user_id": "@banana:test", "accepted": 1}],
- (
- yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart, accepted=True
- )
- ),
- )
-
- yield self.store.del_presence_list(
- observer_localpart=self.u_apple.localpart,
- observed_userid=self.u_banana.to_string(),
- )
-
- self.assertEquals(
- [],
- (
- yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart
- )
- ),
- )
- self.assertEquals(
- [],
- (
- yield self.store.get_presence_list(
- observer_localpart=self.u_apple.localpart, accepted=True
- )
- ),
- )
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 3957561b1e..0fc5019e9f 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -18,7 +18,8 @@ from mock import Mock
from twisted.internet import defer
-from synapse.api.constants import EventTypes, Membership, RoomVersions
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
from tests import unittest
@@ -51,7 +52,7 @@ class RedactionTestCase(unittest.TestCase):
):
content = {"membership": membership}
content.update(extra_content)
- builder = self.event_builder_factory.new(
+ builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Member,
@@ -74,7 +75,7 @@ class RedactionTestCase(unittest.TestCase):
def inject_message(self, room, user, body):
self.depth += 1
- builder = self.event_builder_factory.new(
+ builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Message,
@@ -95,7 +96,7 @@ class RedactionTestCase(unittest.TestCase):
@defer.inlineCallbacks
def inject_redaction(self, room, event_id, user, reason):
- builder = self.event_builder_factory.new(
+ builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Redaction,
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 7fa2f4fd70..063387863e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -18,7 +18,8 @@ from mock import Mock
from twisted.internet import defer
-from synapse.api.constants import EventTypes, Membership, RoomVersions
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
from tests import unittest
@@ -49,7 +50,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership, replaces_state=None):
- builder = self.event_builder_factory.new(
+ builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Member,
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 99cd3e09eb..78e260a7fa 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -17,7 +17,8 @@ import logging
from twisted.internet import defer
-from synapse.api.constants import EventTypes, Membership, RoomVersions
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
@@ -48,7 +49,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content):
- builder = self.event_builder_factory.new(
+ builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": typ,
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 7ee318e4e8..4c8f87e958 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -16,8 +16,8 @@
import unittest
from synapse import event_auth
-from synapse.api.constants import RoomVersions
from synapse.api.errors import AuthError
+from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
@@ -37,7 +37,7 @@ class EventAuthTestCase(unittest.TestCase):
# creator should be able to send state
event_auth.check(
- RoomVersions.V1, _random_state_event(creator), auth_events,
+ RoomVersions.V1.identifier, _random_state_event(creator), auth_events,
do_sig_check=False,
)
@@ -45,7 +45,7 @@ class EventAuthTestCase(unittest.TestCase):
self.assertRaises(
AuthError,
event_auth.check,
- RoomVersions.V1,
+ RoomVersions.V1.identifier,
_random_state_event(joiner),
auth_events,
do_sig_check=False,
@@ -74,7 +74,7 @@ class EventAuthTestCase(unittest.TestCase):
self.assertRaises(
AuthError,
event_auth.check,
- RoomVersions.V1,
+ RoomVersions.V1.identifier,
_random_state_event(pleb),
auth_events,
do_sig_check=False,
@@ -82,7 +82,7 @@ class EventAuthTestCase(unittest.TestCase):
# king should be able to send state
event_auth.check(
- RoomVersions.V1, _random_state_event(king), auth_events,
+ RoomVersions.V1.identifier, _random_state_event(king), auth_events,
do_sig_check=False,
)
diff --git a/tests/test_state.py b/tests/test_state.py
index e20c33322a..5bcc6aaa18 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -18,13 +18,14 @@ from mock import Mock
from twisted.internet import defer
from synapse.api.auth import Auth
-from synapse.api.constants import EventTypes, Membership, RoomVersions
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
from synapse.state import StateHandler, StateResolutionHandler
from tests import unittest
-from .utils import MockClock
+from .utils import MockClock, default_config
_next_event_id = 1000
@@ -118,7 +119,7 @@ class StateGroupStore(object):
self._event_to_state_group[event_id] = state_group
def get_room_version(self, room_id):
- return RoomVersions.V1
+ return RoomVersions.V1.identifier
class DictObj(dict):
@@ -159,6 +160,7 @@ class StateTestCase(unittest.TestCase):
self.store = StateGroupStore()
hs = Mock(
spec_set=[
+ "config",
"get_datastore",
"get_auth",
"get_state_handler",
@@ -166,6 +168,7 @@ class StateTestCase(unittest.TestCase):
"get_state_resolution_handler",
]
)
+ hs.config = default_config("tesths")
hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock()
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 455db9f276..3bdb500514 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -17,7 +17,7 @@ import logging
from twisted.internet import defer
from twisted.internet.defer import succeed
-from synapse.api.constants import RoomVersions
+from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
from synapse.visibility import filter_events_for_server
@@ -124,7 +124,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def inject_visibility(self, user_id, visibility):
content = {"history_visibility": visibility}
- builder = self.event_builder_factory.new(
+ builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": "m.room.history_visibility",
@@ -145,7 +145,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
def inject_room_member(self, user_id, membership="join", extra_content={}):
content = {"membership": membership}
content.update(extra_content)
- builder = self.event_builder_factory.new(
+ builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": "m.room.member",
@@ -167,7 +167,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
def inject_message(self, user_id, content=None):
if content is None:
content = {"body": "testytest", "msgtype": "m.text"}
- builder = self.event_builder_factory.new(
+ builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": "m.room.message",
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 61a55b461b..ec7ba9719c 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd.
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/utils.py b/tests/utils.py
index 1b8eeb5167..e6e6cb4c75 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -27,8 +27,9 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer, reactor
-from synapse.api.constants import EventTypes, RoomVersions
+from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
+from synapse.api.room_versions import RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
@@ -674,7 +675,7 @@ def create_room(hs, room_id, creator_id):
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
- builder = event_builder_factory.new(
+ builder = event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Create,
|