summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/crypto/test_keyring.py211
-rw-r--r--tests/handlers/test_typing.py4
-rw-r--r--tests/replication/slave/storage/_base.py28
-rw-r--r--tests/replication/slave/storage/test_events.py161
-rw-r--r--tests/server.py56
-rw-r--r--tests/storage/test_presence.py118
6 files changed, 302 insertions, 276 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py

index d643bec887..c30a1a69e7 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py
@@ -19,14 +19,14 @@ from mock import Mock import signedjson.key import signedjson.sign -from twisted.internet import defer, reactor +from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.crypto import keyring -from synapse.util import Clock, logcontext +from synapse.util import logcontext from synapse.util.logcontext import LoggingContext -from tests import unittest, utils +from tests import unittest class MockPerspectiveServer(object): @@ -52,75 +52,50 @@ class MockPerspectiveServer(object): return res -class KeyringTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): +class KeyringTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): self.mock_perspective_server = MockPerspectiveServer() self.http_client = Mock() - self.hs = yield utils.setup_test_homeserver( - self.addCleanup, handlers=None, http_client=self.http_client - ) + hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client) keys = self.mock_perspective_server.get_verify_keys() - self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys} - - def assert_sentinel_context(self): - if LoggingContext.current_context() != LoggingContext.sentinel: - self.fail( - "Expected sentinel context but got %s" % ( - LoggingContext.current_context(), - ) - ) + hs.config.perspectives = {self.mock_perspective_server.server_name: keys} + return hs def check_context(self, _, expected): self.assertEquals( getattr(LoggingContext.current_context(), "request", None), expected ) - @defer.inlineCallbacks def test_wait_for_previous_lookups(self): kr = keyring.Keyring(self.hs) lookup_1_deferred = defer.Deferred() lookup_2_deferred = defer.Deferred() - with LoggingContext("one") as context_one: - context_one.request = "one" - - wait_1_deferred = kr.wait_for_previous_lookups( - ["server1"], {"server1": lookup_1_deferred} - ) - - # there were no previous lookups, so the deferred should be ready - self.assertTrue(wait_1_deferred.called) - # ... so we should have preserved the LoggingContext. - self.assertIs(LoggingContext.current_context(), context_one) - wait_1_deferred.addBoth(self.check_context, "one") - - with LoggingContext("two") as context_two: - context_two.request = "two" + # we run the lookup in a logcontext so that the patched inlineCallbacks can check + # it is doing the right thing with logcontexts. + wait_1_deferred = run_in_context( + kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_1_deferred} + ) - # set off another wait. It should block because the first lookup - # hasn't yet completed. - wait_2_deferred = kr.wait_for_previous_lookups( - ["server1"], {"server1": lookup_2_deferred} - ) - self.assertFalse(wait_2_deferred.called) + # there were no previous lookups, so the deferred should be ready + self.successResultOf(wait_1_deferred) - # ... so we should have reset the LoggingContext. - self.assert_sentinel_context() + # set off another wait. It should block because the first lookup + # hasn't yet completed. + wait_2_deferred = run_in_context( + kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_2_deferred} + ) - wait_2_deferred.addBoth(self.check_context, "two") + self.assertFalse(wait_2_deferred.called) - # let the first lookup complete (in the sentinel context) - lookup_1_deferred.callback(None) + # let the first lookup complete (in the sentinel context) + lookup_1_deferred.callback(None) - # now the second wait should complete and restore our - # loggingcontext. - yield wait_2_deferred + # now the second wait should complete. + self.successResultOf(wait_2_deferred) - @defer.inlineCallbacks def test_verify_json_objects_for_server_awaits_previous_requests(self): - clock = Clock(reactor) key1 = signedjson.key.generate_signing_key(1) kr = keyring.Keyring(self.hs) @@ -145,81 +120,103 @@ class KeyringTestCase(unittest.TestCase): self.http_client.post_json.side_effect = get_perspectives - with LoggingContext("11") as context_11: - context_11.request = "11" - - # start off a first set of lookups - res_deferreds = kr.verify_json_objects_for_server( - [("server10", json1), ("server11", {})] - ) - - # the unsigned json should be rejected pretty quickly - self.assertTrue(res_deferreds[1].called) - try: - yield res_deferreds[1] - self.assertFalse("unsigned json didn't cause a failure") - except SynapseError: - pass - - self.assertFalse(res_deferreds[0].called) - res_deferreds[0].addBoth(self.check_context, None) - - # wait a tick for it to send the request to the perspectives server - # (it first tries the datastore) - yield clock.sleep(1) # XXX find out why this takes so long! - self.http_client.post_json.assert_called_once() - - self.assertIs(LoggingContext.current_context(), context_11) - - context_12 = LoggingContext("12") - context_12.request = "12" - with logcontext.PreserveLoggingContext(context_12): - # a second request for a server with outstanding requests - # should block rather than start a second call + # start off a first set of lookups + @defer.inlineCallbacks + def first_lookup(): + with LoggingContext("11") as context_11: + context_11.request = "11" + + res_deferreds = kr.verify_json_objects_for_server( + [("server10", json1), ("server11", {})] + ) + + # the unsigned json should be rejected pretty quickly + self.assertTrue(res_deferreds[1].called) + try: + yield res_deferreds[1] + self.assertFalse("unsigned json didn't cause a failure") + except SynapseError: + pass + + self.assertFalse(res_deferreds[0].called) + res_deferreds[0].addBoth(self.check_context, None) + + yield logcontext.make_deferred_yieldable(res_deferreds[0]) + + # let verify_json_objects_for_server finish its work before we kill the + # logcontext + yield self.clock.sleep(0) + + d0 = first_lookup() + + # wait a tick for it to send the request to the perspectives server + # (it first tries the datastore) + self.pump() + self.http_client.post_json.assert_called_once() + + # a second request for a server with outstanding requests + # should block rather than start a second call + @defer.inlineCallbacks + def second_lookup(): + with LoggingContext("12") as context_12: + context_12.request = "12" self.http_client.post_json.reset_mock() self.http_client.post_json.return_value = defer.Deferred() res_deferreds_2 = kr.verify_json_objects_for_server( - [("server10", json1)] + [("server10", json1, )] ) - yield clock.sleep(1) - self.http_client.post_json.assert_not_called() res_deferreds_2[0].addBoth(self.check_context, None) + yield logcontext.make_deferred_yieldable(res_deferreds_2[0]) - # complete the first request - with logcontext.PreserveLoggingContext(): - persp_deferred.callback(persp_resp) - self.assertIs(LoggingContext.current_context(), context_11) + # let verify_json_objects_for_server finish its work before we kill the + # logcontext + yield self.clock.sleep(0) - with logcontext.PreserveLoggingContext(): - yield res_deferreds[0] - yield res_deferreds_2[0] + d2 = second_lookup() + + self.pump() + self.http_client.post_json.assert_not_called() + + # complete the first request + persp_deferred.callback(persp_resp) + self.get_success(d0) + self.get_success(d2) - @defer.inlineCallbacks def test_verify_json_for_server(self): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - yield self.hs.datastore.store_server_verify_key( + r = self.hs.datastore.store_server_verify_key( "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1) ) + self.get_success(r) json1 = {} signedjson.sign.sign_json(json1, "server9", key1) - with LoggingContext("one") as context_one: - context_one.request = "one" + # should fail immediately on an unsigned object + d = _verify_json_for_server(kr, "server9", {}) + self.failureResultOf(d, SynapseError) + + d = _verify_json_for_server(kr, "server9", json1) + self.assertFalse(d.called) + self.get_success(d) - defer = kr.verify_json_for_server("server9", {}) - try: - yield defer - self.fail("should fail on unsigned json") - except SynapseError: - pass - self.assertIs(LoggingContext.current_context(), context_one) - defer = kr.verify_json_for_server("server9", json1) - self.assertFalse(defer.called) - self.assert_sentinel_context() - yield defer +@defer.inlineCallbacks +def run_in_context(f, *args, **kwargs): + with LoggingContext("testctx"): + rv = yield f(*args, **kwargs) + defer.returnValue(rv) + + +def _verify_json_for_server(keyring, server_name, json_object): + """thin wrapper around verify_json_for_server which makes sure it is wrapped + with the patched defer.inlineCallbacks. + """ + @defer.inlineCallbacks + def v(): + rv1 = yield keyring.verify_json_for_server(server_name, json_object) + defer.returnValue(rv1) - self.assertIs(LoggingContext.current_context(), context_one) + return run_in_context(v) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 6460cbc708..5a0b6c201c 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py
@@ -121,9 +121,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room - def get_current_user_in_room(room_id): + def get_current_users_in_room(room_id): return set(str(u) for u in self.room_members) - hs.get_state_handler().get_current_user_in_room = get_current_user_in_room + hs.get_state_handler().get_current_users_in_room = get_current_users_in_room self.datastore.get_user_directory_stream_pos.return_value = ( # we deliberately return a non-None stream pos to avoid doing an initial_spam diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 524af4f8d1..1f72a2a04f 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py
@@ -56,7 +56,9 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): client = client_factory.buildProtocol(None) client.makeConnection(FakeTransport(server, reactor)) - server.makeConnection(FakeTransport(client, reactor)) + + self.server_to_client_transport = FakeTransport(client, reactor) + server.makeConnection(self.server_to_client_transport) def replicate(self): """Tell the master side of replication that something has happened, and then @@ -69,6 +71,24 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): master_result = self.get_success(getattr(self.master_store, method)(*args)) slaved_result = self.get_success(getattr(self.slaved_store, method)(*args)) if expected_result is not None: - self.assertEqual(master_result, expected_result) - self.assertEqual(slaved_result, expected_result) - self.assertEqual(master_result, slaved_result) + self.assertEqual( + master_result, + expected_result, + "Expected master result to be %r but was %r" % ( + expected_result, master_result + ), + ) + self.assertEqual( + slaved_result, + expected_result, + "Expected slave result to be %r but was %r" % ( + expected_result, slaved_result + ), + ) + self.assertEqual( + master_result, + slaved_result, + "Slave result %r does not match master result %r" % ( + slaved_result, master_result + ), + ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 1688a741d1..65ecff3bd6 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py
@@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from canonicaljson import encode_canonical_json from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events.snapshot import EventContext +from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.roommember import RoomsForUser @@ -26,6 +28,8 @@ USER_ID_2 = "@bright:blue" OUTLIER = {"outlier": True} ROOM_ID = "!room:blue" +logger = logging.getLogger(__name__) + def dict_equals(self, other): me = encode_canonical_json(self.get_pdu_json()) @@ -172,18 +176,142 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): {"highlight_count": 1, "notify_count": 2}, ) + def test_get_rooms_for_user_with_stream_ordering(self): + """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated + by rows in the events stream + """ + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") + self.replicate() + self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set()) + + j2 = self.persist( + type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" + ) + self.replicate() + self.check( + "get_rooms_for_user_with_stream_ordering", + (USER_ID_2,), + {(ROOM_ID, j2.internal_metadata.stream_ordering)}, + ) + + def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): + """Check that current_state invalidation happens correctly with multiple events + in the persistence batch. + + This test attempts to reproduce a race condition between the event persistence + loop and a worker-based Sync handler. + + The problem occurred when the master persisted several events in one batch. It + only updates the current_state at the end of each batch, so the obvious thing + to do is then to issue a current_state_delta stream update corresponding to the + last stream_id in the batch. + + However, that raises the possibility that a worker will see the replication + notification for a join event before the current_state caches are invalidated. + + The test involves: + * creating a join and a message event for a user, and persisting them in the + same batch + + * controlling the replication stream so that updates are sent gradually + + * between each bunch of replication updates, check that we see a consistent + snapshot of the state. + """ + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") + self.replicate() + self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set()) + + # limit the replication rate + repl_transport = self.server_to_client_transport + repl_transport.autoflush = False + + # build the join and message events and persist them in the same batch. + logger.info("----- build test events ------") + j2, j2ctx = self.build_event( + type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" + ) + msg, msgctx = self.build_event() + self.get_success(self.master_store.persist_events([ + (j2, j2ctx), + (msg, msgctx), + ])) + self.replicate() + + event_source = RoomEventSource(self.hs) + event_source.store = self.slaved_store + current_token = self.get_success(event_source.get_current_key()) + + # gradually stream out the replication + while repl_transport.buffer: + logger.info("------ flush ------") + repl_transport.flush(30) + self.pump(0) + + prev_token = current_token + current_token = self.get_success(event_source.get_current_key()) + + # attempt to replicate the behaviour of the sync handler. + # + # First, we get a list of the rooms we are joined to + joined_rooms = self.get_success( + self.slaved_store.get_rooms_for_user_with_stream_ordering( + USER_ID_2, + ), + ) + + # Then, we get a list of the events since the last sync + membership_changes = self.get_success( + self.slaved_store.get_membership_changes_for_user( + USER_ID_2, prev_token, current_token, + ) + ) + + logger.info( + "%s->%s: joined_rooms=%r membership_changes=%r", + prev_token, + current_token, + joined_rooms, + membership_changes, + ) + + # the membership change is only any use to us if the room is in the + # joined_rooms list. + if membership_changes: + self.assertEqual( + joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)} + ) + event_id = 0 - def persist( + def persist(self, backfill=False, **kwargs): + """ + Returns: + synapse.events.FrozenEvent: The event that was persisted. + """ + event, context = self.build_event(**kwargs) + + if backfill: + self.get_success( + self.master_store.persist_events([(event, context)], backfilled=True) + ) + else: + self.get_success( + self.master_store.persist_event(event, context) + ) + + return event + + def build_event( self, sender=USER_ID, room_id=ROOM_ID, - type={}, + type="m.room.message", key=None, internal={}, state=None, - reset_state=False, - backfill=False, depth=None, prev_events=[], auth_events=[], @@ -192,10 +320,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): push_actions=[], **content ): - """ - Returns: - synapse.events.FrozenEvent: The event that was persisted. - """ + if depth is None: depth = self.event_id @@ -234,23 +359,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ) else: state_handler = self.hs.get_state_handler() - context = self.get_success(state_handler.compute_event_context(event)) + context = self.get_success(state_handler.compute_event_context( + event + )) self.master_store.add_push_actions_to_staging( event.event_id, {user_id: actions for user_id, actions in push_actions} ) - - ordering = None - if backfill: - self.get_success( - self.master_store.persist_events([(event, context)], backfilled=True) - ) - else: - ordering, _ = self.get_success( - self.master_store.persist_event(event, context) - ) - - if ordering: - event.internal_metadata.stream_ordering = ordering - - return event + return event, context diff --git a/tests/server.py b/tests/server.py
index ea26dea623..8f89f4a83d 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -365,6 +365,7 @@ class FakeTransport(object): disconnected = False buffer = attr.ib(default=b'') producer = attr.ib(default=None) + autoflush = attr.ib(default=True) def getPeer(self): return None @@ -415,31 +416,44 @@ class FakeTransport(object): def write(self, byt): self.buffer = self.buffer + byt - def _write(): - if not self.buffer: - # nothing to do. Don't write empty buffers: it upsets the - # TLSMemoryBIOProtocol - return - - if self.disconnected: - return - logger.info("%s->%s: %s", self._protocol, self.other, self.buffer) - - if getattr(self.other, "transport") is not None: - try: - self.other.dataReceived(self.buffer) - self.buffer = b"" - except Exception as e: - logger.warning("Exception writing to protocol: %s", e) - return - - self._reactor.callLater(0.0, _write) - # always actually do the write asynchronously. Some protocols (notably the # TLSMemoryBIOProtocol) get very confused if a read comes back while they are # still doing a write. Doing a callLater here breaks the cycle. - self._reactor.callLater(0.0, _write) + if self.autoflush: + self._reactor.callLater(0.0, self.flush) def writeSequence(self, seq): for x in seq: self.write(x) + + def flush(self, maxbytes=None): + if not self.buffer: + # nothing to do. Don't write empty buffers: it upsets the + # TLSMemoryBIOProtocol + return + + if self.disconnected: + return + + if getattr(self.other, "transport") is None: + # the other has no transport yet; reschedule + if self.autoflush: + self._reactor.callLater(0.0, self.flush) + return + + if maxbytes is not None: + to_write = self.buffer[:maxbytes] + else: + to_write = self.buffer + + logger.info("%s->%s: %s", self._protocol, self.other, to_write) + + try: + self.other.dataReceived(to_write) + except Exception as e: + logger.warning("Exception writing to protocol: %s", e) + return + + self.buffer = self.buffer[len(to_write):] + if self.buffer and self.autoflush: + self._reactor.callLater(0.0, self.flush) diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py deleted file mode 100644
index c7a63f39b9..0000000000 --- a/tests/storage/test_presence.py +++ /dev/null
@@ -1,118 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer - -from synapse.types import UserID - -from tests import unittest -from tests.utils import setup_test_homeserver - - -class PresenceStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) - - self.store = hs.get_datastore() - - self.u_apple = UserID.from_string("@apple:test") - self.u_banana = UserID.from_string("@banana:test") - - @defer.inlineCallbacks - def test_presence_list(self): - self.assertEquals( - [], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart - ) - ), - ) - self.assertEquals( - [], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, accepted=True - ) - ), - ) - - yield self.store.add_presence_list_pending( - observer_localpart=self.u_apple.localpart, - observed_userid=self.u_banana.to_string(), - ) - - self.assertEquals( - [{"observed_user_id": "@banana:test", "accepted": 0}], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart - ) - ), - ) - self.assertEquals( - [], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, accepted=True - ) - ), - ) - - yield self.store.set_presence_list_accepted( - observer_localpart=self.u_apple.localpart, - observed_userid=self.u_banana.to_string(), - ) - - self.assertEquals( - [{"observed_user_id": "@banana:test", "accepted": 1}], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart - ) - ), - ) - self.assertEquals( - [{"observed_user_id": "@banana:test", "accepted": 1}], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, accepted=True - ) - ), - ) - - yield self.store.del_presence_list( - observer_localpart=self.u_apple.localpart, - observed_userid=self.u_banana.to_string(), - ) - - self.assertEquals( - [], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart - ) - ), - ) - self.assertEquals( - [], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, accepted=True - ) - ), - )