diff options
-rw-r--r-- | changelog.d/4840.feature | 1 | ||||
-rw-r--r-- | changelog.d/4912.misc | 1 | ||||
-rw-r--r-- | changelog.d/4913.misc | 1 | ||||
-rwxr-xr-x | scripts-dev/check-newsfragment | 4 | ||||
-rw-r--r-- | synapse/federation/transport/client.py | 21 | ||||
-rw-r--r-- | synapse/http/matrixfederationclient.py | 85 | ||||
-rw-r--r-- | tests/handlers/test_register.py | 124 | ||||
-rw-r--r-- | tests/handlers/test_typing.py | 2 | ||||
-rw-r--r-- | tests/http/test_fedclient.py | 99 | ||||
-rw-r--r-- | tests/rest/client/v1/utils.py | 125 | ||||
-rw-r--r-- | tests/server_notices/test_resource_limits_server_notices.py | 92 | ||||
-rw-r--r-- | tests/unittest.py | 12 |
12 files changed, 318 insertions, 249 deletions
diff --git a/changelog.d/4840.feature b/changelog.d/4840.feature new file mode 100644 index 0000000000..9d1fd59053 --- /dev/null +++ b/changelog.d/4840.feature @@ -0,0 +1 @@ +Remove trailing slashes from certain outbound federation requests. Retry if receiving a 404. Context: #3622. \ No newline at end of file diff --git a/changelog.d/4912.misc b/changelog.d/4912.misc new file mode 100644 index 0000000000..f05a239187 --- /dev/null +++ b/changelog.d/4912.misc @@ -0,0 +1 @@ +Allow newsfragments to end with exclamation marks. Exciting! diff --git a/changelog.d/4913.misc b/changelog.d/4913.misc new file mode 100644 index 0000000000..9e835badc0 --- /dev/null +++ b/changelog.d/4913.misc @@ -0,0 +1 @@ +Refactor some more tests to use HomeserverTestCase. diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment index e0ac84198e..0ec5075e79 100755 --- a/scripts-dev/check-newsfragment +++ b/scripts-dev/check-newsfragment @@ -31,8 +31,8 @@ echo # check that any new newsfiles on this branch end with a full stop. for f in `git diff --name-only FETCH_HEAD... -- changelog.d`; do lastchar=`tr -d '\n' < $f | tail -c 1` - if [ $lastchar != '.' ]; then - echo -e "\e[31mERROR: newsfragment $f does not end with a '.'\e[39m" >&2 + if [ $lastchar != '.' -a $lastchar != '!' ]; then + echo -e "\e[31mERROR: newsfragment $f does not end with a '.' or '!'\e[39m" >&2 exit 1 fi done diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 8e2be218e2..0cdb31178f 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -51,9 +51,10 @@ class TransportLayerClient(object): logger.debug("get_room_state dest=%s, room=%s", destination, room_id) - path = _create_v1_path("/state/%s/", room_id) + path = _create_v1_path("/state/%s", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, + try_trailing_slash_on_400=True, ) @log_function @@ -73,9 +74,10 @@ class TransportLayerClient(object): logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) - path = _create_v1_path("/state_ids/%s/", room_id) + path = _create_v1_path("/state_ids/%s", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, + try_trailing_slash_on_400=True, ) @log_function @@ -95,8 +97,11 @@ class TransportLayerClient(object): logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) - path = _create_v1_path("/event/%s/", event_id) - return self.client.get_json(destination, path=path, timeout=timeout) + path = _create_v1_path("/event/%s", event_id) + return self.client.get_json( + destination, path=path, timeout=timeout, + try_trailing_slash_on_400=True, + ) @log_function def backfill(self, destination, room_id, event_tuples, limit): @@ -121,7 +126,7 @@ class TransportLayerClient(object): # TODO: raise? return - path = _create_v1_path("/backfill/%s/", room_id) + path = _create_v1_path("/backfill/%s", room_id) args = { "v": event_tuples, @@ -132,6 +137,7 @@ class TransportLayerClient(object): destination, path=path, args=args, + try_trailing_slash_on_400=True, ) @defer.inlineCallbacks @@ -176,6 +182,7 @@ class TransportLayerClient(object): json_data_callback=json_data_callback, long_retries=True, backoff_on_404=True, # If we get a 404 the other side has gone + try_trailing_slash_on_400=True, ) defer.returnValue(response) @@ -959,7 +966,7 @@ def _create_v1_path(path, *args): Example: - _create_v1_path("/event/%s/", event_id) + _create_v1_path("/event/%s", event_id) Args: path (str): String template for the path @@ -980,7 +987,7 @@ def _create_v2_path(path, *args): Example: - _create_v2_path("/event/%s/", event_id) + _create_v2_path("/event/%s", event_id) Args: path (str): String template for the path diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 1682c9af13..8e855d13d6 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -189,6 +189,57 @@ class MatrixFederationHttpClient(object): self._cooperator = Cooperator(scheduler=schedule) @defer.inlineCallbacks + def _send_request_with_optional_trailing_slash( + self, + request, + try_trailing_slash_on_400=False, + **send_request_args + ): + """Wrapper for _send_request which can optionally retry the request + upon receiving a combination of a 400 HTTP response code and a + 'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3 + due to #3622. + + Args: + request (MatrixFederationRequest): details of request to be sent + try_trailing_slash_on_400 (bool): Whether on receiving a 400 + 'M_UNRECOGNIZED' from the server to retry the request with a + trailing slash appended to the request path. + send_request_args (Dict): A dictionary of arguments to pass to + `_send_request()`. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + + Returns: + Deferred[Dict]: Parsed JSON response body. + """ + try: + response = yield self._send_request( + request, **send_request_args + ) + except HttpResponseException as e: + # Received an HTTP error > 300. Check if it meets the requirements + # to retry with a trailing slash + if not try_trailing_slash_on_400: + raise + + if e.code != 400 or e.to_synapse_error().errcode != "M_UNRECOGNIZED": + raise + + # Retry with a trailing slash if we received a 400 with + # 'M_UNRECOGNIZED' which some endpoints can return when omitting a + # trailing slash on Synapse <= v0.99.3. + request.path += "/" + + response = yield self._send_request( + request, **send_request_args + ) + + defer.returnValue(response) + + @defer.inlineCallbacks def _send_request( self, request, @@ -196,7 +247,7 @@ class MatrixFederationHttpClient(object): timeout=None, long_retries=False, ignore_backoff=False, - backoff_on_404=False + backoff_on_404=False, ): """ Sends a request to the given server. @@ -473,7 +524,8 @@ class MatrixFederationHttpClient(object): json_data_callback=None, long_retries=False, timeout=None, ignore_backoff=False, - backoff_on_404=False): + backoff_on_404=False, + try_trailing_slash_on_400=False): """ Sends the specifed json data using PUT Args: @@ -493,7 +545,12 @@ class MatrixFederationHttpClient(object): and try the request anyway. backoff_on_404 (bool): True if we should count a 404 response as a failure of the server (and should therefore back off future - requests) + requests). + try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED + response we should try appending a trailing slash to the end + of the request. Workaround for #3622 in Synapse <= v0.99.3. This + will be attempted before backing off if backing off has been + enabled. Returns: Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The @@ -509,7 +566,6 @@ class MatrixFederationHttpClient(object): RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. """ - request = MatrixFederationRequest( method="PUT", destination=destination, @@ -519,17 +575,19 @@ class MatrixFederationHttpClient(object): json=data, ) - response = yield self._send_request( + response = yield self._send_request_with_optional_trailing_slash( request, + try_trailing_slash_on_400, + backoff_on_404=backoff_on_404, + ignore_backoff=ignore_backoff, long_retries=long_retries, timeout=timeout, - ignore_backoff=ignore_backoff, - backoff_on_404=backoff_on_404, ) body = yield _handle_json_response( self.hs.get_reactor(), self.default_timeout, request, response, ) + defer.returnValue(body) @defer.inlineCallbacks @@ -592,7 +650,8 @@ class MatrixFederationHttpClient(object): @defer.inlineCallbacks def get_json(self, destination, path, args=None, retry_on_dns_fail=True, - timeout=None, ignore_backoff=False): + timeout=None, ignore_backoff=False, + try_trailing_slash_on_400=False): """ GETs some json from the given host homeserver and path Args: @@ -606,6 +665,9 @@ class MatrixFederationHttpClient(object): be retried. ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. + try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED + response we should try appending a trailing slash to the end of + the request. Workaround for #3622 in Synapse <= v0.99.3. Returns: Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. @@ -631,16 +693,19 @@ class MatrixFederationHttpClient(object): query=args, ) - response = yield self._send_request( + response = yield self._send_request_with_optional_trailing_slash( request, + try_trailing_slash_on_400, + backoff_on_404=False, + ignore_backoff=ignore_backoff, retry_on_dns_fail=retry_on_dns_fail, timeout=timeout, - ignore_backoff=ignore_backoff, ) body = yield _handle_json_response( self.hs.get_reactor(), self.default_timeout, request, response, ) + defer.returnValue(body) @defer.inlineCallbacks diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 2217eb2a10..017ea0385e 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -22,8 +22,6 @@ from synapse.api.errors import ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler from synapse.types import RoomAlias, UserID, create_requester -from tests.utils import default_config, setup_test_homeserver - from .. import unittest @@ -32,26 +30,23 @@ class RegistrationHandlers(object): self.registration_handler = RegistrationHandler(hs) -class RegistrationTestCase(unittest.TestCase): +class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. """ - @defer.inlineCallbacks - def setUp(self): - self.mock_distributor = Mock() - self.mock_distributor.declare("registered_user") - self.mock_captcha_client = Mock() - - hs_config = default_config("test") + def make_homeserver(self, reactor, clock): + hs_config = self.default_config("test") # some of the tests rely on us having a user consent version hs_config.user_consent_version = "test_consent_version" hs_config.max_mau_value = 50 - self.hs = yield setup_test_homeserver( - self.addCleanup, - config=hs_config, - expire_access_token=True, - ) + hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) + return hs + + def prepare(self, reactor, clock, hs): + self.mock_distributor = Mock() + self.mock_distributor.declare("registered_user") + self.mock_captcha_client = Mock() self.macaroon_generator = Mock( generate_access_token=Mock(return_value='secret') ) @@ -63,136 +58,133 @@ class RegistrationTestCase(unittest.TestCase): self.requester = create_requester("@requester:test") - @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): frank = UserID.from_string("@frank:test") user_id = frank.to_string() requester = create_requester(user_id) - result_user_id, result_token = yield self.handler.get_or_create_user( - requester, frank.localpart, "Frankie" + result_user_id, result_token = self.get_success( + self.handler.get_or_create_user(requester, frank.localpart, "Frankie") ) self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) self.assertEquals(result_token, 'secret') - @defer.inlineCallbacks def test_if_user_exists(self): store = self.hs.get_datastore() frank = UserID.from_string("@frank:test") - yield store.register( - user_id=frank.to_string(), - token="jkv;g498752-43gj['eamb!-5", - password_hash=None, + self.get_success( + store.register( + user_id=frank.to_string(), + token="jkv;g498752-43gj['eamb!-5", + password_hash=None, + ) ) local_part = frank.localpart user_id = frank.to_string() requester = create_requester(user_id) - result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, None + result_user_id, result_token = self.get_success( + self.handler.get_or_create_user(requester, local_part, None) ) self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) - @defer.inlineCallbacks def test_mau_limits_when_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.handler.get_or_create_user(self.requester, 'a', "display_name") + self.get_success( + self.handler.get_or_create_user(self.requester, 'a', "display_name") + ) - @defer.inlineCallbacks def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.count_monthly_users = Mock( return_value=defer.succeed(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception - yield self.handler.get_or_create_user(self.requester, 'c', "User") + self.get_success(self.handler.get_or_create_user(self.requester, 'c', "User")) - @defer.inlineCallbacks def test_get_or_create_user_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user(self.requester, 'b', "display_name") + self.get_failure( + self.handler.get_or_create_user(self.requester, 'b', "display_name"), + ResourceLimitError, + ) self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user(self.requester, 'b', "display_name") + self.get_failure( + self.handler.get_or_create_user(self.requester, 'b', "display_name"), + ResourceLimitError, + ) - @defer.inlineCallbacks def test_register_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.register(localpart="local_part") + self.get_failure( + self.handler.register(localpart="local_part"), ResourceLimitError + ) self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.register(localpart="local_part") + self.get_failure( + self.handler.register(localpart="local_part"), ResourceLimitError + ) - @defer.inlineCallbacks def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = yield self.handler.register(localpart='jeff') - rooms = yield self.store.get_rooms_for_user(res[0]) + res = self.get_success(self.handler.register(localpart='jeff')) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) directory_handler = self.hs.get_handlers().directory_handler room_alias = RoomAlias.from_string(room_alias_str) - room_id = yield directory_handler.get_association(room_alias) + room_id = self.get_success(directory_handler.get_association(room_alias)) self.assertTrue(room_id['room_id'] in rooms) self.assertEqual(len(rooms), 1) - @defer.inlineCallbacks def test_auto_create_auto_join_rooms_with_no_rooms(self): self.hs.config.auto_join_rooms = [] frank = UserID.from_string("@frank:test") - res = yield self.handler.register(frank.localpart) + res = self.get_success(self.handler.register(frank.localpart)) self.assertEqual(res[0], frank.to_string()) - rooms = yield self.store.get_rooms_for_user(res[0]) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_auto_create_auto_join_where_room_is_another_domain(self): self.hs.config.auto_join_rooms = ["#room:another"] frank = UserID.from_string("@frank:test") - res = yield self.handler.register(frank.localpart) + res = self.get_success(self.handler.register(frank.localpart)) self.assertEqual(res[0], frank.to_string()) - rooms = yield self.store.get_rooms_for_user(res[0]) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_auto_create_auto_join_where_auto_create_is_false(self): self.hs.config.autocreate_auto_join_rooms = False room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = yield self.handler.register(localpart='jeff') - rooms = yield self.store.get_rooms_for_user(res[0]) + res = self.get_success(self.handler.register(localpart='jeff')) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_auto_create_auto_join_rooms_when_support_user_exists(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] self.store.is_support_user = Mock(return_value=True) - res = yield self.handler.register(localpart='support') - rooms = yield self.store.get_rooms_for_user(res[0]) + res = self.get_success(self.handler.register(localpart='support')) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) directory_handler = self.hs.get_handlers().directory_handler room_alias = RoomAlias.from_string(room_alias_str) - with self.assertRaises(SynapseError): - yield directory_handler.get_association(room_alias) + self.get_failure(directory_handler.get_association(room_alias), SynapseError) - @defer.inlineCallbacks def test_auto_create_auto_join_where_no_consent(self): """Test to ensure that the first user is not auto-joined to a room if they have not given general consent. @@ -208,27 +200,27 @@ class RegistrationTestCase(unittest.TestCase): # (Messing with the internals of event_creation_handler is fragile # but can't see a better way to do this. One option could be to subclass # the test with custom config.) - event_creation_handler._block_events_without_consent_error = ("Error") + event_creation_handler._block_events_without_consent_error = "Error" event_creation_handler._consent_uri_builder = Mock() room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] # When:- # * the user is registered and post consent actions are called - res = yield self.handler.register(localpart='jeff') - yield self.handler.post_consent_actions(res[0]) + res = self.get_success(self.handler.register(localpart='jeff')) + self.get_success(self.handler.post_consent_actions(res[0])) # Then:- # * Ensure that they have not been joined to the room - rooms = yield self.store.get_rooms_for_user(res[0]) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_register_support_user(self): - res = yield self.handler.register(localpart='user', user_type=UserTypes.SUPPORT) + res = self.get_success( + self.handler.register(localpart='user', user_type=UserTypes.SUPPORT) + ) self.assertTrue(self.store.is_support_user(res[0])) - @defer.inlineCallbacks def test_register_not_support_user(self): - res = yield self.handler.register(localpart='user') + res = self.get_success(self.handler.register(localpart='user')) self.assertFalse(self.store.is_support_user(res[0])) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 13486930fb..7decb22933 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -192,6 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): json_data_callback=ANY, long_retries=True, backoff_on_404=True, + try_trailing_slash_on_400=True, ) def test_started_typing_remote_recv(self): @@ -269,6 +270,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): json_data_callback=ANY, long_retries=True, backoff_on_404=True, + try_trailing_slash_on_400=True, ) self.assertEquals(self.event_source.get_current_key(), 1) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index b03b37affe..cd8e086f86 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -268,6 +268,105 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, TimeoutError) + def test_client_requires_trailing_slashes(self): + """ + If a connection is made to a client but the client rejects it due to + requiring a trailing slash. We need to retry the request with a + trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. + """ + d = self.cl.get_json( + "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, + ) + + # Send the request + self.pump() + + # there should have been a call to connectTCP + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (_host, _port, factory, _timeout, _bindAddress) = clients[0] + + # complete the connection and wire it up to a fake transport + client = factory.buildProtocol(None) + conn = StringTransport() + client.makeConnection(conn) + + # that should have made it send the request to the connection + self.assertRegex(conn.value(), b"^GET /foo/bar") + + # Clear the original request data before sending a response + conn.clear() + + # Send the HTTP response + client.dataReceived( + b"HTTP/1.1 400 Bad Request\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 59\r\n" + b"\r\n" + b'{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}' + ) + + # We should get another request with a trailing slash + self.assertRegex(conn.value(), b"^GET /foo/bar/") + + # Send a happy response this time + client.dataReceived( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 2\r\n" + b"\r\n" + b'{}' + ) + + # We should get a successful response + r = self.successResultOf(d) + self.assertEqual(r, {}) + + def test_client_does_not_retry_on_400_plus(self): + """ + Another test for trailing slashes but now test that we don't retry on + trailing slashes on a non-400/M_UNRECOGNIZED response. + + See test_client_requires_trailing_slashes() for context. + """ + d = self.cl.get_json( + "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, + ) + + # Send the request + self.pump() + + # there should have been a call to connectTCP + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (_host, _port, factory, _timeout, _bindAddress) = clients[0] + + # complete the connection and wire it up to a fake transport + client = factory.buildProtocol(None) + conn = StringTransport() + client.makeConnection(conn) + + # that should have made it send the request to the connection + self.assertRegex(conn.value(), b"^GET /foo/bar") + + # Clear the original request data before sending a response + conn.clear() + + # Send the HTTP response + client.dataReceived( + b"HTTP/1.1 404 Not Found\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 2\r\n" + b"\r\n" + b"{}" + ) + + # We should not get another request + self.assertEqual(conn.value(), b"") + + # We should get a 404 failure response + self.failureResultOf(d) + def test_client_sends_body(self): self.cl.post_json( "testserv:8008", "foo/bar", timeout=10000, diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 9c401bf300..05b0143c42 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -18,136 +18,11 @@ import time import attr -from twisted.internet import defer - from synapse.api.constants import Membership -from tests import unittest from tests.server import make_request, render -class RestTestCase(unittest.TestCase): - """Contains extra helper functions to quickly and clearly perform a given - REST action, which isn't the focus of the test. - - This subclass assumes there are mock_resource and auth_user_id attributes. - """ - - def __init__(self, *args, **kwargs): - super(RestTestCase, self).__init__(*args, **kwargs) - self.mock_resource = None - self.auth_user_id = None - - @defer.inlineCallbacks - def create_room_as(self, room_creator, is_public=True, tok=None): - temp_id = self.auth_user_id - self.auth_user_id = room_creator - path = "/createRoom" - content = "{}" - if not is_public: - content = '{"visibility":"private"}' - if tok: - path = path + "?access_token=%s" % tok - (code, response) = yield self.mock_resource.trigger("POST", path, content) - self.assertEquals(200, code, msg=str(response)) - self.auth_user_id = temp_id - defer.returnValue(response["room_id"]) - - @defer.inlineCallbacks - def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): - yield self.change_membership( - room=room, - src=src, - targ=targ, - tok=tok, - membership=Membership.INVITE, - expect_code=expect_code, - ) - - @defer.inlineCallbacks - def join(self, room=None, user=None, expect_code=200, tok=None): - yield self.change_membership( - room=room, - src=user, - targ=user, - tok=tok, - membership=Membership.JOIN, - expect_code=expect_code, - ) - - @defer.inlineCallbacks - def leave(self, room=None, user=None, expect_code=200, tok=None): - yield self.change_membership( - room=room, - src=user, - targ=user, - tok=tok, - membership=Membership.LEAVE, - expect_code=expect_code, - ) - - @defer.inlineCallbacks - def change_membership(self, room, src, targ, membership, tok=None, expect_code=200): - temp_id = self.auth_user_id - self.auth_user_id = src - - path = "/rooms/%s/state/m.room.member/%s" % (room, targ) - if tok: - path = path + "?access_token=%s" % tok - - data = {"membership": membership} - - (code, response) = yield self.mock_resource.trigger( - "PUT", path, json.dumps(data) - ) - self.assertEquals( - expect_code, - code, - msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response), - ) - - self.auth_user_id = temp_id - - @defer.inlineCallbacks - def register(self, user_id): - (code, response) = yield self.mock_resource.trigger( - "POST", - "/register", - json.dumps( - {"user": user_id, "password": "test", "type": "m.login.password"} - ), - ) - self.assertEquals(200, code, msg=response) - defer.returnValue(response) - - @defer.inlineCallbacks - def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) - if body is None: - body = "body_text_here" - - path = "/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) - content = '{"msgtype":"m.text","body":"%s"}' % body - if tok: - path = path + "?access_token=%s" % tok - - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(expect_code, code, msg=str(response)) - - def assert_dict(self, required, actual): - """Does a partial assert of a dict. - - Args: - required (dict): The keys and value which MUST be in 'actual'. - actual (dict): The test result. Extra keys will not be checked. - """ - for key in required: - self.assertEquals( - required[key], actual[key], msg="%s mismatch. %s" % (key, actual) - ) - - @attr.s class RestHelper(object): """Contains extra helper functions to quickly and clearly perform a given diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 3bd9f1e9c1..be73e718c2 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -1,3 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright 2018, 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from mock import Mock from twisted.internet import defer @@ -9,16 +24,18 @@ from synapse.server_notices.resource_limits_server_notices import ( ) from tests import unittest -from tests.utils import default_config, setup_test_homeserver -class TestResourceLimitsServerNotices(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs_config = default_config(name="test") +class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): + + def make_homeserver(self, reactor, clock): + hs_config = self.default_config("test") hs_config.server_notices_mxid = "@server:test" - self.hs = yield setup_test_homeserver(self.addCleanup, config=hs_config) + hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) + return hs + + def prepare(self, reactor, clock, hs): self.server_notices_sender = self.hs.get_server_notices_sender() # relying on [1] is far from ideal, but the only case where @@ -53,23 +70,21 @@ class TestResourceLimitsServerNotices(unittest.TestCase): self._rlsn._store.get_tags_for_room = Mock(return_value={}) self.hs.config.admin_contact = "mailto:user@test.com" - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_flag_off(self): """Tests cases where the flags indicate nothing to do""" # test hs disabled case self.hs.config.hs_disabled = True - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() # Test when mau limiting disabled self.hs.config.hs_disabled = False self.hs.limit_usage_by_mau = False - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, but should have it removed""" @@ -81,13 +96,14 @@ class TestResourceLimitsServerNotices(unittest.TestCase): return_value=defer.succeed({"123": mock_event}) ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event self._send_notice.assert_called_once() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): - """Test when user has blocked notice, but notice ought to be there (NOOP)""" + """ + Test when user has blocked notice, but notice ought to be there (NOOP) + """ self._rlsn._auth.check_auth_blocking = Mock( side_effect=ResourceLimitError(403, 'foo') ) @@ -98,52 +114,49 @@ class TestResourceLimitsServerNotices(unittest.TestCase): self._rlsn._store.get_events = Mock( return_value=defer.succeed({"123": mock_event}) ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_add_blocked_notice(self): - """Test when user does not have blocked notice, but should have one""" + """ + Test when user does not have blocked notice, but should have one + """ self._rlsn._auth.check_auth_blocking = Mock( side_effect=ResourceLimitError(403, 'foo') ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check contents, but 2 calls == set blocking event self.assertTrue(self._send_notice.call_count == 2) - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): - """Test when user does not have blocked notice, nor should they (NOOP)""" - + """ + Test when user does not have blocked notice, nor should they (NOOP) + """ self._rlsn._auth.check_auth_blocking = Mock() - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): - - """Test when user is not part of the MAU cohort - this should not ever + """ + Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock() self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() -class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) +class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = self.hs.get_datastore() self.server_notices_sender = self.hs.get_server_notices_sender() self.server_notices_manager = self.hs.get_server_notices_manager() @@ -168,26 +181,27 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): self.hs.config.admin_contact = "mailto:user@test.com" - @defer.inlineCallbacks def test_server_notice_only_sent_once(self): self.store.get_monthly_active_count = Mock(return_value=1000) self.store.user_last_seen_monthly_active = Mock(return_value=1000) # Call the function multiple times to ensure we only send the notice once - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Now lets get the last load of messages in the service notice room and # check that there is only one server notice - room_id = yield self.server_notices_manager.get_notice_room_for_user( - self.user_id + room_id = self.get_success( + self.server_notices_manager.get_notice_room_for_user(self.user_id) ) - token = yield self.event_source.get_current_token() - events, _ = yield self.store.get_recent_events_for_room( - room_id, limit=100, end_token=token.room_key + token = self.get_success(self.event_source.get_current_token()) + events, _ = self.get_success( + self.store.get_recent_events_for_room( + room_id, limit=100, end_token=token.room_key + ) ) count = 0 diff --git a/tests/unittest.py b/tests/unittest.py index 7772a47078..27403de908 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -314,6 +314,9 @@ class HomeserverTestCase(TestCase): """ kwargs = dict(kwargs) kwargs.update(self._hs_args) + if "config" not in kwargs: + config = self.default_config() + kwargs["config"] = config hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() @@ -336,6 +339,15 @@ class HomeserverTestCase(TestCase): self.pump(by=by) return self.successResultOf(d) + def get_failure(self, d, exc): + """ + Run a Deferred and get a Failure from it. The failure must be of the type `exc`. + """ + if not isinstance(d, Deferred): + return d + self.pump() + return self.failureResultOf(d, exc) + def register_user(self, username, password, admin=False): """ Register a user. Requires the Admin API be registered. |