summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2020-10-21 14:51:58 +0100
committerAndrew Morgan <andrew@amorgan.xyz>2020-10-21 14:51:58 +0100
commit5b0b10352165fd3a46bc55b675678a9dc326e817 (patch)
tree69f8635e6c95301b62394d9f48f07d3866e6cf84 /tests
parentMerge commit '920dd1083' into anoa/dinsic_release_1_21_x (diff)
parentEscape the error description on the sso_error template. (#8405) (diff)
downloadsynapse-5b0b10352165fd3a46bc55b675678a9dc326e817.tar.xz
Merge commit '31acc5c30' into anoa/dinsic_release_1_21_x
* commit '31acc5c30':
  Escape the error description on the sso_error template. (#8405)
  Fix occasional "Re-starting finished log context" from keyring (#8398)
  Allow existing users to login via OpenID Connect. (#8345)
  Fix schema delta for servers that have not backfilled (#8396)
  Fix MultiWriteIdGenerator's handling of restarts. (#8374)
  s/URLs/variables in changelog
  s/accidentally/incorrectly in changelog
  Update changelog wording
  Add type annotations to SimpleHttpClient (#8372)
  Add new sequences to port DB script (#8387)
  Add EventStreamPosition type (#8388)
  Mark the shadow_banned column as boolean in synapse_port_db. (#8386)
Diffstat (limited to 'tests')
-rw-r--r--tests/crypto/test_keyring.py120
-rw-r--r--tests/handlers/test_oidc.py35
-rw-r--r--tests/replication/slave/storage/test_events.py12
-rw-r--r--tests/storage/test_id_generators.py119
4 files changed, 211 insertions, 75 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py

index 2e6e7abf1f..5cf408f21f 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py
@@ -23,6 +23,7 @@ from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key from twisted.internet import defer +from twisted.internet.defer import Deferred, ensureDeferred from synapse.api.errors import SynapseError from synapse.crypto import keyring @@ -33,7 +34,6 @@ from synapse.crypto.keyring import ( ) from synapse.logging.context import ( LoggingContext, - PreserveLoggingContext, current_context, make_deferred_yieldable, ) @@ -68,54 +68,40 @@ class MockPerspectiveServer: class KeyringTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): - self.mock_perspective_server = MockPerspectiveServer() - self.http_client = Mock() - - config = self.default_config() - config["trusted_key_servers"] = [ - { - "server_name": self.mock_perspective_server.server_name, - "verify_keys": self.mock_perspective_server.get_verify_keys(), - } - ] - - return self.setup_test_homeserver( - handlers=None, http_client=self.http_client, config=config - ) - - def check_context(self, _, expected): + def check_context(self, val, expected): self.assertEquals(getattr(current_context(), "request", None), expected) + return val def test_verify_json_objects_for_server_awaits_previous_requests(self): - key1 = signedjson.key.generate_signing_key(1) + mock_fetcher = keyring.KeyFetcher() + mock_fetcher.get_keys = Mock() + kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) - kr = keyring.Keyring(self.hs) + # a signed object that we are going to try to validate + key1 = signedjson.key.generate_signing_key(1) json1 = {} signedjson.sign.sign_json(json1, "server10", key1) - persp_resp = { - "server_keys": [ - self.mock_perspective_server.get_signed_key( - "server10", signedjson.key.get_verify_key(key1) - ) - ] - } - persp_deferred = defer.Deferred() + # start off a first set of lookups. We make the mock fetcher block until this + # deferred completes. + first_lookup_deferred = Deferred() + + async def first_lookup_fetch(keys_to_fetch): + self.assertEquals(current_context().request, "context_11") + self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}}) - async def get_perspectives(**kwargs): - self.assertEquals(current_context().request, "11") - with PreserveLoggingContext(): - await persp_deferred - return persp_resp + await make_deferred_yieldable(first_lookup_deferred) + return { + "server10": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) + } + } - self.http_client.post_json.side_effect = get_perspectives + mock_fetcher.get_keys.side_effect = first_lookup_fetch - # start off a first set of lookups - @defer.inlineCallbacks - def first_lookup(): - with LoggingContext("11") as context_11: - context_11.request = "11" + async def first_lookup(): + with LoggingContext("context_11") as context_11: + context_11.request = "context_11" res_deferreds = kr.verify_json_objects_for_server( [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")] @@ -124,7 +110,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # the unsigned json should be rejected pretty quickly self.assertTrue(res_deferreds[1].called) try: - yield res_deferreds[1] + await res_deferreds[1] self.assertFalse("unsigned json didn't cause a failure") except SynapseError: pass @@ -132,45 +118,51 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.assertFalse(res_deferreds[0].called) res_deferreds[0].addBoth(self.check_context, None) - yield make_deferred_yieldable(res_deferreds[0]) + await make_deferred_yieldable(res_deferreds[0]) - # let verify_json_objects_for_server finish its work before we kill the - # logcontext - yield self.clock.sleep(0) + d0 = ensureDeferred(first_lookup()) - d0 = first_lookup() - - # wait a tick for it to send the request to the perspectives server - # (it first tries the datastore) - self.pump() - self.http_client.post_json.assert_called_once() + mock_fetcher.get_keys.assert_called_once() # a second request for a server with outstanding requests # should block rather than start a second call - @defer.inlineCallbacks - def second_lookup(): - with LoggingContext("12") as context_12: - context_12.request = "12" - self.http_client.post_json.reset_mock() - self.http_client.post_json.return_value = defer.Deferred() + + async def second_lookup_fetch(keys_to_fetch): + self.assertEquals(current_context().request, "context_12") + return { + "server10": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) + } + } + + mock_fetcher.get_keys.reset_mock() + mock_fetcher.get_keys.side_effect = second_lookup_fetch + second_lookup_state = [0] + + async def second_lookup(): + with LoggingContext("context_12") as context_12: + context_12.request = "context_12" res_deferreds_2 = kr.verify_json_objects_for_server( [("server10", json1, 0, "test")] ) res_deferreds_2[0].addBoth(self.check_context, None) - yield make_deferred_yieldable(res_deferreds_2[0]) + second_lookup_state[0] = 1 + await make_deferred_yieldable(res_deferreds_2[0]) + second_lookup_state[0] = 2 - # let verify_json_objects_for_server finish its work before we kill the - # logcontext - yield self.clock.sleep(0) - - d2 = second_lookup() + d2 = ensureDeferred(second_lookup()) self.pump() - self.http_client.post_json.assert_not_called() + # the second request should be pending, but the fetcher should not yet have been + # called + self.assertEqual(second_lookup_state[0], 1) + mock_fetcher.get_keys.assert_not_called() # complete the first request - persp_deferred.callback(persp_resp) + first_lookup_deferred.callback(None) + + # and now both verifications should succeed. self.get_success(d0) self.get_success(d2) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 89ec5fcb31..5910772aa8 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py
@@ -617,3 +617,38 @@ class OidcHandlerTestCase(HomeserverTestCase): ) ) self.assertEqual(mxid, "@test_user_2:test") + + # Test if the mxid is already taken + store = self.hs.get_datastore() + user3 = UserID.from_string("@test_user_3:test") + self.get_success( + store.register_user(user_id=user3.to_string(), password_hash=None) + ) + userinfo = {"sub": "test3", "username": "test_user_3"} + e = self.get_failure( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken") + + @override_config({"oidc_config": {"allow_existing_users": True}}) + def test_map_userinfo_to_existing_user(self): + """Existing users can log in with OpenID Connect when allow_existing_users is True.""" + store = self.hs.get_datastore() + user4 = UserID.from_string("@test_user_4:test") + self.get_success( + store.register_user(user_id=user4.to_string(), password_hash=None) + ) + userinfo = { + "sub": "test4", + "username": "test_user_4", + } + token = {} + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user_4:test") diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index bc578411d6..c0ee1cfbd6 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py
@@ -20,6 +20,7 @@ from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_ from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.roommember import RoomsForUser +from synapse.types import PersistedEventPosition from tests.server import FakeTransport @@ -204,10 +205,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) self.replicate() + + expected_pos = PersistedEventPosition( + "master", j2.internal_metadata.stream_ordering + ) self.check( "get_rooms_for_user_with_stream_ordering", (USER_ID_2,), - {(ROOM_ID, j2.internal_metadata.stream_ordering)}, + {(ROOM_ID, expected_pos)}, ) def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): @@ -293,9 +298,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): # the membership change is only any use to us if the room is in the # joined_rooms list. if membership_changes: - self.assertEqual( - joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)} + expected_pos = PersistedEventPosition( + "master", j2.internal_metadata.stream_ordering ) + self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)}) event_id = 0 diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index fb8f5bc255..d4ff55fbff 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py
@@ -43,16 +43,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): """ ) - def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create_id_generator( + self, instance_name="master", writers=["master"] + ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( conn, self.db_pool, + stream_name="test_stream", instance_name=instance_name, table="foobar", instance_column="instance_name", id_column="stream_id", sequence_name="foobar_seq", + writers=writers, ) return self.get_success(self.db_pool.runWithConnection(_create)) @@ -68,6 +72,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", (instance_name,), ) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, lastval()) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval() + """, + (instance_name,), + ) self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) @@ -81,6 +92,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), ) txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,)) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, ?) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? + """, + (instance_name, stream_id, stream_id), + ) self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) @@ -179,8 +197,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_rows("first", 3) self._insert_rows("second", 4) - first_id_gen = self._create_id_generator("first") - second_id_gen = self._create_id_generator("second") + first_id_gen = self._create_id_generator("first", writers=["first", "second"]) + second_id_gen = self._create_id_generator("second", writers=["first", "second"]) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) @@ -262,7 +280,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("first", 3) self._insert_row_with_id("second", 5) - id_gen = self._create_id_generator("first") + id_gen = self._create_id_generator("first", writers=["first", "second"]) self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) @@ -300,7 +318,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("first", 3) self._insert_row_with_id("second", 5) - id_gen = self._create_id_generator("first") + id_gen = self._create_id_generator("first", writers=["first", "second"]) self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) @@ -319,6 +337,80 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # `persisted_upto_position` in this case, then it will be correct in the # other cases that are tested above (since they'll hit the same code). + def test_restart_during_out_of_order_persistence(self): + """Test that restarting a process while another process is writing out + of order updates are handled correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + # Persist two rows at once + ctx1 = self.get_success(id_gen.get_next()) + ctx2 = self.get_success(id_gen.get_next()) + + s1 = self.get_success(ctx1.__aenter__()) + s2 = self.get_success(ctx2.__aenter__()) + + self.assertEqual(s1, 8) + self.assertEqual(s2, 9) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + # We finish persisting the second row before restart + self.get_success(ctx2.__aexit__(None, None, None)) + + # We simulate a restart of another worker by just creating a new ID gen. + id_gen_worker = self._create_id_generator("worker") + + # Restarted worker should not see the second persisted row + self.assertEqual(id_gen_worker.get_positions(), {"master": 7}) + self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7) + + # Now if we persist the first row then both instances should jump ahead + # correctly. + self.get_success(ctx1.__aexit__(None, None, None)) + + self.assertEqual(id_gen.get_positions(), {"master": 9}) + id_gen_worker.advance("master", 9) + self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) + + def test_writer_config_change(self): + """Test that changing the writer config correctly works. + """ + + self._insert_row_with_id("first", 3) + self._insert_row_with_id("second", 5) + + # Initial config has two writers + id_gen = self._create_id_generator("first", writers=["first", "second"]) + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + # New config removes one of the configs. Note that if the writer is + # removed from config we assume that it has been shut down and has + # finished persisting, hence why the persisted upto position is 5. + id_gen_2 = self._create_id_generator("second", writers=["second"]) + self.assertEqual(id_gen_2.get_persisted_upto_position(), 5) + + # This config points to a single, previously unused writer. + id_gen_3 = self._create_id_generator("third", writers=["third"]) + self.assertEqual(id_gen_3.get_persisted_upto_position(), 5) + + # Check that we get a sane next stream ID with this new config. + + async def _get_next_async(): + async with id_gen_3.get_next() as stream_id: + self.assertEqual(stream_id, 6) + + self.get_success(_get_next_async()) + self.assertEqual(id_gen_3.get_persisted_upto_position(), 6) + class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests MultiWriterIdGenerator that produce *negative* stream IDs. @@ -345,16 +437,20 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """ ) - def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create_id_generator( + self, instance_name="master", writers=["master"] + ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( conn, self.db_pool, + stream_name="test_stream", instance_name=instance_name, table="foobar", instance_column="instance_name", id_column="stream_id", sequence_name="foobar_seq", + writers=writers, positive=False, ) @@ -368,6 +464,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): txn.execute( "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), ) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, ?) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? + """, + (instance_name, -stream_id, -stream_id), + ) self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) @@ -409,8 +512,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests that having multiple instances that get advanced over federation works corretly. """ - id_gen_1 = self._create_id_generator("first") - id_gen_2 = self._create_id_generator("second") + id_gen_1 = self._create_id_generator("first", writers=["first", "second"]) + id_gen_2 = self._create_id_generator("second", writers=["first", "second"]) async def _get_next_async(): async with id_gen_1.get_next() as stream_id: