From 8000cf131592b6edcded65ef4be20b8ac0f1bfd3 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 16 Mar 2021 16:44:25 +0100 Subject: Return m.change_password.enabled=false if local database is disabled (#9588) Instead of if the user does not have a password hash. This allows a SSO user to add a password to their account, but only if the local password database is configured. --- tests/rest/client/v2_alpha/test_capabilities.py | 36 ++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index e808339fb3..287a1a485c 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -18,6 +18,7 @@ from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import capabilities from tests import unittest +from tests.unittest import override_config class CapabilitiesTestCase(unittest.HomeserverTestCase): @@ -33,6 +34,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver() self.store = hs.get_datastore() self.config = hs.config + self.auth_handler = hs.get_auth_handler() return hs def test_check_auth_required(self): @@ -56,7 +58,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): capabilities["m.room_versions"]["default"], ) - def test_get_change_password_capabilities(self): + def test_get_change_password_capabilities_password_login(self): localpart = "user" password = "pass" user = self.register_user(localpart, password) @@ -66,10 +68,36 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) - - # Test case where password is handled outside of Synapse self.assertTrue(capabilities["m.change_password"]["enabled"]) - self.get_success(self.store.user_set_password_hash(user, None)) + + @override_config({"password_config": {"localdb_enabled": False}}) + def test_get_change_password_capabilities_localdb_disabled(self): + localpart = "user" + password = "pass" + user = self.register_user(localpart, password) + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["m.change_password"]["enabled"]) + + @override_config({"password_config": {"enabled": False}}) + def test_get_change_password_capabilities_password_disabled(self): + localpart = "user" + password = "pass" + user = self.register_user(localpart, password) + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + user, device_id=None, valid_until_ms=None + ) + ) + channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] -- cgit 1.5.1 From dd5e5dc1d6c88a3532d25f18cfc312d8bc813473 Mon Sep 17 00:00:00 2001 From: Hubbe Date: Tue, 16 Mar 2021 17:46:07 +0200 Subject: Add SSO attribute requirements for OIDC providers (#9609) Allows limiting who can login using OIDC via the claims made from the IdP. 
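For illustration, the matching rule this commit adds can be sketched as follows (a hypothetical standalone example — the function name is made up, and the real check is performed by the SSO handler's `check_required_attributes` shown in the diff below): a required value must equal a scalar claim, or be one of the entries of a list claim.

    # Rough sketch of the attribute-requirement semantics, for illustration
    # only; Synapse's actual check lives in its SSO handler.
    from typing import Any, Dict, List

    def userinfo_meets_requirements(
        userinfo: Dict[str, Any], requirements: List[Dict[str, str]]
    ) -> bool:
        """Return True only if every required attribute/value pair is met."""
        for req in requirements:
            value = userinfo.get(req["attribute"])
            # As in the handler below, wrap scalars in a list so that a list
            # claim must contain the value and a scalar claim must equal it.
            values = value if isinstance(value, list) else [value]
            if req["value"] not in values:
                return False
        return True

    # With the sample requirement from this change:
    reqs = [{"attribute": "groups", "value": "admin"}]
    assert userinfo_meets_requirements({"groups": ["admin", "staff"]}, reqs)
    assert not userinfo_meets_requirements({"groups": ["staff"]}, reqs)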
--- changelog.d/9609.feature | 1 + docs/sample_config.yaml | 24 +++++++ synapse/config/oidc_config.py | 40 +++++++++++- synapse/handlers/oidc_handler.py | 13 ++++ tests/handlers/test_oidc.py | 132 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 209 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9609.feature (limited to 'tests') diff --git a/changelog.d/9609.feature b/changelog.d/9609.feature new file mode 100644 index 0000000000..f3b6342069 --- /dev/null +++ b/changelog.d/9609.feature @@ -0,0 +1 @@ +Logins using OpenID Connect can require attributes on the `userinfo` response in order to login. Contributed by Hubbe King. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 7de000f4a4..a9f59e39f7 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1873,6 +1873,24 @@ saml2_config: # which is set to the claims returned by the UserInfo Endpoint and/or # in the ID Token. # +# It is possible to configure Synapse to only allow logins if certain attributes +# match particular values in the OIDC userinfo. The requirements can be listed under +# `attribute_requirements` as shown below. All of the listed attributes must +# match for the login to be permitted. Additional attributes can be added to +# userinfo by expanding the `scopes` section of the OIDC config to retrieve +# additional information from the OIDC provider. +# +# If the OIDC claim is a list, then the attribute must match any value in the list. +# Otherwise, it must exactly match the value of the claim. Using the example +# below, the `family_name` claim MUST be "Stephensson", but the `groups` +# claim MUST contain "admin". +# +# attribute_requirements: +# - attribute: family_name +# value: "Stephensson" +# - attribute: groups +# value: "admin" +# # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md # for information on how to configure these options. # @@ -1905,6 +1923,9 @@ oidc_providers: # localpart_template: "{{ user.login }}" # display_name_template: "{{ user.name }}" # email_template: "{{ user.email }}" + # attribute_requirements: + # - attribute: userGroup + # value: "synapseUsers" # For use with Keycloak # @@ -1914,6 +1935,9 @@ oidc_providers: # client_id: "synapse" # client_secret: "copy secret generated in Keycloak UI" # scopes: ["openid", "profile"] + # attribute_requirements: + # - attribute: groups + # value: "admin" # For use with Github # diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 2bfb537c15..eab042a085 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -15,11 +15,12 @@ # limitations under the License. from collections import Counter -from typing import Iterable, Mapping, Optional, Tuple, Type +from typing import Iterable, List, Mapping, Optional, Tuple, Type import attr from synapse.config._util import validate_config +from synapse.config.sso import SsoAttributeRequirement from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import Collection, JsonDict from synapse.util.module_loader import load_module @@ -191,6 +192,24 @@ class OIDCConfig(Config): # which is set to the claims returned by the UserInfo Endpoint and/or # in the ID Token. # + # It is possible to configure Synapse to only allow logins if certain attributes + # match particular values in the OIDC userinfo. The requirements can be listed under + # `attribute_requirements` as shown below. All of the listed attributes must + # match for the login to be permitted. 
Additional attributes can be added to + # userinfo by expanding the `scopes` section of the OIDC config to retrieve + # additional information from the OIDC provider. + # + # If the OIDC claim is a list, then the attribute must match any value in the list. + # Otherwise, it must exactly match the value of the claim. Using the example + # below, the `family_name` claim MUST be "Stephensson", but the `groups` + # claim MUST contain "admin". + # + # attribute_requirements: + # - attribute: family_name + # value: "Stephensson" + # - attribute: groups + # value: "admin" + # # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md # for information on how to configure these options. # @@ -223,6 +242,9 @@ class OIDCConfig(Config): # localpart_template: "{{{{ user.login }}}}" # display_name_template: "{{{{ user.name }}}}" # email_template: "{{{{ user.email }}}}" + # attribute_requirements: + # - attribute: userGroup + # value: "synapseUsers" # For use with Keycloak # @@ -232,6 +254,9 @@ class OIDCConfig(Config): # client_id: "synapse" # client_secret: "copy secret generated in Keycloak UI" # scopes: ["openid", "profile"] + # attribute_requirements: + # - attribute: groups + # value: "admin" # For use with Github # @@ -329,6 +354,10 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { }, "allow_existing_users": {"type": "boolean"}, "user_mapping_provider": {"type": ["object", "null"]}, + "attribute_requirements": { + "type": "array", + "items": SsoAttributeRequirement.JSON_SCHEMA, + }, }, } @@ -465,6 +494,11 @@ def _parse_oidc_config_dict( jwt_header=client_secret_jwt_key_config["jwt_header"], jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}), ) + # parse attribute_requirements from config (list of dicts) into a list of SsoAttributeRequirement + attribute_requirements = [ + SsoAttributeRequirement(**x) + for x in oidc_config.get("attribute_requirements", []) + ] return OidcProviderConfig( idp_id=idp_id, @@ -488,6 +522,7 @@ def _parse_oidc_config_dict( allow_existing_users=oidc_config.get("allow_existing_users", False), user_mapping_provider_class=user_mapping_provider_class, user_mapping_provider_config=user_mapping_provider_config, + attribute_requirements=attribute_requirements, ) @@ -577,3 +612,6 @@ class OidcProviderConfig: # the config of the user mapping provider user_mapping_provider_config = attr.ib() + + # required attributes to require in userinfo to allow login/registration + attribute_requirements = attr.ib(type=List[SsoAttributeRequirement]) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 6d8551a6d6..bc3630e9e9 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -280,6 +280,7 @@ class OidcProvider: self._config = provider self._callback_url = hs.config.oidc_callback_url # type: str + self._oidc_attribute_requirements = provider.attribute_requirements self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method @@ -859,6 +860,18 @@ class OidcProvider: ) # otherwise, it's a login + logger.debug("Userinfo for OIDC login: %s", userinfo) + + # Ensure that the attributes of the logged in user meet the required + # attributes by checking the userinfo against attribute_requirements + # In order to deal with the fact that OIDC userinfo can contain many + # types of data, we wrap non-list values in lists. 
+ if not self._sso_handler.check_required_attributes( + request, + {k: v if isinstance(v, list) else [v] for k, v in userinfo.items()}, + self._oidc_attribute_requirements, + ): + return # Call the mapper to register/login the user try: diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 5e9c9c2e88..c7796fb837 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -989,6 +989,138 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) self.assertRenderedError("mapping_error", "localpart is invalid: ") + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, + "attribute_requirements": [{"attribute": "test", "value": "foobar"}], + } + } + ) + def test_attribute_requirements(self): + """The required attributes must be met from the OIDC userinfo response.""" + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # userinfo lacking "test": "foobar" attribute should fail. + userinfo = { + "sub": "tester", + "username": "tester", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": "foobar" attribute should succeed. + userinfo = { + "sub": "tester", + "username": "tester", + "test": "foobar", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@tester:test", "oidc", ANY, ANY, None, new_user=True + ) + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, + "attribute_requirements": [{"attribute": "test", "value": "foobar"}], + } + } + ) + def test_attribute_requirements_contains(self): + """Test that auth succeeds if userinfo attribute CONTAINS required value""" + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed. + userinfo = { + "sub": "tester", + "username": "tester", + "test": ["foobar", "foo", "bar"], + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@tester:test", "oidc", ANY, ANY, None, new_user=True + ) + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, + "attribute_requirements": [{"attribute": "test", "value": "foobar"}], + } + } + ) + def test_attribute_requirements_mismatch(self): + """ + Test that auth fails if attributes exist but don't match, + or are non-string values. 
+ """ + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # userinfo with "test": "not_foobar" attribute should fail + userinfo = { + "sub": "tester", + "username": "tester", + "test": "not_foobar", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": ["foo", "bar"] attribute should fail + userinfo = { + "sub": "tester", + "username": "tester", + "test": ["foo", "bar"], + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": False attribute should fail + # this is largely just to ensure we don't crash here + userinfo = { + "sub": "tester", + "username": "tester", + "test": False, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": None attribute should fail + # a value of None breaks the OIDC spec, but it's important to not crash here + userinfo = { + "sub": "tester", + "username": "tester", + "test": None, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": 1 attribute should fail + # this is largely just to ensure we don't crash here + userinfo = { + "sub": "tester", + "username": "tester", + "test": 1, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": 3.14 attribute should fail + # this is largely just to ensure we don't crash here + userinfo = { + "sub": "tester", + "username": "tester", + "test": 3.14, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + def _generate_oidc_session_token( self, state: str, -- cgit 1.5.1 From 27d2820c33d94cd99aea128b6ade76a7de838c3d Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Tue, 16 Mar 2021 19:19:27 +0100 Subject: Enable flake8-bugbear, but disable most checks. (#9499) * Adds B00 to ignored checks. * Fixes remaining issues. --- changelog.d/9499.misc | 1 + setup.cfg | 3 ++- setup.py | 1 + synapse/app/__init__.py | 4 +++- synapse/config/key.py | 6 +++++- synapse/config/metrics.py | 4 +++- synapse/config/oidc_config.py | 4 +++- synapse/config/repository.py | 4 +++- synapse/config/saml2_config.py | 4 +++- synapse/config/tracer.py | 4 +++- synapse/crypto/context_factory.py | 2 +- tests/unittest.py | 2 +- 12 files changed, 29 insertions(+), 10 deletions(-) create mode 100644 changelog.d/9499.misc (limited to 'tests') diff --git a/changelog.d/9499.misc b/changelog.d/9499.misc new file mode 100644 index 0000000000..1513017a10 --- /dev/null +++ b/changelog.d/9499.misc @@ -0,0 +1 @@ +Introduce bugbear to the test suite and fix some of it's lint violations. \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 5e301c2cd7..920868df20 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,8 @@ ignore = # E203: whitespace before ':' (which is contrary to pep8?) 
# E731: do not assign a lambda expression, use a def # E501: Line too long (black enforces this for us) -ignore=W503,W504,E203,E731,E501 +# B00: Subsection of the bugbear suite (TODO: add in remaining fixes) +ignore=W503,W504,E203,E731,E501,B00 [isort] line_length = 88 diff --git a/setup.py b/setup.py index bbd9e7862a..b834e4e55b 100755 --- a/setup.py +++ b/setup.py @@ -99,6 +99,7 @@ CONDITIONAL_REQUIREMENTS["lint"] = [ "isort==5.7.0", "black==20.8b1", "flake8-comprehensions", + "flake8-bugbear", "flake8", ] diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index 4a9b0129c3..d1a2cd5e19 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -22,7 +22,9 @@ logger = logging.getLogger(__name__) try: python_dependencies.check_requirements() except python_dependencies.DependencyException as e: - sys.stderr.writelines(e.message) + sys.stderr.writelines( + e.message # noqa: B306, DependencyException.message is a property + ) sys.exit(1) diff --git a/synapse/config/key.py b/synapse/config/key.py index de964dff13..350ff1d665 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -404,7 +404,11 @@ def _parse_key_servers(key_servers, federation_verify_certificates): try: jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA) except jsonschema.ValidationError as e: - raise ConfigError("Unable to parse 'trusted_key_servers': " + e.message) + raise ConfigError( + "Unable to parse 'trusted_key_servers': {}".format( + e.message # noqa: B306, jsonschema.ValidationError.message is a valid attribute + ) + ) for server in key_servers: server_name = server["server_name"] diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index dfd27e1523..2b289f4208 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -56,7 +56,9 @@ class MetricsConfig(Config): try: check_requirements("sentry") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) self.sentry_dsn = config["sentry"].get("dsn") if not self.sentry_dsn: diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index eab042a085..747ab9a7fe 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -42,7 +42,9 @@ class OIDCConfig(Config): try: check_requirements("oidc") except DependencyException as e: - raise ConfigError(e.message) from e + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) from e # check we don't have any duplicate idp_ids now. 
(The SSO handler will also + # check for duplicates when the REST listeners get registered, but that happens diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 69d9de5a43..061c4ec83f 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -176,7 +176,9 @@ class ContentRepositoryConfig(Config): check_requirements("url_preview") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) if "url_preview_ip_range_blacklist" not in config: raise ConfigError( diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 4b494f217f..6db9cb5ced 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -76,7 +76,9 @@ class SAML2Config(Config): try: check_requirements("saml2") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) self.saml2_enabled = True diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py index 0c1a854f09..727a1e7008 100644 --- a/synapse/config/tracer.py +++ b/synapse/config/tracer.py @@ -39,7 +39,9 @@ class TracerConfig(Config): try: check_requirements("opentracing") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) # The tracer is enabled so sanitize the config diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 14b21796d9..4ca13011e5 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -219,7 +219,7 @@ class SSLClientConnectionCreator: # ... and we also gut-wrench a '_synapse_tls_verifier' attribute into the # tls_protocol so that the SSL context's info callback has something to # call to do the cert verification. - setattr(tls_protocol, "_synapse_tls_verifier", self._verifier) + tls_protocol._synapse_tls_verifier = self._verifier return connection diff --git a/tests/unittest.py b/tests/unittest.py index ca7031c724..224f037ce1 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -140,7 +140,7 @@ class TestCase(unittest.TestCase): try: self.assertEquals(attrs[key], getattr(obj, key)) except AssertionError as e: - raise (type(e))(e.message + " for '.%s'" % key) + raise (type(e))("Assert error for '.{}':".format(key)) from e def assert_dict(self, required, actual): """Does a partial assert of a dict. -- cgit 1.5.1 From 7b06f85c0e18b62775f12789fdf4adb6a0a47a4b Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 17 Mar 2021 16:51:55 +0000 Subject: Ensure we use a copy of the event content dict before modifying it in serialize_event (#9585) This bug was discovered by DINUM. We were modifying `serialized_event["content"]`, which - if you've got `USE_FROZEN_DICTS` turned on or are [using a third party rules module](https://github.com/matrix-org/synapse/blob/17cd48fe5171d50da4cb59db647b993168e7dfab/synapse/events/third_party_rules.py#L73-L76) - will raise a 500 if you try to edit a reply to a message. `serialized_event["content"]` could be set to the edit event's content, instead of a copy of it, which is bad as we attempt to modify it. As a result, we also end up modifying the original event's content. DINUM uses a third party rules module, which meant the event's content got frozen and thus an exception was raised.
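A minimal sketch of that failure mode, using `MappingProxyType` as a stand-in for the frozen dicts involved (illustrative only, not the actual Synapse code):

    # Illustration only: MappingProxyType stands in for a frozen event
    # content dict such as those created when USE_FROZEN_DICTS is enabled.
    from types import MappingProxyType

    frozen_content = MappingProxyType({"m.new_content": {"body": "edited"}})

    try:
        # Mutating the original content in place is what used to happen...
        frozen_content["m.relates_to"] = {}  # type: ignore[index]
    except TypeError:
        pass  # ...and it raises, surfacing as a 500 to the client.

    # Taking a mutable copy first, as this commit does, avoids both the
    # exception and silently mutating the cached event's content.
    safe_content = dict(frozen_content)
    safe_content["m.relates_to"] = {}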
To be clear, the problem is not that the event's content was frozen. In fact doing so helped us uncover the fact we weren't copying event content correctly. --- changelog.d/9585.bugfix | 1 + synapse/events/utils.py | 14 ++++++- tests/rest/client/test_third_party_rules.py | 62 ++++++++++++++++++++++++++++ tests/rest/client/v2_alpha/test_relations.py | 62 ++++++++++++++++++++++++++++ tests/unittest.py | 10 +++++ 5 files changed, 147 insertions(+), 2 deletions(-) create mode 100644 changelog.d/9585.bugfix (limited to 'tests') diff --git a/changelog.d/9585.bugfix b/changelog.d/9585.bugfix new file mode 100644 index 0000000000..de472ddfd1 --- /dev/null +++ b/changelog.d/9585.bugfix @@ -0,0 +1 @@ +Fix a longstanding bug that could cause issues when editing a reply to a message. \ No newline at end of file diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 5022e0fcb3..0f8a3b5ad8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.util.async_helpers import yieldable_gather_results +from synapse.util.frozenutils import unfreeze from . import EventBase @@ -402,10 +403,19 @@ class EventClientSerializer: # If there is an edit replace the content, preserving existing # relations. + # Ensure we take copies of the edit content, otherwise we risk modifying + # the original event. + edit_content = edit.content.copy() + + # Unfreeze the event content if necessary, so that we may modify it below + edit_content = unfreeze(edit_content) + serialized_event["content"] = edit_content.get("m.new_content", {}) + + # Check for existing relations relations = event.content.get("m.relates_to") - serialized_event["content"] = edit.content.get("m.new_content", {}) if relations: - serialized_event["content"]["m.relates_to"] = relations + # Keep the relations, ensuring we use a dict copy of the original + serialized_event["content"]["m.relates_to"] = relations.copy() else: serialized_event["content"].pop("m.relates_to", None) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 227fffab58..bf39014277 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -161,6 +161,68 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): ev = channel.json_body self.assertEqual(ev["content"]["x"], "y") + def test_message_edit(self): + """Ensure that the module doesn't cause issues with edited messages.""" + # first patch the event checker so that it will modify the event + async def check(ev: EventBase, state): + d = ev.get_dict() + d["content"] = { + "msgtype": "m.text", + "body": d["content"]["body"].upper(), + } + return d + + current_rules_module().check_event_allowed = check + + # Send an event, then edit it. 
+ channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id, + { + "msgtype": "m.text", + "body": "Original body", + }, + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + orig_event_id = channel.json_body["event_id"] + + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/m.room.message/2" % self.room_id, + { + "m.new_content": {"msgtype": "m.text", "body": "Edited body"}, + "m.relates_to": { + "rel_type": "m.replace", + "event_id": orig_event_id, + }, + "msgtype": "m.text", + "body": "Edited body", + }, + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + edited_event_id = channel.json_body["event_id"] + + # ... and check that they both got modified + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id), + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + ev = channel.json_body + self.assertEqual(ev["content"]["body"], "ORIGINAL BODY") + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id), + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + ev = channel.json_body + self.assertEqual(ev["content"]["body"], "EDITED BODY") + def test_send_event(self): """Tests that the module can send an event into a room via the module api""" content = { diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 7c457754f1..e7bb5583fc 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -39,6 +39,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We need to enable msc1849 support for aggregations config = self.default_config() config["experimental_msc1849_support_enabled"] = True + + # We enable frozen dicts as relations/edits change event contents, so we + # want to test that we don't modify the events in the caches. + config["use_frozen_dicts"] = True + return self.setup_test_homeserver(config=config) def prepare(self, reactor, clock, hs): @@ -518,6 +523,63 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) + def test_edit_reply(self): + """Test that editing a reply works.""" + + # Create a reply to edit. + channel = self._send_relation( + RelationTypes.REFERENCE, + "m.room.message", + content={"msgtype": "m.text", "body": "A reply!"}, + ) + self.assertEquals(200, channel.code, channel.json_body) + reply = channel.json_body["event_id"] + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=reply, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, reply), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect to see the new body in the dict, as well as the reference + # metadata sill intact. 
+ self.assertDictContainsSubset(new_body, channel.json_body["content"]) + self.assertDictContainsSubset( + { + "m.relates_to": { + "event_id": self.parent_id, + "key": None, + "rel_type": "m.reference", + } + }, + channel.json_body["content"], + ) + + # We expect that the edit relation appears in the unsigned relations + # section. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + def test_relations_redaction_redacts_edits(self): """Test that edits of an event are redacted when the original event is redacted. diff --git a/tests/unittest.py b/tests/unittest.py index 224f037ce1..58a4daa1ec 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -32,6 +32,7 @@ from twisted.python.threadpool import ThreadPool from twisted.trial import unittest from twisted.web.resource import Resource +from synapse import events from synapse.api.constants import EventTypes, Membership from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig @@ -229,6 +230,11 @@ class HomeserverTestCase(TestCase): self._hs_args = {"clock": self.clock, "reactor": self.reactor} self.hs = self.make_homeserver(self.reactor, self.clock) + # Honour the `use_frozen_dicts` config option. We have to do this + # manually because this is taken care of in the app `start` code, which + # we don't run. Plus we want to reset it on tearDown. + events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts + if self.hs is None: raise Exception("No homeserver returned from make_homeserver.") @@ -292,6 +298,10 @@ class HomeserverTestCase(TestCase): if hasattr(self, "prepare"): self.prepare(self.reactor, self.clock, self.hs) + def tearDown(self): + # Reset to not use frozen dicts. + events.USE_FROZEN_DICTS = False + def wait_on_thread(self, deferred, timeout=10): """ Wait until a Deferred is done, where it's waiting on a real thread. -- cgit 1.5.1 From 405aeb0b2c40443d22ce8c265df18e81bd995b44 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 18 Mar 2021 16:34:47 +0100 Subject: Implement MSC3026: busy presence state --- changelog.d/9644.feature | 1 + synapse/api/constants.py | 1 + synapse/app/generic_worker.py | 1 + synapse/handlers/presence.py | 3 ++- synapse/rest/client/versions.py | 2 ++ tests/handlers/test_presence.py | 20 ++++++++++++++++++++ 6 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9644.feature (limited to 'tests') diff --git a/changelog.d/9644.feature b/changelog.d/9644.feature new file mode 100644 index 0000000000..556bcf0f9f --- /dev/null +++ b/changelog.d/9644.feature @@ -0,0 +1 @@ +Implement the busy presence state as described in [MSC3026](https://github.com/matrix-org/matrix-doc/pull/3026). 
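(As a usage note: a client would opt into this unstable state by PUTting it to the standard presence endpoint. The snippet below is a hypothetical sketch — the homeserver URL, user ID and token are placeholders, not values from this change.)

    # Hypothetical client-side sketch: set the MSC3026 busy state via the
    # standard presence endpoint. All identifiers here are placeholders.
    import requests

    resp = requests.put(
        "https://example.com/_matrix/client/r0/presence/@alice:example.com/status",
        headers={"Authorization": "Bearer <access_token>"},
        json={"presence": "org.matrix.msc3026.busy", "status_msg": "In a call"},
    )
    resp.raise_for_status()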
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 691f8f9adf..cc8541bc16 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -51,6 +51,7 @@ class PresenceState: OFFLINE = "offline" UNAVAILABLE = "unavailable" ONLINE = "online" + BUSY = "org.matrix.msc3026.busy" class JoinRules: diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 274d582d07..236d98a29d 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -439,6 +439,7 @@ class GenericWorkerPresence(BasePresenceHandler): PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE, + PresenceState.BUSY, ) if presence not in valid_presence: raise SynapseError(400, "Invalid presence state") diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 54631b4ee2..bcb99f627b 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -730,6 +730,7 @@ class PresenceHandler(BasePresenceHandler): PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE, + PresenceState.BUSY, ) if presence not in valid_presence: raise SynapseError(400, "Invalid presence state") @@ -744,7 +745,7 @@ class PresenceHandler(BasePresenceHandler): msg = status_msg if presence != PresenceState.OFFLINE else None new_fields["status_msg"] = msg - if presence == PresenceState.ONLINE: + if presence == PresenceState.ONLINE or presence == PresenceState.BUSY: new_fields["last_active_ts"] = self.clock.time_msec() await self._update_states([prev_state.copy_and_replace(**new_fields)]) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index d24a199318..f387d29b57 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -81,6 +81,8 @@ class VersionsRestServlet(RestServlet): "io.element.e2ee_forced.public": self.e2ee_forced_public, "io.element.e2ee_forced.private": self.e2ee_forced_private, "io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private, + # Supports the busy presence state described in MSC3026. + "org.matrix.msc3026.busy_presence": True, }, }, ) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 996c614198..77330f59a9 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -310,6 +310,26 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) + def test_busy_no_idle(self): + """ + Tests that a user setting their presence to busy but idling doesn't turn their + presence state into unavailable. + """ + user_id = "@foo:bar" + now = 5000000 + + state = UserPresenceState.default(user_id) + state = state.copy_and_replace( + state=PresenceState.BUSY, + last_active_ts=now - IDLE_TIMER - 1, + last_user_sync_ts=now, + ) + + new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) + + self.assertIsNotNone(new_state) + self.assertEquals(new_state.state, PresenceState.BUSY) + def test_sync_timeout(self): user_id = "@foo:bar" now = 5000000 -- cgit 1.5.1 From dd71eb0f8ab5a6e0d8eda3be8c2d5ff01271d147 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Mar 2021 15:52:26 +0000 Subject: Make federation catchup send last event from any server. (#9640) Currently federation catchup will send the last *local* event that we failed to send to the remote. 
This can cause issues for large rooms where lots of servers have sent events while the remote server was down, as when it comes back up again it'll be flooded with events from various points in the DAG. Instead, let's make it so that all the servers send the most recent events, even if it's not theirs. The remote should deduplicate the events, so there shouldn't be much overhead in doing this. Alternatively, the servers could only send local events if they were also extremities and hope that the other server will send the event over, but that is a bit risky. --- changelog.d/9640.misc | 1 + synapse/federation/federation_server.py | 25 +---- synapse/federation/sender/per_destination_queue.py | 104 ++++++++++++++++++--- tests/federation/test_federation_catch_up.py | 49 ++++++++++ 4 files changed, 141 insertions(+), 38 deletions(-) create mode 100644 changelog.d/9640.misc (limited to 'tests') diff --git a/changelog.d/9640.misc b/changelog.d/9640.misc new file mode 100644 index 0000000000..3d410ed4cd --- /dev/null +++ b/changelog.d/9640.misc @@ -0,0 +1 @@ +Improve performance of federation catch up by sending the latest events in the room to the remote, rather than just the last event sent by the local server. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 9839d3d016..d84e362070 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -35,7 +35,7 @@ from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import EduTypes, EventTypes, Membership +from synapse.api.constants import EduTypes, EventTypes from synapse.api.errors import ( AuthError, Codes, @@ -63,7 +63,7 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEduRestServlet, ReplicationGetQueryRestServlet, ) -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache @@ -727,27 +727,6 @@ class FederationServer(FederationBase): if the event was unacceptable for any other reason (eg, too large, too many prev_events, couldn't find the prev_events) """ - # check that it's actually being sent from a valid destination to - # workaround bug #1753 in 0.18.5 and 0.18.6 - if origin != get_domain_from_id(pdu.sender): - # We continue to accept join events from any server; this is - # necessary for the federation join dance to work correctly. - # (When we join over federation, the "helper" server is - # responsible for sending out the join event, rather than the - # origin. See bug #1893. This is also true for some third party - # invites).
- if not ( - pdu.type == "m.room.member" - and pdu.content - and pdu.content.get("membership", None) - in (Membership.JOIN, Membership.INVITE) - ): - logger.info( - "Discarding PDU %s from invalid origin %s", pdu.event_id, origin - ) - return - else: - logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point room_version = await self.store.get_room_version(pdu.room_id) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index cc0d765e5f..af85fe0a1e 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -15,7 +15,7 @@ # limitations under the License. import datetime import logging -from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple import attr from prometheus_client import Counter @@ -77,6 +77,7 @@ class PerDestinationQueue: self._transaction_manager = transaction_manager self._instance_name = hs.get_instance_name() self._federation_shard_config = hs.config.worker.federation_shard_config + self._state = hs.get_state_handler() self._should_send_on_this_instance = True if not self._federation_shard_config.should_handle( @@ -415,22 +416,95 @@ class PerDestinationQueue: "This should not happen." % event_ids ) - if logger.isEnabledFor(logging.INFO): - rooms = [p.room_id for p in catchup_pdus] - logger.info("Catching up rooms to %s: %r", self._destination, rooms) + # We send transactions with events from one room only, as its likely + # that the remote will have to do additional processing, which may + # take some time. It's better to give it small amounts of work + # rather than risk the request timing out and repeatedly being + # retried, and not making any progress. + # + # Note: `catchup_pdus` will have exactly one PDU per room. + for pdu in catchup_pdus: + # The PDU from the DB will be the last PDU in the room from + # *this server* that wasn't sent to the remote. However, other + # servers may have sent lots of events since then, and we want + # to try and tell the remote only about the *latest* events in + # the room. This is so that it doesn't get inundated by events + # from various parts of the DAG, which all need to be processed. + # + # Note: this does mean that in large rooms a server coming back + # online will get sent the same events from all the different + # servers, but the remote will correctly deduplicate them and + # handle it only once. + + # Step 1, fetch the current extremities + extrems = await self._store.get_prev_events_for_room(pdu.room_id) + + if pdu.event_id in extrems: + # If the event is in the extremities, then great! We can just + # use that without having to do further checks. + room_catchup_pdus = [pdu] + else: + # If not, fetch the extremities and figure out which we can + # send. + extrem_events = await self._store.get_events_as_list(extrems) + + new_pdus = [] + for p in extrem_events: + # We pulled this from the DB, so it'll be non-null + assert p.internal_metadata.stream_ordering + + # Filter out events that happened before the remote went + # offline + if ( + p.internal_metadata.stream_ordering + < self._last_successful_stream_ordering + ): + continue - await self._transaction_manager.send_new_transaction( - self._destination, catchup_pdus, [] - ) + # Filter out events where the server is not in the room, + # e.g. 
it may have left/been kicked. *Ideally* we'd pull + # out the kick and send that, but it's a rare edge case + # so we don't bother for now (the server that sent the + # kick should send it out if its online). + hosts = await self._state.get_hosts_in_room_at_events( + p.room_id, [p.event_id] + ) + if self._destination not in hosts: + continue - sent_transactions_counter.inc() - final_pdu = catchup_pdus[-1] - self._last_successful_stream_ordering = cast( - int, final_pdu.internal_metadata.stream_ordering - ) - await self._store.set_destination_last_successful_stream_ordering( - self._destination, self._last_successful_stream_ordering - ) + new_pdus.append(p) + + # If we've filtered out all the extremities, fall back to + # sending the original event. This should ensure that the + # server gets at least some of missed events (especially if + # the other sending servers are up). + if new_pdus: + room_catchup_pdus = new_pdus + + logger.info( + "Catching up rooms to %s: %r", self._destination, pdu.room_id + ) + + await self._transaction_manager.send_new_transaction( + self._destination, room_catchup_pdus, [] + ) + + sent_transactions_counter.inc() + + # We pulled this from the DB, so it'll be non-null + assert pdu.internal_metadata.stream_ordering + + # Note that we mark the last successful stream ordering as that + # from the *original* PDU, rather than the PDU(s) we actually + # send. This is because we use it to mark our position in the + # queue of missed PDUs to process. + self._last_successful_stream_ordering = ( + pdu.internal_metadata.stream_ordering + ) + + await self._store.set_destination_last_successful_stream_ordering( + self._destination, self._last_successful_stream_ordering + ) def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: if not self._pending_rrs: diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 6f96cd7940..95eac6a5a3 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -2,6 +2,7 @@ from typing import List, Tuple from mock import Mock +from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.units import Edu @@ -421,3 +422,51 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): self.assertNotIn("zzzerver", woken) # - all destinations are woken exactly once; they appear once in woken. self.assertCountEqual(woken, server_names[:-1]) + + @override_config({"send_federation": True}) + def test_not_latest_event(self): + """Test that we send the latest event in the room even if its not ours.""" + + per_dest_queue, sent_pdus = self.make_fake_destination_queue() + + # Make a room with a local user, and two servers. One will go offline + # and one will send some events. + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room_1 = self.helper.create_room_as("u1", tok=u1_token) + + self.get_success( + event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join") + ) + event_1 = self.get_success( + event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join") + ) + + # First we send something from the local server, so that we notice the + # remote is down and go into catchup mode. + self.helper.send(room_1, "you hear me!!", tok=u1_token) + + # Now simulate us receiving an event from the still online remote. 
+ event_2 = self.get_success( + event_injection.inject_event( + self.hs, + type=EventTypes.Message, + sender="@user:host3", + room_id=room_1, + content={"msgtype": "m.text", "body": "Hello"}, + ) + ) + + self.get_success( + self.hs.get_datastore().set_destination_last_successful_stream_ordering( + "host2", event_1.internal_metadata.stream_ordering + ) + ) + + self.get_success(per_dest_queue._catch_up_transmission_loop()) + + # We expect only the last message from the remote, event_2, to have been + # sent, rather than the last *local* event that was sent. + self.assertEqual(len(sent_pdus), 1) + self.assertEqual(sent_pdus[0].event_id, event_2.event_id) + self.assertFalse(per_dest_queue._catching_up) -- cgit 1.5.1 From 8dd2ea65a9566fd0850df0d989f700f61b490ed9 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 18 Mar 2021 17:54:08 +0100 Subject: Consistently check whether a password may be set for a user. (#9636) --- changelog.d/9636.bugfix | 1 + synapse/handlers/set_password.py | 2 +- synapse/rest/admin/users.py | 2 +- synapse/storage/databases/main/registration.py | 1 + tests/rest/admin/test_user.py | 173 +++++++++++++++++-------- 5 files changed, 122 insertions(+), 57 deletions(-) create mode 100644 changelog.d/9636.bugfix (limited to 'tests') diff --git a/changelog.d/9636.bugfix b/changelog.d/9636.bugfix new file mode 100644 index 0000000000..fa772ed6fc --- /dev/null +++ b/changelog.d/9636.bugfix @@ -0,0 +1 @@ +Checks if passwords are allowed before setting it for the user. \ No newline at end of file diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 84af2dde7e..04e7c64c94 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -41,7 +41,7 @@ class SetPasswordHandler(BaseHandler): logout_devices: bool, requester: Optional[Requester] = None, ) -> None: - if not self.hs.config.password_localdb_enabled: + if not self._auth_handler.can_change_password(): raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) try: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2c89b62e25..aaa56a7024 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -271,7 +271,7 @@ class UserRestServletV2(RestServlet): elif not deactivate and user["deactivated"]: if ( "password" not in body - and self.hs.config.password_localdb_enabled + and self.auth_handler.can_change_password() ): raise SynapseError( 400, "Must provide a password to re-activate an account." 
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index eba66ff352..90a8f664ef 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1210,6 +1210,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): self._invalidate_cache_and_stream( txn, self.get_user_deactivated_status, (user_id,) ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) @cached() diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e58d5cf0db..cf61f284cb 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1003,12 +1003,23 @@ class UserRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() + self.auth_handler = hs.get_auth_handler() + # create users and get access tokens + # regardless of whether password login or SSO is allowed self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") + self.admin_user_tok = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.admin_user, device_id=None, valid_until_ms=None + ) + ) self.other_user = self.register_user("user", "pass", displayname="User") - self.other_user_token = self.login("user", "pass") + self.other_user_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.other_user, device_id=None, valid_until_ms=None + ) + ) self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( self.other_user ) @@ -1081,7 +1092,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(True, channel.json_body["admin"]) + self.assertTrue(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user @@ -1096,9 +1107,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(True, channel.json_body["admin"]) - self.assertEqual(False, channel.json_body["is_guest"]) - self.assertEqual(False, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["admin"]) + self.assertFalse(channel.json_body["is_guest"]) + self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) def test_create_user(self): @@ -1130,7 +1141,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(False, channel.json_body["admin"]) + self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user @@ -1145,10 +1156,10 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) 
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(False, channel.json_body["admin"]) - self.assertEqual(False, channel.json_body["is_guest"]) - self.assertEqual(False, channel.json_body["deactivated"]) - self.assertEqual(False, channel.json_body["shadow_banned"]) + self.assertFalse(channel.json_body["admin"]) + self.assertFalse(channel.json_body["is_guest"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertFalse(channel.json_body["shadow_banned"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) @override_config( @@ -1197,7 +1208,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["admin"]) + self.assertFalse(channel.json_body["admin"]) @override_config( {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} @@ -1237,7 +1248,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Admin user is not blocked by mau anymore self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["admin"]) + self.assertFalse(channel.json_body["admin"]) @override_config( { @@ -1429,24 +1440,23 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["deactivated"]) + self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) # Deactivate user - body = json.dumps({"deactivated": True}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"deactivated": True}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) @@ -1461,7 +1471,8 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) @@ -1478,41 +1489,37 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertTrue(profile["display_name"] == "User") # Deactivate user - body = json.dumps({"deactivated": True}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + 
content={"deactivated": True}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) # is not in user directory profile = self.get_success(self.store.get_user_in_directory(self.other_user)) - self.assertTrue(profile is None) + self.assertIsNone(profile) # Set new displayname user - body = json.dumps({"displayname": "Foobar"}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"displayname": "Foobar"}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) self.assertEqual("Foobar", channel.json_body["displayname"]) # is not in user directory profile = self.get_success(self.store.get_user_in_directory(self.other_user)) - self.assertTrue(profile is None) + self.assertIsNone(profile) def test_reactivate_user(self): """ @@ -1520,48 +1527,92 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Deactivate the user. + self._deactivate_user("@user:test") + + # Attempt to reactivate the user (without a password). + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"deactivated": False}, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + + # Reactivate the user. channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=json.dumps({"deactivated": True}).encode(encoding="utf_8"), + content={"deactivated": False, "password": "foo"}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertIsNotNone(channel.json_body["password_hash"]) self._is_erased("@user:test", False) - d = self.store.mark_user_erased("@user:test") - self.assertIsNone(self.get_success(d)) - self._is_erased("@user:test", True) - # Attempt to reactivate the user (without a password). + @override_config({"password_config": {"localdb_enabled": False}}) + def test_reactivate_user_localdb_disabled(self): + """ + Test reactivating another user when using SSO. + """ + + # Deactivate the user. + self._deactivate_user("@user:test") + + # Reactivate the user with a password channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=json.dumps({"deactivated": False}).encode(encoding="utf_8"), + content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - # Reactivate the user. + # Reactivate the user without a password. 
channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=json.dumps({"deactivated": False, "password": "foo"}).encode( - encoding="utf_8" - ), + content={"deactivated": False}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) + self._is_erased("@user:test", False) - # Get user + @override_config({"password_config": {"enabled": False}}) + def test_reactivate_user_password_disabled(self): + """ + Test reactivating another user when using SSO. + """ + + # Deactivate the user. + self._deactivate_user("@user:test") + + # Reactivate the user with a password channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"deactivated": False, "password": "foo"}, ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + # Reactivate the user without a password. + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"deactivated": False}, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["deactivated"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) self._is_erased("@user:test", False) def test_set_user_as_admin(self): @@ -1570,18 +1621,16 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Set a user as an admin - body = json.dumps({"admin": True}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"admin": True}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["admin"]) + self.assertTrue(channel.json_body["admin"]) # Get user channel = self.make_request( @@ -1592,7 +1641,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["admin"]) + self.assertTrue(channel.json_body["admin"]) def test_accidental_deactivation_prevention(self): """ @@ -1602,13 +1651,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps({"password": "abc123"}) - channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "abc123"}, ) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) @@ -1628,13 +1675,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["deactivated"]) # Change password (and use a str for deactivate instead of a bool) - body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops! 
- channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "abc123", "deactivated": "false"}, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -1653,7 +1698,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Ensure they're still alive self.assertEqual(0, channel.json_body["deactivated"]) - def _is_erased(self, user_id, expect): + def _is_erased(self, user_id: str, expect: bool) -> None: """Assert that the user is erased or not""" d = self.store.is_user_erased(user_id) if expect: @@ -1661,6 +1706,24 @@ class UserRestTestCase(unittest.HomeserverTestCase): else: self.assertFalse(self.get_success(d)) + def _deactivate_user(self, user_id: str) -> None: + """Deactivate user and set as erased""" + + # Deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id), + access_token=self.admin_user_tok, + content={"deactivated": True}, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertTrue(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) + self._is_erased(user_id, False) + d = self.store.mark_user_erased(user_id) + self.assertIsNone(self.get_success(d)) + self._is_erased(user_id, True) + class UserMembershipRestTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From 5b268997bdde04c905eed0a7c04e9e2352e35fba Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 22 Mar 2021 17:20:47 +0000 Subject: Allow providing credentials to HTTPS_PROXY (#9657) Addresses https://github.com/matrix-org/synapse-dinsic/issues/70 This PR causes `ProxyAgent` to attempt to extract credentials from an `HTTPS_PROXY` env var. If credentials are found, a `Proxy-Authorization` header ([details](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Proxy-Authorization)) is sent to the proxy server to authenticate against it. The headers are *not* passed to the remote server. Also added some type hints. --- changelog.d/9657.feature | 1 + synapse/http/connectproxyclient.py | 96 +++++++++++++++++++++++++++----------- synapse/http/proxyagent.py | 81 ++++++++++++++++++++++++++++---- tests/http/test_proxyagent.py | 40 ++++++++++++++++ 4 files changed, 184 insertions(+), 34 deletions(-) create mode 100644 changelog.d/9657.feature (limited to 'tests') diff --git a/changelog.d/9657.feature b/changelog.d/9657.feature new file mode 100644 index 0000000000..c56a615a8b --- /dev/null +++ b/changelog.d/9657.feature @@ -0,0 +1 @@ +Add support for credentials for proxy authentication in the `HTTPS_PROXY` environment variable. 
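As an aside, a minimal standalone sketch of the credential handling this commit introduces (the proxy hostname and port are invented for illustration; the credentials mirror the ones used in the tests below; the real logic lives in synapse/http/proxyagent.py): the connection string is split on the last `@`, and the credentials are base64-encoded into a `Proxy-Authorization: Basic ...` value that is sent to the proxy only.

    import base64

    proxy = b"bob:pinkponies@proxy.example.com:1080"
    # Split on the *last* '@' so that a password containing '@' still parses.
    credentials, host_port = proxy.rsplit(b"@", 1)
    # Value sent to the proxy (and only the proxy) as Proxy-Authorization.
    header_value = b"Basic " + base64.encodebytes(credentials)
    print(host_port)     # b'proxy.example.com:1080'
    print(header_value)  # b'Basic Ym9iOnBpbmtwb25pZXM=\n' (encodebytes appends a newline)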
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py index 856e28454f..b797e3ce80 100644 --- a/synapse/http/connectproxyclient.py +++ b/synapse/http/connectproxyclient.py @@ -19,9 +19,10 @@ from zope.interface import implementer from twisted.internet import defer, protocol from twisted.internet.error import ConnectError -from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.internet.protocol import connectionDone +from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint +from twisted.internet.protocol import ClientFactory, Protocol, connectionDone from twisted.web import http +from twisted.web.http_headers import Headers logger = logging.getLogger(__name__) @@ -43,23 +44,33 @@ class HTTPConnectProxyEndpoint: Args: reactor: the Twisted reactor to use for the connection - proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the - proxy - host (bytes): hostname that we want to CONNECT to - port (int): port that we want to connect to + proxy_endpoint: the endpoint to use to connect to the proxy + host: hostname that we want to CONNECT to + port: port that we want to connect to + headers: Extra HTTP headers to include in the CONNECT request """ - def __init__(self, reactor, proxy_endpoint, host, port): + def __init__( + self, + reactor: IReactorCore, + proxy_endpoint: IStreamClientEndpoint, + host: bytes, + port: int, + headers: Headers, + ): self._reactor = reactor self._proxy_endpoint = proxy_endpoint self._host = host self._port = port + self._headers = headers def __repr__(self): return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,) - def connect(self, protocolFactory): - f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory) + def connect(self, protocolFactory: ClientFactory): + f = HTTPProxiedClientFactory( + self._host, self._port, protocolFactory, self._headers + ) d = self._proxy_endpoint.connect(f) # once the tcp socket connects successfully, we need to wait for the # CONNECT to complete. @@ -74,15 +85,23 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): HTTP Protocol object and run the rest of the connection.
Args: - dst_host (bytes): hostname that we want to CONNECT to - dst_port (int): port that we want to connect to - wrapped_factory (protocol.ClientFactory): The original Factory + dst_host: hostname that we want to CONNECT to + dst_port: port that we want to connect to + wrapped_factory: The original Factory + headers: Extra HTTP headers to include in the CONNECT request """ - def __init__(self, dst_host, dst_port, wrapped_factory): + def __init__( + self, + dst_host: bytes, + dst_port: int, + wrapped_factory: ClientFactory, + headers: Headers, + ): self.dst_host = dst_host self.dst_port = dst_port self.wrapped_factory = wrapped_factory + self.headers = headers self.on_connection = defer.Deferred() def startedConnecting(self, connector): @@ -92,7 +111,11 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): wrapped_protocol = self.wrapped_factory.buildProtocol(addr) return HTTPConnectProtocol( - self.dst_host, self.dst_port, wrapped_protocol, self.on_connection + self.dst_host, + self.dst_port, + wrapped_protocol, + self.on_connection, + self.headers, ) def clientConnectionFailed(self, connector, reason): @@ -112,24 +135,37 @@ class HTTPConnectProtocol(protocol.Protocol): """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect Args: - host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal + host: The original HTTP(s) hostname or IPv4 or IPv6 address literal to put in the CONNECT request - port (int): The original HTTP(s) port to put in the CONNECT request + port: The original HTTP(s) port to put in the CONNECT request - wrapped_protocol (interfaces.IProtocol): the original protocol (probably - HTTPChannel or TLSMemoryBIOProtocol, but could be anything really) + wrapped_protocol: the original protocol (probably HTTPChannel or + TLSMemoryBIOProtocol, but could be anything really) - connected_deferred (Deferred): a Deferred which will be callbacked with + connected_deferred: a Deferred which will be callbacked with wrapped_protocol when the CONNECT completes + + headers: Extra HTTP headers to include in the CONNECT request """ - def __init__(self, host, port, wrapped_protocol, connected_deferred): + def __init__( + self, + host: bytes, + port: int, + wrapped_protocol: Protocol, + connected_deferred: defer.Deferred, + headers: Headers, + ): self.host = host self.port = port self.wrapped_protocol = wrapped_protocol self.connected_deferred = connected_deferred - self.http_setup_client = HTTPConnectSetupClient(self.host, self.port) + self.headers = headers + + self.http_setup_client = HTTPConnectSetupClient( + self.host, self.port, self.headers + ) self.http_setup_client.on_connected.addCallback(self.proxyConnected) def connectionMade(self): @@ -154,7 +190,7 @@ class HTTPConnectProtocol(protocol.Protocol): if buf: self.wrapped_protocol.dataReceived(buf) - def dataReceived(self, data): + def dataReceived(self, data: bytes): # if we've set up the HTTP protocol, we can send the data there if self.wrapped_protocol.connected: return self.wrapped_protocol.dataReceived(data) @@ -168,21 +204,29 @@ class HTTPConnectSetupClient(http.HTTPClient): """HTTPClient protocol to send a CONNECT message for proxies and read the response. 
Args: - host (bytes): The hostname to send in the CONNECT message - port (int): The port to send in the CONNECT message + host: The hostname to send in the CONNECT message + port: The port to send in the CONNECT message + headers: Extra headers to send with the CONNECT message """ - def __init__(self, host, port): + def __init__(self, host: bytes, port: int, headers: Headers): self.host = host self.port = port + self.headers = headers self.on_connected = defer.Deferred() def connectionMade(self): logger.debug("Connected to proxy, sending CONNECT") self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) + + # Send any additional specified headers + for name, values in self.headers.getAllRawHeaders(): + for value in values: + self.sendHeader(name, value) + self.endHeaders() - def handleStatus(self, version, status, message): + def handleStatus(self, version: bytes, status: bytes, message: bytes): logger.debug("Got Status: %s %s %s", status, message, version) if status != b"200": raise ProxyConnectError("Unexpected status on CONNECT: %s" % status) diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 3d553ae236..16ec850064 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -12,10 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import base64 import logging import re +from typing import Optional, Tuple from urllib.request import getproxies_environment, proxy_bypass_environment +import attr from zope.interface import implementer from twisted.internet import defer @@ -23,6 +26,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.python.failure import Failure from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase from twisted.web.error import SchemeNotSupported +from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint @@ -32,6 +36,22 @@ logger = logging.getLogger(__name__) _VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z") +@attr.s +class ProxyCredentials: + username_password = attr.ib(type=bytes) + + def as_proxy_authorization_value(self) -> bytes: + """ + Return the value for a Proxy-Authorization header (i.e. 'Basic abdef=='). + + Returns: + The authentication string, encoded for use as the value of + a Proxy-Authorization header.
+ """ + # Encode as base64 and prepend the authorization type + return b"Basic " + base64.encodebytes(self.username_password) + + @implementer(IAgent) class ProxyAgent(_AgentBase): """An Agent implementation which will use an HTTP proxy if one was requested @@ -96,6 +116,9 @@ class ProxyAgent(_AgentBase): https_proxy = proxies["https"].encode() if "https" in proxies else None no_proxy = proxies["no"] if "no" in proxies else None + # Parse credentials from https proxy connection string if present + self.https_proxy_creds, https_proxy = parse_username_password(https_proxy) + self.http_proxy_endpoint = _http_proxy_endpoint( http_proxy, self.proxy_reactor, **self._endpoint_kwargs ) @@ -175,11 +198,22 @@ class ProxyAgent(_AgentBase): and self.https_proxy_endpoint and not should_skip_proxy ): + connect_headers = Headers() + + # Determine whether we need to set Proxy-Authorization headers + if self.https_proxy_creds: + # Set a Proxy-Authorization header + connect_headers.addRawHeader( + b"Proxy-Authorization", + self.https_proxy_creds.as_proxy_authorization_value(), + ) + endpoint = HTTPConnectProxyEndpoint( self.proxy_reactor, self.https_proxy_endpoint, parsed_uri.host, parsed_uri.port, + headers=connect_headers, ) else: # not using a proxy @@ -208,12 +242,16 @@ class ProxyAgent(_AgentBase): ) -def _http_proxy_endpoint(proxy, reactor, **kwargs): +def _http_proxy_endpoint(proxy: Optional[bytes], reactor, **kwargs): """Parses an http proxy setting and returns an endpoint for the proxy Args: - proxy (bytes|None): the proxy setting + proxy: the proxy setting in the form: [:@][:] + Note that compared to other apps, this function currently lacks support + for specifying a protocol schema (i.e. protocol://...). + reactor: reactor to be used to connect to the proxy + kwargs: other args to be passed to HostnameEndpoint Returns: @@ -223,16 +261,43 @@ def _http_proxy_endpoint(proxy, reactor, **kwargs): if proxy is None: return None - # currently we only support hostname:port. Some apps also support - # protocol://[:port], which allows a way of requiring a TLS connection to the - # proxy. - + # Parse the connection string host, port = parse_host_port(proxy, default_port=1080) return HostnameEndpoint(reactor, host, port, **kwargs) -def parse_host_port(hostport, default_port=None): - # could have sworn we had one of these somewhere else... +def parse_username_password(proxy: bytes) -> Tuple[Optional[ProxyCredentials], bytes]: + """ + Parses the username and password from a proxy declaration e.g + username:password@hostname:port. + + Args: + proxy: The proxy connection string. + + Returns + An instance of ProxyCredentials and the proxy connection string with any credentials + stripped, i.e u:p@host:port -> host:port. If no credentials were found, the + ProxyCredentials instance is replaced with None. + """ + if proxy and b"@" in proxy: + # We use rsplit here as the password could contain an @ character + credentials, proxy_without_credentials = proxy.rsplit(b"@", 1) + return ProxyCredentials(credentials), proxy_without_credentials + + return None, proxy + + +def parse_host_port(hostport: bytes, default_port: int = None) -> Tuple[bytes, int]: + """ + Parse the hostname and port from a proxy connection byte string. + + Args: + hostport: The proxy connection string. Must be in the form 'host[:port]'. + default_port: The default port to return if one is not found in `hostport`. + + Returns: + A tuple containing the hostname and port. Uses `default_port` if one was not found. 
+ """ if b":" in hostport: host, port = hostport.rsplit(b":", 1) try: diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index 505ffcd300..3ea8b5bec7 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -12,8 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import base64 import logging import os +from typing import Optional from unittest.mock import patch import treq @@ -242,6 +244,21 @@ class MatrixFederationAgentTests(TestCase): @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) def test_https_request_via_proxy(self): + """Tests that TLS-encrypted requests can be made through a proxy""" + self._do_https_request_via_proxy(auth_credentials=None) + + @patch.dict( + os.environ, + {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, + ) + def test_https_request_via_proxy_with_auth(self): + """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" + self._do_https_request_via_proxy(auth_credentials="bob:pinkponies") + + def _do_https_request_via_proxy( + self, + auth_credentials: Optional[str] = None, + ): agent = ProxyAgent( self.reactor, contextFactory=get_test_https_policy(), @@ -278,6 +295,22 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b"CONNECT") self.assertEqual(request.path, b"test.com:443") + # Check whether auth credentials have been supplied to the proxy + proxy_auth_header_values = request.requestHeaders.getRawHeaders( + b"Proxy-Authorization" + ) + + if auth_credentials is not None: + # Compute the correct header value for Proxy-Authorization + encoded_credentials = base64.b64encode(b"bob:pinkponies") + expected_header_value = b"Basic " + encoded_credentials + + # Validate the header's value + self.assertIn(expected_header_value, proxy_auth_header_values) + else: + # Check that the Proxy-Authorization header has not been supplied to the proxy + self.assertIsNone(proxy_auth_header_values) + # tell the proxy server not to close the connection proxy_server.persistent = True @@ -312,6 +345,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/abc") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + + # Check that the destination server DID NOT receive proxy credentials + proxy_auth_header_values = request.requestHeaders.getRawHeaders( + b"Proxy-Authorization" + ) + self.assertIsNone(proxy_auth_header_values) + request.write(b"result") request.finish() -- cgit 1.5.1