From c4675e1b24f06a72c323c8131eab4998b4e71af1 Mon Sep 17 00:00:00 2001 From: David Florness Date: Wed, 2 Dec 2020 10:01:15 -0500 Subject: Add additional validation for the admin register endpoint. (#8837) Raise a proper 400 error if the `mac` field is missing. --- synapse/rest/admin/users.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'synapse/rest') diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index b0ff5e1ead..90940ff185 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -420,6 +420,9 @@ class UserRegisterServlet(RestServlet): if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: raise SynapseError(400, "Invalid user type") + if "mac" not in body: + raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON) + got_mac = body["mac"] want_mac_builder = hmac.new( -- cgit 1.5.1 From 30fba6210834a4ecd91badf0c8f3eb278b72e746 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 2 Dec 2020 11:09:24 -0500 Subject: Apply an IP range blacklist to push and key revocation requests. (#8821) Replaces the `federation_ip_range_blacklist` configuration setting with an `ip_range_blacklist` setting with wider scope. It now applies to: * Federation * Identity servers * Push notifications * Checking key validitity for third-party invite events The old `federation_ip_range_blacklist` setting is still honored if present, but with reduced scope (it only applies to federation and identity servers). --- changelog.d/8821.bugfix | 1 + docs/sample_config.yaml | 14 ++++--- synapse/app/generic_worker.py | 1 - synapse/config/federation.py | 40 ++++++++++++------- synapse/crypto/keyring.py | 4 +- synapse/federation/federation_server.py | 1 - synapse/federation/transport/client.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/identity.py | 6 +-- synapse/http/client.py | 46 +++++++++++++++------- synapse/http/federation/matrix_federation_agent.py | 16 ++++++-- synapse/http/matrixfederationclient.py | 26 ++++-------- synapse/push/httppusher.py | 2 +- synapse/rest/media/v1/media_repository.py | 2 +- synapse/server.py | 36 +++++++++++++---- tests/api/test_filtering.py | 4 +- tests/app/test_frontend_proxy.py | 2 +- tests/app/test_openid_listener.py | 4 +- tests/crypto/test_keyring.py | 6 ++- tests/handlers/test_device.py | 4 +- tests/handlers/test_directory.py | 2 +- tests/handlers/test_federation.py | 2 +- tests/handlers/test_presence.py | 2 +- tests/handlers/test_profile.py | 2 +- tests/handlers/test_typing.py | 6 +-- .../federation/test_matrix_federation_agent.py | 3 ++ tests/push/test_http.py | 4 +- tests/replication/_base.py | 4 +- tests/replication/test_federation_sender_shard.py | 10 ++--- tests/replication/test_pusher_shard.py | 6 +-- tests/rest/admin/test_admin.py | 2 +- tests/rest/client/v1/test_presence.py | 2 +- tests/rest/client/v1/test_profile.py | 2 +- tests/rest/client/v1/test_rooms.py | 2 +- tests/rest/client/v1/test_typing.py | 2 +- tests/rest/key/v2/test_remote_key_resource.py | 4 +- tests/rest/media/v1/test_media_storage.py | 2 +- tests/storage/test_e2e_room_keys.py | 2 +- tests/storage/test_purge.py | 2 +- tests/storage/test_redaction.py | 2 +- tests/storage/test_roommember.py | 2 +- tests/test_federation.py | 2 +- tests/test_server.py | 5 ++- 43 files changed, 176 insertions(+), 115 deletions(-) create mode 100644 changelog.d/8821.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/8821.bugfix b/changelog.d/8821.bugfix new file mode 100644 index 0000000000..8ddfbf31ce --- /dev/null +++ b/changelog.d/8821.bugfix @@ -0,0 +1 @@ +Apply the `federation_ip_range_blacklist` to push and key revocation requests. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 394eb9a3ff..6dbccf5932 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -642,17 +642,19 @@ acme: # - nyc.example.com # - syd.example.com -# Prevent federation requests from being sent to the following -# blacklist IP address CIDR ranges. If this option is not specified, or -# specified with an empty list, no ip range blacklist will be enforced. +# Prevent outgoing requests from being sent to the following blacklisted IP address +# CIDR ranges. If this option is not specified, or specified with an empty list, +# no IP range blacklist will be enforced. # -# As of Synapse v1.4.0 this option also affects any outbound requests to identity -# servers provided by user input. +# The blacklist applies to the outbound requests for federation, identity servers, +# push servers, and for checking key validitity for third-party invite events. # # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly # listed here, since they correspond to unroutable addresses.) # -federation_ip_range_blacklist: +# This option replaces federation_ip_range_blacklist in Synapse v1.24.0. +# +ip_range_blacklist: - '127.0.0.0/8' - '10.0.0.0/8' - '172.16.0.0/12' diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 1b511890aa..aa12c74358 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -266,7 +266,6 @@ class GenericWorkerPresence(BasePresenceHandler): super().__init__(hs) self.hs = hs self.is_mine_id = hs.is_mine_id - self.http_client = hs.get_simple_http_client() self._presence_enabled = hs.config.use_presence diff --git a/synapse/config/federation.py b/synapse/config/federation.py index ffd8fca54e..27ccf61c3c 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -36,22 +36,30 @@ class FederationConfig(Config): for domain in federation_domain_whitelist: self.federation_domain_whitelist[domain] = True - self.federation_ip_range_blacklist = config.get( - "federation_ip_range_blacklist", [] - ) + ip_range_blacklist = config.get("ip_range_blacklist", []) # Attempt to create an IPSet from the given ranges try: - self.federation_ip_range_blacklist = IPSet( - self.federation_ip_range_blacklist - ) - - # Always blacklist 0.0.0.0, :: - self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + self.ip_range_blacklist = IPSet(ip_range_blacklist) + except Exception as e: + raise ConfigError("Invalid range(s) provided in ip_range_blacklist: %s" % e) + # Always blacklist 0.0.0.0, :: + self.ip_range_blacklist.update(["0.0.0.0", "::"]) + + # The federation_ip_range_blacklist is used for backwards-compatibility + # and only applies to federation and identity servers. If it is not given, + # default to ip_range_blacklist. + federation_ip_range_blacklist = config.get( + "federation_ip_range_blacklist", ip_range_blacklist + ) + try: + self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist) except Exception as e: raise ConfigError( "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e ) + # Always blacklist 0.0.0.0, :: + self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) federation_metrics_domains = config.get("federation_metrics_domains") or [] validate_config( @@ -76,17 +84,19 @@ class FederationConfig(Config): # - nyc.example.com # - syd.example.com - # Prevent federation requests from being sent to the following - # blacklist IP address CIDR ranges. If this option is not specified, or - # specified with an empty list, no ip range blacklist will be enforced. + # Prevent outgoing requests from being sent to the following blacklisted IP address + # CIDR ranges. If this option is not specified, or specified with an empty list, + # no IP range blacklist will be enforced. # - # As of Synapse v1.4.0 this option also affects any outbound requests to identity - # servers provided by user input. + # The blacklist applies to the outbound requests for federation, identity servers, + # push servers, and for checking key validitity for third-party invite events. # # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly # listed here, since they correspond to unroutable addresses.) # - federation_ip_range_blacklist: + # This option replaces federation_ip_range_blacklist in Synapse v1.24.0. + # + ip_range_blacklist: - '127.0.0.0/8' - '10.0.0.0/8' - '172.16.0.0/12' diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index c04ad77cf9..f23eacc0d7 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -578,7 +578,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): def __init__(self, hs): super().__init__(hs) self.clock = hs.get_clock() - self.client = hs.get_http_client() + self.client = hs.get_federation_http_client() self.key_servers = self.config.key_servers async def get_keys(self, keys_to_fetch): @@ -748,7 +748,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): def __init__(self, hs): super().__init__(hs) self.clock = hs.get_clock() - self.client = hs.get_http_client() + self.client = hs.get_federation_http_client() async def get_keys(self, keys_to_fetch): """ diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 4b6ab470d0..35e345ce70 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -845,7 +845,6 @@ class FederationHandlerRegistry: def __init__(self, hs: "HomeServer"): self.config = hs.config - self.http_client = hs.get_simple_http_client() self.clock = hs.get_clock() self._instance_name = hs.get_instance_name() diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 17a10f622e..abe9168c78 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -35,7 +35,7 @@ class TransportLayerClient: def __init__(self, hs): self.server_name = hs.hostname - self.client = hs.get_http_client() + self.client = hs.get_federation_http_client() @log_function def get_room_state_ids(self, destination, room_id, event_id): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b9799090f7..df82e60b33 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -140,7 +140,7 @@ class FederationHandler(BaseHandler): self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config - self.http_client = hs.get_simple_http_client() + self.http_client = hs.get_proxied_blacklisted_http_client() self._instance_name = hs.get_instance_name() self._replication = hs.get_replication_data_handler() diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 9b3c6b4551..7301c24710 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -46,13 +46,13 @@ class IdentityHandler(BaseHandler): def __init__(self, hs): super().__init__(hs) + # An HTTP client for contacting trusted URLs. self.http_client = SimpleHttpClient(hs) - # We create a blacklisting instance of SimpleHttpClient for contacting identity - # servers specified by clients + # An HTTP client for contacting identity servers specified by clients. self.blacklisting_http_client = SimpleHttpClient( hs, ip_blacklist=hs.config.federation_ip_range_blacklist ) - self.federation_http_client = hs.get_http_client() + self.federation_http_client = hs.get_federation_http_client() self.hs = hs async def threepid_from_creds( diff --git a/synapse/http/client.py b/synapse/http/client.py index e5b13593f2..df7730078f 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -125,7 +125,7 @@ def _make_scheduler(reactor): return _scheduler -class IPBlacklistingResolver: +class _IPBlacklistingResolver: """ A proxy for reactor.nameResolver which only produces non-blacklisted IP addresses, preventing DNS rebinding attacks on URL preview. @@ -199,6 +199,35 @@ class IPBlacklistingResolver: return r +@implementer(IReactorPluggableNameResolver) +class BlacklistingReactorWrapper: + """ + A Reactor wrapper which will prevent DNS resolution to blacklisted IP + addresses, to prevent DNS rebinding. + """ + + def __init__( + self, + reactor: IReactorPluggableNameResolver, + ip_whitelist: Optional[IPSet], + ip_blacklist: IPSet, + ): + self._reactor = reactor + + # We need to use a DNS resolver which filters out blacklisted IP + # addresses, to prevent DNS rebinding. + self._nameResolver = _IPBlacklistingResolver( + self._reactor, ip_whitelist, ip_blacklist + ) + + def __getattr__(self, attr: str) -> Any: + # Passthrough to the real reactor except for the DNS resolver. + if attr == "nameResolver": + return self._nameResolver + else: + return getattr(self._reactor, attr) + + class BlacklistingAgentWrapper(Agent): """ An Agent wrapper which will prevent access to IP addresses being accessed @@ -292,22 +321,11 @@ class SimpleHttpClient: self.user_agent = self.user_agent.encode("ascii") if self._ip_blacklist: - real_reactor = hs.get_reactor() # If we have an IP blacklist, we need to use a DNS resolver which # filters out blacklisted IP addresses, to prevent DNS rebinding. - nameResolver = IPBlacklistingResolver( - real_reactor, self._ip_whitelist, self._ip_blacklist + self.reactor = BlacklistingReactorWrapper( + hs.get_reactor(), self._ip_whitelist, self._ip_blacklist ) - - @implementer(IReactorPluggableNameResolver) - class Reactor: - def __getattr__(_self, attr): - if attr == "nameResolver": - return nameResolver - else: - return getattr(real_reactor, attr) - - self.reactor = Reactor() else: self.reactor = hs.get_reactor() diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index e77f9587d0..3b756a7dc2 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -16,7 +16,7 @@ import logging import urllib.parse from typing import List, Optional -from netaddr import AddrFormatError, IPAddress +from netaddr import AddrFormatError, IPAddress, IPSet from zope.interface import implementer from twisted.internet import defer @@ -31,6 +31,7 @@ from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer from synapse.crypto.context_factory import FederationPolicyForHTTPS +from synapse.http.client import BlacklistingAgentWrapper from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -70,6 +71,7 @@ class MatrixFederationAgent: reactor: IReactorCore, tls_client_options_factory: Optional[FederationPolicyForHTTPS], user_agent: bytes, + ip_blacklist: IPSet, _srv_resolver: Optional[SrvResolver] = None, _well_known_resolver: Optional[WellKnownResolver] = None, ): @@ -90,12 +92,18 @@ class MatrixFederationAgent: self.user_agent = user_agent if _well_known_resolver is None: + # Note that the name resolver has already been wrapped in a + # IPBlacklistingResolver by MatrixFederationHttpClient. _well_known_resolver = WellKnownResolver( self._reactor, - agent=Agent( + agent=BlacklistingAgentWrapper( + Agent( + self._reactor, + pool=self._pool, + contextFactory=tls_client_options_factory, + ), self._reactor, - pool=self._pool, - contextFactory=tls_client_options_factory, + ip_blacklist=ip_blacklist, ), user_agent=self.user_agent, ) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 4e27f93b7a..c962994727 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -26,11 +26,10 @@ import treq from canonicaljson import encode_canonical_json from prometheus_client import Counter from signedjson.sign import sign_json -from zope.interface import implementer from twisted.internet import defer from twisted.internet.error import DNSLookupError -from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime +from twisted.internet.interfaces import IReactorTime from twisted.internet.task import _EPSILON, Cooperator from twisted.web.http_headers import Headers from twisted.web.iweb import IBodyProducer, IResponse @@ -45,7 +44,7 @@ from synapse.api.errors import ( from synapse.http import QuieterFileBodyProducer from synapse.http.client import ( BlacklistingAgentWrapper, - IPBlacklistingResolver, + BlacklistingReactorWrapper, encode_query_args, readBodyToFile, ) @@ -221,31 +220,22 @@ class MatrixFederationHttpClient: self.signing_key = hs.signing_key self.server_name = hs.hostname - real_reactor = hs.get_reactor() - # We need to use a DNS resolver which filters out blacklisted IP # addresses, to prevent DNS rebinding. - nameResolver = IPBlacklistingResolver( - real_reactor, None, hs.config.federation_ip_range_blacklist + self.reactor = BlacklistingReactorWrapper( + hs.get_reactor(), None, hs.config.federation_ip_range_blacklist ) - @implementer(IReactorPluggableNameResolver) - class Reactor: - def __getattr__(_self, attr): - if attr == "nameResolver": - return nameResolver - else: - return getattr(real_reactor, attr) - - self.reactor = Reactor() - user_agent = hs.version_string if hs.config.user_agent_suffix: user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix) user_agent = user_agent.encode("ascii") self.agent = MatrixFederationAgent( - self.reactor, tls_client_options_factory, user_agent + self.reactor, + tls_client_options_factory, + user_agent, + hs.config.federation_ip_range_blacklist, ) # Use a BlacklistingAgentWrapper to prevent circumventing the IP diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index eff0975b6a..0e845212a9 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -100,7 +100,7 @@ class HttpPusher: if "url" not in self.data: raise PusherConfigException("'url' required in data for HTTP pusher") self.url = self.data["url"] - self.http_client = hs.get_proxied_http_client() + self.http_client = hs.get_proxied_blacklisted_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url["url"] diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 9cac74ebd8..83beb02b05 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -66,7 +66,7 @@ class MediaRepository: def __init__(self, hs): self.hs = hs self.auth = hs.get_auth() - self.client = hs.get_http_client() + self.client = hs.get_federation_http_client() self.clock = hs.get_clock() self.server_name = hs.hostname self.store = hs.get_datastore() diff --git a/synapse/server.py b/synapse/server.py index b017e3489f..9af759626e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -350,16 +350,45 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_simple_http_client(self) -> SimpleHttpClient: + """ + An HTTP client with no special configuration. + """ return SimpleHttpClient(self) @cache_in_self def get_proxied_http_client(self) -> SimpleHttpClient: + """ + An HTTP client that uses configured HTTP(S) proxies. + """ + return SimpleHttpClient( + self, + http_proxy=os.getenvb(b"http_proxy"), + https_proxy=os.getenvb(b"HTTPS_PROXY"), + ) + + @cache_in_self + def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient: + """ + An HTTP client that uses configured HTTP(S) proxies and blacklists IPs + based on the IP range blacklist. + """ return SimpleHttpClient( self, + ip_blacklist=self.config.ip_range_blacklist, http_proxy=os.getenvb(b"http_proxy"), https_proxy=os.getenvb(b"HTTPS_PROXY"), ) + @cache_in_self + def get_federation_http_client(self) -> MatrixFederationHttpClient: + """ + An HTTP client for federation. + """ + tls_client_options_factory = context_factory.FederationPolicyForHTTPS( + self.config + ) + return MatrixFederationHttpClient(self, tls_client_options_factory) + @cache_in_self def get_room_creation_handler(self) -> RoomCreationHandler: return RoomCreationHandler(self) @@ -514,13 +543,6 @@ class HomeServer(metaclass=abc.ABCMeta): def get_pusherpool(self) -> PusherPool: return PusherPool(self) - @cache_in_self - def get_http_client(self) -> MatrixFederationHttpClient: - tls_client_options_factory = context_factory.FederationPolicyForHTTPS( - self.config - ) - return MatrixFederationHttpClient(self, tls_client_options_factory) - @cache_in_self def get_media_repository_resource(self) -> MediaRepositoryResource: # build the media repo resource. This indirects through the HomeServer diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index c98ae75974..bcf1a8010e 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -50,7 +50,9 @@ class FilteringTestCase(unittest.TestCase): self.mock_http_client.put_json = DeferredMockCallable() hs = yield setup_test_homeserver( - self.addCleanup, http_client=self.mock_http_client, keyring=Mock(), + self.addCleanup, + federation_http_client=self.mock_http_client, + keyring=Mock(), ) self.filtering = hs.get_filtering() diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 40abe9d72d..43fef5d64a 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -23,7 +23,7 @@ class FrontendProxyTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserver_to_use=GenericWorkerServer + federation_http_client=None, homeserver_to_use=GenericWorkerServer ) return hs diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index ea3be95cf1..b260ab734d 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -27,7 +27,7 @@ from tests.unittest import HomeserverTestCase class FederationReaderOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserver_to_use=GenericWorkerServer + federation_http_client=None, homeserver_to_use=GenericWorkerServer ) return hs @@ -84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserver_to_use=SynapseHomeServer + federation_http_client=None, homeserver_to_use=SynapseHomeServer ) return hs diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 697916a019..d146f2254f 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -315,7 +315,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.http_client = Mock() - hs = self.setup_test_homeserver(http_client=self.http_client) + hs = self.setup_test_homeserver(federation_http_client=self.http_client) return hs def test_get_keys_from_server(self): @@ -395,7 +395,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): } ] - return self.setup_test_homeserver(http_client=self.http_client, config=config) + return self.setup_test_homeserver( + federation_http_client=self.http_client, config=config + ) def build_perspectives_response( self, server_name: str, signing_key: SigningKey, valid_until_ts: int, diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 875aaec2c6..5dfeccfeb6 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -27,7 +27,7 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver("server", http_client=None) + hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() self.store = hs.get_datastore() return hs @@ -229,7 +229,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver("server", http_client=None) + hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() self.registration = hs.get_registration_handler() self.auth = hs.get_auth() diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index ee6ef5e6fa..2f1f2a5517 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -42,7 +42,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.mock_registry.register_query_handler = register_query_handler hs = self.setup_test_homeserver( - http_client=None, + federation_http_client=None, resource_for_federation=Mock(), federation_client=self.mock_federation, federation_registry=self.mock_registry, diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index bf866dacf3..d0452e1490 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -37,7 +37,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver(http_client=None) + hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastore() return hs diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 8ed67640f8..0794b32c9c 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -463,7 +463,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "server", http_client=None, federation_sender=Mock() + "server", federation_http_client=None, federation_sender=Mock() ) return hs diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index a69fa28b41..ea2bcf2655 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -44,7 +44,7 @@ class ProfileTestCase(unittest.TestCase): hs = yield setup_test_homeserver( self.addCleanup, - http_client=None, + federation_http_client=None, resource_for_federation=Mock(), federation_client=self.mock_federation, federation_server=Mock(), diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index abbdf2d524..36086ca836 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -70,7 +70,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver( notifier=Mock(), - http_client=mock_federation_client, + federation_http_client=mock_federation_client, keyring=mock_keyring, replication_streams={}, ) @@ -192,7 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - put_json = self.hs.get_http_client().put_json + put_json = self.hs.get_federation_http_client().put_json put_json.assert_called_once_with( "farm", path="/_matrix/federation/v1/send/1000000", @@ -270,7 +270,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) - put_json = self.hs.get_http_client().put_json + put_json = self.hs.get_federation_http_client().put_json put_json.assert_called_once_with( "farm", path="/_matrix/federation/v1/send/1000000", diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 8b5ad4574f..626acdcaa3 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -17,6 +17,7 @@ import logging from mock import Mock import treq +from netaddr import IPSet from service_identity import VerificationError from zope.interface import implementer @@ -103,6 +104,7 @@ class MatrixFederationAgentTests(unittest.TestCase): reactor=self.reactor, tls_client_options_factory=self.tls_factory, user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. + ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=self.well_known_resolver, ) @@ -736,6 +738,7 @@ class MatrixFederationAgentTests(unittest.TestCase): reactor=self.reactor, tls_client_options_factory=tls_factory, user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below. + ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( self.reactor, diff --git a/tests/push/test_http.py b/tests/push/test_http.py index f118430309..e8cea39c83 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -49,7 +49,9 @@ class HTTPPusherTests(HomeserverTestCase): config = self.default_config() config["start_pushers"] = True - hs = self.setup_test_homeserver(config=config, proxied_http_client=m) + hs = self.setup_test_homeserver( + config=config, proxied_blacklisted_http_client=m + ) return hs diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 295c5d58a6..728de28277 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -67,7 +67,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Make a new HomeServer object for the worker self.reactor.lookups["testserv"] = "1.2.3.4" self.worker_hs = self.setup_test_homeserver( - http_client=None, + federation_http_client=None, homeserver_to_use=GenericWorkerServer, config=self._get_worker_hs_config(), reactor=self.reactor, @@ -264,7 +264,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): worker_app: Type of worker, e.g. `synapse.app.federation_sender`. extra_config: Any extra config to use for this instances. **kwargs: Options that get passed to `self.setup_test_homeserver`, - useful to e.g. pass some mocks for things like `http_client` + useful to e.g. pass some mocks for things like `federation_http_client` Returns: The new worker HomeServer instance. diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 779745ae9d..fffdb742c8 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -50,7 +50,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): self.make_worker_hs( "synapse.app.federation_sender", {"send_federation": True}, - http_client=mock_client, + federation_http_client=mock_client, ) user = self.register_user("user", "pass") @@ -81,7 +81,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender1", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client1, + federation_http_client=mock_client1, ) mock_client2 = Mock(spec=["put_json"]) @@ -93,7 +93,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender2", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client2, + federation_http_client=mock_client2, ) user = self.register_user("user2", "pass") @@ -144,7 +144,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender1", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client1, + federation_http_client=mock_client1, ) mock_client2 = Mock(spec=["put_json"]) @@ -156,7 +156,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender2", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client2, + federation_http_client=mock_client2, ) user = self.register_user("user3", "pass") diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 67c27a089f..f894bcd6e7 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -98,7 +98,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): self.make_worker_hs( "synapse.app.pusher", {"start_pushers": True}, - proxied_http_client=http_client_mock, + proxied_blacklisted_http_client=http_client_mock, ) event_id = self._create_pusher_and_send_msg("user") @@ -133,7 +133,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "pusher1", "pusher_instances": ["pusher1", "pusher2"], }, - proxied_http_client=http_client_mock1, + proxied_blacklisted_http_client=http_client_mock1, ) http_client_mock2 = Mock(spec_set=["post_json_get_json"]) @@ -148,7 +148,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "pusher2", "pusher_instances": ["pusher1", "pusher2"], }, - proxied_http_client=http_client_mock2, + proxied_blacklisted_http_client=http_client_mock2, ) # We choose a user name that we know should go to pusher1. diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 4f76f8f768..67d8878395 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -210,7 +210,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): } config["media_storage_providers"] = [provider_config] - hs = self.setup_test_homeserver(config=config, http_client=client) + hs = self.setup_test_homeserver(config=config, federation_http_client=client) return hs diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 5d5c24d01c..11cd8efe21 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -38,7 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver( "red", - http_client=None, + federation_http_client=None, federation_client=Mock(), presence_handler=presence_handler, ) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 383a9eafac..2a3b483eaf 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -63,7 +63,7 @@ class MockHandlerProfileTestCase(unittest.TestCase): hs = yield setup_test_homeserver( self.addCleanup, "test", - http_client=None, + federation_http_client=None, resource_for_client=self.mock_resource, federation=Mock(), federation_client=Mock(), diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 49f1073c88..e67de41c18 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -45,7 +45,7 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock(), + "red", federation_http_client=None, federation_client=Mock(), ) self.hs.get_federation_handler = Mock() diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index bbd30f594b..ae0207366b 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -39,7 +39,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock(), + "red", federation_http_client=None, federation_client=Mock(), ) self.event_source = hs.get_event_sources().sources["typing"] diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index fbcf8d5b86..5e90d656f7 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -39,7 +39,7 @@ from tests.utils import default_config class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.http_client = Mock() - return self.setup_test_homeserver(http_client=self.http_client) + return self.setup_test_homeserver(federation_http_client=self.http_client) def create_test_resource(self): return create_resource_tree( @@ -172,7 +172,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): } ] self.hs2 = self.setup_test_homeserver( - http_client=self.http_client2, config=config + federation_http_client=self.http_client2, config=config ) # wire up outbound POST /key/v2/query requests from hs2 so that they diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 2a3b2a8f27..4c749f1a61 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -214,7 +214,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): } config["media_storage_providers"] = [provider_config] - hs = self.setup_test_homeserver(config=config, http_client=client) + hs = self.setup_test_homeserver(config=config, federation_http_client=client) return hs diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index 35dafbb904..3d7760d5d9 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -26,7 +26,7 @@ room_key = { class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver("server", http_client=None) + hs = self.setup_test_homeserver("server", federation_http_client=None) self.store = hs.get_datastore() return hs diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index cc1f3c53c5..a06ad2c03e 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -27,7 +27,7 @@ class PurgeTests(HomeserverTestCase): servlets = [room.register_servlets] def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver("server", http_client=None) + hs = self.setup_test_homeserver("server", federation_http_client=None) return hs def prepare(self, reactor, clock, hs): diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index d4f9e809db..fd0add5db3 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -34,7 +34,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): config = self.default_config() config["redaction_retention_period"] = "30d" return self.setup_test_homeserver( - resource_for_federation=Mock(), http_client=None, config=config + resource_for_federation=Mock(), federation_http_client=None, config=config ) def prepare(self, reactor, clock, hs): diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index ff972daeaa..5ba1db2332 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -36,7 +36,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - resource_for_federation=Mock(), http_client=None + resource_for_federation=Mock(), federation_http_client=None ) return hs diff --git a/tests/test_federation.py b/tests/test_federation.py index 1ce4ea3a01..fa45f8b3b7 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -37,7 +37,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( self.addCleanup, - http_client=self.http_client, + federation_http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor, ) diff --git a/tests/test_server.py b/tests/test_server.py index c387a85f2e..6b2d2f0401 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -38,7 +38,10 @@ class JsonResourceTests(unittest.TestCase): self.reactor = ThreadedMemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( - self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor + self.addCleanup, + federation_http_client=None, + clock=self.hs_clock, + reactor=self.reactor, ) def test_handler_for_request(self): -- cgit 1.5.1 From cf3b8156bec7d137894ebc3f6998c3c81a3854aa Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 3 Dec 2020 15:41:19 +0000 Subject: Fix errorcode for disabled registration (#8867) The spec says we should return `M_FORBIDDEN` when someone tries to register and registration is disabled. --- changelog.d/8867.bugfix | 1 + synapse/rest/client/v2_alpha/register.py | 2 +- tests/rest/client/v2_alpha/test_register.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8867.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/8867.bugfix b/changelog.d/8867.bugfix new file mode 100644 index 0000000000..f2414ff111 --- /dev/null +++ b/changelog.d/8867.bugfix @@ -0,0 +1 @@ +Fix the error code that is returned when a user tries to register on a homeserver on which new-user registration has been disabled. diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index a89ae6ddf9..9041e7ed76 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -451,7 +451,7 @@ class RegisterRestServlet(RestServlet): # == Normal User Registration == (everyone else) if not self._registration_enabled: - raise SynapseError(403, "Registration has been disabled") + raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN) # For regular registration, convert the provided username to lowercase # before attempting to register it. This should mean that people who try diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 8f0c2430e8..bcb21d0ced 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -121,6 +121,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Registration has been disabled") + self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") def test_POST_guest_registration(self): self.hs.config.macaroon_secret_key = "test" -- cgit 1.5.1 From df3e6a23a74e85005881bd5bead3443f4f758095 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 4 Dec 2020 10:26:09 -0500 Subject: Do not 500 if the content-length is not provided when uploading media. (#8862) Instead return the proper 400 error. --- changelog.d/8862.bugfix | 1 + synapse/rest/media/v1/upload_resource.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8862.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/8862.bugfix b/changelog.d/8862.bugfix new file mode 100644 index 0000000000..bdbd633f72 --- /dev/null +++ b/changelog.d/8862.bugfix @@ -0,0 +1 @@ +Fix a longstanding bug where a 500 error would be returned if the `Content-Length` header was not provided to the upload media resource. diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index d76f7389e1..42febc9afc 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -44,7 +44,7 @@ class UploadResource(DirectServeJsonResource): requester = await self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point - content_length = request.getHeader(b"Content-Length").decode("ascii") + content_length = request.getHeader("Content-Length") if content_length is None: raise SynapseError(msg="Request must specify a Content-Length", code=400) if int(content_length) > self.max_upload_size: -- cgit 1.5.1 From 1f3748f03398f8f91ec5121312aa79dd58306ec1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 7 Dec 2020 10:00:08 -0500 Subject: Do not raise a 500 exception when previewing empty media. (#8883) --- changelog.d/8883.bugfix | 1 + synapse/rest/media/v1/preview_url_resource.py | 6 +++++- tests/test_preview.py | 27 ++++++++++++++++----------- 3 files changed, 22 insertions(+), 12 deletions(-) create mode 100644 changelog.d/8883.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/8883.bugfix b/changelog.d/8883.bugfix new file mode 100644 index 0000000000..6137fc5b2b --- /dev/null +++ b/changelog.d/8883.bugfix @@ -0,0 +1 @@ +Fix a 500 error when attempting to preview an empty HTML file. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index dce6c4d168..1082389d9b 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -676,7 +676,11 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("No media removed from url cache") -def decode_and_calc_og(body, media_uri, request_encoding=None): +def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]: + # If there's no body, nothing useful is going to be found. + if not body: + return {} + from lxml import etree try: diff --git a/tests/test_preview.py b/tests/test_preview.py index 7f67ee9e1f..a883d707df 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -56,7 +56,7 @@ class PreviewTestCase(unittest.TestCase): desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - self.assertEquals( + self.assertEqual( desc, "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -69,7 +69,7 @@ class PreviewTestCase(unittest.TestCase): desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) - self.assertEquals( + self.assertEqual( desc, "Tromsø lies in Northern Norway. The municipality has a population of" " (2015) 72,066, but with an annual influx of students it has over 75,000" @@ -96,7 +96,7 @@ class PreviewTestCase(unittest.TestCase): desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - self.assertEquals( + self.assertEqual( desc, "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -122,7 +122,7 @@ class PreviewTestCase(unittest.TestCase): ] desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - self.assertEquals( + self.assertEqual( desc, "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -149,7 +149,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_comment(self): html = """ @@ -164,7 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_comment2(self): html = """ @@ -182,7 +182,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals( + self.assertEqual( og, { "og:title": "Foo", @@ -203,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_missing_title(self): html = """ @@ -216,7 +216,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": None, "og:description": "Some text."}) + self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) def test_h1_as_title(self): html = """ @@ -230,7 +230,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."}) + self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) def test_missing_title_and_broken_h1(self): html = """ @@ -244,4 +244,9 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": None, "og:description": "Some text."}) + self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) + + def test_empty(self): + html = "" + og = decode_and_calc_og(html, "http://example.com/test.html") + self.assertEqual(og, {}) -- cgit 1.5.1 From ff1f0ee09472b554832fb39952f389d01a4233ac Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Mon, 7 Dec 2020 19:13:07 +0000 Subject: Call set_avatar_url with target_user, not user_id (#8872) * Call set_avatar_url with target_user, not user_id Fixes https://github.com/matrix-org/synapse/issues/8871 * Create 8872.bugfix * Update synapse/rest/admin/users.py Co-authored-by: Patrick Cloke * Testing * Update changelog.d/8872.bugfix Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Co-authored-by: Patrick Cloke Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/8872.bugfix | 1 + synapse/rest/admin/users.py | 4 ++-- tests/rest/admin/test_user.py | 7 ++++++- 3 files changed, 9 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8872.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/8872.bugfix b/changelog.d/8872.bugfix new file mode 100644 index 0000000000..ed00b70a0f --- /dev/null +++ b/changelog.d/8872.bugfix @@ -0,0 +1 @@ +Fix a bug where `PUT /_synapse/admin/v2/users/` failed to create a new user when `avatar_url` is specified. Bug introduced in Synapse v1.9.0. diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 90940ff185..88cba369f5 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -320,9 +320,9 @@ class UserRestServletV2(RestServlet): data={}, ) - if "avatar_url" in body and type(body["avatar_url"]) == str: + if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( - user_id, requester, body["avatar_url"], True + target_user, requester, body["avatar_url"], True ) ret = await self.admin_handler.get_user(target_user) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 35c546aa69..ba1438cdc7 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -561,7 +561,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): "admin": True, "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - "avatar_url": None, + "avatar_url": "mxc://fibble/wibble", } ) @@ -578,6 +578,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual(True, channel.json_body["admin"]) + self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user request, channel = self.make_request( @@ -592,6 +593,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(True, channel.json_body["admin"]) self.assertEqual(False, channel.json_body["is_guest"]) self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) def test_create_user(self): """ @@ -606,6 +608,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): "admin": False, "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "avatar_url": "mxc://fibble/wibble", } ) @@ -622,6 +625,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual(False, channel.json_body["admin"]) + self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user request, channel = self.make_request( @@ -636,6 +640,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(False, channel.json_body["admin"]) self.assertEqual(False, channel.json_body["is_guest"]) self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) @override_config( {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} -- cgit 1.5.1 From cd9e72b185e36ac3f18ab1fe567fdeee87bfcb8e Mon Sep 17 00:00:00 2001 From: Aaron Raimist Date: Tue, 8 Dec 2020 16:51:03 -0600 Subject: Add X-Robots-Tag header to stop crawlers from indexing media (#8887) Fixes / related to: https://github.com/matrix-org/synapse/issues/6533 This should do essentially the same thing as a robots.txt file telling robots to not index the media repo. https://developers.google.com/search/reference/robots_meta_tag Signed-off-by: Aaron Raimist --- changelog.d/8887.feature | 1 + synapse/rest/media/v1/_base.py | 5 +++++ tests/rest/media/v1/test_media_storage.py | 13 +++++++++++++ 3 files changed, 19 insertions(+) create mode 100644 changelog.d/8887.feature (limited to 'synapse/rest') diff --git a/changelog.d/8887.feature b/changelog.d/8887.feature new file mode 100644 index 0000000000..729eb1f1ea --- /dev/null +++ b/changelog.d/8887.feature @@ -0,0 +1 @@ +Add `X-Robots-Tag` header to stop web crawlers from indexing media. diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 67aa993f19..47c2b44bff 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -155,6 +155,11 @@ def add_file_headers(request, media_type, file_size, upload_name): request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") request.setHeader(b"Content-Length", b"%d" % (file_size,)) + # Tell web crawlers to not index, archive, or follow links in media. This + # should help to prevent things in the media repo from showing up in web + # search results. + request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex") + # separators as defined in RFC2616. SP and HT are handled separately. # see _can_encode_filename_as_token. diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 4c749f1a61..6f0677d335 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -362,3 +362,16 @@ class MediaRepoTests(unittest.HomeserverTestCase): "error": "Not found [b'example.com', b'12345']", }, ) + + def test_x_robots_tag_header(self): + """ + Tests that the `X-Robots-Tag` header is present, which informs web crawlers + to not index, archive, or follow links in media. + """ + channel = self._req(b"inline; filename=out" + self.test_image.extension) + + headers = channel.headers + self.assertEqual( + headers.getRawHeaders(b"X-Robots-Tag"), + [b"noindex, nofollow, noarchive, noimageindex"], + ) -- cgit 1.5.1 From 0a34cdfc6682c2654c745c4d7c2f5ffd1865dbc8 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 11 Dec 2020 11:42:47 +0100 Subject: Add number of local devices to Room Details Admin API (#8886) --- changelog.d/8886.feature | 1 + docs/admin_api/rooms.md | 24 +++++++++------- synapse/rest/admin/rooms.py | 48 ++++++++++++++++++++----------- synapse/storage/databases/main/devices.py | 32 +++++++++++++++++++++ tests/rest/admin/test_room.py | 34 ++++++++++++++++++++++ tests/storage/test_devices.py | 26 +++++++++++++++++ 6 files changed, 138 insertions(+), 27 deletions(-) create mode 100644 changelog.d/8886.feature (limited to 'synapse/rest') diff --git a/changelog.d/8886.feature b/changelog.d/8886.feature new file mode 100644 index 0000000000..9e446f28bd --- /dev/null +++ b/changelog.d/8886.feature @@ -0,0 +1 @@ +Add number of local devices to Room Details Admin API. Contributed by @dklimpel. \ No newline at end of file diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 3ac21b5cae..d7b1740fe3 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -87,7 +87,7 @@ GET /_synapse/admin/v1/rooms Response: -``` +```jsonc { "rooms": [ { @@ -139,7 +139,7 @@ GET /_synapse/admin/v1/rooms?search_term=TWIM Response: -``` +```json { "rooms": [ { @@ -174,7 +174,7 @@ GET /_synapse/admin/v1/rooms?order_by=size Response: -``` +```jsonc { "rooms": [ { @@ -230,14 +230,14 @@ GET /_synapse/admin/v1/rooms?order_by=size&from=100 Response: -``` +```jsonc { "rooms": [ { "room_id": "!mscvqgqpHYjBGDxNym:matrix.org", "name": "Music Theory", "canonical_alias": "#musictheory:matrix.org", - "joined_members": 127 + "joined_members": 127, "joined_local_members": 2, "version": "1", "creator": "@foo:matrix.org", @@ -254,7 +254,7 @@ Response: "room_id": "!twcBhHVdZlQWuuxBhN:termina.org.uk", "name": "weechat-matrix", "canonical_alias": "#weechat-matrix:termina.org.uk", - "joined_members": 137 + "joined_members": 137, "joined_local_members": 20, "version": "4", "creator": "@foo:termina.org.uk", @@ -289,6 +289,7 @@ The following fields are possible in the JSON response body: * `canonical_alias` - The canonical (main) alias address of the room. * `joined_members` - How many users are currently in the room. * `joined_local_members` - How many local users are currently in the room. +* `joined_local_devices` - How many local devices are currently in the room. * `version` - The version of the room as a string. * `creator` - The `user_id` of the room creator. * `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active. @@ -311,15 +312,16 @@ GET /_synapse/admin/v1/rooms/ Response: -``` +```json { "room_id": "!mscvqgqpHYjBGDxNym:matrix.org", "name": "Music Theory", "avatar": "mxc://matrix.org/AQDaVFlbkQoErdOgqWRgiGSV", "topic": "Theory, Composition, Notation, Analysis", "canonical_alias": "#musictheory:matrix.org", - "joined_members": 127 + "joined_members": 127, "joined_local_members": 2, + "joined_local_devices": 2, "version": "1", "creator": "@foo:matrix.org", "encryption": null, @@ -353,13 +355,13 @@ GET /_synapse/admin/v1/rooms//members Response: -``` +```json { "members": [ "@foo:matrix.org", "@bar:matrix.org", - "@foobar:matrix.org - ], + "@foobar:matrix.org" + ], "total": 3 } ``` diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 25f89e4685..b902af8028 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -14,7 +14,7 @@ # limitations under the License. import logging from http import HTTPStatus -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -25,13 +25,17 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.rest.admin._base import ( admin_patterns, assert_requester_is_admin, assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.types import RoomAlias, RoomID, UserID, create_requester +from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -45,12 +49,14 @@ class ShutdownRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/shutdown_room/(?P[^/]+)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.room_shutdown_handler = hs.get_room_shutdown_handler() - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -86,13 +92,15 @@ class DeleteRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms/(?P[^/]+)/delete$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -146,12 +154,12 @@ class ListRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -236,19 +244,24 @@ class RoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms/(?P[^/]+)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room_with_stats(room_id) if not ret: raise NotFoundError("Room not found") - return 200, ret + members = await self.store.get_users_in_room(room_id) + ret["joined_local_devices"] = await self.store.count_devices_by_users(members) + + return (200, ret) class RoomMembersRestServlet(RestServlet): @@ -258,12 +271,14 @@ class RoomMembersRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms/(?P[^/]+)/members") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room(room_id) @@ -280,14 +295,16 @@ class JoinRoomAliasServlet(RestServlet): PATTERNS = admin_patterns("/join/(?P[^/]*)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.room_member_handler = hs.get_room_member_handler() self.admin_handler = hs.get_admin_handler() self.state_handler = hs.get_state_handler() - async def on_POST(self, request, room_identifier): + async def on_POST( + self, request: SynapseRequest, room_identifier: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -314,7 +331,6 @@ class JoinRoomAliasServlet(RestServlet): handler = self.room_member_handler room_alias = RoomAlias.from_string(room_identifier) room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) - room_id = room_id.to_string() else: raise SynapseError( 400, "%s was not legal room ID or room alias" % (room_identifier,) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index dfb4f87b8f..9097677648 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -57,6 +57,38 @@ class DeviceWorkerStore(SQLBaseStore): self._prune_old_outbound_device_pokes, 60 * 60 * 1000 ) + async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int: + """Retrieve number of all devices of given users. + Only returns number of devices that are not marked as hidden. + + Args: + user_ids: The IDs of the users which owns devices + Returns: + Number of devices of this users. + """ + + def count_devices_by_users_txn(txn, user_ids): + sql = """ + SELECT count(*) + FROM devices + WHERE + hidden = '0' AND + """ + + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", user_ids + ) + + txn.execute(sql + clause, args) + return txn.fetchone()[0] + + if not user_ids: + return 0 + + return await self.db_pool.runInteraction( + "count_devices_by_users", count_devices_by_users_txn, user_ids + ) + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """Retrieve a device. Only returns devices that are not marked as hidden. diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 46933a0493..9c100050d2 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1084,6 +1084,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("canonical_alias", channel.json_body) self.assertIn("joined_members", channel.json_body) self.assertIn("joined_local_members", channel.json_body) + self.assertIn("joined_local_devices", channel.json_body) self.assertIn("version", channel.json_body) self.assertIn("creator", channel.json_body) self.assertIn("encryption", channel.json_body) @@ -1096,6 +1097,39 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id_1, channel.json_body["room_id"]) + def test_single_room_devices(self): + """Test that `joined_local_devices` can be requested correctly""" + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["joined_local_devices"]) + + # Have another user join the room + user_1 = self.register_user("foo", "pass") + user_tok_1 = self.login("foo", "pass") + self.helper.join(room_id_1, user_1, tok=user_tok_1) + + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(2, channel.json_body["joined_local_devices"]) + + # leave room + self.helper.leave(room_id_1, self.admin_user, tok=self.admin_user_tok) + self.helper.leave(room_id_1, user_1, tok=user_tok_1) + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["joined_local_devices"]) + def test_room_members(self): """Test that room members can be requested correctly""" # Create two test rooms diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index ecb00f4e02..dabc1c5f09 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -79,6 +79,32 @@ class DeviceStoreTestCase(tests.unittest.TestCase): res["device2"], ) + @defer.inlineCallbacks + def test_count_devices_by_users(self): + yield defer.ensureDeferred( + self.store.store_device("user_id", "device1", "display_name 1") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id", "device2", "display_name 2") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id2", "device3", "display_name 3") + ) + + res = yield defer.ensureDeferred(self.store.count_devices_by_users()) + self.assertEqual(0, res) + + res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"])) + self.assertEqual(0, res) + + res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"])) + self.assertEqual(2, res) + + res = yield defer.ensureDeferred( + self.store.count_devices_by_users(["user_id", "user_id2"]) + ) + self.assertEqual(3, res) + @defer.inlineCallbacks def test_get_device_updates_by_remote(self): device_ids = ["device_id1", "device_id2"] -- cgit 1.5.1 From a8eceb01e59fcbddcea7d19031ed2392772e6d66 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 11 Dec 2020 16:33:31 +0000 Subject: Honour AS ratelimit settings for /login requests (#8920) Fixes #8846. --- changelog.d/8920.bugfix | 1 + synapse/api/auth.py | 4 +++- synapse/handlers/auth.py | 7 ++++--- synapse/rest/client/v1/login.py | 25 +++++++++++++++++++------ 4 files changed, 27 insertions(+), 10 deletions(-) create mode 100644 changelog.d/8920.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/8920.bugfix b/changelog.d/8920.bugfix new file mode 100644 index 0000000000..abcf186bda --- /dev/null +++ b/changelog.d/8920.bugfix @@ -0,0 +1 @@ +Fix login API to not ratelimit application services that have ratelimiting disabled. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index bfcaf68b2a..1951f6e178 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -31,7 +31,9 @@ from synapse.api.errors import ( MissingClientTokenError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.appservice import ApplicationService from synapse.events import EventBase +from synapse.http.site import SynapseRequest from synapse.logging import opentracing as opentracing from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import StateMap, UserID @@ -474,7 +476,7 @@ class Auth: now = self.hs.get_clock().time_msec() return now < expiry - def get_appservice_by_req(self, request): + def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService: token = self.get_access_token_from_request(request) service = self.store.get_app_service_by_token(token) if not service: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index afae6d3272..62f98dabc0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -22,6 +22,7 @@ import urllib.parse from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Dict, Iterable, @@ -861,7 +862,7 @@ class AuthHandler(BaseHandler): async def validate_login( self, login_submission: Dict[str, Any], ratelimit: bool = False, - ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: + ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate auth types which don't @@ -1004,7 +1005,7 @@ class AuthHandler(BaseHandler): async def _validate_userid_login( self, username: str, login_submission: Dict[str, Any], - ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: + ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: """Helper for validate_login Handles login, once we've mapped 3pids onto userids @@ -1082,7 +1083,7 @@ class AuthHandler(BaseHandler): async def check_password_provider_3pid( self, medium: str, address: str, password: str - ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]: + ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: """Check if a password provider is able to validate a thirdparty login Args: diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index d7ae148214..5f4c6703db 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Awaitable, Callable, Dict, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter @@ -30,6 +30,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import JsonDict, UserID +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -42,7 +45,7 @@ class LoginRestServlet(RestServlet): JWT_TYPE_DEPRECATED = "m.login.jwt" APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs @@ -105,22 +108,27 @@ class LoginRestServlet(RestServlet): return 200, {"flows": flows} async def on_POST(self, request: SynapseRequest): - self._address_ratelimiter.ratelimit(request.getClientIP()) - login_submission = parse_json_object_from_request(request) try: if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: appservice = self.auth.get_appservice_by_req(request) + + if appservice.is_rate_limited(): + self._address_ratelimiter.ratelimit(request.getClientIP()) + result = await self._do_appservice_login(login_submission, appservice) elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): + self._address_ratelimiter.ratelimit(request.getClientIP()) result = await self._do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: + self._address_ratelimiter.ratelimit(request.getClientIP()) result = await self._do_token_login(login_submission) else: + self._address_ratelimiter.ratelimit(request.getClientIP()) result = await self._do_other_login(login_submission) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -159,7 +167,9 @@ class LoginRestServlet(RestServlet): if not appservice.is_interested_in_user(qualified_user_id): raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) - return await self._complete_login(qualified_user_id, login_submission) + return await self._complete_login( + qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited() + ) async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: """Handle non-token/saml/jwt logins @@ -194,6 +204,7 @@ class LoginRestServlet(RestServlet): login_submission: JsonDict, callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, create_non_existent_users: bool = False, + ratelimit: bool = True, ) -> Dict[str, str]: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -208,6 +219,7 @@ class LoginRestServlet(RestServlet): callback: Callback function to run after login. create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. + ratelimit: Whether to ratelimit the login request. Returns: result: Dictionary of account information after successful login. @@ -216,7 +228,8 @@ class LoginRestServlet(RestServlet): # Before we actually log them in we check if they've already logged in # too often. This happens here rather than before as we don't # necessarily know the user before now. - self._account_ratelimiter.ratelimit(user_id.lower()) + if ratelimit: + self._account_ratelimiter.ratelimit(user_id.lower()) if create_non_existent_users: canonical_uid = await self.auth_handler.check_user_exists(user_id) -- cgit 1.5.1 From f14428b25c37e44675edac4a80d7bd1e47112586 Mon Sep 17 00:00:00 2001 From: David Teller Date: Fri, 11 Dec 2020 20:05:15 +0100 Subject: Allow spam-checker modules to be provide async methods. (#8890) Spam checker modules can now provide async methods. This is implemented in a backwards-compatible manner. --- changelog.d/8890.feature | 1 + docs/spam_checker.md | 19 ++++++--- synapse/events/spamcheck.py | 55 +++++++++++++++++++-------- synapse/federation/federation_base.py | 7 +++- synapse/handlers/auth.py | 8 ++-- synapse/handlers/directory.py | 6 ++- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 2 +- synapse/handlers/receipts.py | 7 +--- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 4 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/user_directory.py | 10 ++--- synapse/metrics/background_process_metrics.py | 9 +---- synapse/rest/media/v1/storage_provider.py | 16 +++----- synapse/server.py | 2 +- synapse/util/async_helpers.py | 8 ++-- synapse/util/distributor.py | 7 +--- tests/handlers/test_user_directory.py | 4 +- 19 files changed, 98 insertions(+), 73 deletions(-) create mode 100644 changelog.d/8890.feature (limited to 'synapse/rest') diff --git a/changelog.d/8890.feature b/changelog.d/8890.feature new file mode 100644 index 0000000000..97aa72a76e --- /dev/null +++ b/changelog.d/8890.feature @@ -0,0 +1 @@ +Spam-checkers may now define their methods as `async`. diff --git a/docs/spam_checker.md b/docs/spam_checker.md index 7fc08f1b70..5b4f6428e6 100644 --- a/docs/spam_checker.md +++ b/docs/spam_checker.md @@ -22,6 +22,8 @@ well as some specific methods: * `user_may_create_room` * `user_may_create_room_alias` * `user_may_publish_room` +* `check_username_for_spam` +* `check_registration_for_spam` The details of the each of these methods (as well as their inputs and outputs) are documented in the `synapse.events.spamcheck.SpamChecker` class. @@ -32,28 +34,33 @@ call back into the homeserver internals. ### Example ```python +from synapse.spam_checker_api import RegistrationBehaviour + class ExampleSpamChecker: def __init__(self, config, api): self.config = config self.api = api - def check_event_for_spam(self, foo): + async def check_event_for_spam(self, foo): return False # allow all events - def user_may_invite(self, inviter_userid, invitee_userid, room_id): + async def user_may_invite(self, inviter_userid, invitee_userid, room_id): return True # allow all invites - def user_may_create_room(self, userid): + async def user_may_create_room(self, userid): return True # allow all room creations - def user_may_create_room_alias(self, userid, room_alias): + async def user_may_create_room_alias(self, userid, room_alias): return True # allow all room aliases - def user_may_publish_room(self, userid, room_id): + async def user_may_publish_room(self, userid, room_id): return True # allow publishing of all rooms - def check_username_for_spam(self, user_profile): + async def check_username_for_spam(self, user_profile): return False # allow all usernames + + async def check_registration_for_spam(self, email_threepid, username, request_info): + return RegistrationBehaviour.ALLOW # allow all registrations ``` ## Configuration diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 936896656a..e7e3a7b9a4 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,10 +15,11 @@ # limitations under the License. import inspect -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import Collection +from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: import synapse.events @@ -39,7 +40,9 @@ class SpamChecker: else: self.spam_checkers.append(module(config=config)) - def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool: + async def check_event_for_spam( + self, event: "synapse.events.EventBase" + ) -> Union[bool, str]: """Checks if a given event is considered "spammy" by this server. If the server considers an event spammy, then it will be rejected if @@ -50,15 +53,16 @@ class SpamChecker: event: the event to be checked Returns: - True if the event is spammy. + True or a string if the event is spammy. If a string is returned it + will be used as the error message returned to the user. """ for spam_checker in self.spam_checkers: - if spam_checker.check_event_for_spam(event): + if await maybe_awaitable(spam_checker.check_event_for_spam(event)): return True return False - def user_may_invite( + async def user_may_invite( self, inviter_userid: str, invitee_userid: str, room_id: str ) -> bool: """Checks if a given user may send an invite @@ -75,14 +79,18 @@ class SpamChecker: """ for spam_checker in self.spam_checkers: if ( - spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id) + await maybe_awaitable( + spam_checker.user_may_invite( + inviter_userid, invitee_userid, room_id + ) + ) is False ): return False return True - def user_may_create_room(self, userid: str) -> bool: + async def user_may_create_room(self, userid: str) -> bool: """Checks if a given user may create a room If this method returns false, the creation request will be rejected. @@ -94,12 +102,15 @@ class SpamChecker: True if the user may create a room, otherwise False """ for spam_checker in self.spam_checkers: - if spam_checker.user_may_create_room(userid) is False: + if ( + await maybe_awaitable(spam_checker.user_may_create_room(userid)) + is False + ): return False return True - def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool: + async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool: """Checks if a given user may create a room alias If this method returns false, the association request will be rejected. @@ -112,12 +123,17 @@ class SpamChecker: True if the user may create a room alias, otherwise False """ for spam_checker in self.spam_checkers: - if spam_checker.user_may_create_room_alias(userid, room_alias) is False: + if ( + await maybe_awaitable( + spam_checker.user_may_create_room_alias(userid, room_alias) + ) + is False + ): return False return True - def user_may_publish_room(self, userid: str, room_id: str) -> bool: + async def user_may_publish_room(self, userid: str, room_id: str) -> bool: """Checks if a given user may publish a room to the directory If this method returns false, the publish request will be rejected. @@ -130,12 +146,17 @@ class SpamChecker: True if the user may publish the room, otherwise False """ for spam_checker in self.spam_checkers: - if spam_checker.user_may_publish_room(userid, room_id) is False: + if ( + await maybe_awaitable( + spam_checker.user_may_publish_room(userid, room_id) + ) + is False + ): return False return True - def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: + async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: """Checks if a user ID or display name are considered "spammy" by this server. If the server considers a username spammy, then it will not be included in @@ -157,12 +178,12 @@ class SpamChecker: if checker: # Make a copy of the user profile object to ensure the spam checker # cannot modify it. - if checker(user_profile.copy()): + if await maybe_awaitable(checker(user_profile.copy())): return True return False - def check_registration_for_spam( + async def check_registration_for_spam( self, email_threepid: Optional[dict], username: Optional[str], @@ -185,7 +206,9 @@ class SpamChecker: # spam checker checker = getattr(spam_checker, "check_registration_for_spam", None) if checker: - behaviour = checker(email_threepid, username, request_info) + behaviour = await maybe_awaitable( + checker(email_threepid, username, request_info) + ) assert isinstance(behaviour, RegistrationBehaviour) if behaviour != RegistrationBehaviour.ALLOW: return behaviour diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 38aa47963f..383737520a 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -78,6 +78,7 @@ class FederationBase: ctx = current_context() + @defer.inlineCallbacks def callback(_, pdu: EventBase): with PreserveLoggingContext(ctx): if not check_event_content_hash(pdu): @@ -105,7 +106,11 @@ class FederationBase: ) return redacted_event - if self.spam_checker.check_event_for_spam(pdu): + result = yield defer.ensureDeferred( + self.spam_checker.check_event_for_spam(pdu) + ) + + if result: logger.warning( "Event contains spam, redacting %s: %s", pdu.event_id, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 62f98dabc0..8deec4cd0c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import time import unicodedata @@ -59,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils +from synapse.util.async_helpers import maybe_awaitable from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email @@ -1639,6 +1639,6 @@ class PasswordProvider: # This might return an awaitable, if it does block the log out # until it completes. - result = g(user_id=user_id, device_id=device_id, access_token=access_token,) - if inspect.isawaitable(result): - await result + await maybe_awaitable( + g(user_id=user_id, device_id=device_id, access_token=access_token,) + ) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index ad5683d251..abcf86352d 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler): 403, "You must be in the room to create an alias for it" ) - if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): + if not await self.spam_checker.user_may_create_room_alias( + user_id, room_alias + ): raise AuthError(403, "This user is not permitted to create this alias") if not self.config.is_alias_creation_allowed( @@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler): """ user_id = requester.user.to_string() - if not self.spam_checker.user_may_publish_room(user_id, room_id): + if not await self.spam_checker.user_may_publish_room(user_id, room_id): raise AuthError( 403, "This user is not permitted to publish rooms to the room list" ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index df82e60b33..fd8de8696d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler): if self.hs.config.block_non_admin_invites: raise SynapseError(403, "This server does not accept room invites") - if not self.spam_checker.user_may_invite( + if not await self.spam_checker.user_may_invite( event.sender, event.state_key, event.room_id ): raise SynapseError( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 96843338ae..2b8aa9443d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -744,7 +744,7 @@ class EventCreationHandler: event.sender, ) - spam_error = self.spam_checker.check_event_for_spam(event) + spam_error = await self.spam_checker.check_event_for_spam(event) if spam_error: if not isinstance(spam_error, str): spam_error = "Spam is not permitted here" diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 153cbae7b9..e850e45e46 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,7 +18,6 @@ from typing import List, Tuple from synapse.appservice import ApplicationService from synapse.handlers._base import BaseHandler from synapse.types import JsonDict, ReadReceipt, get_domain_from_id -from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -98,10 +97,8 @@ class ReceiptsHandler(BaseHandler): self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. - await maybe_awaitable( - self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids - ) + await self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids ) return True diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0d85fd0868..94b5610acd 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -187,7 +187,7 @@ class RegistrationHandler(BaseHandler): """ self.check_registration_ratelimit(address) - result = self.spam_checker.check_registration_for_spam( + result = await self.spam_checker.check_registration_for_spam( threepid, localpart, user_agent_ips or [], ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 82fb72b381..7583418946 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -358,7 +358,7 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - if not self.spam_checker.user_may_create_room(user_id): + if not await self.spam_checker.user_may_create_room(user_id): raise SynapseError(403, "You are not permitted to create rooms") creation_content = { @@ -609,7 +609,7 @@ class RoomCreationHandler(BaseHandler): 403, "You are not permitted to create rooms", Codes.FORBIDDEN ) - if not is_requester_admin and not self.spam_checker.user_may_create_room( + if not is_requester_admin and not await self.spam_checker.user_may_create_room( user_id ): raise SynapseError(403, "You are not permitted to create rooms") diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index d85110a35e..cb5a29bc7e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -408,7 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) block_invite = True - if not self.spam_checker.user_may_invite( + if not await self.spam_checker.user_may_invite( requester.user.to_string(), target.to_string(), room_id ): logger.info("Blocking invite due to spam checker") diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index afbebfc200..f263a638f8 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -81,11 +81,11 @@ class UserDirectoryHandler(StateDeltasHandler): results = await self.store.search_user_dir(user_id, search_term, limit) # Remove any spammy users from the results. - results["results"] = [ - user - for user in results["results"] - if not self.spam_checker.check_username_for_spam(user) - ] + non_spammy_users = [] + for user in results["results"]: + if not await self.spam_checker.check_username_for_spam(user): + non_spammy_users.append(user) + results["results"] = non_spammy_users return results diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 658f6ecd72..76b7decf26 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import threading from functools import wraps @@ -25,6 +24,7 @@ from twisted.internet import defer from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.logging.opentracing import noop_context_manager, start_active_span +from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: import resource @@ -206,12 +206,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar if bg_start_span: ctx = start_active_span(desc, tags={"request_id": context.request}) with ctx: - result = func(*args, **kwargs) - - if inspect.isawaitable(result): - result = await result - - return result + return await maybe_awaitable(func(*args, **kwargs)) except Exception: logger.exception( "Background process '%s' threw an exception", desc, diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 18c9ed48d6..67f67efde7 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import os import shutil @@ -21,6 +20,7 @@ from typing import Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background +from synapse.util.async_helpers import maybe_awaitable from ._base import FileInfo, Responder from .media_storage import FileResponder @@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider): if self.store_synchronous: # store_file is supposed to return an Awaitable, but guard # against improper implementations. - result = self.backend.store_file(path, file_info) - if inspect.isawaitable(result): - return await result + return await maybe_awaitable(self.backend.store_file(path, file_info)) else: # TODO: Handle errors. async def store(): try: - result = self.backend.store_file(path, file_info) - if inspect.isawaitable(result): - return await result + return await maybe_awaitable( + self.backend.store_file(path, file_info) + ) except Exception: logger.exception("Error storing file") @@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider): async def fetch(self, path, file_info): # store_file is supposed to return an Awaitable, but guard # against improper implementations. - result = self.backend.fetch(path, file_info) - if inspect.isawaitable(result): - return await result + return await maybe_awaitable(self.backend.fetch(path, file_info)) class FileStorageProviderBackend(StorageProvider): diff --git a/synapse/server.py b/synapse/server.py index 043810ad31..a198b0eb46 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -618,7 +618,7 @@ class HomeServer(metaclass=abc.ABCMeta): return StatsHandler(self) @cache_in_self - def get_spam_checker(self): + def get_spam_checker(self) -> SpamChecker: return SpamChecker(self) @cache_in_self diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 382f0cf3f0..9a873c8e8e 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -15,10 +15,12 @@ # limitations under the License. import collections +import inspect import logging from contextlib import contextmanager from typing import ( Any, + Awaitable, Callable, Dict, Hashable, @@ -542,11 +544,11 @@ class DoneAwaitable: raise StopIteration(self.value) -def maybe_awaitable(value): +def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: """Convert a value to an awaitable if not already an awaitable. """ - - if hasattr(value, "__await__"): + if inspect.isawaitable(value): + assert isinstance(value, Awaitable) return value return DoneAwaitable(value) diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index f73e95393c..a6ee9edaec 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -12,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -105,10 +105,7 @@ class Signal: async def do(observer): try: - result = observer(*args, **kwargs) - if inspect.isawaitable(result): - result = await result - return result + return await maybe_awaitable(observer(*args, **kwargs)) except Exception as e: logger.warning( "%s signal observer %s failed: %r", self.name, observer, e, diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 98e5af2072..647a17cb90 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -270,7 +270,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): spam_checker = self.hs.get_spam_checker() class AllowAll: - def check_username_for_spam(self, user_profile): + async def check_username_for_spam(self, user_profile): # Allow all users. return False @@ -283,7 +283,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Configure a spam checker that filters all users. class BlockAll: - def check_username_for_spam(self, user_profile): + async def check_username_for_spam(self, user_profile): # All users are spammy. return True -- cgit 1.5.1 From bd30cfe86a5413191fe44d8f937a00117334ea82 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Dec 2020 11:25:30 -0500 Subject: Convert internal pusher dicts to attrs classes. (#8940) This improves type hinting and should use less memory. --- changelog.d/8940.misc | 1 + mypy.ini | 1 + synapse/push/__init__.py | 60 +++++++-- synapse/push/emailpusher.py | 27 +++-- synapse/push/httppusher.py | 36 +++--- synapse/push/pusher.py | 24 ++-- synapse/push/pusherpool.py | 135 +++++++++++---------- .../slave/storage/_slaved_id_tracker.py | 20 ++- synapse/replication/slave/storage/pushers.py | 17 ++- synapse/rest/admin/users.py | 16 +-- synapse/rest/client/v1/pusher.py | 15 +-- synapse/storage/databases/main/__init__.py | 3 - synapse/storage/databases/main/pusher.py | 93 ++++++++------ synapse/storage/util/id_generators.py | 4 +- tests/push/test_email.py | 6 +- tests/push/test_http.py | 10 +- tests/rest/admin/test_user.py | 2 +- 17 files changed, 266 insertions(+), 204 deletions(-) create mode 100644 changelog.d/8940.misc (limited to 'synapse/rest') diff --git a/changelog.d/8940.misc b/changelog.d/8940.misc new file mode 100644 index 0000000000..4ff0b94b94 --- /dev/null +++ b/changelog.d/8940.misc @@ -0,0 +1 @@ +Add type hints to push module. diff --git a/mypy.ini b/mypy.ini index 334e3a22fb..1904204025 100644 --- a/mypy.ini +++ b/mypy.ini @@ -65,6 +65,7 @@ files = synapse/state, synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/events.py, + synapse/storage/databases/main/pusher.py, synapse/storage/databases/main/registration.py, synapse/storage/databases/main/stream.py, synapse/storage/databases/main/ui_auth.py, diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index ad07ee86f6..9e7ac149a1 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -14,24 +14,70 @@ # limitations under the License. import abc -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional -from synapse.types import RoomStreamToken +import attr + +from synapse.types import JsonDict, RoomStreamToken if TYPE_CHECKING: from synapse.app.homeserver import HomeServer +@attr.s(slots=True) +class PusherConfig: + """Parameters necessary to configure a pusher.""" + + id = attr.ib(type=Optional[str]) + user_name = attr.ib(type=str) + access_token = attr.ib(type=Optional[int]) + profile_tag = attr.ib(type=str) + kind = attr.ib(type=str) + app_id = attr.ib(type=str) + app_display_name = attr.ib(type=str) + device_display_name = attr.ib(type=str) + pushkey = attr.ib(type=str) + ts = attr.ib(type=int) + lang = attr.ib(type=Optional[str]) + data = attr.ib(type=Optional[JsonDict]) + last_stream_ordering = attr.ib(type=Optional[int]) + last_success = attr.ib(type=Optional[int]) + failing_since = attr.ib(type=Optional[int]) + + def as_dict(self) -> Dict[str, Any]: + """Information that can be retrieved about a pusher after creation.""" + return { + "app_display_name": self.app_display_name, + "app_id": self.app_id, + "data": self.data, + "device_display_name": self.device_display_name, + "kind": self.kind, + "lang": self.lang, + "profile_tag": self.profile_tag, + "pushkey": self.pushkey, + } + + +@attr.s(slots=True) +class ThrottleParams: + """Parameters for controlling the rate of sending pushes via email.""" + + last_sent_ts = attr.ib(type=int) + throttle_ms = attr.ib(type=int) + + class Pusher(metaclass=abc.ABCMeta): - def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]): + def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): self.hs = hs self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self.pusher_id = pusherdict["id"] - self.user_id = pusherdict["user_name"] - self.app_id = pusherdict["app_id"] - self.pushkey = pusherdict["pushkey"] + self.pusher_id = pusher_config.id + self.user_id = pusher_config.user_name + self.app_id = pusher_config.app_id + self.pushkey = pusher_config.pushkey + + self.last_stream_ordering = pusher_config.last_stream_ordering # This is the highest stream ordering we know it's safe to process. # When new events arrive, we'll be given a window of new events: we diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 11a97b8df4..d2eff75a58 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -14,13 +14,13 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from twisted.internet.base import DelayedCall from twisted.internet.error import AlreadyCalled, AlreadyCancelled from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.push import Pusher +from synapse.push import Pusher, PusherConfig, ThrottleParams from synapse.push.mailer import Mailer if TYPE_CHECKING: @@ -60,15 +60,14 @@ class EmailPusher(Pusher): factor out the common parts """ - def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any], mailer: Mailer): - super().__init__(hs, pusherdict) + def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer): + super().__init__(hs, pusher_config) self.mailer = mailer self.store = self.hs.get_datastore() - self.email = pusherdict["pushkey"] - self.last_stream_ordering = pusherdict["last_stream_ordering"] + self.email = pusher_config.pushkey self.timed_call = None # type: Optional[DelayedCall] - self.throttle_params = {} # type: Dict[str, Dict[str, int]] + self.throttle_params = {} # type: Dict[str, ThrottleParams] self._inited = False self._is_processing = False @@ -132,6 +131,7 @@ class EmailPusher(Pusher): if not self._inited: # this is our first loop: load up the throttle params + assert self.pusher_id is not None self.throttle_params = await self.store.get_throttle_params_by_room( self.pusher_id ) @@ -157,6 +157,7 @@ class EmailPusher(Pusher): being run. """ start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering + assert start is not None unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email( self.user_id, start, self.max_stream_ordering ) @@ -244,13 +245,13 @@ class EmailPusher(Pusher): def get_room_throttle_ms(self, room_id: str) -> int: if room_id in self.throttle_params: - return self.throttle_params[room_id]["throttle_ms"] + return self.throttle_params[room_id].throttle_ms else: return 0 def get_room_last_sent_ts(self, room_id: str) -> int: if room_id in self.throttle_params: - return self.throttle_params[room_id]["last_sent_ts"] + return self.throttle_params[room_id].last_sent_ts else: return 0 @@ -301,10 +302,10 @@ class EmailPusher(Pusher): new_throttle_ms = min( current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS ) - self.throttle_params[room_id] = { - "last_sent_ts": self.clock.time_msec(), - "throttle_ms": new_throttle_ms, - } + self.throttle_params[room_id] = ThrottleParams( + self.clock.time_msec(), new_throttle_ms, + ) + assert self.pusher_id is not None await self.store.set_throttle_params( self.pusher_id, room_id, self.throttle_params[room_id] ) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index e8b25bcd2a..417fe0f1f5 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -25,7 +25,7 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.logging import opentracing from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.push import Pusher, PusherConfigException +from synapse.push import Pusher, PusherConfig, PusherConfigException from . import push_rule_evaluator, push_tools @@ -62,33 +62,29 @@ class HttpPusher(Pusher): # This one's in ms because we compare it against the clock GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000 - def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]): - super().__init__(hs, pusherdict) + def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): + super().__init__(hs, pusher_config) self.storage = self.hs.get_storage() - self.app_display_name = pusherdict["app_display_name"] - self.device_display_name = pusherdict["device_display_name"] - self.pushkey_ts = pusherdict["ts"] - self.data = pusherdict["data"] - self.last_stream_ordering = pusherdict["last_stream_ordering"] + self.app_display_name = pusher_config.app_display_name + self.device_display_name = pusher_config.device_display_name + self.pushkey_ts = pusher_config.ts + self.data = pusher_config.data self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC - self.failing_since = pusherdict["failing_since"] + self.failing_since = pusher_config.failing_since self.timed_call = None self._is_processing = False self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room - if "data" not in pusherdict: - raise PusherConfigException("No 'data' key for HTTP pusher") - self.data = pusherdict["data"] + self.data = pusher_config.data + if self.data is None: + raise PusherConfigException("'data' key can not be null for HTTP pusher") self.name = "%s/%s/%s" % ( - pusherdict["user_name"], - pusherdict["app_id"], - pusherdict["pushkey"], + pusher_config.user_name, + pusher_config.app_id, + pusher_config.pushkey, ) - if self.data is None: - raise PusherConfigException("data can not be null for HTTP pusher") - # Validate that there's a URL and it is of the proper form. if "url" not in self.data: raise PusherConfigException("'url' required in data for HTTP pusher") @@ -180,6 +176,7 @@ class HttpPusher(Pusher): Never call this directly: use _process which will only allow this to run once per pusher. """ + assert self.last_stream_ordering is not None unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) @@ -208,6 +205,7 @@ class HttpPusher(Pusher): http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] + assert self.last_stream_ordering is not None pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( self.app_id, self.pushkey, @@ -314,6 +312,8 @@ class HttpPusher(Pusher): # or may do so (i.e. is encrypted so has unknown effects). priority = "high" + # This was checked in the __init__, but mypy doesn't seem to know that. + assert self.data is not None if self.data.get("format") == "event_id_only": d = { "notification": { diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 8f1072b094..2aa7918fb4 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -14,9 +14,9 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional +from typing import TYPE_CHECKING, Callable, Dict, Optional -from synapse.push import Pusher +from synapse.push import Pusher, PusherConfig from synapse.push.emailpusher import EmailPusher from synapse.push.httppusher import HttpPusher from synapse.push.mailer import Mailer @@ -34,7 +34,7 @@ class PusherFactory: self.pusher_types = { "http": HttpPusher - } # type: Dict[str, Callable[[HomeServer, dict], Pusher]] + } # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]] logger.info("email enable notifs: %r", hs.config.email_enable_notifs) if hs.config.email_enable_notifs: @@ -47,18 +47,18 @@ class PusherFactory: logger.info("defined email pusher type") - def create_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]: - kind = pusherdict["kind"] + def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]: + kind = pusher_config.kind f = self.pusher_types.get(kind, None) if not f: return None - logger.debug("creating %s pusher for %r", kind, pusherdict) - return f(self.hs, pusherdict) + logger.debug("creating %s pusher for %r", kind, pusher_config) + return f(self.hs, pusher_config) def _create_email_pusher( - self, _hs: "HomeServer", pusherdict: Dict[str, Any] + self, _hs: "HomeServer", pusher_config: PusherConfig ) -> EmailPusher: - app_name = self._app_name_from_pusherdict(pusherdict) + app_name = self._app_name_from_pusherdict(pusher_config) mailer = self.mailers.get(app_name) if not mailer: mailer = Mailer( @@ -68,10 +68,10 @@ class PusherFactory: template_text=self._notif_template_text, ) self.mailers[app_name] = mailer - return EmailPusher(self.hs, pusherdict, mailer) + return EmailPusher(self.hs, pusher_config, mailer) - def _app_name_from_pusherdict(self, pusherdict: Dict[str, Any]) -> str: - data = pusherdict["data"] + def _app_name_from_pusherdict(self, pusher_config: PusherConfig) -> str: + data = pusher_config.data if isinstance(data, dict): brand = data.get("brand") diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 9c12d81cfb..8158356d40 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Dict, Iterable, Optional from prometheus_client import Gauge @@ -23,9 +23,9 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.push import Pusher, PusherConfigException +from synapse.push import Pusher, PusherConfig, PusherConfigException from synapse.push.pusher import PusherFactory -from synapse.types import RoomStreamToken +from synapse.types import JsonDict, RoomStreamToken from synapse.util.async_helpers import concurrently_execute if TYPE_CHECKING: @@ -77,7 +77,7 @@ class PusherPool: # map from user id to app_id:pushkey to pusher self.pushers = {} # type: Dict[str, Dict[str, Pusher]] - def start(self): + def start(self) -> None: """Starts the pushers off in a background process. """ if not self._should_start_pushers: @@ -87,16 +87,16 @@ class PusherPool: async def add_pusher( self, - user_id, - access_token, - kind, - app_id, - app_display_name, - device_display_name, - pushkey, - lang, - data, - profile_tag="", + user_id: str, + access_token: Optional[int], + kind: str, + app_id: str, + app_display_name: str, + device_display_name: str, + pushkey: str, + lang: Optional[str], + data: JsonDict, + profile_tag: str = "", ) -> Optional[Pusher]: """Creates a new pusher and adds it to the pool @@ -111,21 +111,23 @@ class PusherPool: # recreated, added and started: this means we have only one # code path adding pushers. self.pusher_factory.create_pusher( - { - "id": None, - "user_name": user_id, - "kind": kind, - "app_id": app_id, - "app_display_name": app_display_name, - "device_display_name": device_display_name, - "pushkey": pushkey, - "ts": time_now_msec, - "lang": lang, - "data": data, - "last_stream_ordering": None, - "last_success": None, - "failing_since": None, - } + PusherConfig( + id=None, + user_name=user_id, + access_token=access_token, + profile_tag=profile_tag, + kind=kind, + app_id=app_id, + app_display_name=app_display_name, + device_display_name=device_display_name, + pushkey=pushkey, + ts=time_now_msec, + lang=lang, + data=data, + last_stream_ordering=None, + last_success=None, + failing_since=None, + ) ) # create the pusher setting last_stream_ordering to the current maximum @@ -151,43 +153,44 @@ class PusherPool: return pusher async def remove_pushers_by_app_id_and_pushkey_not_user( - self, app_id, pushkey, not_user_id - ): + self, app_id: str, pushkey: str, not_user_id: str + ) -> None: to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) for p in to_remove: - if p["user_name"] != not_user_id: + if p.user_name != not_user_id: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", app_id, pushkey, - p["user_name"], + p.user_name, ) - await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p.app_id, p.pushkey, p.user_name) - async def remove_pushers_by_access_token(self, user_id, access_tokens): + async def remove_pushers_by_access_token( + self, user_id: str, access_tokens: Iterable[int] + ) -> None: """Remove the pushers for a given user corresponding to a set of access_tokens. Args: - user_id (str): user to remove pushers for - access_tokens (Iterable[int]): access token *ids* to remove pushers - for + user_id: user to remove pushers for + access_tokens: access token *ids* to remove pushers for """ if not self._pusher_shard_config.should_handle(self._instance_name, user_id): return tokens = set(access_tokens) for p in await self.store.get_pushers_by_user_id(user_id): - if p["access_token"] in tokens: + if p.access_token in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", - p["app_id"], - p["pushkey"], - p["user_name"], + p.app_id, + p.pushkey, + p.user_name, ) - await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p.app_id, p.pushkey, p.user_name) - def on_new_notifications(self, max_token: RoomStreamToken): + def on_new_notifications(self, max_token: RoomStreamToken) -> None: if not self.pushers: # nothing to do here. return @@ -206,7 +209,7 @@ class PusherPool: self._on_new_notifications(max_token) @wrap_as_background_process("on_new_notifications") - async def _on_new_notifications(self, max_token: RoomStreamToken): + async def _on_new_notifications(self, max_token: RoomStreamToken) -> None: # We just use the minimum stream ordering and ignore the vector clock # component. This is safe to do as long as we *always* ignore the vector # clock components. @@ -236,7 +239,9 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_notifications") - async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + async def on_new_receipts( + self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str] + ) -> None: if not self.pushers: # nothing to do here. return @@ -280,14 +285,14 @@ class PusherPool: resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) - pusher_dict = None + pusher_config = None for r in resultlist: - if r["user_name"] == user_id: - pusher_dict = r + if r.user_name == user_id: + pusher_config = r pusher = None - if pusher_dict: - pusher = await self._start_pusher(pusher_dict) + if pusher_config: + pusher = await self._start_pusher(pusher_config) return pusher @@ -302,44 +307,44 @@ class PusherPool: logger.info("Started pushers") - async def _start_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]: + async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]: """Start the given pusher Args: - pusherdict: dict with the values pulled from the db table + pusher_config: The pusher configuration with the values pulled from the db table Returns: The newly created pusher or None. """ if not self._pusher_shard_config.should_handle( - self._instance_name, pusherdict["user_name"] + self._instance_name, pusher_config.user_name ): return None try: - p = self.pusher_factory.create_pusher(pusherdict) + p = self.pusher_factory.create_pusher(pusher_config) except PusherConfigException as e: logger.warning( "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s", - pusherdict["id"], - pusherdict.get("user_name"), - pusherdict.get("app_id"), - pusherdict.get("pushkey"), + pusher_config.id, + pusher_config.user_name, + pusher_config.app_id, + pusher_config.pushkey, e, ) return None except Exception: logger.exception( - "Couldn't start pusher id %i: caught Exception", pusherdict["id"], + "Couldn't start pusher id %i: caught Exception", pusher_config.id, ) return None if not p: return None - appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) + appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey) - byuser = self.pushers.setdefault(pusherdict["user_name"], {}) + byuser = self.pushers.setdefault(pusher_config.user_name, {}) if appid_pushkey in byuser: byuser[appid_pushkey].on_stop() byuser[appid_pushkey] = p @@ -349,8 +354,8 @@ class PusherPool: # Check if there *may* be push to process. We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to # push. - user_id = pusherdict["user_name"] - last_stream_ordering = pusherdict["last_stream_ordering"] + user_id = pusher_config.user_name + last_stream_ordering = pusher_config.last_stream_ordering if last_stream_ordering: have_notifs = await self.store.get_if_maybe_push_in_range_for_user( user_id, last_stream_ordering @@ -364,7 +369,7 @@ class PusherPool: return p - async def remove_pusher(self, app_id, pushkey, user_id): + async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: appid_pushkey = "%s:%s" % (app_id, pushkey) byuser = self.pushers.get(user_id, {}) diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index eb74903d68..0d39a93ed2 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -12,21 +12,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional, Tuple +from synapse.storage.types import Connection from synapse.storage.util.id_generators import _load_current_id class SlavedIdTracker: - def __init__(self, db_conn, table, column, extra_tables=[], step=1): + def __init__( + self, + db_conn: Connection, + table: str, + column: str, + extra_tables: Optional[List[Tuple[str, str]]] = None, + step: int = 1, + ): self.step = step self._current = _load_current_id(db_conn, table, column, step) - for table, column in extra_tables: - self.advance(None, _load_current_id(db_conn, table, column)) + if extra_tables: + for table, column in extra_tables: + self.advance(None, _load_current_id(db_conn, table, column)) - def advance(self, instance_name, new_id): + def advance(self, instance_name: Optional[str], new_id: int): self._current = (max if self.step > 0 else min)(self._current, new_id) - def get_current_token(self): + def get_current_token(self) -> int: """ Returns: diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index c418730ba8..045bd014da 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -13,26 +13,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING from synapse.replication.tcp.streams import PushersStream from synapse.storage.database import DatabasePool from synapse.storage.databases.main.pusher import PusherWorkerStore +from synapse.storage.types import Connection from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) - self._pushers_id_gen = SlavedIdTracker( + self._pushers_id_gen = SlavedIdTracker( # type: ignore db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) - def get_pushers_stream_token(self): + def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token, rows + ) -> None: if stream_name == PushersStream.NAME: - self._pushers_id_gen.advance(instance_name, token) + self._pushers_id_gen.advance(instance_name, token) # type: ignore return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 88cba369f5..6658c2da56 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -42,17 +42,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -_GET_PUSHERS_ALLOWED_KEYS = { - "app_display_name", - "app_id", - "data", - "device_display_name", - "kind", - "lang", - "profile_tag", - "pushkey", -} - class UsersRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)$") @@ -770,10 +759,7 @@ class PushersRestServlet(RestServlet): pushers = await self.store.get_pushers_by_user_id(user_id) - filtered_pushers = [ - {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS} - for p in pushers - ] + filtered_pushers = [p.as_dict() for p in pushers] return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)} diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 8fe83f321a..89823fcc39 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns logger = logging.getLogger(__name__) -ALLOWED_KEYS = { - "app_display_name", - "app_id", - "data", - "device_display_name", - "kind", - "lang", - "profile_tag", - "pushkey", -} - class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) @@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet): pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) - filtered_pushers = [ - {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers - ] + filtered_pushers = [p.as_dict() for p in pushers] return 200, {"pushers": filtered_pushers} diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 43660ec4fb..871fb646a5 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -149,9 +149,6 @@ class DataStore( self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") - self._pushers_id_gen = StreamIdGenerator( - db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] - ) self._group_updates_id_gen = StreamIdGenerator( db_conn, "local_group_updates", "stream_id" ) diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 7997242d90..77ba9d819e 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -15,18 +15,32 @@ # limitations under the License. import logging -from typing import Iterable, Iterator, List, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple from canonicaljson import encode_canonical_json +from synapse.push import PusherConfig, ThrottleParams from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool +from synapse.storage.types import Connection +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached, cachedList +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class PusherWorkerStore(SQLBaseStore): - def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]: + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + super().__init__(database, db_conn, hs) + self._pushers_id_gen = StreamIdGenerator( + db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] + ) + + def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]: """JSON-decode the data in the rows returned from the `pushers` table Drops any rows whose data cannot be decoded @@ -44,21 +58,23 @@ class PusherWorkerStore(SQLBaseStore): ) continue - yield r + yield PusherConfig(**r) - async def user_has_pusher(self, user_id): + async def user_has_pusher(self, user_id: str) -> bool: ret = await self.db_pool.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None - def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): - return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey}) + async def get_pushers_by_app_id_and_pushkey( + self, app_id: str, pushkey: str + ) -> Iterator[PusherConfig]: + return await self.get_pushers_by({"app_id": app_id, "pushkey": pushkey}) - def get_pushers_by_user_id(self, user_id): - return self.get_pushers_by({"user_name": user_id}) + async def get_pushers_by_user_id(self, user_id: str) -> Iterator[PusherConfig]: + return await self.get_pushers_by({"user_name": user_id}) - async def get_pushers_by(self, keyvalues): + async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]: ret = await self.db_pool.simple_select_list( "pushers", keyvalues, @@ -83,7 +99,7 @@ class PusherWorkerStore(SQLBaseStore): ) return self._decode_pushers_rows(ret) - async def get_all_pushers(self): + async def get_all_pushers(self) -> Iterator[PusherConfig]: def get_pushers(txn): txn.execute("SELECT * FROM pushers") rows = self.db_pool.cursor_to_dict(txn) @@ -159,14 +175,16 @@ class PusherWorkerStore(SQLBaseStore): ) @cached(num_args=1, max_entries=15000) - async def get_if_user_has_pusher(self, user_id): + async def get_if_user_has_pusher(self, user_id: str): # This only exists for the cachedList decorator raise NotImplementedError() @cachedList( cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1, ) - async def get_if_users_have_pushers(self, user_ids): + async def get_if_users_have_pushers( + self, user_ids: Iterable[str] + ) -> Dict[str, bool]: rows = await self.db_pool.simple_select_many_batch( table="pushers", column="user_name", @@ -224,7 +242,7 @@ class PusherWorkerStore(SQLBaseStore): return bool(updated) async def update_pusher_failing_since( - self, app_id, pushkey, user_id, failing_since + self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int] ) -> None: await self.db_pool.simple_update( table="pushers", @@ -233,7 +251,9 @@ class PusherWorkerStore(SQLBaseStore): desc="update_pusher_failing_since", ) - async def get_throttle_params_by_room(self, pusher_id): + async def get_throttle_params_by_room( + self, pusher_id: str + ) -> Dict[str, ThrottleParams]: res = await self.db_pool.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, @@ -243,43 +263,44 @@ class PusherWorkerStore(SQLBaseStore): params_by_room = {} for row in res: - params_by_room[row["room_id"]] = { - "last_sent_ts": row["last_sent_ts"], - "throttle_ms": row["throttle_ms"], - } + params_by_room[row["room_id"]] = ThrottleParams( + row["last_sent_ts"], row["throttle_ms"], + ) return params_by_room - async def set_throttle_params(self, pusher_id, room_id, params) -> None: + async def set_throttle_params( + self, pusher_id: str, room_id: str, params: ThrottleParams + ) -> None: # no need to lock because `pusher_throttle` has a primary key on # (pusher, room_id) so simple_upsert will retry await self.db_pool.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, - params, + {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms}, desc="set_throttle_params", lock=False, ) class PusherStore(PusherWorkerStore): - def get_pushers_stream_token(self): + def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() async def add_pusher( self, - user_id, - access_token, - kind, - app_id, - app_display_name, - device_display_name, - pushkey, - pushkey_ts, - lang, - data, - last_stream_ordering, - profile_tag="", + user_id: str, + access_token: Optional[int], + kind: str, + app_id: str, + app_display_name: str, + device_display_name: str, + pushkey: str, + pushkey_ts: int, + lang: Optional[str], + data: Optional[JsonDict], + last_stream_ordering: int, + profile_tag: str = "", ) -> None: async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on @@ -311,16 +332,16 @@ class PusherStore(PusherWorkerStore): # invalidate, since we the user might not have had a pusher before await self.db_pool.runInteraction( "add_pusher", - self._invalidate_cache_and_stream, + self._invalidate_cache_and_stream, # type: ignore self.get_if_user_has_pusher, (user_id,), ) async def delete_pusher_by_app_id_pushkey_user_id( - self, app_id, pushkey, user_id + self, app_id: str, pushkey: str, user_id: str ) -> None: def delete_pusher_txn(txn, stream_id): - self._invalidate_cache_and_stream( + self._invalidate_cache_and_stream( # type: ignore txn, self.get_if_user_has_pusher, (user_id,) ) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 02d71302ea..133c0e7a28 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -153,12 +153,12 @@ class StreamIdGenerator: return _AsyncCtxManagerWrapper(manager()) - def get_current_token(self): + def get_current_token(self) -> int: """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. Returns: - int + The maximum stream id. """ with self._lock: if self._unfinished_ids: diff --git a/tests/push/test_email.py b/tests/push/test_email.py index bcdcafa5a9..961bf09de9 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -209,7 +209,7 @@ class EmailPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - last_stream_ordering = pushers[0]["last_stream_ordering"] + last_stream_ordering = pushers[0].last_stream_ordering # Advance time a bit, so the pusher will register something has happened self.pump(10) @@ -220,7 +220,7 @@ class EmailPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) + self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) # One email was attempted to be sent self.assertEqual(len(self.email_attempts), 1) @@ -238,4 +238,4 @@ class EmailPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) + self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index cb3245d8cf..60f0820cff 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -144,7 +144,7 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - last_stream_ordering = pushers[0]["last_stream_ordering"] + last_stream_ordering = pushers[0].last_stream_ordering # Advance time a bit, so the pusher will register something has happened self.pump() @@ -155,7 +155,7 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) + self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) # One push was attempted to be sent -- it'll be the first message self.assertEqual(len(self.push_attempts), 1) @@ -176,8 +176,8 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) - last_stream_ordering = pushers[0]["last_stream_ordering"] + self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) + last_stream_ordering = pushers[0].last_stream_ordering # Now it'll try and send the second push message, which will be the second one self.assertEqual(len(self.push_attempts), 2) @@ -198,7 +198,7 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) + self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) def test_sends_high_priority_for_encrypted(self): """ diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 582f983225..df62317e69 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -766,7 +766,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertEqual("@bob:test", pushers[0]["user_name"]) + self.assertEqual("@bob:test", pushers[0].user_name) @override_config( { -- cgit 1.5.1 From 4c33796b20f934a43f4f09a2bac6653c18d72b69 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 17 Dec 2020 12:55:21 +0000 Subject: Correctly handle AS registerations and add test --- synapse/handlers/auth.py | 8 +++++- synapse/handlers/register.py | 7 ++++- synapse/replication/http/login.py | 12 +++++++-- synapse/rest/client/v2_alpha/register.py | 14 +++++++--- tests/test_mau.py | 45 ++++++++++++++++++++++++++++++-- 5 files changed, 77 insertions(+), 9 deletions(-) (limited to 'synapse/rest') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c7dc07008a..3b8ac4325b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -709,6 +709,7 @@ class AuthHandler(BaseHandler): device_id: Optional[str], valid_until_ms: Optional[int], puppets_user_id: Optional[str] = None, + is_appservice_ghost: bool = False, ) -> str: """ Creates a new access token for the user with the given user ID. @@ -725,6 +726,7 @@ class AuthHandler(BaseHandler): we should always have a device ID) valid_until_ms: when the token is valid until. None for no expiry. + is_appservice_ghost: Whether the user is an application ghost user Returns: The access token for the user's session. Raises: @@ -745,7 +747,11 @@ class AuthHandler(BaseHandler): "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry ) - await self.auth.check_auth_blocking(user_id) + if ( + not is_appservice_ghost + or self.hs.config.appservice.track_appservice_user_ips + ): + await self.auth.check_auth_blocking(user_id) access_token = self.macaroon_gen.generate_access_token(user_id) await self.store.add_access_token_to_user( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0d85fd0868..039aff1061 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -630,6 +630,7 @@ class RegistrationHandler(BaseHandler): device_id: Optional[str], initial_display_name: Optional[str], is_guest: bool = False, + is_appservice_ghost: bool = False, ) -> Tuple[str, str]: """Register a device for a user and generate an access token. @@ -651,6 +652,7 @@ class RegistrationHandler(BaseHandler): device_id=device_id, initial_display_name=initial_display_name, is_guest=is_guest, + is_appservice_ghost=is_appservice_ghost, ) return r["device_id"], r["access_token"] @@ -672,7 +674,10 @@ class RegistrationHandler(BaseHandler): ) else: access_token = await self._auth_handler.get_access_token_for_user_id( - user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms + user_id, + device_id=registered_device_id, + valid_until_ms=valid_until_ms, + is_appservice_ghost=is_appservice_ghost, ) return (registered_device_id, access_token) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 4c81e2d784..36071feb36 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -36,7 +36,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - async def _serialize_payload(user_id, device_id, initial_display_name, is_guest): + async def _serialize_payload( + user_id, device_id, initial_display_name, is_guest, is_appservice_ghost + ): """ Args: device_id (str|None): Device ID to use, if None a new one is @@ -48,6 +50,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): "device_id": device_id, "initial_display_name": initial_display_name, "is_guest": is_guest, + "is_appservice_ghost": is_appservice_ghost, } async def _handle_request(self, request, user_id): @@ -56,9 +59,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): device_id = content["device_id"] initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] + is_appservice_ghost = content["is_appservice_ghost"] device_id, access_token = await self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest + user_id, + device_id, + initial_display_name, + is_guest, + is_appservice_ghost=is_appservice_ghost, ) return 200, {"device_id": device_id, "access_token": access_token} diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index a89ae6ddf9..722d993811 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -655,9 +655,13 @@ class RegisterRestServlet(RestServlet): user_id = await self.registration_handler.appservice_register( username, as_token ) - return await self._create_registration_details(user_id, body) + return await self._create_registration_details( + user_id, body, is_appservice_ghost=True, + ) - async def _create_registration_details(self, user_id, params): + async def _create_registration_details( + self, user_id, params, is_appservice_ghost=False + ): """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. @@ -674,7 +678,11 @@ class RegisterRestServlet(RestServlet): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") device_id, access_token = await self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest=False + user_id, + device_id, + initial_display_name, + is_guest=False, + is_appservice_ghost=is_appservice_ghost, ) result.update({"access_token": access_token, "device_id": device_id}) diff --git a/tests/test_mau.py b/tests/test_mau.py index c5ec6396a7..26548b4611 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -19,6 +19,7 @@ import json from synapse.api.constants import LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError +from synapse.appservice import ApplicationService from synapse.rest.client.v2_alpha import register, sync from tests import unittest @@ -75,6 +76,44 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.code, 403) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + def test_as_ignores_mau(self): + """Test that application services can still create users when the MAU + limit has been reached. + """ + + # Create and sync so that the MAU counts get updated + token1 = self.create_user("kermit1") + self.do_sync_for_user(token1) + token2 = self.create_user("kermit2") + self.do_sync_for_user(token2) + + # check we're testing what we think we are: there should be two active users + self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2) + + # We've created and activated two users, we shouldn't be able to + # register new users + with self.assertRaises(SynapseError) as cm: + self.create_user("kermit3") + + e = cm.exception + self.assertEqual(e.code, 403) + self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + # Cheekily add an application service that we use to register a new user + # with. + as_token = "foobartoken" + self.store.services_cache.append( + ApplicationService( + token=as_token, + hostname=self.hs.hostname, + id="SomeASID", + sender="@as_sender:test", + namespaces={"users": [{"regex": "@as_*", "exclusive": True}]}, + ) + ) + + self.create_user("as_kermit4", token=as_token) + def test_allowed_after_a_month_mau(self): # Create and sync so that the MAU counts get updated token1 = self.create_user("kermit1") @@ -192,7 +231,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.reactor.advance(100) self.assertEqual(2, self.successResultOf(count)) - def create_user(self, localpart): + def create_user(self, localpart, token=None): request_data = json.dumps( { "username": localpart, @@ -201,7 +240,9 @@ class TestMauLimit(unittest.HomeserverTestCase): } ) - request, channel = self.make_request("POST", "/register", request_data) + request, channel = self.make_request( + "POST", "/register", request_data, access_token=token + ) if channel.code != 200: raise HttpResponseException( -- cgit 1.5.1 From 5d4c330ed979b0d60efe5f80fd76de8f162263a1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Dec 2020 07:33:57 -0500 Subject: Allow re-using a UI auth validation for a period of time (#8970) --- changelog.d/8970.feature | 1 + docs/sample_config.yaml | 15 +++ synapse/config/_base.pyi | 4 +- synapse/config/auth.py | 110 +++++++++++++++++++++ synapse/config/homeserver.py | 4 +- synapse/config/password.py | 90 ----------------- synapse/handlers/auth.py | 32 ++++-- synapse/rest/client/v2_alpha/account.py | 10 +- synapse/storage/databases/main/registration.py | 38 +++++++ .../delta/58/26access_token_last_validated.sql | 18 ++++ tests/rest/client/v2_alpha/test_auth.py | 94 ++++++++++++------ 11 files changed, 280 insertions(+), 136 deletions(-) create mode 100644 changelog.d/8970.feature create mode 100644 synapse/config/auth.py delete mode 100644 synapse/config/password.py create mode 100644 synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql (limited to 'synapse/rest') diff --git a/changelog.d/8970.feature b/changelog.d/8970.feature new file mode 100644 index 0000000000..6d5b3303a6 --- /dev/null +++ b/changelog.d/8970.feature @@ -0,0 +1 @@ +Allow re-using an user-interactive authentication session for a period of time. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 75a01094d5..549c581a97 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2068,6 +2068,21 @@ password_config: # #require_uppercase: true +ui_auth: + # The number of milliseconds to allow a user-interactive authentication + # session to be active. + # + # This defaults to 0, meaning the user is queried for their credentials + # before every action, but this can be overridden to alow a single + # validation to be re-used. This weakens the protections afforded by + # the user-interactive authentication process, by allowing for multiple + # (and potentially different) operations to use the same validation session. + # + # Uncomment below to allow for credential validation to last for 15 + # seconds. + # + #session_timeout: 15000 + # Configuration for sending emails from Synapse. # diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index ed26e2fb60..29aa064e57 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -3,6 +3,7 @@ from typing import Any, Iterable, List, Optional from synapse.config import ( api, appservice, + auth, captcha, cas, consent_config, @@ -14,7 +15,6 @@ from synapse.config import ( logger, metrics, oidc_config, - password, password_auth_providers, push, ratelimiting, @@ -65,7 +65,7 @@ class RootConfig: sso: sso.SSOConfig oidc: oidc_config.OIDCConfig jwt: jwt_config.JWTConfig - password: password.PasswordConfig + auth: auth.AuthConfig email: emailconfig.EmailConfig worker: workers.WorkerConfig authproviders: password_auth_providers.PasswordAuthProviderConfig diff --git a/synapse/config/auth.py b/synapse/config/auth.py new file mode 100644 index 0000000000..2b3e2ce87b --- /dev/null +++ b/synapse/config/auth.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class AuthConfig(Config): + """Password and login configuration + """ + + section = "auth" + + def read_config(self, config, **kwargs): + password_config = config.get("password_config", {}) + if password_config is None: + password_config = {} + + self.password_enabled = password_config.get("enabled", True) + self.password_localdb_enabled = password_config.get("localdb_enabled", True) + self.password_pepper = password_config.get("pepper", "") + + # Password policy + self.password_policy = password_config.get("policy") or {} + self.password_policy_enabled = self.password_policy.get("enabled", False) + + # User-interactive authentication + ui_auth = config.get("ui_auth") or {} + self.ui_auth_session_timeout = ui_auth.get("session_timeout", 0) + + def generate_config_section(self, config_dir_path, server_name, **kwargs): + return """\ + password_config: + # Uncomment to disable password login + # + #enabled: false + + # Uncomment to disable authentication against the local password + # database. This is ignored if `enabled` is false, and is only useful + # if you have other password_providers. + # + #localdb_enabled: false + + # Uncomment and change to a secret random string for extra security. + # DO NOT CHANGE THIS AFTER INITIAL SETUP! + # + #pepper: "EVEN_MORE_SECRET" + + # Define and enforce a password policy. Each parameter is optional. + # This is an implementation of MSC2000. + # + policy: + # Whether to enforce the password policy. + # Defaults to 'false'. + # + #enabled: true + + # Minimum accepted length for a password. + # Defaults to 0. + # + #minimum_length: 15 + + # Whether a password must contain at least one digit. + # Defaults to 'false'. + # + #require_digit: true + + # Whether a password must contain at least one symbol. + # A symbol is any character that's not a number or a letter. + # Defaults to 'false'. + # + #require_symbol: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_lowercase: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_uppercase: true + + ui_auth: + # The number of milliseconds to allow a user-interactive authentication + # session to be active. + # + # This defaults to 0, meaning the user is queried for their credentials + # before every action, but this can be overridden to alow a single + # validation to be re-used. This weakens the protections afforded by + # the user-interactive authentication process, by allowing for multiple + # (and potentially different) operations to use the same validation session. + # + # Uncomment below to allow for credential validation to last for 15 + # seconds. + # + #session_timeout: 15000 + """ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index be65554524..4bd2b3587b 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -17,6 +17,7 @@ from ._base import RootConfig from .api import ApiConfig from .appservice import AppServiceConfig +from .auth import AuthConfig from .cache import CacheConfig from .captcha import CaptchaConfig from .cas import CasConfig @@ -30,7 +31,6 @@ from .key import KeyConfig from .logger import LoggingConfig from .metrics import MetricsConfig from .oidc_config import OIDCConfig -from .password import PasswordConfig from .password_auth_providers import PasswordAuthProviderConfig from .push import PushConfig from .ratelimiting import RatelimitConfig @@ -76,7 +76,7 @@ class HomeServerConfig(RootConfig): CasConfig, SSOConfig, JWTConfig, - PasswordConfig, + AuthConfig, EmailConfig, PasswordAuthProviderConfig, PushConfig, diff --git a/synapse/config/password.py b/synapse/config/password.py deleted file mode 100644 index 9c0ea8c30a..0000000000 --- a/synapse/config/password.py +++ /dev/null @@ -1,90 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ._base import Config - - -class PasswordConfig(Config): - """Password login configuration - """ - - section = "password" - - def read_config(self, config, **kwargs): - password_config = config.get("password_config", {}) - if password_config is None: - password_config = {} - - self.password_enabled = password_config.get("enabled", True) - self.password_localdb_enabled = password_config.get("localdb_enabled", True) - self.password_pepper = password_config.get("pepper", "") - - # Password policy - self.password_policy = password_config.get("policy") or {} - self.password_policy_enabled = self.password_policy.get("enabled", False) - - def generate_config_section(self, config_dir_path, server_name, **kwargs): - return """\ - password_config: - # Uncomment to disable password login - # - #enabled: false - - # Uncomment to disable authentication against the local password - # database. This is ignored if `enabled` is false, and is only useful - # if you have other password_providers. - # - #localdb_enabled: false - - # Uncomment and change to a secret random string for extra security. - # DO NOT CHANGE THIS AFTER INITIAL SETUP! - # - #pepper: "EVEN_MORE_SECRET" - - # Define and enforce a password policy. Each parameter is optional. - # This is an implementation of MSC2000. - # - policy: - # Whether to enforce the password policy. - # Defaults to 'false'. - # - #enabled: true - - # Minimum accepted length for a password. - # Defaults to 0. - # - #minimum_length: 15 - - # Whether a password must contain at least one digit. - # Defaults to 'false'. - # - #require_digit: true - - # Whether a password must contain at least one symbol. - # A symbol is any character that's not a number or a letter. - # Defaults to 'false'. - # - #require_symbol: true - - # Whether a password must contain at least one lowercase letter. - # Defaults to 'false'. - # - #require_lowercase: true - - # Whether a password must contain at least one lowercase letter. - # Defaults to 'false'. - # - #require_uppercase: true - """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 57ff461f92..f4434673dc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -226,6 +226,9 @@ class AuthHandler(BaseHandler): burst_count=self.hs.config.rc_login_failed_attempts.burst_count, ) + # The number of seconds to keep a UI auth session active. + self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout + # Ratelimitier for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( clock=hs.get_clock(), @@ -283,7 +286,7 @@ class AuthHandler(BaseHandler): request_body: Dict[str, Any], clientip: str, description: str, - ) -> Tuple[dict, str]: + ) -> Tuple[dict, Optional[str]]: """ Checks that the user is who they claim to be, via a UI auth. @@ -310,7 +313,8 @@ class AuthHandler(BaseHandler): have been given only in a previous call). 'session_id' is the ID of this session, either passed in by the - client or assigned by this call + client or assigned by this call. This is None if UI auth was + skipped (by re-using a previous validation). Raises: InteractiveAuthIncompleteError if the client has not yet completed @@ -324,6 +328,16 @@ class AuthHandler(BaseHandler): """ + if self._ui_auth_session_timeout: + last_validated = await self.store.get_access_token_last_validated( + requester.access_token_id + ) + if self.clock.time_msec() - last_validated < self._ui_auth_session_timeout: + # Return the input parameters, minus the auth key, which matches + # the logic in check_ui_auth. + request_body.pop("auth", None) + return request_body, None + user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts @@ -359,6 +373,9 @@ class AuthHandler(BaseHandler): if user_id != requester.user.to_string(): raise AuthError(403, "Invalid auth") + # Note that the access token has been validated. + await self.store.update_access_token_last_validated(requester.access_token_id) + return params, session_id async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]: @@ -452,13 +469,10 @@ class AuthHandler(BaseHandler): all the stages in any of the permitted flows. """ - authdict = None sid = None # type: Optional[str] - if clientdict and "auth" in clientdict: - authdict = clientdict["auth"] - del clientdict["auth"] - if "session" in authdict: - sid = authdict["session"] + authdict = clientdict.pop("auth", {}) + if "session" in authdict: + sid = authdict["session"] # Convert the URI and method to strings. uri = request.uri.decode("utf-8") @@ -563,6 +577,8 @@ class AuthHandler(BaseHandler): creds = await self.store.get_completed_ui_auth_stages(session.session_id) for f in flows: + # If all the required credentials have been supplied, the user has + # successfully completed the UI auth process! if len(set(f) - set(creds)) == 0: # it's very useful to know what args are stored, but this can # include the password in the case of registering, so only log diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index eebee44a44..d837bde1d6 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -254,14 +254,18 @@ class PasswordRestServlet(RestServlet): logger.error("Auth succeeded but no known type! %r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - # If we have a password in this request, prefer it. Otherwise, there - # must be a password hash from an earlier request. + # If we have a password in this request, prefer it. Otherwise, use the + # password hash from an earlier request. if new_password: password_hash = await self.auth_handler.hash(new_password) - else: + elif session_id is not None: password_hash = await self.auth_handler.get_session_data( session_id, "password_hash", None ) + else: + # UI validation was skipped, but the request did not include a new + # password. + password_hash = None if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index ff96c34c2e..8d05288ed4 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -943,6 +943,42 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="del_user_pending_deactivation", ) + async def get_access_token_last_validated(self, token_id: int) -> int: + """Retrieves the time (in milliseconds) of the last validation of an access token. + + Args: + token_id: The ID of the access token to update. + Raises: + StoreError if the access token was not found. + + Returns: + The last validation time. + """ + result = await self.db_pool.simple_select_one_onecol( + "access_tokens", {"id": token_id}, "last_validated" + ) + + # If this token has not been validated (since starting to track this), + # return 0 instead of None. + return result or 0 + + async def update_access_token_last_validated(self, token_id: int) -> None: + """Updates the last time an access token was validated. + + Args: + token_id: The ID of the access token to update. + Raises: + StoreError if there was a problem updating this. + """ + now = self._clock.time_msec() + + await self.db_pool.simple_update_one( + "access_tokens", + {"id": token_id}, + {"last_validated": now}, + desc="update_access_token_last_validated", + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): @@ -1150,6 +1186,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): The token ID """ next_id = self._access_tokens_id_gen.get_next() + now = self._clock.time_msec() await self.db_pool.simple_insert( "access_tokens", @@ -1160,6 +1197,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "device_id": device_id, "valid_until_ms": valid_until_ms, "puppets_user_id": puppets_user_id, + "last_validated": now, }, desc="add_access_token_to_user", ) diff --git a/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql new file mode 100644 index 0000000000..1a101cd5eb --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The last time this access token was "validated" (i.e. logged in or succeeded +-- at user-interactive authentication). +ALTER TABLE access_tokens ADD COLUMN last_validated BIGINT; diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 51323b3da3..ac66a4e0b7 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union from twisted.internet.defer import succeed @@ -177,13 +177,8 @@ class UIAuthTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) - self.user_tok = self.login("test", self.user_pass) - - def get_device_ids(self, access_token: str) -> List[str]: - # Get the list of devices so one can be deleted. - channel = self.make_request("GET", "devices", access_token=access_token,) - self.assertEqual(channel.code, 200) - return [d["device_id"] for d in channel.json_body["devices"]] + self.device_id = "dev1" + self.user_tok = self.login("test", self.user_pass, self.device_id) def delete_device( self, @@ -219,11 +214,9 @@ class UIAuthTests(unittest.HomeserverTestCase): """ Test user interactive authentication outside of registration. """ - device_id = self.get_device_ids(self.user_tok)[0] - # Attempt to delete this device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, device_id, 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) # Grab the session session = channel.json_body["session"] @@ -233,7 +226,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow. self.delete_device( self.user_tok, - device_id, + self.device_id, 200, { "auth": { @@ -252,14 +245,13 @@ class UIAuthTests(unittest.HomeserverTestCase): UIA - check that still works. """ - device_id = self.get_device_ids(self.user_tok)[0] - channel = self.delete_device(self.user_tok, device_id, 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) session = channel.json_body["session"] # Make another request providing the UI auth flow. self.delete_device( self.user_tok, - device_id, + self.device_id, 200, { "auth": { @@ -282,14 +274,11 @@ class UIAuthTests(unittest.HomeserverTestCase): session ID should be rejected. """ # Create a second login. - self.login("test", self.user_pass) - - device_ids = self.get_device_ids(self.user_tok) - self.assertEqual(len(device_ids), 2) + self.login("test", self.user_pass, "dev2") # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_devices(401, {"devices": [device_ids[0]]}) + channel = self.delete_devices(401, {"devices": [self.device_id]}) # Grab the session session = channel.json_body["session"] @@ -301,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_devices( 200, { - "devices": [device_ids[1]], + "devices": ["dev2"], "auth": { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": self.user}, @@ -316,14 +305,11 @@ class UIAuthTests(unittest.HomeserverTestCase): The initial requested URI cannot be modified during the user interactive authentication session. """ # Create a second login. - self.login("test", self.user_pass) - - device_ids = self.get_device_ids(self.user_tok) - self.assertEqual(len(device_ids), 2) + self.login("test", self.user_pass, "dev2") # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, device_ids[0], 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) # Grab the session session = channel.json_body["session"] @@ -332,9 +318,11 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow, but try to delete the # second device. This results in an error. + # + # This makes use of the fact that the device ID is embedded into the URL. self.delete_device( self.user_tok, - device_ids[1], + "dev2", 403, { "auth": { @@ -346,6 +334,52 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) + @unittest.override_config({"ui_auth": {"session_timeout": 5 * 1000}}) + def test_can_reuse_session(self): + """ + The session can be reused if configured. + + Compare to test_cannot_change_uri. + """ + # Create a second and third login. + self.login("test", self.user_pass, "dev2") + self.login("test", self.user_pass, "dev3") + + # Attempt to delete a device. This works since the user just logged in. + self.delete_device(self.user_tok, "dev2", 200) + + # Move the clock forward past the validation timeout. + self.reactor.advance(6) + + # Deleting another devices throws the user into UI auth. + channel = self.delete_device(self.user_tok, "dev3", 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + self.user_tok, + "dev3", + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + # Make another request, but try to delete the first device. This works + # due to re-using the previous session. + # + # Note that *no auth* information is provided, not even a session iD! + self.delete_device(self.user_tok, self.device_id, 200) + def test_does_not_offer_password_for_sso_user(self): login_resp = self.helper.login_via_oidc("username") user_tok = login_resp["access_token"] @@ -361,8 +395,7 @@ class UIAuthTests(unittest.HomeserverTestCase): def test_does_not_offer_sso_for_password_user(self): # now call the device deletion API: we should get the option to auth with SSO # and not password. - device_ids = self.get_device_ids(self.user_tok) - channel = self.delete_device(self.user_tok, device_ids[0], 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.password"]}]) @@ -373,8 +406,7 @@ class UIAuthTests(unittest.HomeserverTestCase): login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) - device_ids = self.get_device_ids(self.user_tok) - channel = self.delete_device(self.user_tok, device_ids[0], 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) flows = channel.json_body["flows"] # we have no particular expectations of ordering here -- cgit 1.5.1 From 28877fade90a5cfb3457c9e6c70924dbbe8af715 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 18 Dec 2020 14:19:46 +0000 Subject: Implement a username picker for synapse (#8942) The final part (for now) of my work to implement a username picker in synapse itself. The idea is that we allow `UsernameMappingProvider`s to return `localpart=None`, in which case, rather than redirecting the browser back to the client, we redirect to a username-picker resource, which allows the user to enter a username. We *then* complete the SSO flow (including doing the client permission checks). The static resources for the username picker itself (in https://github.com/matrix-org/synapse/tree/rav/username_picker/synapse/res/username_picker) are essentially lifted wholesale from https://github.com/matrix-org/matrix-synapse-saml-mozilla/tree/master/matrix_synapse_saml_mozilla/res. As the comment says, we might want to think about making them customisable, but that can be a follow-up. Fixes #8876. --- changelog.d/8942.feature | 1 + docs/sample_config.yaml | 5 +- docs/sso_mapping_providers.md | 28 +-- synapse/app/homeserver.py | 2 + synapse/config/oidc_config.py | 5 +- synapse/handlers/oidc_handler.py | 59 +++---- synapse/handlers/sso.py | 254 ++++++++++++++++++++++++++- synapse/res/username_picker/index.html | 19 ++ synapse/res/username_picker/script.js | 95 ++++++++++ synapse/res/username_picker/style.css | 27 +++ synapse/rest/synapse/client/pick_username.py | 88 ++++++++++ synapse/types.py | 8 +- tests/handlers/test_oidc.py | 143 ++++++++++++++- tests/unittest.py | 8 +- 14 files changed, 683 insertions(+), 59 deletions(-) create mode 100644 changelog.d/8942.feature create mode 100644 synapse/res/username_picker/index.html create mode 100644 synapse/res/username_picker/script.js create mode 100644 synapse/res/username_picker/style.css create mode 100644 synapse/rest/synapse/client/pick_username.py (limited to 'synapse/rest') diff --git a/changelog.d/8942.feature b/changelog.d/8942.feature new file mode 100644 index 0000000000..d450ef4998 --- /dev/null +++ b/changelog.d/8942.feature @@ -0,0 +1 @@ +Add support for allowing users to pick their own user ID during a single-sign-on login. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 549c581a97..077cb619c7 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1825,9 +1825,10 @@ oidc_config: # * user: The claims returned by the UserInfo Endpoint and/or in the ID # Token # - # This must be configured if using the default mapping provider. + # If this is not set, the user will be prompted to choose their + # own username. # - localpart_template: "{{ user.preferred_username }}" + #localpart_template: "{{ user.preferred_username }}" # Jinja2 template for the display name to set on first login. # diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md index 7714b1d844..e1d6ede7ba 100644 --- a/docs/sso_mapping_providers.md +++ b/docs/sso_mapping_providers.md @@ -15,12 +15,18 @@ where SAML mapping providers come into play. SSO mapping providers are currently supported for OpenID and SAML SSO configurations. Please see the details below for how to implement your own. -It is the responsibility of the mapping provider to normalise the SSO attributes -and map them to a valid Matrix ID. The -[specification for Matrix IDs](https://matrix.org/docs/spec/appendices#user-identifiers) -has some information about what is considered valid. Alternately an easy way to -ensure it is valid is to use a Synapse utility function: -`synapse.types.map_username_to_mxid_localpart`. +It is up to the mapping provider whether the user should be assigned a predefined +Matrix ID based on the SSO attributes, or if the user should be allowed to +choose their own username. + +In the first case - where users are automatically allocated a Matrix ID - it is +the responsibility of the mapping provider to normalise the SSO attributes and +map them to a valid Matrix ID. The [specification for Matrix +IDs](https://matrix.org/docs/spec/appendices#user-identifiers) has some +information about what is considered valid. + +If the mapping provider does not assign a Matrix ID, then Synapse will +automatically serve an HTML page allowing the user to pick their own username. External mapping providers are provided to Synapse in the form of an external Python module. You can retrieve this module from [PyPI](https://pypi.org) or elsewhere, @@ -80,8 +86,9 @@ A custom mapping provider must specify the following methods: with failures=1. The method should then return a different `localpart` value, such as `john.doe1`. - Returns a dictionary with two keys: - - localpart: A required string, used to generate the Matrix ID. - - displayname: An optional string, the display name for the user. + - `localpart`: A string, used to generate the Matrix ID. If this is + `None`, the user is prompted to pick their own username. + - `displayname`: An optional string, the display name for the user. * `get_extra_attributes(self, userinfo, token)` - This method must be async. - Arguments: @@ -165,12 +172,13 @@ A custom mapping provider must specify the following methods: redirected to. - This method must return a dictionary, which will then be used by Synapse to build a new user. The following keys are allowed: - * `mxid_localpart` - Required. The mxid localpart of the new user. + * `mxid_localpart` - The mxid localpart of the new user. If this is + `None`, the user is prompted to pick their own username. * `displayname` - The displayname of the new user. If not provided, will default to the value of `mxid_localpart`. * `emails` - A list of emails for the new user. If not provided, will default to an empty list. - + Alternatively it can raise a `synapse.api.errors.RedirectException` to redirect the user to another page. This is useful to prompt the user for additional information, e.g. if you want them to provide their own username. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index bbb7407838..8d9b53be53 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -63,6 +63,7 @@ from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore @@ -192,6 +193,7 @@ class SynapseHomeServer(HomeServer): "/_matrix/client/versions": client_resource, "/.well-known/matrix/client": WellKnownResource(self), "/_synapse/admin": AdminRestResource(self), + "/_synapse/client/pick_username": pick_username_resource(self), } ) diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 1abf8ed405..4e3055282d 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -203,9 +203,10 @@ class OIDCConfig(Config): # * user: The claims returned by the UserInfo Endpoint and/or in the ID # Token # - # This must be configured if using the default mapping provider. + # If this is not set, the user will be prompted to choose their + # own username. # - localpart_template: "{{{{ user.preferred_username }}}}" + #localpart_template: "{{{{ user.preferred_username }}}}" # Jinja2 template for the display name to set on first login. # diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index cbd11a1382..709f8dfc13 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -947,7 +947,7 @@ class OidcHandler(BaseHandler): UserAttributeDict = TypedDict( - "UserAttributeDict", {"localpart": str, "display_name": Optional[str]} + "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]} ) C = TypeVar("C") @@ -1028,10 +1028,10 @@ env = Environment(finalize=jinja_finalize) @attr.s class JinjaOidcMappingConfig: - subject_claim = attr.ib() # type: str - localpart_template = attr.ib() # type: Template - display_name_template = attr.ib() # type: Optional[Template] - extra_attributes = attr.ib() # type: Dict[str, Template] + subject_claim = attr.ib(type=str) + localpart_template = attr.ib(type=Optional[Template]) + display_name_template = attr.ib(type=Optional[Template]) + extra_attributes = attr.ib(type=Dict[str, Template]) class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @@ -1047,18 +1047,14 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): def parse_config(config: dict) -> JinjaOidcMappingConfig: subject_claim = config.get("subject_claim", "sub") - if "localpart_template" not in config: - raise ConfigError( - "missing key: oidc_config.user_mapping_provider.config.localpart_template" - ) - - try: - localpart_template = env.from_string(config["localpart_template"]) - except Exception as e: - raise ConfigError( - "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r" - % (e,) - ) + localpart_template = None # type: Optional[Template] + if "localpart_template" in config: + try: + localpart_template = env.from_string(config["localpart_template"]) + except Exception as e: + raise ConfigError( + "invalid jinja template", path=["localpart_template"] + ) from e display_name_template = None # type: Optional[Template] if "display_name_template" in config: @@ -1066,26 +1062,22 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): display_name_template = env.from_string(config["display_name_template"]) except Exception as e: raise ConfigError( - "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r" - % (e,) - ) + "invalid jinja template", path=["display_name_template"] + ) from e extra_attributes = {} # type Dict[str, Template] if "extra_attributes" in config: extra_attributes_config = config.get("extra_attributes") or {} if not isinstance(extra_attributes_config, dict): - raise ConfigError( - "oidc_config.user_mapping_provider.config.extra_attributes must be a dict" - ) + raise ConfigError("must be a dict", path=["extra_attributes"]) for key, value in extra_attributes_config.items(): try: extra_attributes[key] = env.from_string(value) except Exception as e: raise ConfigError( - "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r" - % (key, e) - ) + "invalid jinja template", path=["extra_attributes", key] + ) from e return JinjaOidcMappingConfig( subject_claim=subject_claim, @@ -1100,14 +1092,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): async def map_user_attributes( self, userinfo: UserInfo, token: Token, failures: int ) -> UserAttributeDict: - localpart = self._config.localpart_template.render(user=userinfo).strip() + localpart = None + + if self._config.localpart_template: + localpart = self._config.localpart_template.render(user=userinfo).strip() - # Ensure only valid characters are included in the MXID. - localpart = map_username_to_mxid_localpart(localpart) + # Ensure only valid characters are included in the MXID. + localpart = map_username_to_mxid_localpart(localpart) - # Append suffix integer if last call to this function failed to produce - # a usable mxid. - localpart += str(failures) if failures else "" + # Append suffix integer if last call to this function failed to produce + # a usable mxid. + localpart += str(failures) if failures else "" display_name = None # type: Optional[str] if self._config.display_name_template is not None: diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index f054b66a53..548b02211b 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -13,17 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional import attr +from typing_extensions import NoReturn from twisted.web.http import Request -from synapse.api.errors import RedirectException +from synapse.api.errors import RedirectException, SynapseError from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer +from synapse.util.stringutils import random_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -40,16 +42,52 @@ class MappingException(Exception): @attr.s class UserAttributes: - localpart = attr.ib(type=str) + # the localpart of the mxid that the mapper has assigned to the user. + # if `None`, the mapper has not picked a userid, and the user should be prompted to + # enter one. + localpart = attr.ib(type=Optional[str]) display_name = attr.ib(type=Optional[str], default=None) emails = attr.ib(type=List[str], default=attr.Factory(list)) +@attr.s(slots=True) +class UsernameMappingSession: + """Data we track about SSO sessions""" + + # A unique identifier for this SSO provider, e.g. "oidc" or "saml". + auth_provider_id = attr.ib(type=str) + + # user ID on the IdP server + remote_user_id = attr.ib(type=str) + + # attributes returned by the ID mapper + display_name = attr.ib(type=Optional[str]) + emails = attr.ib(type=List[str]) + + # An optional dictionary of extra attributes to be provided to the client in the + # login response. + extra_login_attributes = attr.ib(type=Optional[JsonDict]) + + # where to redirect the client back to + client_redirect_url = attr.ib(type=str) + + # expiry time for the session, in milliseconds + expiry_time_ms = attr.ib(type=int) + + +# the HTTP cookie used to track the mapping session id +USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session" + + class SsoHandler: # The number of attempts to ask the mapping provider for when generating an MXID. _MAP_USERNAME_RETRIES = 1000 + # the time a UsernameMappingSession remains valid for + _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000 + def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() self._store = hs.get_datastore() self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() @@ -59,6 +97,9 @@ class SsoHandler: # a lock on the mappings self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) + # a map from session id to session data + self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession] + def render_error( self, request, error: str, error_description: Optional[str] = None ) -> None: @@ -206,6 +247,18 @@ class SsoHandler: # Otherwise, generate a new user. if not user_id: attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper) + + if attributes.localpart is None: + # the mapper doesn't return a username. bail out with a redirect to + # the username picker. + await self._redirect_to_username_picker( + auth_provider_id, + remote_user_id, + attributes, + client_redirect_url, + extra_login_attributes, + ) + user_id = await self._register_mapped_user( attributes, auth_provider_id, @@ -243,10 +296,8 @@ class SsoHandler: ) if not attributes.localpart: - raise MappingException( - "Error parsing SSO response: SSO mapping provider plugin " - "did not return a localpart value" - ) + # the mapper has not picked a localpart + return attributes # Check if this mxid already exists user_id = UserID(attributes.localpart, self._server_name).to_string() @@ -261,6 +312,59 @@ class SsoHandler: ) return attributes + async def _redirect_to_username_picker( + self, + auth_provider_id: str, + remote_user_id: str, + attributes: UserAttributes, + client_redirect_url: str, + extra_login_attributes: Optional[JsonDict], + ) -> NoReturn: + """Creates a UsernameMappingSession and redirects the browser + + Called if the user mapping provider doesn't return a localpart for a new user. + Raises a RedirectException which redirects the browser to the username picker. + + Args: + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + + remote_user_id: The unique identifier from the SSO provider. + + attributes: the user attributes returned by the user mapping provider. + + client_redirect_url: The redirect URL passed in by the client, which we + will eventually redirect back to. + + extra_login_attributes: An optional dictionary of extra + attributes to be provided to the client in the login response. + + Raises: + RedirectException + """ + session_id = random_string(16) + now = self._clock.time_msec() + session = UsernameMappingSession( + auth_provider_id=auth_provider_id, + remote_user_id=remote_user_id, + display_name=attributes.display_name, + emails=attributes.emails, + client_redirect_url=client_redirect_url, + expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS, + extra_login_attributes=extra_login_attributes, + ) + + self._username_mapping_sessions[session_id] = session + logger.info("Recorded registration session id %s", session_id) + + # Set the cookie and redirect to the username picker + e = RedirectException(b"/_synapse/client/pick_username") + e.cookies.append( + b"%s=%s; path=/" + % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii")) + ) + raise e + async def _register_mapped_user( self, attributes: UserAttributes, @@ -269,9 +373,38 @@ class SsoHandler: user_agent: str, ip_address: str, ) -> str: + """Register a new SSO user. + + This is called once we have successfully mapped the remote user id onto a local + user id, one way or another. + + Args: + attributes: user attributes returned by the user mapping provider, + including a non-empty localpart. + + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + + remote_user_id: The unique identifier from the SSO provider. + + user_agent: The user-agent in the HTTP request (used for potential + shadow-banning.) + + ip_address: The IP address of the requester (used for potential + shadow-banning.) + + Raises: + a MappingException if the localpart is invalid. + + a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart + is already taken. + """ + # Since the localpart is provided via a potentially untrusted module, # ensure the MXID is valid before registering. - if contains_invalid_mxid_characters(attributes.localpart): + if not attributes.localpart or contains_invalid_mxid_characters( + attributes.localpart + ): raise MappingException("localpart is invalid: %s" % (attributes.localpart,)) logger.debug("Mapped SSO user to local part %s", attributes.localpart) @@ -326,3 +459,108 @@ class SsoHandler: await self._auth_handler.complete_sso_ui_auth( user_id, ui_auth_session_id, request ) + + async def check_username_availability( + self, localpart: str, session_id: str, + ) -> bool: + """Handle an "is username available" callback check + + Args: + localpart: desired localpart + session_id: the session id for the username picker + Returns: + True if the username is available + Raises: + SynapseError if the localpart is invalid or the session is unknown + """ + + # make sure that there is a valid mapping session, to stop people dictionary- + # scanning for accounts + + self._expire_old_sessions() + session = self._username_mapping_sessions.get(session_id) + if not session: + logger.info("Couldn't find session id %s", session_id) + raise SynapseError(400, "unknown session") + + logger.info( + "[session %s] Checking for availability of username %s", + session_id, + localpart, + ) + + if contains_invalid_mxid_characters(localpart): + raise SynapseError(400, "localpart is invalid: %s" % (localpart,)) + user_id = UserID(localpart, self._server_name).to_string() + user_infos = await self._store.get_users_by_id_case_insensitive(user_id) + + logger.info("[session %s] users: %s", session_id, user_infos) + return not user_infos + + async def handle_submit_username_request( + self, request: SynapseRequest, localpart: str, session_id: str + ) -> None: + """Handle a request to the username-picker 'submit' endpoint + + Will serve an HTTP response to the request. + + Args: + request: HTTP request + localpart: localpart requested by the user + session_id: ID of the username mapping session, extracted from a cookie + """ + self._expire_old_sessions() + session = self._username_mapping_sessions.get(session_id) + if not session: + logger.info("Couldn't find session id %s", session_id) + raise SynapseError(400, "unknown session") + + logger.info("[session %s] Registering localpart %s", session_id, localpart) + + attributes = UserAttributes( + localpart=localpart, + display_name=session.display_name, + emails=session.emails, + ) + + # the following will raise a 400 error if the username has been taken in the + # meantime. + user_id = await self._register_mapped_user( + attributes, + session.auth_provider_id, + session.remote_user_id, + request.get_user_agent(""), + request.getClientIP(), + ) + + logger.info("[session %s] Registered userid %s", session_id, user_id) + + # delete the mapping session and the cookie + del self._username_mapping_sessions[session_id] + + # delete the cookie + request.addCookie( + USERNAME_MAPPING_SESSION_COOKIE_NAME, + b"", + expires=b"Thu, 01 Jan 1970 00:00:00 GMT", + path=b"/", + ) + + await self._auth_handler.complete_sso_login( + user_id, + request, + session.client_redirect_url, + session.extra_login_attributes, + ) + + def _expire_old_sessions(self): + to_expire = [] + now = int(self._clock.time_msec()) + + for session_id, session in self._username_mapping_sessions.items(): + if session.expiry_time_ms <= now: + to_expire.append(session_id) + + for session_id in to_expire: + logger.info("Expiring mapping session %s", session_id) + del self._username_mapping_sessions[session_id] diff --git a/synapse/res/username_picker/index.html b/synapse/res/username_picker/index.html new file mode 100644 index 0000000000..37ea8bb6d8 --- /dev/null +++ b/synapse/res/username_picker/index.html @@ -0,0 +1,19 @@ + + + + Synapse Login + + + +
+
+ + + +
+ + + +
+ + diff --git a/synapse/res/username_picker/script.js b/synapse/res/username_picker/script.js new file mode 100644 index 0000000000..416a7c6f41 --- /dev/null +++ b/synapse/res/username_picker/script.js @@ -0,0 +1,95 @@ +let inputField = document.getElementById("field-username"); +let inputForm = document.getElementById("form"); +let submitButton = document.getElementById("button-submit"); +let message = document.getElementById("message"); + +// Submit username and receive response +function showMessage(messageText) { + // Unhide the message text + message.classList.remove("hidden"); + + message.textContent = messageText; +}; + +function doSubmit() { + showMessage("Success. Please wait a moment for your browser to redirect."); + + // remove the event handler before re-submitting the form. + delete inputForm.onsubmit; + inputForm.submit(); +} + +function onResponse(response) { + // Display message + showMessage(response); + + // Enable submit button and input field + submitButton.classList.remove('button--disabled'); + submitButton.value = "Submit"; +}; + +let allowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]"); +function usernameIsValid(username) { + return !allowedUsernameCharacters.test(username); +} +let allowedCharactersString = "lowercase letters, digits, ., _, -, /, ="; + +function buildQueryString(params) { + return Object.keys(params) + .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k])) + .join('&'); +} + +function submitUsername(username) { + if(username.length == 0) { + onResponse("Please enter a username."); + return; + } + if(!usernameIsValid(username)) { + onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString); + return; + } + + // if this browser doesn't support fetch, skip the availability check. + if(!window.fetch) { + doSubmit(); + return; + } + + let check_uri = 'check?' + buildQueryString({"username": username}); + fetch(check_uri, { + // include the cookie + "credentials": "same-origin", + }).then((response) => { + if(!response.ok) { + // for non-200 responses, raise the body of the response as an exception + return response.text().then((text) => { throw text; }); + } else { + return response.json(); + } + }).then((json) => { + if(json.error) { + throw json.error; + } else if(json.available) { + doSubmit(); + } else { + onResponse("This username is not available, please choose another."); + } + }).catch((err) => { + onResponse("Error checking username availability: " + err); + }); +} + +function clickSubmit() { + event.preventDefault(); + if(submitButton.classList.contains('button--disabled')) { return; } + + // Disable submit button and input field + submitButton.classList.add('button--disabled'); + + // Submit username + submitButton.value = "Checking..."; + submitUsername(inputField.value); +}; + +inputForm.onsubmit = clickSubmit; diff --git a/synapse/res/username_picker/style.css b/synapse/res/username_picker/style.css new file mode 100644 index 0000000000..745bd4c684 --- /dev/null +++ b/synapse/res/username_picker/style.css @@ -0,0 +1,27 @@ +input[type="text"] { + font-size: 100%; + background-color: #ededf0; + border: 1px solid #fff; + border-radius: .2em; + padding: .5em .9em; + display: block; + width: 26em; +} + +.button--disabled { + border-color: #fff; + background-color: transparent; + color: #000; + text-transform: none; +} + +.hidden { + display: none; +} + +.tooltip { + background-color: #f9f9fa; + padding: 1em; + margin: 1em 0; +} + diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py new file mode 100644 index 0000000000..d3b6803e65 --- /dev/null +++ b/synapse/rest/synapse/client/pick_username.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +import pkg_resources + +from twisted.web.http import Request +from twisted.web.resource import Resource +from twisted.web.static import File + +from synapse.api.errors import SynapseError +from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME +from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource +from synapse.http.servlet import parse_string +from synapse.http.site import SynapseRequest + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +def pick_username_resource(hs: "HomeServer") -> Resource: + """Factory method to generate the username picker resource. + + This resource gets mounted under /_synapse/client/pick_username. The top-level + resource is just a File resource which serves up the static files in the resources + "res" directory, but it has a couple of children: + + * "submit", which does the mechanics of registering the new user, and redirects the + browser back to the client URL + + * "check": checks if a userid is free. + """ + + # XXX should we make this path customisable so that admins can restyle it? + base_path = pkg_resources.resource_filename("synapse", "res/username_picker") + + res = File(base_path) + res.putChild(b"submit", SubmitResource(hs)) + res.putChild(b"check", AvailabilityCheckResource(hs)) + + return res + + +class AvailabilityCheckResource(DirectServeJsonResource): + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + async def _async_render_GET(self, request: Request): + localpart = parse_string(request, "username", required=True) + + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + + is_available = await self._sso_handler.check_username_availability( + localpart, session_id.decode("ascii", errors="replace") + ) + return 200, {"available": is_available} + + +class SubmitResource(DirectServeHtmlResource): + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + async def _async_render_POST(self, request: SynapseRequest): + localpart = parse_string(request, "username", required=True) + + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + + await self._sso_handler.handle_submit_username_request( + request, localpart, session_id.decode("ascii", errors="replace") + ) diff --git a/synapse/types.py b/synapse/types.py index 3ab6bdbe06..c7d4e95809 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -349,15 +349,17 @@ NON_MXID_CHARACTER_PATTERN = re.compile( ) -def map_username_to_mxid_localpart(username, case_sensitive=False): +def map_username_to_mxid_localpart( + username: Union[str, bytes], case_sensitive: bool = False +) -> str: """Map a username onto a string suitable for a MXID This follows the algorithm laid out at https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets. Args: - username (unicode|bytes): username to be mapped - case_sensitive (bool): true if TEST and test should be mapped + username: username to be mapped + case_sensitive: true if TEST and test should be mapped onto different mxids Returns: diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index c54f1c5797..368d600b33 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,14 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from urllib.parse import parse_qs, urlparse +import re +from typing import Dict +from urllib.parse import parse_qs, urlencode, urlparse from mock import ANY, Mock, patch import pymacaroons +from twisted.web.resource import Resource + +from synapse.api.errors import RedirectException from synapse.handlers.oidc_handler import OidcError from synapse.handlers.sso import MappingException +from synapse.rest.client.v1 import login +from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.server import HomeServer from synapse.types import UserID @@ -793,6 +800,140 @@ class OidcHandlerTestCase(HomeserverTestCase): "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) + def test_empty_localpart(self): + """Attempts to map onto an empty localpart should be rejected.""" + userinfo = { + "sub": "tester", + "username": "", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + self.assertRenderedError("mapping_error", "localpart is invalid: ") + + @override_config( + { + "oidc_config": { + "user_mapping_provider": { + "config": {"localpart_template": "{{ user.username }}"} + } + } + } + ) + def test_null_localpart(self): + """Mapping onto a null localpart via an empty OIDC attribute should be rejected""" + userinfo = { + "sub": "tester", + "username": None, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + self.assertRenderedError("mapping_error", "localpart is invalid: ") + + +class UsernamePickerTestCase(HomeserverTestCase): + servlets = [login.register_servlets] + + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + oidc_config = { + "enabled": True, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "issuer": ISSUER, + "scopes": SCOPES, + "user_mapping_provider": { + "config": {"display_name_template": "{{ user.displayname }}"} + }, + } + + # Update this config with what's in the default config so that + # override_config works as expected. + oidc_config.update(config.get("oidc_config", {})) + config["oidc_config"] = oidc_config + + # whitelist this client URI so we redirect straight to it rather than + # serving a confirmation page + config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) + return d + + def test_username_picker(self): + """Test the happy path of a username picker flow.""" + client_redirect_url = "https://whitelisted.client" + + # first of all, mock up an OIDC callback to the OidcHandler, which should + # raise a RedirectException + userinfo = {"sub": "tester", "displayname": "Jonny"} + f = self.get_failure( + _make_callback_with_userinfo( + self.hs, userinfo, client_redirect_url=client_redirect_url + ), + RedirectException, + ) + + # check the Location and cookies returned by the RedirectException + self.assertEqual(f.value.location, b"/_synapse/client/pick_username") + cookieheader = f.value.cookies[0] + regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);") + m = regex.search(cookieheader) + if not m: + self.fail("cookie header %s does not match %s" % (cookieheader, regex)) + + # introspect the sso handler a bit to check that the username mapping session + # looks ok. + session_id = m.group(1).decode("ascii") + username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions + self.assertIn( + session_id, username_mapping_sessions, "session id not found in map" + ) + session = username_mapping_sessions[session_id] + self.assertEqual(session.remote_user_id, "tester") + self.assertEqual(session.display_name, "Jonny") + self.assertEqual(session.client_redirect_url, client_redirect_url) + + # the expiry time should be about 15 minutes away + expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) + self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) + + # Now, submit a username to the username picker, which should serve a redirect + # back to the client + submit_path = f.value.location + b"/submit" + content = urlencode({b"username": b"bobby"}).encode("utf8") + chan = self.make_request( + "POST", + path=submit_path, + content=content, + content_is_form=True, + custom_headers=[ + ("Cookie", cookieheader), + # old versions of twisted don't do form-parsing without a valid + # content-length header. + ("Content-Length", str(len(content))), + ], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + # ensure that the returned location starts with the requested redirect URL + self.assertEqual( + location_headers[0][: len(client_redirect_url)], client_redirect_url + ) + + # fish the login token out of the returned redirect uri + parts = urlparse(location_headers[0]) + query = parse_qs(parts.query) + login_token = query["loginToken"][0] + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + chan = self.make_request( + "POST", "/login", content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@bobby:test") + async def _make_callback_with_userinfo( hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" diff --git a/tests/unittest.py b/tests/unittest.py index 39e5e7b85c..af7f752c5a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,7 @@ import hmac import inspect import logging import time -from typing import Dict, Optional, Type, TypeVar, Union +from typing import Dict, Iterable, Optional, Tuple, Type, TypeVar, Union from mock import Mock, patch @@ -383,6 +383,9 @@ class HomeserverTestCase(TestCase): federation_auth_origin: str = None, content_is_form: bool = False, await_result: bool = True, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ) -> FakeChannel: """ Create a SynapseRequest at the path using the method and containing the @@ -405,6 +408,8 @@ class HomeserverTestCase(TestCase): true (the default), will pump the test reactor until the the renderer tells the channel the request is finished. + custom_headers: (name, value) pairs to add as request headers + Returns: The FakeChannel object which stores the result of the request. """ @@ -420,6 +425,7 @@ class HomeserverTestCase(TestCase): federation_auth_origin, content_is_form, await_result, + custom_headers, ) def setup_test_homeserver(self, *args, **kwargs): -- cgit 1.5.1 From d781a81e692563c5785e3efd4aa2487696b9c995 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 18 Dec 2020 15:37:19 +0000 Subject: Allow server admin to get admin bit in rooms where local user is an admin (#8756) This adds an admin API that allows a server admin to get power in a room if a local user has power in a room. Will also invite the user if they're not in the room and its a private room. Can specify another user (rather than the admin user) to be granted power. Co-authored-by: Matthew Hodgson --- changelog.d/8756.feature | 1 + docs/admin_api/rooms.md | 20 +++++- synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/rooms.py | 136 +++++++++++++++++++++++++++++++++++++++- tests/rest/admin/test_room.py | 138 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 294 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8756.feature (limited to 'synapse/rest') diff --git a/changelog.d/8756.feature b/changelog.d/8756.feature new file mode 100644 index 0000000000..03eb79fb0a --- /dev/null +++ b/changelog.d/8756.feature @@ -0,0 +1 @@ +Add admin API that lets server admins get power in rooms in which local users have power. diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index d7b1740fe3..9e560003a9 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -8,6 +8,7 @@ * [Parameters](#parameters-1) * [Response](#response) * [Undoing room shutdowns](#undoing-room-shutdowns) +- [Make Room Admin API](#make-room-admin-api) # List Room API @@ -467,6 +468,7 @@ The following fields are returned in the JSON response body: the old room to the new. * `new_room_id` - A string representing the room ID of the new room. + ## Undoing room shutdowns *Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level, @@ -492,4 +494,20 @@ You will have to manually handle, if you so choose, the following: * Aliases that would have been redirected to the Content Violation room. * Users that would have been booted from the room (and will have been force-joined to the Content Violation room). -* Removal of the Content Violation room if desired. \ No newline at end of file +* Removal of the Content Violation room if desired. + + +# Make Room Admin API + +Grants another user the highest power available to a local user who is in the room. +If the user is not in the room, and it is not publicly joinable, then invite the user. + +By default the server admin (the caller) is granted power, but another user can +optionally be specified, e.g.: + +``` + POST /_synapse/admin/v1/rooms//make_room_admin + { + "user_id": "@foo:example.com" + } +``` diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 55ddebb4fe..6f7dc06503 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -38,6 +38,7 @@ from synapse.rest.admin.rooms import ( DeleteRoomRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, + MakeRoomAdminRestServlet, RoomMembersRestServlet, RoomRestServlet, ShutdownRoomRestServlet, @@ -228,6 +229,7 @@ def register_servlets(hs, http_server): EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server) + MakeRoomAdminRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index b902af8028..ab7cc9102a 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -16,8 +16,8 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Tuple -from synapse.api.constants import EventTypes, JoinRules -from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -37,6 +37,7 @@ from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester if TYPE_CHECKING: from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -367,3 +368,134 @@ class JoinRoomAliasServlet(RestServlet): ) return 200, {"room_id": room_id} + + +class MakeRoomAdminRestServlet(RestServlet): + """Allows a server admin to get power in a room if a local user has power in + a room. Will also invite the user if they're not in the room and it's a + private room. Can specify another user (rather than the admin user) to be + granted power, e.g.: + + POST/_synapse/admin/v1/rooms//make_room_admin + { + "user_id": "@foo:example.com" + } + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/make_room_admin") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.room_member_handler = hs.get_room_member_handler() + self.event_creation_handler = hs.get_event_creation_handler() + self.state_handler = hs.get_state_handler() + self.is_mine_id = hs.is_mine_id + + async def on_POST(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + content = parse_json_object_from_request(request, allow_empty_body=True) + + # Resolve to a room ID, if necessary. + if RoomID.is_valid(room_identifier): + room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + + # Which user to grant room admin rights to. + user_to_add = content.get("user_id", requester.user.to_string()) + + # Figure out which local users currently have power in the room, if any. + room_state = await self.state_handler.get_current_state(room_id) + if not room_state: + raise SynapseError(400, "Server not in room") + + create_event = room_state[(EventTypes.Create, "")] + power_levels = room_state.get((EventTypes.PowerLevels, "")) + + if power_levels is not None: + # We pick the local user with the highest power. + user_power = power_levels.content.get("users", {}) + admin_users = [ + user_id for user_id in user_power if self.is_mine_id(user_id) + ] + admin_users.sort(key=lambda user: user_power[user]) + + if not admin_users: + raise SynapseError(400, "No local admin user in room") + + admin_user_id = admin_users[-1] + + pl_content = power_levels.content + else: + # If there is no power level events then the creator has rights. + pl_content = {} + admin_user_id = create_event.sender + if not self.is_mine_id(admin_user_id): + raise SynapseError( + 400, "No local admin user in room", + ) + + # Grant the user power equal to the room admin by attempting to send an + # updated power level event. + new_pl_content = dict(pl_content) + new_pl_content["users"] = dict(pl_content.get("users", {})) + new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id] + + fake_requester = create_requester( + admin_user_id, authenticated_entity=requester.authenticated_entity, + ) + + try: + await self.event_creation_handler.create_and_send_nonmember_event( + fake_requester, + event_dict={ + "content": new_pl_content, + "sender": admin_user_id, + "type": EventTypes.PowerLevels, + "state_key": "", + "room_id": room_id, + }, + ) + except AuthError: + # The admin user we found turned out not to have enough power. + raise SynapseError( + 400, "No local admin user in room with power to update power levels." + ) + + # Now we check if the user we're granting admin rights to is already in + # the room. If not and it's not a public room we invite them. + member_event = room_state.get((EventTypes.Member, user_to_add)) + is_joined = False + if member_event: + is_joined = member_event.content["membership"] in ( + Membership.JOIN, + Membership.INVITE, + ) + + if is_joined: + return 200, {} + + join_rules = room_state.get((EventTypes.JoinRules, "")) + is_public = False + if join_rules: + is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC + + if is_public: + return 200, {} + + await self.room_member_handler.update_membership( + fake_requester, + target=UserID.from_string(user_to_add), + room_id=room_id, + action=Membership.INVITE, + ) + + return 200, {} diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 014c30287a..60a5fcecf7 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -20,6 +20,7 @@ from typing import List, Optional from mock import Mock import synapse.rest.admin +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes from synapse.rest.client.v1 import directory, events, login, room @@ -1432,6 +1433,143 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) +class MakeRoomAdminTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.public_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + self.url = "/_synapse/admin/v1/rooms/{}/make_room_admin".format( + self.public_room_id + ) + + def test_public_room(self): + """Test that getting admin in a public room works. + """ + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Now we test that we can join the room and ban a user. + self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) + self.helper.change_membership( + room_id, + self.admin_user, + "@test:test", + Membership.BAN, + tok=self.admin_user_tok, + ) + + def test_private_room(self): + """Test that getting admin in a private room works and we get invited. + """ + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False, + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Now we test that we can join the room (we should have received an + # invite) and can ban a user. + self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) + self.helper.change_membership( + room_id, + self.admin_user, + "@test:test", + Membership.BAN, + tok=self.admin_user_tok, + ) + + def test_other_user(self): + """Test that giving admin in a public room works to a non-admin user works. + """ + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + content={"user_id": self.second_user_id}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Now we test that we can join the room and ban a user. + self.helper.join(room_id, self.second_user_id, tok=self.second_tok) + self.helper.change_membership( + room_id, + self.second_user_id, + "@test:test", + Membership.BAN, + tok=self.second_tok, + ) + + def test_not_enough_power(self): + """Test that we get a sensible error if there are no local room admins. + """ + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + + # The creator drops admin rights in the room. + pl = self.helper.get_state( + room_id, EventTypes.PowerLevels, tok=self.creator_tok + ) + pl["users"][self.creator] = 0 + self.helper.send_state( + room_id, EventTypes.PowerLevels, body=pl, tok=self.creator_tok + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + content={}, + access_token=self.admin_user_tok, + ) + + # We expect this to fail with a 400 as there are no room admins. + # + # (Note we assert the error message to ensure that it's not denied for + # some other reason) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + channel.json_body["error"], + "No local admin user in room with power to update power levels.", + ) + + PURGE_TABLES = [ "current_state_events", "event_backward_extremities", -- cgit 1.5.1 From 68bb26da690c6db759983ba0cb86491af48da0a0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 29 Dec 2020 07:40:12 -0500 Subject: Allow redacting events on workers (#8994) Adds the redacts endpoint to workers that have the client listener. --- changelog.d/8994.feature | 1 + docs/workers.md | 1 + synapse/app/generic_worker.py | 31 ++++--------------------------- synapse/rest/client/v1/room.py | 17 ++++++++++------- 4 files changed, 16 insertions(+), 34 deletions(-) create mode 100644 changelog.d/8994.feature (limited to 'synapse/rest') diff --git a/changelog.d/8994.feature b/changelog.d/8994.feature new file mode 100644 index 0000000000..76aeb185cb --- /dev/null +++ b/changelog.d/8994.feature @@ -0,0 +1 @@ +Allow running the redact endpoint on workers. diff --git a/docs/workers.md b/docs/workers.md index efe97af31a..298adf8695 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -229,6 +229,7 @@ expressions: ^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$ # Event sending requests + ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$ diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index aa12c74358..fa23d9bb20 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -89,7 +89,7 @@ from synapse.replication.tcp.streams import ( ToDeviceStream, ) from synapse.rest.admin import register_servlets_for_media_repo -from synapse.rest.client.v1 import events +from synapse.rest.client.v1 import events, room from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet from synapse.rest.client.v1.login import LoginRestServlet from synapse.rest.client.v1.profile import ( @@ -98,20 +98,6 @@ from synapse.rest.client.v1.profile import ( ProfileRestServlet, ) from synapse.rest.client.v1.push_rule import PushRuleRestServlet -from synapse.rest.client.v1.room import ( - JoinedRoomMemberListRestServlet, - JoinRoomAliasServlet, - PublicRoomListRestServlet, - RoomEventContextServlet, - RoomInitialSyncRestServlet, - RoomMemberListRestServlet, - RoomMembershipRestServlet, - RoomMessageListRestServlet, - RoomSendEventRestServlet, - RoomStateEventRestServlet, - RoomStateRestServlet, - RoomTypingRestServlet, -) from synapse.rest.client.v1.voip import VoipRestServlet from synapse.rest.client.v2_alpha import groups, sync, user_directory from synapse.rest.client.v2_alpha._base import client_patterns @@ -512,12 +498,6 @@ class GenericWorkerServer(HomeServer): elif name == "client": resource = JsonResource(self, canonical_json=False) - PublicRoomListRestServlet(self).register(resource) - RoomMemberListRestServlet(self).register(resource) - JoinedRoomMemberListRestServlet(self).register(resource) - RoomStateRestServlet(self).register(resource) - RoomEventContextServlet(self).register(resource) - RoomMessageListRestServlet(self).register(resource) RegisterRestServlet(self).register(resource) LoginRestServlet(self).register(resource) ThreepidRestServlet(self).register(resource) @@ -526,22 +506,19 @@ class GenericWorkerServer(HomeServer): VoipRestServlet(self).register(resource) PushRuleRestServlet(self).register(resource) VersionsRestServlet(self).register(resource) - RoomSendEventRestServlet(self).register(resource) - RoomMembershipRestServlet(self).register(resource) - RoomStateEventRestServlet(self).register(resource) - JoinRoomAliasServlet(self).register(resource) + ProfileAvatarURLRestServlet(self).register(resource) ProfileDisplaynameRestServlet(self).register(resource) ProfileRestServlet(self).register(resource) KeyUploadServlet(self).register(resource) AccountDataServlet(self).register(resource) RoomAccountDataServlet(self).register(resource) - RoomTypingRestServlet(self).register(resource) sync.register_servlets(self, resource) events.register_servlets(self, resource) + room.register_servlets(self, resource, True) + room.register_deprecated_servlets(self, resource) InitialSyncRestServlet(self).register(resource) - RoomInitialSyncRestServlet(self).register(resource) user_directory.register_servlets(self, resource) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 93c06afe27..5647e8c577 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -963,25 +963,28 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): ) -def register_servlets(hs, http_server): +def register_servlets(hs, http_server, is_worker=False): RoomStateEventRestServlet(hs).register(http_server) - RoomCreateRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server) RoomMessageListRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) - RoomForgetRestServlet(hs).register(http_server) RoomMembershipRestServlet(hs).register(http_server) RoomSendEventRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) - SearchRestServlet(hs).register(http_server) - JoinedRoomsRestServlet(hs).register(http_server) - RoomEventServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) - RoomAliasListServlet(hs).register(http_server) + + # Some servlets only get registered for the main process. + if not is_worker: + RoomCreateRestServlet(hs).register(http_server) + RoomForgetRestServlet(hs).register(http_server) + SearchRestServlet(hs).register(http_server) + JoinedRoomsRestServlet(hs).register(http_server) + RoomEventServlet(hs).register(http_server) + RoomAliasListServlet(hs).register(http_server) def register_deprecated_servlets(hs, http_server): -- cgit 1.5.1 From 14a73713751f2aea2932708d25eb13dd89f67fa2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 29 Dec 2020 12:47:45 -0500 Subject: Validate input parameters for the sendToDevice API. (#8975) This makes the "messages" key in the content required. This is currently optional in the spec, but that seems to be an error. --- changelog.d/8975.bugfix | 1 + synapse/rest/client/v2_alpha/sendtodevice.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8975.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/8975.bugfix b/changelog.d/8975.bugfix new file mode 100644 index 0000000000..75049b8e18 --- /dev/null +++ b/changelog.d/8975.bugfix @@ -0,0 +1 @@ +Add validation to the `sendToDevice` API to raise a missing parameters error instead of a 500 error. diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index bc4f43639a..a3dee14ed4 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -17,7 +17,7 @@ import logging from typing import Tuple from synapse.http import servlet -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request from synapse.logging.opentracing import set_tag, trace from synapse.rest.client.transactions import HttpTransactionCache @@ -54,6 +54,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) + assert_params_in_dict(content, ("messages",)) sender_user_id = requester.user.to_string() -- cgit 1.5.1 From b7c580e33341ffbd12533a572439778929f811bf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 30 Dec 2020 08:39:59 -0500 Subject: Check if group IDs are valid before using them. (#8977) --- changelog.d/8977.bugfix | 1 + synapse/handlers/groups_local.py | 2 +- synapse/rest/client/v2_alpha/groups.py | 48 +++++++++++++++++++++++++++++++--- 3 files changed, 47 insertions(+), 4 deletions(-) create mode 100644 changelog.d/8977.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/8977.bugfix b/changelog.d/8977.bugfix new file mode 100644 index 0000000000..ae0b6bec14 --- /dev/null +++ b/changelog.d/8977.bugfix @@ -0,0 +1 @@ +Properly return 400 errors on invalid group IDs. diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index abd8d2af44..df29edeb83 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -29,7 +29,7 @@ def _create_rerouter(func_name): async def f(self, group_id, *args, **kwargs): if not GroupID.is_valid(group_id): - raise SynapseError(400, "%s was not legal group ID" % (group_id,)) + raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) if self.is_mine_id(group_id): return await getattr(self.groups_server_handler, func_name)( diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index a3bb095c2d..5b5da71815 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from functools import wraps from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -25,6 +26,22 @@ from ._base import client_patterns logger = logging.getLogger(__name__) +def _validate_group_id(f): + """Wrapper to validate the form of the group ID. + + Can be applied to any on_FOO methods that accepts a group ID as a URL parameter. + """ + + @wraps(f) + def wrapper(self, request, group_id, *args, **kwargs): + if not GroupID.is_valid(group_id): + raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) + + return f(self, request, group_id, *args, **kwargs) + + return wrapper + + class GroupServlet(RestServlet): """Get the group profile """ @@ -37,6 +54,7 @@ class GroupServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -47,6 +65,7 @@ class GroupServlet(RestServlet): return 200, group_description + @_validate_group_id async def on_POST(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -71,6 +90,7 @@ class GroupSummaryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -102,6 +122,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, category_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -117,6 +138,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, category_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -142,6 +164,7 @@ class GroupCategoryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -152,6 +175,7 @@ class GroupCategoryServlet(RestServlet): return 200, category + @_validate_group_id async def on_PUT(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -163,6 +187,7 @@ class GroupCategoryServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -186,6 +211,7 @@ class GroupCategoriesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -209,6 +235,7 @@ class GroupRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -219,6 +246,7 @@ class GroupRoleServlet(RestServlet): return 200, category + @_validate_group_id async def on_PUT(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -230,6 +258,7 @@ class GroupRoleServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -253,6 +282,7 @@ class GroupRolesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -284,6 +314,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, role_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -299,6 +330,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, role_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -322,13 +354,11 @@ class GroupRoomServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - if not GroupID.is_valid(group_id): - raise SynapseError(400, "%s was not legal group ID" % (group_id,)) - result = await self.groups_handler.get_rooms_in_group( group_id, requester_user_id ) @@ -348,6 +378,7 @@ class GroupUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -371,6 +402,7 @@ class GroupInvitedUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -393,6 +425,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -449,6 +482,7 @@ class GroupAdminRoomsServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -460,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet): return 200, result + @_validate_group_id async def on_DELETE(self, request, group_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -486,6 +521,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, room_id, config_key): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -514,6 +550,7 @@ class GroupAdminUsersInviteServlet(RestServlet): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id + @_validate_group_id async def on_PUT(self, request, group_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -541,6 +578,7 @@ class GroupAdminUsersKickServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -565,6 +603,7 @@ class GroupSelfLeaveServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -589,6 +628,7 @@ class GroupSelfJoinServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -613,6 +653,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -637,6 +678,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() -- cgit 1.5.1 From 1c9a8505623475ae28067e6f0e8e74ede70c728a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 4 Jan 2021 10:04:50 -0500 Subject: Add type hints to the crypto module. (#8999) --- changelog.d/8999.misc | 1 + mypy.ini | 2 + synapse/crypto/context_factory.py | 2 +- synapse/crypto/event_signing.py | 29 ++-- synapse/crypto/keyring.py | 206 +++++++++++++++++------------ synapse/federation/transport/server.py | 2 +- synapse/rest/key/v2/remote_key_resource.py | 9 +- synapse/storage/databases/main/keys.py | 10 +- tests/crypto/test_keyring.py | 10 +- 9 files changed, 158 insertions(+), 113 deletions(-) create mode 100644 changelog.d/8999.misc (limited to 'synapse/rest') diff --git a/changelog.d/8999.misc b/changelog.d/8999.misc new file mode 100644 index 0000000000..3987204f06 --- /dev/null +++ b/changelog.d/8999.misc @@ -0,0 +1 @@ +Add type hints to the crypto module. diff --git a/mypy.ini b/mypy.ini index a54f34fe24..6a53abfaa9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,6 +17,7 @@ files = synapse/api, synapse/appservice, synapse/config, + synapse/crypto, synapse/event_auth.py, synapse/events/builder.py, synapse/events/validator.py, @@ -75,6 +76,7 @@ files = synapse/storage/background_updates.py, synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/events.py, + synapse/storage/databases/main/keys.py, synapse/storage/databases/main/pusher.py, synapse/storage/databases/main/registration.py, synapse/storage/databases/main/stream.py, diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 57fd426e87..74b67b230a 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -227,7 +227,7 @@ class ConnectionVerifier: # This code is based on twisted.internet.ssl.ClientTLSOptions. - def __init__(self, hostname: bytes, verify_certs): + def __init__(self, hostname: bytes, verify_certs: bool): self._verify_certs = verify_certs _decoded = hostname.decode("ascii") diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 0422c43fab..8fb116ae18 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -18,7 +18,7 @@ import collections.abc import hashlib import logging -from typing import Dict +from typing import Any, Callable, Dict, Tuple from canonicaljson import encode_canonical_json from signedjson.sign import sign_json @@ -27,13 +27,18 @@ from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion +from synapse.events import EventBase from synapse.events.utils import prune_event, prune_event_dict from synapse.types import JsonDict logger = logging.getLogger(__name__) +Hasher = Callable[[bytes], "hashlib._Hash"] -def check_event_content_hash(event, hash_algorithm=hashlib.sha256): + +def check_event_content_hash( + event: EventBase, hash_algorithm: Hasher = hashlib.sha256 +) -> bool: """Check whether the hash for this PDU matches the contents""" name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm) logger.debug( @@ -67,18 +72,19 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): return message_hash_bytes == expected_hash -def compute_content_hash(event_dict, hash_algorithm): +def compute_content_hash( + event_dict: Dict[str, Any], hash_algorithm: Hasher +) -> Tuple[str, bytes]: """Compute the content hash of an event, which is the hash of the unredacted event. Args: - event_dict (dict): The unredacted event as a dict + event_dict: The unredacted event as a dict hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use to hash the event Returns: - tuple[str, bytes]: A tuple of the name of hash and the hash as raw - bytes. + A tuple of the name of hash and the hash as raw bytes. """ event_dict = dict(event_dict) event_dict.pop("age_ts", None) @@ -94,18 +100,19 @@ def compute_content_hash(event_dict, hash_algorithm): return hashed.name, hashed.digest() -def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): +def compute_event_reference_hash( + event, hash_algorithm: Hasher = hashlib.sha256 +) -> Tuple[str, bytes]: """Computes the event reference hash. This is the hash of the redacted event. Args: - event (FrozenEvent) + event hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use to hash the event Returns: - tuple[str, bytes]: A tuple of the name of hash and the hash as raw - bytes. + A tuple of the name of hash and the hash as raw bytes. """ tmp_event = prune_event(event) event_dict = tmp_event.get_pdu_json() @@ -156,7 +163,7 @@ def add_hashes_and_signatures( event_dict: JsonDict, signature_name: str, signing_key: SigningKey, -): +) -> None: """Add content hash and sign the event Args: diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index f23eacc0d7..902128a23c 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging import urllib from collections import defaultdict +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple import attr from signedjson.key import ( @@ -40,6 +42,7 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.config.key import TrustedKeyServer from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, @@ -47,11 +50,15 @@ from synapse.logging.context import ( run_in_background, ) from synapse.storage.keys import FetchKeyResult +from synapse.types import JsonDict from synapse.util import unwrapFirstError from synapse.util.async_helpers import yieldable_gather_results from synapse.util.metrics import Measure from synapse.util.retryutils import NotRetryingDestination +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -61,16 +68,17 @@ class VerifyJsonRequest: A request to verify a JSON object. Attributes: - server_name(str): The name of the server to verify against. - - key_ids(set[str]): The set of key_ids to that could be used to verify the - JSON object + server_name: The name of the server to verify against. - json_object(dict): The JSON object to verify. + json_object: The JSON object to verify. - minimum_valid_until_ts (int): time at which we require the signing key to + minimum_valid_until_ts: time at which we require the signing key to be valid. (0 implies we don't care) + request_name: The name of the request. + + key_ids: The set of key_ids to that could be used to verify the JSON object + key_ready (Deferred[str, str, nacl.signing.VerifyKey]): A deferred (server_name, key_id, verify_key) tuple that resolves when a verify key has been fetched. The deferreds' callbacks are run with no @@ -80,12 +88,12 @@ class VerifyJsonRequest: errbacks with an M_UNAUTHORIZED SynapseError. """ - server_name = attr.ib() - json_object = attr.ib() - minimum_valid_until_ts = attr.ib() - request_name = attr.ib() - key_ids = attr.ib(init=False) - key_ready = attr.ib(default=attr.Factory(defer.Deferred)) + server_name = attr.ib(type=str) + json_object = attr.ib(type=JsonDict) + minimum_valid_until_ts = attr.ib(type=int) + request_name = attr.ib(type=str) + key_ids = attr.ib(init=False, type=List[str]) + key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred) def __attrs_post_init__(self): self.key_ids = signature_ids(self.json_object, self.server_name) @@ -96,7 +104,9 @@ class KeyLookupError(ValueError): class Keyring: - def __init__(self, hs, key_fetchers=None): + def __init__( + self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None + ): self.clock = hs.get_clock() if key_fetchers is None: @@ -112,22 +122,26 @@ class Keyring: # completes. # # These are regular, logcontext-agnostic Deferreds. - self.key_downloads = {} + self.key_downloads = {} # type: Dict[str, defer.Deferred] def verify_json_for_server( - self, server_name, json_object, validity_time, request_name - ): + self, + server_name: str, + json_object: JsonDict, + validity_time: int, + request_name: str, + ) -> defer.Deferred: """Verify that a JSON object has been signed by a given server Args: - server_name (str): name of the server which must have signed this object + server_name: name of the server which must have signed this object - json_object (dict): object to be checked + json_object: object to be checked - validity_time (int): timestamp at which we require the signing key to + validity_time: timestamp at which we require the signing key to be valid. (0 implies we don't care) - request_name (str): an identifier for this json object (eg, an event id) + request_name: an identifier for this json object (eg, an event id) for logging. Returns: @@ -138,12 +152,14 @@ class Keyring: requests = (req,) return make_deferred_yieldable(self._verify_objects(requests)[0]) - def verify_json_objects_for_server(self, server_and_json): + def verify_json_objects_for_server( + self, server_and_json: Iterable[Tuple[str, dict, int, str]] + ) -> List[defer.Deferred]: """Bulk verifies signatures of json objects, bulk fetching keys as necessary. Args: - server_and_json (iterable[Tuple[str, dict, int, str]): + server_and_json: Iterable of (server_name, json_object, validity_time, request_name) tuples. @@ -164,13 +180,14 @@ class Keyring: for server_name, json_object, validity_time, request_name in server_and_json ) - def _verify_objects(self, verify_requests): + def _verify_objects( + self, verify_requests: Iterable[VerifyJsonRequest] + ) -> List[defer.Deferred]: """Does the work of verify_json_[objects_]for_server Args: - verify_requests (iterable[VerifyJsonRequest]): - Iterable of verification requests. + verify_requests: Iterable of verification requests. Returns: List: for each input item, a deferred indicating success @@ -182,7 +199,7 @@ class Keyring: key_lookups = [] handle = preserve_fn(_handle_key_deferred) - def process(verify_request): + def process(verify_request: VerifyJsonRequest) -> defer.Deferred: """Process an entry in the request list Adds a key request to key_lookups, and returns a deferred which @@ -222,18 +239,20 @@ class Keyring: return results - async def _start_key_lookups(self, verify_requests): + async def _start_key_lookups( + self, verify_requests: List[VerifyJsonRequest] + ) -> None: """Sets off the key fetches for each verify request Once each fetch completes, verify_request.key_ready will be resolved. Args: - verify_requests (List[VerifyJsonRequest]): + verify_requests: """ try: # map from server name to a set of outstanding request ids - server_to_request_ids = {} + server_to_request_ids = {} # type: Dict[str, Set[int]] for verify_request in verify_requests: server_name = verify_request.server_name @@ -275,11 +294,11 @@ class Keyring: except Exception: logger.exception("Error starting key lookups") - async def wait_for_previous_lookups(self, server_names) -> None: + async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None: """Waits for any previous key lookups for the given servers to finish. Args: - server_names (Iterable[str]): list of servers which we want to look up + server_names: list of servers which we want to look up Returns: Resolves once all key lookups for the given servers have @@ -304,7 +323,7 @@ class Keyring: loop_count += 1 - def _get_server_verify_keys(self, verify_requests): + def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None: """Tries to find at least one key for each verify request For each verify_request, verify_request.key_ready is called back with @@ -312,7 +331,7 @@ class Keyring: with a SynapseError if none of the keys are found. Args: - verify_requests (list[VerifyJsonRequest]): list of verify requests + verify_requests: list of verify requests """ remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} @@ -366,17 +385,19 @@ class Keyring: run_in_background(do_iterations) - async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests): + async def _attempt_key_fetches_with_fetcher( + self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest] + ): """Use a key fetcher to attempt to satisfy some key requests Args: - fetcher (KeyFetcher): fetcher to use to fetch the keys - remaining_requests (set[VerifyJsonRequest]): outstanding key requests. + fetcher: fetcher to use to fetch the keys + remaining_requests: outstanding key requests. Any successfully-completed requests will be removed from the list. """ - # dict[str, dict[str, int]]: keys to fetch. + # The keys to fetch. # server_name -> key_id -> min_valid_ts - missing_keys = defaultdict(dict) + missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]] for verify_request in remaining_requests: # any completed requests should already have been removed @@ -438,16 +459,18 @@ class Keyring: remaining_requests.difference_update(completed) -class KeyFetcher: - async def get_keys(self, keys_to_fetch): +class KeyFetcher(metaclass=abc.ABCMeta): + @abc.abstractmethod + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: - keys_to_fetch (dict[str, dict[str, int]]): + keys_to_fetch: the keys to be fetched. server_name -> key_id -> min_valid_ts Returns: - Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: - map from server_name -> key_id -> FetchKeyResult + Map from server_name -> key_id -> FetchKeyResult """ raise NotImplementedError @@ -455,31 +478,35 @@ class KeyFetcher: class StoreKeyFetcher(KeyFetcher): """KeyFetcher impl which fetches keys from our data store""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - async def get_keys(self, keys_to_fetch): + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """see KeyFetcher.get_keys""" - keys_to_fetch = ( + key_ids_to_fetch = ( (server_name, key_id) for server_name, keys_for_server in keys_to_fetch.items() for key_id in keys_for_server.keys() ) - res = await self.store.get_server_verify_keys(keys_to_fetch) - keys = {} + res = await self.store.get_server_verify_keys(key_ids_to_fetch) + keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] for (server_name, key_id), key in res.items(): keys.setdefault(server_name, {})[key_id] = key return keys -class BaseV2KeyFetcher: - def __init__(self, hs): +class BaseV2KeyFetcher(KeyFetcher): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.config = hs.get_config() - async def process_v2_response(self, from_server, response_json, time_added_ms): + async def process_v2_response( + self, from_server: str, response_json: JsonDict, time_added_ms: int + ) -> Dict[str, FetchKeyResult]: """Parse a 'Server Keys' structure from the result of a /key request This is used to parse either the entirety of the response from @@ -493,16 +520,16 @@ class BaseV2KeyFetcher: to /_matrix/key/v2/query. Args: - from_server (str): the name of the server producing this result: either + from_server: the name of the server producing this result: either the origin server for a /_matrix/key/v2/server request, or the notary for a /_matrix/key/v2/query. - response_json (dict): the json-decoded Server Keys response object + response_json: the json-decoded Server Keys response object - time_added_ms (int): the timestamp to record in server_keys_json + time_added_ms: the timestamp to record in server_keys_json Returns: - Deferred[dict[str, FetchKeyResult]]: map from key_id to result object + Map from key_id to result object """ ts_valid_until_ms = response_json["valid_until_ts"] @@ -575,21 +602,22 @@ class BaseV2KeyFetcher: class PerspectivesKeyFetcher(BaseV2KeyFetcher): """KeyFetcher impl which fetches keys from the "perspectives" servers""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.clock = hs.get_clock() self.client = hs.get_federation_http_client() self.key_servers = self.config.key_servers - async def get_keys(self, keys_to_fetch): + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """see KeyFetcher.get_keys""" - async def get_key(key_server): + async def get_key(key_server: TrustedKeyServer) -> Dict: try: - result = await self.get_server_verify_key_v2_indirect( + return await self.get_server_verify_key_v2_indirect( keys_to_fetch, key_server ) - return result except KeyLookupError as e: logger.warning( "Key lookup failed from %r: %s", key_server.server_name, e @@ -611,25 +639,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ).addErrback(unwrapFirstError) ) - union_of_keys = {} + union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] for result in results: for server_name, keys in result.items(): union_of_keys.setdefault(server_name, {}).update(keys) return union_of_keys - async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): + async def get_server_verify_key_v2_indirect( + self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer + ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: - keys_to_fetch (dict[str, dict[str, int]]): + keys_to_fetch: the keys to be fetched. server_name -> key_id -> min_valid_ts - key_server (synapse.config.key.TrustedKeyServer): notary server to query for - the keys + key_server: notary server to query for the keys Returns: - dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map - from server_name -> key_id -> FetchKeyResult + Map from server_name -> key_id -> FetchKeyResult Raises: KeyLookupError if there was an error processing the entire response from @@ -662,11 +690,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): except HttpResponseException as e: raise KeyLookupError("Remote server returned an error: %s" % (e,)) - keys = {} - added_keys = [] + keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] + added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]] time_now_ms = self.clock.time_msec() + assert isinstance(query_response, dict) for response in query_response["server_keys"]: # do this first, so that we can give useful errors thereafter server_name = response.get("server_name") @@ -704,14 +733,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): return keys - def _validate_perspectives_response(self, key_server, response): + def _validate_perspectives_response( + self, key_server: TrustedKeyServer, response: JsonDict + ) -> None: """Optionally check the signature on the result of a /key/query request Args: - key_server (synapse.config.key.TrustedKeyServer): the notary server that - produced this result + key_server: the notary server that produced this result - response (dict): the json-decoded Server Keys response object + response: the json-decoded Server Keys response object """ perspective_name = key_server.server_name perspective_keys = key_server.verify_keys @@ -745,25 +775,26 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): class ServerKeyFetcher(BaseV2KeyFetcher): """KeyFetcher impl which fetches keys from the origin servers""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.clock = hs.get_clock() self.client = hs.get_federation_http_client() - async def get_keys(self, keys_to_fetch): + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: - keys_to_fetch (dict[str, iterable[str]]): + keys_to_fetch: the keys to be fetched. server_name -> key_ids Returns: - dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]: - map from server_name -> key_id -> FetchKeyResult + Map from server_name -> key_id -> FetchKeyResult """ results = {} - async def get_key(key_to_fetch_item): + async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None: server_name, key_ids = key_to_fetch_item try: keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) @@ -778,20 +809,22 @@ class ServerKeyFetcher(BaseV2KeyFetcher): await yieldable_gather_results(get_key, keys_to_fetch.items()) return results - async def get_server_verify_key_v2_direct(self, server_name, key_ids): + async def get_server_verify_key_v2_direct( + self, server_name: str, key_ids: Iterable[str] + ) -> Dict[str, FetchKeyResult]: """ Args: - server_name (str): - key_ids (iterable[str]): + server_name: + key_ids: Returns: - dict[str, FetchKeyResult]: map from key ID to lookup result + Map from key ID to lookup result Raises: KeyLookupError if there was a problem making the lookup """ - keys = {} # type: dict[str, FetchKeyResult] + keys = {} # type: Dict[str, FetchKeyResult] for requested_key_id in key_ids: # we may have found this key as a side-effect of asking for another. @@ -825,6 +858,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): except HttpResponseException as e: raise KeyLookupError("Remote server returned an error: %s" % (e,)) + assert isinstance(response, dict) if response["server_name"] != server_name: raise KeyLookupError( "Expected a response for server %r not %r" @@ -846,11 +880,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher): return keys -async def _handle_key_deferred(verify_request) -> None: +async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None: """Waits for the key to become available, and then performs a verification Args: - verify_request (VerifyJsonRequest): + verify_request: Raises: SynapseError if there was a problem performing the verification diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 434718ddfc..cfd094e58f 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -144,7 +144,7 @@ class Authenticator: ): raise FederationDeniedError(origin) - if not json_request["signatures"]: + if origin is None or not json_request["signatures"]: raise NoAuthenticationError( 401, "Missing Authorization headers", Codes.UNAUTHORIZED ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index f843f02454..c57ac22e58 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Dict, Set +from typing import Dict from signedjson.sign import sign_json @@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource): time_now_ms = self.clock.time_msec() - cache_misses = {} # type: Dict[str, Set[str]] + # Note that the value is unused. + cache_misses = {} # type: Dict[str, Dict[str, int]] for (server_name, key_id, from_server), results in cached.items(): results = [(result["ts_added_ms"], result) for result in results] if not results and key_id is not None: - cache_misses.setdefault(server_name, set()).add(key_id) + cache_misses.setdefault(server_name, {})[key_id] = 0 continue if key_id is not None: @@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource): ) if miss: - cache_misses.setdefault(server_name, set()).add(key_id) + cache_misses.setdefault(server_name, {})[key_id] = 0 # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(most_recent_result["key_json"])) else: diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index f8f4bb9b3f..04ac2d0ced 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -22,6 +22,7 @@ from signedjson.key import decode_verify_key_bytes from synapse.storage._base import SQLBaseStore from synapse.storage.keys import FetchKeyResult +from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -44,7 +45,7 @@ class KeyStore(SQLBaseStore): ) async def get_server_verify_keys( self, server_name_and_key_ids: Iterable[Tuple[str, str]] - ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]: + ) -> Dict[Tuple[str, str], FetchKeyResult]: """ Args: server_name_and_key_ids: @@ -56,7 +57,7 @@ class KeyStore(SQLBaseStore): """ keys = {} - def _get_keys(txn, batch): + def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None: """Processes a batch of keys to fetch, and adds the result to `keys`.""" # batch_iter always returns tuples so it's safe to do len(batch) @@ -77,13 +78,12 @@ class KeyStore(SQLBaseStore): # `ts_valid_until_ms`. ts_valid_until_ms = 0 - res = FetchKeyResult( + keys[(server_name, key_id)] = FetchKeyResult( verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), valid_until_ts=ts_valid_until_ms, ) - keys[(server_name, key_id)] = res - def _txn(txn): + def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: for batch in batch_iter(server_name_and_key_ids, 50): _get_keys(txn, batch) return keys diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index d146f2254f..1d65ea2f9c 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -75,7 +75,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): return val def test_verify_json_objects_for_server_awaits_previous_requests(self): - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock() kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) @@ -195,7 +195,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): """Tests that we correctly handle key requests for keys we've stored with a null `ts_valid_until_ms` """ - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) kr = keyring.Keyring( @@ -249,7 +249,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): } } - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock(side_effect=get_keys) kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) @@ -288,9 +288,9 @@ class KeyringTestCase(unittest.HomeserverTestCase): } } - mock_fetcher1 = keyring.KeyFetcher() + mock_fetcher1 = Mock() mock_fetcher1.get_keys = Mock(side_effect=get_keys1) - mock_fetcher2 = keyring.KeyFetcher() + mock_fetcher2 = Mock() mock_fetcher2.get_keys = Mock(side_effect=get_keys2) kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2)) -- cgit 1.5.1 From d2c616a41381c9e2d43b08d5f225b52042d94d23 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 4 Jan 2021 18:13:49 +0000 Subject: Combine the SSO Redirect Servlets (#9015) * Implement CasHandler.handle_redirect_request ... to make it match OidcHandler and SamlHandler * Clean up interface for OidcHandler.handle_redirect_request Make it accept `client_redirect_url=None`. * Clean up interface for `SamlHandler.handle_redirect_request` ... bring it into line with CAS and OIDC by making it take a Request parameter, move the magic for `client_redirect_url` for UIA into the handler, and fix the return type to be a `str` rather than a `bytes`. * Define a common protocol for SSO auth provider impls * Give SsoIdentityProvider an ID and register them * Combine the SSO Redirect servlets Now that the SsoHandler knows about the identity providers, we can combine the various *RedirectServlets into a single implementation which delegates to the right IdP. * changelog --- changelog.d/9015.feature | 1 + synapse/handlers/cas_handler.py | 35 ++++++++++---- synapse/handlers/oidc_handler.py | 15 ++++-- synapse/handlers/saml_handler.py | 25 +++++++--- synapse/handlers/sso.py | 86 +++++++++++++++++++++++++++++++++- synapse/rest/client/v1/login.py | 89 ++++++++---------------------------- synapse/rest/client/v2_alpha/auth.py | 34 ++++++-------- tests/rest/client/v1/test_login.py | 2 +- 8 files changed, 174 insertions(+), 113 deletions(-) create mode 100644 changelog.d/9015.feature (limited to 'synapse/rest') diff --git a/changelog.d/9015.feature b/changelog.d/9015.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9015.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index fca210a5a6..295974c521 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -75,10 +75,12 @@ class CasHandler: self._http_client = hs.get_proxied_http_client() # identifier for the external_ids table - self._auth_provider_id = "cas" + self.idp_id = "cas" self._sso_handler = hs.get_sso_handler() + self._sso_handler.register_identity_provider(self) + def _build_service_param(self, args: Dict[str, str]) -> str: """ Generates a value to use as the "service" parameter when redirecting or @@ -105,7 +107,7 @@ class CasHandler: Args: ticket: The CAS ticket from the client. service_args: Additional arguments to include in the service URL. - Should be the same as those passed to `get_redirect_url`. + Should be the same as those passed to `handle_redirect_request`. Raises: CasError: If there's an error parsing the CAS response. @@ -184,16 +186,31 @@ class CasHandler: return CasResponse(user, attributes) - def get_redirect_url(self, service_args: Dict[str, str]) -> str: - """ - Generates a URL for the CAS server where the client should be redirected. + async def handle_redirect_request( + self, + request: SynapseRequest, + client_redirect_url: Optional[bytes], + ui_auth_session_id: Optional[str] = None, + ) -> str: + """Generates a URL for the CAS server where the client should be redirected. Args: - service_args: Additional arguments to include in the final redirect URL. + request: the incoming HTTP request + client_redirect_url: the URL that we should redirect the + client to after login (or None for UI Auth). + ui_auth_session_id: The session ID of the ongoing UI Auth (or + None if this is a login). Returns: - The URL to redirect the client to. + URL to redirect to """ + + if ui_auth_session_id: + service_args = {"session": ui_auth_session_id} + else: + assert client_redirect_url + service_args = {"redirectUrl": client_redirect_url.decode("utf8")} + args = urllib.parse.urlencode( {"service": self._build_service_param(service_args)} ) @@ -275,7 +292,7 @@ class CasHandler: # first check if we're doing a UIA if session: return await self._sso_handler.complete_sso_ui_auth_request( - self._auth_provider_id, cas_response.username, session, request, + self.idp_id, cas_response.username, session, request, ) # otherwise, we're handling a login request. @@ -375,7 +392,7 @@ class CasHandler: return None await self._sso_handler.complete_sso_login_request( - self._auth_provider_id, + self.idp_id, cas_response.username, request, client_redirect_url, diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 709f8dfc13..3e2b60eb7b 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -119,10 +119,12 @@ class OidcHandler(BaseHandler): self._macaroon_secret_key = hs.config.macaroon_secret_key # identifier for the external_ids table - self._auth_provider_id = "oidc" + self.idp_id = "oidc" self._sso_handler = hs.get_sso_handler() + self._sso_handler.register_identity_provider(self) + def _validate_metadata(self): """Verifies the provider metadata. @@ -475,7 +477,7 @@ class OidcHandler(BaseHandler): async def handle_redirect_request( self, request: SynapseRequest, - client_redirect_url: bytes, + client_redirect_url: Optional[bytes], ui_auth_session_id: Optional[str] = None, ) -> str: """Handle an incoming request to /login/sso/redirect @@ -499,7 +501,7 @@ class OidcHandler(BaseHandler): request: the incoming request from the browser. We'll respond to it with a redirect and a cookie. client_redirect_url: the URL that we should redirect the client to - when everything is done + when everything is done (or None for UI Auth) ui_auth_session_id: The session ID of the ongoing UI Auth (or None if this is a login). @@ -511,6 +513,9 @@ class OidcHandler(BaseHandler): state = generate_token() nonce = generate_token() + if not client_redirect_url: + client_redirect_url = b"" + cookie = self._generate_oidc_session_token( state=state, nonce=nonce, @@ -682,7 +687,7 @@ class OidcHandler(BaseHandler): return return await self._sso_handler.complete_sso_ui_auth_request( - self._auth_provider_id, remote_user_id, ui_auth_session_id, request + self.idp_id, remote_user_id, ui_auth_session_id, request ) # otherwise, it's a login @@ -923,7 +928,7 @@ class OidcHandler(BaseHandler): extra_attributes = await get_extra_attributes(userinfo, token) await self._sso_handler.complete_sso_login_request( - self._auth_provider_id, + self.idp_id, remote_user_id, request, client_redirect_url, diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 5fa7ab3f8b..6106237f1f 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -73,27 +73,38 @@ class SamlHandler(BaseHandler): ) # identifier for the external_ids table - self._auth_provider_id = "saml" + self.idp_id = "saml" # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] self._sso_handler = hs.get_sso_handler() + self._sso_handler.register_identity_provider(self) - def handle_redirect_request( - self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None - ) -> bytes: + async def handle_redirect_request( + self, + request: SynapseRequest, + client_redirect_url: Optional[bytes], + ui_auth_session_id: Optional[str] = None, + ) -> str: """Handle an incoming request to /login/sso/redirect Args: + request: the incoming HTTP request client_redirect_url: the URL that we should redirect the - client to when everything is done + client to after login (or None for UI Auth). ui_auth_session_id: The session ID of the ongoing UI Auth (or None if this is a login). Returns: URL to redirect to """ + if not client_redirect_url: + # Some SAML identity providers (e.g. Google) require a + # RelayState parameter on requests, so pass in a dummy redirect URL + # (which will never get used). + client_redirect_url = b"unused" + reqid, info = self._saml_client.prepare_for_authenticate( entityid=self._saml_idp_entityid, relay_state=client_redirect_url ) @@ -210,7 +221,7 @@ class SamlHandler(BaseHandler): return return await self._sso_handler.complete_sso_ui_auth_request( - self._auth_provider_id, + self.idp_id, remote_user_id, current_session.ui_auth_session_id, request, @@ -306,7 +317,7 @@ class SamlHandler(BaseHandler): return None await self._sso_handler.complete_sso_login_request( - self._auth_provider_id, + self.idp_id, remote_user_id, request, client_redirect_url, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 33cd6bc178..d8fb8cdd05 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -12,15 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional import attr -from typing_extensions import NoReturn +from typing_extensions import NoReturn, Protocol from twisted.web.http import Request -from synapse.api.errors import RedirectException, SynapseError +from synapse.api.errors import Codes, RedirectException, SynapseError from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters @@ -40,6 +41,53 @@ class MappingException(Exception): """ +class SsoIdentityProvider(Protocol): + """Abstract base class to be implemented by SSO Identity Providers + + An Identity Provider, or IdP, is an external HTTP service which authenticates a user + to say whether they should be allowed to log in, or perform a given action. + + Synapse supports various implementations of IdPs, including OpenID Connect, SAML, + and CAS. + + The main entry point is `handle_redirect_request`, which should return a URI to + redirect the user's browser to the IdP's authentication page. + + Each IdP should be registered with the SsoHandler via + `hs.get_sso_handler().register_identity_provider()`, so that requests to + `/_matrix/client/r0/login/sso/redirect` can be correctly dispatched. + """ + + @property + @abc.abstractmethod + def idp_id(self) -> str: + """A unique identifier for this SSO provider + + Eg, "saml", "cas", "github" + """ + + @abc.abstractmethod + async def handle_redirect_request( + self, + request: SynapseRequest, + client_redirect_url: Optional[bytes], + ui_auth_session_id: Optional[str] = None, + ) -> str: + """Handle an incoming request to /login/sso/redirect + + Args: + request: the incoming HTTP request + client_redirect_url: the URL that we should redirect the + client to after login (or None for UI Auth). + ui_auth_session_id: The session ID of the ongoing UI Auth (or + None if this is a login). + + Returns: + URL to redirect to + """ + raise NotImplementedError() + + @attr.s class UserAttributes: # the localpart of the mxid that the mapper has assigned to the user. @@ -100,6 +148,14 @@ class SsoHandler: # a map from session id to session data self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession] + # map from idp_id to SsoIdentityProvider + self._identity_providers = {} # type: Dict[str, SsoIdentityProvider] + + def register_identity_provider(self, p: SsoIdentityProvider): + p_id = p.idp_id + assert p_id not in self._identity_providers + self._identity_providers[p_id] = p + def render_error( self, request: Request, @@ -124,6 +180,32 @@ class SsoHandler: ) respond_with_html(request, code, html) + async def handle_redirect_request( + self, request: SynapseRequest, client_redirect_url: bytes, + ) -> str: + """Handle a request to /login/sso/redirect + + Args: + request: incoming HTTP request + client_redirect_url: the URL that we should redirect the + client to after login. + + Returns: + the URI to redirect to + """ + if not self._identity_providers: + raise SynapseError( + 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED + ) + + # if we only have one auth provider, redirect to it directly + if len(self._identity_providers) == 1: + ap = next(iter(self._identity_providers.values())) + return await ap.handle_redirect_request(request, client_redirect_url) + + # otherwise, we have a configuration error + raise Exception("Multiple SSO identity providers have been configured!") + async def get_sso_user_by_remote_user_id( self, auth_provider_id: str, remote_user_id: str ) -> Optional[str]: diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 5f4c6703db..ebc346105b 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -311,48 +311,31 @@ class LoginRestServlet(RestServlet): return result -class BaseSSORedirectServlet(RestServlet): - """Common base class for /login/sso/redirect impls""" - +class SsoRedirectServlet(RestServlet): PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) + def __init__(self, hs: "HomeServer"): + # make sure that the relevant handlers are instantiated, so that they + # register themselves with the main SSOHandler. + if hs.config.cas_enabled: + hs.get_cas_handler() + elif hs.config.saml2_enabled: + hs.get_saml_handler() + elif hs.config.oidc_enabled: + hs.get_oidc_handler() + self._sso_handler = hs.get_sso_handler() + async def on_GET(self, request: SynapseRequest): - args = request.args - if b"redirectUrl" not in args: - return 400, "Redirect URL not specified for SSO auth" - client_redirect_url = args[b"redirectUrl"][0] - sso_url = await self.get_sso_url(request, client_redirect_url) + client_redirect_url = parse_string( + request, "redirectUrl", required=True, encoding=None + ) + sso_url = await self._sso_handler.handle_redirect_request( + request, client_redirect_url + ) + logger.info("Redirecting to %s", sso_url) request.redirect(sso_url) finish_request(request) - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - """Get the URL to redirect to, to perform SSO auth - - Args: - request: The client request to redirect. - client_redirect_url: the URL that we should redirect the - client to when everything is done - - Returns: - URL to redirect to - """ - # to be implemented by subclasses - raise NotImplementedError() - - -class CasRedirectServlet(BaseSSORedirectServlet): - def __init__(self, hs): - self._cas_handler = hs.get_cas_handler() - - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - return self._cas_handler.get_redirect_url( - {"redirectUrl": client_redirect_url} - ).encode("ascii") - class CasTicketServlet(RestServlet): PATTERNS = client_patterns("/login/cas/ticket", v1=True) @@ -379,40 +362,8 @@ class CasTicketServlet(RestServlet): ) -class SAMLRedirectServlet(BaseSSORedirectServlet): - PATTERNS = client_patterns("/login/sso/redirect", v1=True) - - def __init__(self, hs): - self._saml_handler = hs.get_saml_handler() - - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - return self._saml_handler.handle_redirect_request(client_redirect_url) - - -class OIDCRedirectServlet(BaseSSORedirectServlet): - """Implementation for /login/sso/redirect for the OIDC login flow.""" - - PATTERNS = client_patterns("/login/sso/redirect", v1=True) - - def __init__(self, hs): - self._oidc_handler = hs.get_oidc_handler() - - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - return await self._oidc_handler.handle_redirect_request( - request, client_redirect_url - ) - - def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) + SsoRedirectServlet(hs).register(http_server) if hs.config.cas_enabled: - CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) - elif hs.config.saml2_enabled: - SAMLRedirectServlet(hs).register(http_server) - elif hs.config.oidc_enabled: - OIDCRedirectServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index fab077747f..9b9514632f 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -14,15 +14,20 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX +from synapse.handlers.sso import SsoIdentityProvider from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet, parse_string from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -35,7 +40,7 @@ class AuthRestServlet(RestServlet): PATTERNS = client_patterns(r"/auth/(?P[\w\.]*)/fallback/web") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -85,31 +90,20 @@ class AuthRestServlet(RestServlet): elif stagetype == LoginType.SSO: # Display a confirmation page which prompts the user to # re-authenticate with their SSO provider. - if self._cas_enabled: - # Generate a request to CAS that redirects back to an endpoint - # to verify the successful authentication. - sso_redirect_url = self._cas_handler.get_redirect_url( - {"session": session}, - ) + if self._cas_enabled: + sso_auth_provider = self._cas_handler # type: SsoIdentityProvider elif self._saml_enabled: - # Some SAML identity providers (e.g. Google) require a - # RelayState parameter on requests. It is not necessary here, so - # pass in a dummy redirect URL (which will never get used). - client_redirect_url = b"unused" - sso_redirect_url = self._saml_handler.handle_redirect_request( - client_redirect_url, session - ) - + sso_auth_provider = self._saml_handler elif self._oidc_enabled: - client_redirect_url = b"" - sso_redirect_url = await self._oidc_handler.handle_redirect_request( - request, client_redirect_url, session - ) - + sso_auth_provider = self._oidc_handler else: raise SynapseError(400, "Homeserver not configured for SSO.") + sso_redirect_url = await sso_auth_provider.handle_redirect_request( + request, None, session + ) + html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) else: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 18932d7518..999d628315 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -385,7 +385,7 @@ class CASTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", cas_ticket_url) # Test that the response is HTML. - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, 200, channel.result) content_type_header_value = "" for header in channel.result.get("headers", []): if header[0] == b"Content-Type": -- cgit 1.5.1 From 111b673fc1bbd3d51302d915f2ad2c044ed7d3b8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 5 Jan 2021 11:25:28 +0000 Subject: Add initial support for a "pick your IdP" page (#9017) During login, if there are multiple IdPs enabled, offer the user a choice of IdPs. --- changelog.d/9017.feature | 1 + docs/sample_config.yaml | 25 ++++++++ synapse/app/homeserver.py | 2 + synapse/config/sso.py | 27 ++++++++ synapse/handlers/cas_handler.py | 3 + synapse/handlers/oidc_handler.py | 3 + synapse/handlers/saml_handler.py | 3 + synapse/handlers/sso.py | 18 +++++- synapse/res/templates/sso_login_idp_picker.html | 28 +++++++++ synapse/rest/synapse/client/pick_idp.py | 82 +++++++++++++++++++++++++ synapse/static/client/login/style.css | 5 ++ 11 files changed, 194 insertions(+), 3 deletions(-) create mode 100644 changelog.d/9017.feature create mode 100644 synapse/res/templates/sso_login_idp_picker.html create mode 100644 synapse/rest/synapse/client/pick_idp.py (limited to 'synapse/rest') diff --git a/changelog.d/9017.feature b/changelog.d/9017.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9017.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index dd981609ac..c8ae46d1b3 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1909,6 +1909,31 @@ sso: # # Synapse will look for the following templates in this directory: # + # * HTML page to prompt the user to choose an Identity Provider during + # login: 'sso_login_idp_picker.html'. + # + # This is only used if multiple SSO Identity Providers are configured. + # + # When rendering, this template is given the following variables: + # * redirect_url: the URL that the user will be redirected to after + # login. Needs manual escaping (see + # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * server_name: the homeserver's name. + # + # * providers: a list of available Identity Providers. Each element is + # an object with the following attributes: + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # + # The rendered HTML page should contain a form which submits its results + # back as a GET request, with the following query parameters: + # + # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed + # to the template) + # + # * idp: the 'idp_id' of the chosen IDP. + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8d9b53be53..b1d9817a6a 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -63,6 +63,7 @@ from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer @@ -194,6 +195,7 @@ class SynapseHomeServer(HomeServer): "/.well-known/matrix/client": WellKnownResource(self), "/_synapse/admin": AdminRestResource(self), "/_synapse/client/pick_username": pick_username_resource(self), + "/_synapse/client/pick_idp": PickIdpResource(self), } ) diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 93bbd40937..1aeb1c5c92 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -31,6 +31,7 @@ class SSOConfig(Config): # Read templates from disk ( + self.sso_login_idp_picker_template, self.sso_redirect_confirm_template, self.sso_auth_confirm_template, self.sso_error_template, @@ -38,6 +39,7 @@ class SSOConfig(Config): sso_auth_success_template, ) = self.read_templates( [ + "sso_login_idp_picker.html", "sso_redirect_confirm.html", "sso_auth_confirm.html", "sso_error.html", @@ -98,6 +100,31 @@ class SSOConfig(Config): # # Synapse will look for the following templates in this directory: # + # * HTML page to prompt the user to choose an Identity Provider during + # login: 'sso_login_idp_picker.html'. + # + # This is only used if multiple SSO Identity Providers are configured. + # + # When rendering, this template is given the following variables: + # * redirect_url: the URL that the user will be redirected to after + # login. Needs manual escaping (see + # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * server_name: the homeserver's name. + # + # * providers: a list of available Identity Providers. Each element is + # an object with the following attributes: + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # + # The rendered HTML page should contain a form which submits its results + # back as a GET request, with the following query parameters: + # + # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed + # to the template) + # + # * idp: the 'idp_id' of the chosen IDP. + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 295974c521..f3430c6713 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -77,6 +77,9 @@ class CasHandler: # identifier for the external_ids table self.idp_id = "cas" + # user-facing name of this auth provider + self.idp_name = "CAS" + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 3e2b60eb7b..6835c6c462 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -121,6 +121,9 @@ class OidcHandler(BaseHandler): # identifier for the external_ids table self.idp_id = "oidc" + # user-facing name of this auth provider + self.idp_name = "OIDC" + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 6106237f1f..a8376543c9 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -75,6 +75,9 @@ class SamlHandler(BaseHandler): # identifier for the external_ids table self.idp_id = "saml" + # user-facing name of this auth provider + self.idp_name = "SAML" + # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index d8fb8cdd05..2da1ea2223 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -14,7 +14,8 @@ # limitations under the License. import abc import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional +from urllib.parse import urlencode import attr from typing_extensions import NoReturn, Protocol @@ -66,6 +67,11 @@ class SsoIdentityProvider(Protocol): Eg, "saml", "cas", "github" """ + @property + @abc.abstractmethod + def idp_name(self) -> str: + """User-facing name for this provider""" + @abc.abstractmethod async def handle_redirect_request( self, @@ -156,6 +162,10 @@ class SsoHandler: assert p_id not in self._identity_providers self._identity_providers[p_id] = p + def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]: + """Get the configured identity providers""" + return self._identity_providers + def render_error( self, request: Request, @@ -203,8 +213,10 @@ class SsoHandler: ap = next(iter(self._identity_providers.values())) return await ap.handle_redirect_request(request, client_redirect_url) - # otherwise, we have a configuration error - raise Exception("Multiple SSO identity providers have been configured!") + # otherwise, redirect to the IDP picker + return "/_synapse/client/pick_idp?" + urlencode( + (("redirectUrl", client_redirect_url),) + ) async def get_sso_user_by_remote_user_id( self, auth_provider_id: str, remote_user_id: str diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html new file mode 100644 index 0000000000..f53c9cd679 --- /dev/null +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -0,0 +1,28 @@ + + + + + + {{server_name | e}} Login + + +
+

{{server_name | e}} Login

+ +
+ + diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py new file mode 100644 index 0000000000..e5b720bbca --- /dev/null +++ b/synapse/rest/synapse/client/pick_idp.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING + +from synapse.http.server import ( + DirectServeHtmlResource, + finish_request, + respond_with_html, +) +from synapse.http.servlet import parse_string +from synapse.http.site import SynapseRequest + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class PickIdpResource(DirectServeHtmlResource): + """IdP picker resource. + + This resource gets mounted under /_synapse/client/pick_idp. It serves an HTML page + which prompts the user to choose an Identity Provider from the list. + """ + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + self._sso_login_idp_picker_template = ( + hs.config.sso.sso_login_idp_picker_template + ) + self._server_name = hs.hostname + + async def _async_render_GET(self, request: SynapseRequest) -> None: + client_redirect_url = parse_string(request, "redirectUrl", required=True) + idp = parse_string(request, "idp", required=False) + + # if we need to pick an IdP, do so + if not idp: + return await self._serve_id_picker(request, client_redirect_url) + + # otherwise, redirect to the IdP's redirect URI + providers = self._sso_handler.get_identity_providers() + auth_provider = providers.get(idp) + if not auth_provider: + logger.info("Unknown idp %r", idp) + self._sso_handler.render_error( + request, "unknown_idp", "Unknown identity provider ID" + ) + return + + sso_url = await auth_provider.handle_redirect_request( + request, client_redirect_url.encode("utf8") + ) + logger.info("Redirecting to %s", sso_url) + request.redirect(sso_url) + finish_request(request) + + async def _serve_id_picker( + self, request: SynapseRequest, client_redirect_url: str + ) -> None: + # otherwise, serve up the IdP picker + providers = self._sso_handler.get_identity_providers() + html = self._sso_login_idp_picker_template.render( + redirect_url=client_redirect_url, + server_name=self._server_name, + providers=providers.values(), + ) + respond_with_html(request, 200, html) diff --git a/synapse/static/client/login/style.css b/synapse/static/client/login/style.css index 83e4f6abc8..dd76714a92 100644 --- a/synapse/static/client/login/style.css +++ b/synapse/static/client/login/style.css @@ -31,6 +31,11 @@ form { margin: 10px 0 0 0; } +ul.radiobuttons { + text-align: left; + list-style: none; +} + /* * Add some padding to the viewport. */ -- cgit 1.5.1 From 8a910f97a47edb398acba2ad87805e4593c76139 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 6 Jan 2021 16:16:16 +0000 Subject: Add some tests for the IDP picker flow --- synapse/rest/client/v1/login.py | 4 +- tests/rest/client/v1/test_login.py | 191 ++++++++++++++++++++++++++++++++++++- tests/rest/client/v1/utils.py | 3 +- 3 files changed, 193 insertions(+), 5 deletions(-) (limited to 'synapse/rest') diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index ebc346105b..be938df962 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -319,9 +319,9 @@ class SsoRedirectServlet(RestServlet): # register themselves with the main SSOHandler. if hs.config.cas_enabled: hs.get_cas_handler() - elif hs.config.saml2_enabled: + if hs.config.saml2_enabled: hs.get_saml_handler() - elif hs.config.oidc_enabled: + if hs.config.oidc_enabled: hs.get_oidc_handler() self._sso_handler = hs.get_sso_handler() diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index b5a2ac23d4..1d1dc9f8a2 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -15,17 +15,26 @@ import time import urllib.parse -from typing import Any, Dict, Union +from html.parser import HTMLParser +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from mock import Mock +import pymacaroons + +from twisted.web.resource import Resource + import synapse.rest.admin from synapse.appservice import ApplicationService from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet +from synapse.rest.synapse.client.pick_idp import PickIdpResource from tests import unittest +from tests.handlers.test_oidc import HAS_OIDC +from tests.handlers.test_saml import has_saml2 +from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG from tests.unittest import override_config, skip_unless try: @@ -350,6 +359,184 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) +@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") +class MultiSSOTestCase(unittest.HomeserverTestCase): + """Tests for homeservers with multiple SSO providers enabled""" + + servlets = [ + login.register_servlets, + ] + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + + config["public_baseurl"] = BASE_URL + + config["cas_config"] = { + "enabled": True, + "server_url": CAS_SERVER, + "service_url": "https://matrix.goodserver.com:8448", + } + + config["saml2_config"] = { + "sp_config": { + "metadata": {"inline": [TEST_SAML_METADATA]}, + # use the XMLSecurity backend to avoid relying on xmlsec1 + "crypto_backend": "XMLSecurity", + }, + } + + config["oidc_config"] = TEST_OIDC_CONFIG + + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs) + return d + + def test_multi_sso_redirect(self): + """/login/sso/redirect should redirect to an identity picker""" + client_redirect_url = "https://x?" + + # first hit the redirect url, which should redirect to our idp picker + channel = self.make_request( + "GET", + "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url, + ) + self.assertEqual(channel.code, 302, channel.result) + uri = channel.headers.getRawHeaders("Location")[0] + + # hitting that picker should give us some HTML + channel = self.make_request("GET", uri) + self.assertEqual(channel.code, 200, channel.result) + + # parse the form to check it has fields assumed elsewhere in this class + class FormPageParser(HTMLParser): + def __init__(self): + super().__init__() + + # the values of the hidden inputs: map from name to value + self.hiddens = {} # type: Dict[str, Optional[str]] + + # the values of the radio buttons + self.radios = [] # type: List[Optional[str]] + + def handle_starttag( + self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] + ) -> None: + attr_dict = dict(attrs) + if tag == "input": + if attr_dict["type"] == "radio" and attr_dict["name"] == "idp": + self.radios.append(attr_dict["value"]) + elif attr_dict["type"] == "hidden": + input_name = attr_dict["name"] + assert input_name + self.hiddens[input_name] = attr_dict["value"] + + def error(_, message): + self.fail(message) + + p = FormPageParser() + p.feed(channel.result["body"].decode("utf-8")) + p.close() + + self.assertCountEqual(p.radios, ["cas", "oidc", "saml"]) + + self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url) + + def test_multi_sso_redirect_to_cas(self): + """If CAS is chosen, should redirect to the CAS server""" + client_redirect_url = "https://x?" + + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas", + shorthand=False, + ) + self.assertEqual(channel.code, 302, channel.result) + cas_uri = channel.headers.getRawHeaders("Location")[0] + cas_uri_path, cas_uri_query = cas_uri.split("?", 1) + + # it should redirect us to the login page of the cas server + self.assertEqual(cas_uri_path, CAS_SERVER + "/login") + + # check that the redirectUrl is correctly encoded in the service param - ie, the + # place that CAS will redirect to + cas_uri_params = urllib.parse.parse_qs(cas_uri_query) + service_uri = cas_uri_params["service"][0] + _, service_uri_query = service_uri.split("?", 1) + service_uri_params = urllib.parse.parse_qs(service_uri_query) + self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url) + + def test_multi_sso_redirect_to_saml(self): + """If SAML is chosen, should redirect to the SAML server""" + client_redirect_url = "https://x?" + + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + client_redirect_url + + "&idp=saml", + ) + self.assertEqual(channel.code, 302, channel.result) + saml_uri = channel.headers.getRawHeaders("Location")[0] + saml_uri_path, saml_uri_query = saml_uri.split("?", 1) + + # it should redirect us to the login page of the SAML server + self.assertEqual(saml_uri_path, SAML_SERVER) + + # the RelayState is used to carry the client redirect url + saml_uri_params = urllib.parse.parse_qs(saml_uri_query) + relay_state_param = saml_uri_params["RelayState"][0] + self.assertEqual(relay_state_param, client_redirect_url) + + def test_multi_sso_redirect_to_oidc(self): + """If OIDC is chosen, should redirect to the OIDC auth endpoint""" + client_redirect_url = "https://x?" + + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + client_redirect_url + + "&idp=oidc", + ) + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + # ... and should have set a cookie including the redirect url + cookies = dict( + h.split(";")[0].split("=", maxsplit=1) + for h in channel.headers.getRawHeaders("Set-Cookie") + ) + + oidc_session_cookie = cookies["oidc_session"] + macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) + self.assertEqual( + self._get_value_from_macaroon(macaroon, "client_redirect_url"), + client_redirect_url, + ) + + def test_multi_sso_redirect_to_unknown(self): + """An unknown IdP should cause a 400""" + channel = self.make_request( + "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", + ) + self.assertEqual(channel.code, 400, channel.result) + + @staticmethod + def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: + prefix = key + " = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(prefix): + return caveat.caveat_id[len(prefix) :] + raise ValueError("No %s caveat in macaroon" % (key,)) + + class CASTestCase(unittest.HomeserverTestCase): servlets = [ @@ -363,7 +550,7 @@ class CASTestCase(unittest.HomeserverTestCase): config = self.default_config() config["cas_config"] = { "enabled": True, - "server_url": "https://fake.test", + "server_url": CAS_SERVER, "service_url": "https://matrix.goodserver.com:8448", } diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index dbc27893b5..81b7f84360 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -444,6 +444,7 @@ class RestHelper: # an 'oidc_config' suitable for login_via_oidc. +TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" TEST_OIDC_CONFIG = { "enabled": True, "discover": False, @@ -451,7 +452,7 @@ TEST_OIDC_CONFIG = { "client_id": "test-client-id", "client_secret": "test-client-secret", "scopes": ["profile"], - "authorization_endpoint": "https://z", + "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, "token_endpoint": "https://issuer.test/token", "userinfo_endpoint": "https://issuer.test/userinfo", "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, -- cgit 1.5.1 From b849e46139675c3098fdaca8ceff6b76be3f2f02 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Thu, 7 Jan 2021 23:01:59 +0200 Subject: Add forward extremities endpoint to rooms admin API GET /_synapse/admin/v1/rooms//forward_extremities now gets forward extremities for a room, returning count and the list of extremities. Signed-off-by: Jason Robinson --- synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/rooms.py | 53 ++++++++++++++++++++++ synapse/storage/databases/main/__init__.py | 2 + .../databases/main/events_forward_extremities.py | 20 ++++++++ 4 files changed, 77 insertions(+) create mode 100644 synapse/storage/databases/main/events_forward_extremities.py (limited to 'synapse/rest') diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 6f7dc06503..b80b036090 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -36,6 +36,7 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.rooms import ( DeleteRoomRestServlet, + ForwardExtremitiesRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, MakeRoomAdminRestServlet, @@ -230,6 +231,7 @@ def register_servlets(hs, http_server): EventReportsRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server) MakeRoomAdminRestServlet(hs).register(http_server) + ForwardExtremitiesRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index ab7cc9102a..37703610c5 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -499,3 +499,56 @@ class MakeRoomAdminRestServlet(RestServlet): ) return 200, {} + + +class ForwardExtremitiesRestServlet(RestServlet): + """Allows a server admin to get or clear forward extremities. + + Clearing does not require restarting the server. + + Clear forward extremities: + DELETE /_synapse/admin/v1/rooms//forward_extremities + + Get forward_extremities: + GET /_synapse/admin/v1/rooms//forward_extremities + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/forward_extremities") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.room_member_handler = hs.get_room_member_handler() + self.store = hs.get_datastore() + + async def resolve_room_id(self, room_identifier: str) -> str: + """Resolve to a room ID, if necessary.""" + if RoomID.is_valid(room_identifier): + return room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias) + return room_id.to_string() + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + + async def on_DELETE(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + room_id = await self.resolve_room_id(room_identifier) + + async def on_GET(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + room_id = await self.resolve_room_id(room_identifier) + if not room_id: + raise SynapseError(400, "Unknown room ID or room alias %s" % room_identifier) + + extremities = await self.store.get_forward_extremities_for_room(room_id) + return 200, { + "count": len(extremities), + "results": extremities, + } diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index c4de07a0a8..93b25af057 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore from .event_federation import EventFederationStore from .event_push_actions import EventPushActionsStore from .events_bg_updates import EventsBackgroundUpdatesStore +from .events_forward_extremities import EventForwardExtremitiesStore from .filtering import FilteringStore from .group_server import GroupServerStore from .keys import KeyStore @@ -118,6 +119,7 @@ class DataStore( UIAuthStore, CacheInvalidationWorkerStore, ServerMetricsStore, + EventForwardExtremitiesStore, ): def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py new file mode 100644 index 0000000000..250a424cc0 --- /dev/null +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -0,0 +1,20 @@ +from typing import List, Dict + +from synapse.storage._base import SQLBaseStore + + +class EventForwardExtremitiesStore(SQLBaseStore): + async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]: + def get_forward_extremities_for_room_txn(txn): + sql = ( + "SELECT event_id, state_group FROM event_forward_extremities NATURAL JOIN event_to_state_groups " + "WHERE room_id = ?" + ) + + txn.execute(sql, (room_id,)) + rows = txn.fetchall() + return [{"event_id": row[0], "state_group": row[1]} for row in rows] + + return await self.db_pool.runInteraction( + "get_forward_extremities_for_room", get_forward_extremities_for_room_txn + ) -- cgit 1.5.1 From c91045f56c8acf78a11fd722525e98c7cee77ac3 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Thu, 7 Jan 2021 23:03:54 +0200 Subject: Move unknown room ID error into resolve_room_id Signed-off-by: Jason Robinson --- synapse/rest/admin/rooms.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) (limited to 'synapse/rest') diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 37703610c5..1f7b7daea9 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -524,14 +524,18 @@ class ForwardExtremitiesRestServlet(RestServlet): async def resolve_room_id(self, room_identifier: str) -> str: """Resolve to a room ID, if necessary.""" if RoomID.is_valid(room_identifier): - return room_identifier + room_id = room_identifier elif RoomAlias.is_valid(room_identifier): room_alias = RoomAlias.from_string(room_identifier) room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias) - return room_id.to_string() - raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) - ) + room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + if not room_id: + raise SynapseError(400, "Unknown room ID or room alias %s" % room_identifier) + return room_id async def on_DELETE(self, request, room_identifier): requester = await self.auth.get_user_by_req(request) @@ -544,8 +548,6 @@ class ForwardExtremitiesRestServlet(RestServlet): await assert_user_is_admin(self.auth, requester.user) room_id = await self.resolve_room_id(room_identifier) - if not room_id: - raise SynapseError(400, "Unknown room ID or room alias %s" % room_identifier) extremities = await self.store.get_forward_extremities_for_room(room_id) return 200, { -- cgit 1.5.1 From 85c0999bfb70f2e8438a9730b8858e7845027190 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Fri, 8 Jan 2021 00:12:23 +0200 Subject: Add Rooms admin forward extremities DELETE endpoint Signed-off-by: Jason Robinson --- synapse/rest/admin/rooms.py | 5 +++ .../databases/main/events_forward_extremities.py | 49 +++++++++++++++++++++- 2 files changed, 53 insertions(+), 1 deletion(-) (limited to 'synapse/rest') diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 1f7b7daea9..76f8603821 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -543,6 +543,11 @@ class ForwardExtremitiesRestServlet(RestServlet): room_id = await self.resolve_room_id(room_identifier) + deleted_count = await self.store.delete_forward_extremities_for_room(room_id) + return 200, { + "deleted": deleted_count, + } + async def on_GET(self, request, room_identifier): requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index 250a424cc0..cc684a94fe 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -4,7 +4,54 @@ from synapse.storage._base import SQLBaseStore class EventForwardExtremitiesStore(SQLBaseStore): + + async def delete_forward_extremities_for_room(self, room_id: str) -> int: + """Delete any extra forward extremities for a room. + + Returns count deleted. + """ + def delete_forward_extremities_for_room_txn(txn): + # First we need to get the event_id to not delete + sql = ( + "SELECT " + " last_value(event_id) OVER w AS event_id" + " FROM event_forward_extremities" + " NATURAL JOIN events" + " where room_id = ?" + " WINDOW w AS (" + " PARTITION BY room_id" + " ORDER BY stream_ordering" + " range between unbounded preceding and unbounded following" + " )" + " ORDER BY stream_ordering" + ) + txn.execute(sql, (room_id,)) + rows = txn.fetchall() + + # TODO: should this raise a SynapseError instead of better to blow? + event_id = rows[0][0] + + # Now delete the extra forward extremities + sql = ( + "DELETE FROM event_forward_extremities " + "WHERE" + " event_id != ?" + " AND room_id = ?" + ) + + # TODO we should not commit yet + txn.execute(sql, (event_id, room_id)) + + # TODO flush the cache then commit + + return txn.rowcount + + return await self.db_pool.runInteraction( + "delete_forward_extremities_for_room", delete_forward_extremities_for_room_txn, + ) + async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]: + """Get list of forward extremities for a room.""" def get_forward_extremities_for_room_txn(txn): sql = ( "SELECT event_id, state_group FROM event_forward_extremities NATURAL JOIN event_to_state_groups " @@ -16,5 +63,5 @@ class EventForwardExtremitiesStore(SQLBaseStore): return [{"event_id": row[0], "state_group": row[1]} for row in rows] return await self.db_pool.runInteraction( - "get_forward_extremities_for_room", get_forward_extremities_for_room_txn + "get_forward_extremities_for_room", get_forward_extremities_for_room_txn, ) -- cgit 1.5.1 From 90ad4d443a109ad95741b499d914006578acceef Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Sat, 9 Jan 2021 21:57:41 +0200 Subject: Implement clearing cache after deleting forward extremities Also run linter. Signed-off-by: Jason Robinson --- synapse/rest/admin/rooms.py | 21 +++++------ .../databases/main/events_forward_extremities.py | 41 +++++++++++++++++----- 2 files changed, 42 insertions(+), 20 deletions(-) (limited to 'synapse/rest') diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 76f8603821..6757a8100b 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -524,18 +524,20 @@ class ForwardExtremitiesRestServlet(RestServlet): async def resolve_room_id(self, room_identifier: str) -> str: """Resolve to a room ID, if necessary.""" if RoomID.is_valid(room_identifier): - room_id = room_identifier + resolved_room_id = room_identifier elif RoomAlias.is_valid(room_identifier): room_alias = RoomAlias.from_string(room_identifier) room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias) - room_id = room_id.to_string() + resolved_room_id = room_id.to_string() else: raise SynapseError( 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - if not room_id: - raise SynapseError(400, "Unknown room ID or room alias %s" % room_identifier) - return room_id + if not resolved_room_id: + raise SynapseError( + 400, "Unknown room ID or room alias %s" % room_identifier + ) + return resolved_room_id async def on_DELETE(self, request, room_identifier): requester = await self.auth.get_user_by_req(request) @@ -544,9 +546,7 @@ class ForwardExtremitiesRestServlet(RestServlet): room_id = await self.resolve_room_id(room_identifier) deleted_count = await self.store.delete_forward_extremities_for_room(room_id) - return 200, { - "deleted": deleted_count, - } + return 200, {"deleted": deleted_count} async def on_GET(self, request, room_identifier): requester = await self.auth.get_user_by_req(request) @@ -555,7 +555,4 @@ class ForwardExtremitiesRestServlet(RestServlet): room_id = await self.resolve_room_id(room_identifier) extremities = await self.store.get_forward_extremities_for_room(room_id) - return 200, { - "count": len(extremities), - "results": extremities, - } + return 200, {"count": len(extremities), "results": extremities} diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index cc684a94fe..6b8da52fee 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -1,15 +1,22 @@ -from typing import List, Dict +import logging +from typing import Dict, List +from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore +logger = logging.getLogger(__name__) -class EventForwardExtremitiesStore(SQLBaseStore): +class EventForwardExtremitiesStore(SQLBaseStore): async def delete_forward_extremities_for_room(self, room_id: str) -> int: """Delete any extra forward extremities for a room. + Invalidates the "get_latest_event_ids_in_room" cache if any forward + extremities were deleted. + Returns count deleted. """ + def delete_forward_extremities_for_room_txn(txn): # First we need to get the event_id to not delete sql = ( @@ -27,9 +34,17 @@ class EventForwardExtremitiesStore(SQLBaseStore): ) txn.execute(sql, (room_id,)) rows = txn.fetchall() - - # TODO: should this raise a SynapseError instead of better to blow? - event_id = rows[0][0] + try: + event_id = rows[0][0] + logger.debug( + "Found event_id %s as the forward extremity to keep for room %s", + event_id, + room_id, + ) + except KeyError: + msg = f"No forward extremity event found for room {room_id}" + logger.warning(msg) + raise SynapseError(400, msg) # Now delete the extra forward extremities sql = ( @@ -39,19 +54,29 @@ class EventForwardExtremitiesStore(SQLBaseStore): " AND room_id = ?" ) - # TODO we should not commit yet txn.execute(sql, (event_id, room_id)) + logger.info( + "Deleted %s extra forward extremities for room %s", + txn.rowcount, + room_id, + ) - # TODO flush the cache then commit + if txn.rowcount > 0: + # Invalidate the cache + self._invalidate_cache_and_stream( + txn, self.get_latest_event_ids_in_room, (room_id,), + ) return txn.rowcount return await self.db_pool.runInteraction( - "delete_forward_extremities_for_room", delete_forward_extremities_for_room_txn, + "delete_forward_extremities_for_room", + delete_forward_extremities_for_room_txn, ) async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]: """Get list of forward extremities for a room.""" + def get_forward_extremities_for_room_txn(txn): sql = ( "SELECT event_id, state_group FROM event_forward_extremities NATURAL JOIN event_to_state_groups " -- cgit 1.5.1 From b161528fccac4bf17f7afa9438a75d796433194e Mon Sep 17 00:00:00 2001 From: David Teller Date: Mon, 11 Jan 2021 20:32:17 +0100 Subject: Also support remote users on the joined_rooms admin API. (#8948) For remote users, only the rooms which the server knows about are returned. Local users have all of their joined rooms returned. --- changelog.d/8948.feature | 1 + docs/admin_api/user_admin_api.rst | 4 +++ synapse/rest/admin/users.py | 7 ----- tests/rest/admin/test_user.py | 58 +++++++++++++++++++++++++++++++++++---- 4 files changed, 57 insertions(+), 13 deletions(-) create mode 100644 changelog.d/8948.feature (limited to 'synapse/rest') diff --git a/changelog.d/8948.feature b/changelog.d/8948.feature new file mode 100644 index 0000000000..3b06cbfa22 --- /dev/null +++ b/changelog.d/8948.feature @@ -0,0 +1 @@ +Update `/_synapse/admin/v1/users//joined_rooms` to work for both local and remote users. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index e4d6f8203b..3115951e1f 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -337,6 +337,10 @@ A response body like the following is returned: "total": 2 } +The server returns the list of rooms of which the user and the server +are member. If the user is local, all the rooms of which the user is +member are returned. + **Parameters** The following parameters should be set in the URL: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 6658c2da56..f8a73e7d9d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -714,13 +714,6 @@ class UserMembershipRestServlet(RestServlet): async def on_GET(self, request, user_id): await assert_requester_is_admin(self.auth, request) - if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only lookup local users") - - user = await self.store.get_user_by_id(user_id) - if user is None: - raise NotFoundError("Unknown user") - room_ids = await self.store.get_rooms_for_user(user_id) ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} return 200, ret diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 877fd2587b..ad4588c1da 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -25,6 +25,7 @@ from mock import Mock import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError +from synapse.api.room_versions import RoomVersions from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v2_alpha import devices, sync @@ -1234,24 +1235,26 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns an empty list """ url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms" channel = self.make_request("GET", url, access_token=self.admin_user_tok,) - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["joined_rooms"])) def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local and participates in no conversation returns an empty list """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms" channel = self.make_request("GET", url, access_token=self.admin_user_tok,) - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["joined_rooms"])) def test_no_memberships(self): """ @@ -1282,6 +1285,49 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) + def test_get_rooms_with_nonlocal_user(self): + """ + Tests that a normal lookup for rooms is successful with a non-local user + """ + + other_user_tok = self.login("user", "pass") + event_builder_factory = self.hs.get_event_builder_factory() + event_creation_handler = self.hs.get_event_creation_handler() + storage = self.hs.get_storage() + + # Create two rooms, one with a local user only and one with both a local + # and remote user. + self.helper.create_room_as(self.other_user, tok=other_user_tok) + local_and_remote_room_id = self.helper.create_room_as( + self.other_user, tok=other_user_tok + ) + + # Add a remote user to the room. + builder = event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": "m.room.member", + "sender": "@joiner:remote_hs", + "state_key": "@joiner:remote_hs", + "room_id": local_and_remote_room_id, + "content": {"membership": "join"}, + }, + ) + + event, context = self.get_success( + event_creation_handler.create_new_client_event(builder) + ) + + self.get_success(storage.persistence.persist_event(event, context)) + + # Now get rooms + url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"]) + class PushersRestTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From da16d06301aec83d144812d727c24192eb890c93 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Mon, 11 Jan 2021 23:43:58 +0200 Subject: Address pr feedback * docs updates * prettify SQL * add missing copyright * cursor_to_dict * update touched files copyright years Signed-off-by: Jason Robinson --- docs/admin_api/rooms.md | 12 +--- synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/rooms.py | 2 +- synapse/storage/databases/main/__init__.py | 2 +- .../databases/main/events_forward_extremities.py | 64 +++++++++++++--------- 5 files changed, 46 insertions(+), 36 deletions(-) (limited to 'synapse/rest') diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 1d59bb5c4b..86daa393a7 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -516,11 +516,8 @@ optionally be specified, e.g.: # Forward Extremities Admin API Enables querying and deleting forward extremities from rooms. When a lot of forward -extremities accumulate in a room, performance can become degraded. - -When using this API endpoint to delete any extra forward extremities for a room, -the server does not need to be restarted as the relevant caches will be cleared -in the API call. +extremities accumulate in a room, performance can become degraded. For details, see +[#1760](https://github.com/matrix-org/synapse/issues/1760). ## Check for forward extremities @@ -537,7 +534,7 @@ A response as follows will be returned: "count": 1, "results": [ { - "event_id": "$M5SP266vsnxctfwFgFLNceaCo3ujhRtg_NiiHabcdfgh", + "event_id": "$M5SP266vsnxctfwFgFLNceaCo3ujhRtg_NiiHabcdefgh", "state_group": 439 } ] @@ -561,6 +558,3 @@ that were deleted. "deleted": 1 } ``` - -The cache `get_latest_event_ids_in_room` will be invalidated, if any forward extremities -were deleted. diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index b80b036090..319ad7bf7f 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018-2019 New Vector Ltd +# Copyright 2020, 2021 The Matrix.org Foundation C.I.C. + # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 6757a8100b..da1499cab3 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 93b25af057..b936f54f1e 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index 83f751cf5b..e6c2d6e122 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -1,3 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging from typing import Dict, List @@ -19,19 +34,19 @@ class EventForwardExtremitiesStore(SQLBaseStore): def delete_forward_extremities_for_room_txn(txn): # First we need to get the event_id to not delete - sql = ( - "SELECT " - " last_value(event_id) OVER w AS event_id" - " FROM event_forward_extremities" - " NATURAL JOIN events" - " where room_id = ?" - " WINDOW w AS (" - " PARTITION BY room_id" - " ORDER BY stream_ordering" - " range between unbounded preceding and unbounded following" - " )" - " ORDER BY stream_ordering" - ) + sql = """ + SELECT + last_value(event_id) OVER w AS event_id + FROM event_forward_extremities + NATURAL JOIN events + WHERE room_id = ? + WINDOW w AS ( + PARTITION BY room_id + ORDER BY stream_ordering + range between unbounded preceding and unbounded following + ) + ORDER BY stream_ordering + """ txn.execute(sql, (room_id,)) rows = txn.fetchall() try: @@ -47,12 +62,10 @@ class EventForwardExtremitiesStore(SQLBaseStore): raise SynapseError(400, msg) # Now delete the extra forward extremities - sql = ( - "DELETE FROM event_forward_extremities " - "WHERE" - " event_id != ?" - " AND room_id = ?" - ) + sql = """ + DELETE FROM event_forward_extremities + WHERE event_id != ? AND room_id = ? + """ txn.execute(sql, (event_id, room_id)) logger.info( @@ -78,14 +91,15 @@ class EventForwardExtremitiesStore(SQLBaseStore): """Get list of forward extremities for a room.""" def get_forward_extremities_for_room_txn(txn): - sql = ( - "SELECT event_id, state_group FROM event_forward_extremities NATURAL JOIN event_to_state_groups " - "WHERE room_id = ?" - ) + sql = """ + SELECT event_id, state_group + FROM event_forward_extremities + NATURAL JOIN event_to_state_groups + WHERE room_id = ? + """ txn.execute(sql, (room_id,)) - rows = txn.fetchall() - return [{"event_id": row[0], "state_group": row[1]} for row in rows] + return self.db_pool.cursor_to_dict(txn) return await self.db_pool.runInteraction( "get_forward_extremities_for_room", get_forward_extremities_for_room_txn, -- cgit 1.5.1 From 0f8945e166b5f1965e69943e16c8220da74211bd Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 12 Jan 2021 12:48:12 +0000 Subject: Kill off `HomeServer.get_ip_from_request()` (#9080) Homeserver.get_ip_from_request() used to be a bit more complicated, but now it is totally redundant. Let's get rid of it. --- changelog.d/9080.misc | 1 + synapse/api/auth.py | 4 ++-- synapse/handlers/auth.py | 9 ++------- synapse/rest/client/v2_alpha/account.py | 19 +++---------------- synapse/rest/client/v2_alpha/auth.py | 4 ++-- synapse/rest/client/v2_alpha/devices.py | 12 ++---------- synapse/rest/client/v2_alpha/keys.py | 6 +----- synapse/rest/client/v2_alpha/register.py | 8 ++------ synapse/server.py | 4 ---- 9 files changed, 15 insertions(+), 52 deletions(-) create mode 100644 changelog.d/9080.misc (limited to 'synapse/rest') diff --git a/changelog.d/9080.misc b/changelog.d/9080.misc new file mode 100644 index 0000000000..3da8171f5f --- /dev/null +++ b/changelog.d/9080.misc @@ -0,0 +1 @@ +Remove redundant `Homeserver.get_ip_from_request` method. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 6d6703250b..67ecbd32ff 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -187,7 +187,7 @@ class Auth: AuthError if access is denied for the user in the access token """ try: - ip_addr = self.hs.get_ip_from_request(request) + ip_addr = request.getClientIP() user_agent = get_request_user_agent(request) access_token = self.get_access_token_from_request(request) @@ -276,7 +276,7 @@ class Auth: return None, None if app_service.ip_range_whitelist: - ip_address = IPAddress(self.hs.get_ip_from_request(request)) + ip_address = IPAddress(request.getClientIP()) if ip_address not in app_service.ip_range_whitelist: return None, None diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 5b86ee85c7..2f5b2b61aa 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -284,7 +284,6 @@ class AuthHandler(BaseHandler): requester: Requester, request: SynapseRequest, request_body: Dict[str, Any], - clientip: str, description: str, ) -> Tuple[dict, Optional[str]]: """ @@ -301,8 +300,6 @@ class AuthHandler(BaseHandler): request_body: The body of the request sent by the client - clientip: The IP address of the client. - description: A human readable string to be displayed to the user that describes the operation happening on their account. @@ -351,7 +348,7 @@ class AuthHandler(BaseHandler): try: result, params, session_id = await self.check_ui_auth( - flows, request, request_body, clientip, description + flows, request, request_body, description ) except LoginError: # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). @@ -426,7 +423,6 @@ class AuthHandler(BaseHandler): flows: List[List[str]], request: SynapseRequest, clientdict: Dict[str, Any], - clientip: str, description: str, ) -> Tuple[dict, dict, str]: """ @@ -448,8 +444,6 @@ class AuthHandler(BaseHandler): clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. - clientip: The IP address of the client. - description: A human readable string to be displayed to the user that describes the operation happening on their account. @@ -540,6 +534,7 @@ class AuthHandler(BaseHandler): await self.store.set_ui_auth_clientdict(sid, clientdict) user_agent = get_request_user_agent(request) + clientip = request.getClientIP() await self.store.add_user_agent_ip_to_ui_auth_session( session.session_id, user_agent, clientip diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index d837bde1d6..7f3445fe5d 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -189,11 +189,7 @@ class PasswordRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) try: params, session_id = await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "modify your account password", + requester, request, body, "modify your account password", ) except InteractiveAuthIncompleteError as e: # The user needs to provide more steps to complete auth, but @@ -215,7 +211,6 @@ class PasswordRestServlet(RestServlet): [[LoginType.EMAIL_IDENTITY]], request, body, - self.hs.get_ip_from_request(request), "modify your account password", ) except InteractiveAuthIncompleteError as e: @@ -309,11 +304,7 @@ class DeactivateAccountRestServlet(RestServlet): return 200, {} await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "deactivate your account", + requester, request, body, "deactivate your account", ) result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, id_server=body.get("id_server") @@ -695,11 +686,7 @@ class ThreepidAddRestServlet(RestServlet): assert_valid_client_secret(client_secret) await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "add a third-party identifier to your account", + requester, request, body, "add a third-party identifier to your account", ) validation_session = await self.identity_handler.validate_threepid_session( diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 9b9514632f..149697fc23 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -128,7 +128,7 @@ class AuthRestServlet(RestServlet): authdict = {"response": response, "session": session} success = await self.auth_handler.add_oob_auth( - LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) + LoginType.RECAPTCHA, authdict, request.getClientIP() ) if success: @@ -144,7 +144,7 @@ class AuthRestServlet(RestServlet): authdict = {"session": session} success = await self.auth_handler.add_oob_auth( - LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) + LoginType.TERMS, authdict, request.getClientIP() ) if success: diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index af117cb27c..314e01dfe4 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "remove device(s) from your account", + requester, request, body, "remove device(s) from your account", ) await self.device_handler.delete_devices( @@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet): raise await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "remove a device from your account", + requester, request, body, "remove a device from your account", ) await self.device_handler.delete_device(requester.user.to_string(), device_id) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index b91996c738..a6134ead8a 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet): body = parse_json_object_from_request(request) await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "add a device signing key to your account", + requester, request, body, "add a device signing key to your account", ) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 6b5a1b7109..35e646390e 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -353,7 +353,7 @@ class UsernameAvailabilityRestServlet(RestServlet): 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) - ip = self.hs.get_ip_from_request(request) + ip = request.getClientIP() with self.ratelimiter.ratelimit(ip) as wait_deferred: await wait_deferred @@ -513,11 +513,7 @@ class RegisterRestServlet(RestServlet): # not this will raise a user-interactive auth error. try: auth_result, params, session_id = await self.auth_handler.check_ui_auth( - self._registration_flows, - request, - body, - self.hs.get_ip_from_request(request), - "register a new account", + self._registration_flows, request, body, "register a new account", ) except InteractiveAuthIncompleteError as e: # The user needs to provide more steps to complete auth. diff --git a/synapse/server.py b/synapse/server.py index a198b0eb46..12da92b63c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -283,10 +283,6 @@ class HomeServer(metaclass=abc.ABCMeta): """ return self._reactor - def get_ip_from_request(self, request) -> str: - # X-Forwarded-For is handled by our custom request type. - return request.getClientIP() - def is_mine(self, domain_specific_string: DomainSpecificString) -> bool: return domain_specific_string.domain == self.hostname -- cgit 1.5.1 From 789d9ebad3043b54a7c70cfadb41af7a20ce3877 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 12 Jan 2021 17:38:03 +0000 Subject: UI Auth via SSO: redirect the user to an appropriate SSO. (#9081) If we have integrations with multiple identity providers, when the user does a UI Auth, we need to redirect them to the right one. There are a few steps to this. First of all we actually need to store the userid of the user we are trying to validate in the UIA session, since the /auth/sso/fallback/web request is unauthenticated. Then, once we get the /auth/sso/fallback/web request, we can fish the user id out of the session, and use it to look up the external id mappings, and hence pick an SSO provider for them. --- changelog.d/9081.feature | 1 + synapse/handlers/auth.py | 82 +++++++++++++++++++++++++------- synapse/handlers/sso.py | 31 ++++++++++++ synapse/handlers/ui_auth/__init__.py | 15 ++++++ synapse/rest/client/v2_alpha/account.py | 18 ++++--- synapse/rest/client/v2_alpha/auth.py | 33 +------------ synapse/rest/client/v2_alpha/register.py | 13 +++-- 7 files changed, 133 insertions(+), 60 deletions(-) create mode 100644 changelog.d/9081.feature (limited to 'synapse/rest') diff --git a/changelog.d/9081.feature b/changelog.d/9081.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9081.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2f5b2b61aa..4f881a439a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -50,7 +50,10 @@ from synapse.api.errors import ( ) from synapse.api.ratelimiting import Ratelimiter from synapse.handlers._base import BaseHandler -from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS +from synapse.handlers.ui_auth import ( + INTERACTIVE_AUTH_CHECKERS, + UIAuthSessionDataConstants, +) from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.http import get_request_user_agent from synapse.http.server import finish_request, respond_with_html @@ -335,10 +338,10 @@ class AuthHandler(BaseHandler): request_body.pop("auth", None) return request_body, None - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) + self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False) # build a list of supported flows supported_ui_auth_types = await self._get_available_ui_auth_types( @@ -346,13 +349,16 @@ class AuthHandler(BaseHandler): ) flows = [[login_type] for login_type in supported_ui_auth_types] + def get_new_session_data() -> JsonDict: + return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id} + try: result, params, session_id = await self.check_ui_auth( - flows, request, request_body, description + flows, request, request_body, description, get_new_session_data, ) except LoginError: # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). - self._failed_uia_attempts_ratelimiter.can_do_action(user_id) + self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id) raise # find the completed login type @@ -360,14 +366,14 @@ class AuthHandler(BaseHandler): if login_type not in result: continue - user_id = result[login_type] + validated_user_id = result[login_type] break else: # this can't happen raise Exception("check_auth returned True but no successful login type") # check that the UI auth matched the access token - if user_id != requester.user.to_string(): + if validated_user_id != requester_user_id: raise AuthError(403, "Invalid auth") # Note that the access token has been validated. @@ -399,13 +405,9 @@ class AuthHandler(BaseHandler): # if sso is enabled, allow the user to log in via SSO iff they have a mapping # from sso to mxid. - if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled: - if await self.store.get_external_ids_by_user(user.to_string()): - ui_auth_types.add(LoginType.SSO) - - # Our CAS impl does not (yet) correctly register users in user_external_ids, - # so always offer that if it's available. - if self.hs.config.cas.cas_enabled: + if await self.hs.get_sso_handler().get_identity_providers_for_user( + user.to_string() + ): ui_auth_types.add(LoginType.SSO) return ui_auth_types @@ -424,6 +426,7 @@ class AuthHandler(BaseHandler): request: SynapseRequest, clientdict: Dict[str, Any], description: str, + get_new_session_data: Optional[Callable[[], JsonDict]] = None, ) -> Tuple[dict, dict, str]: """ Takes a dictionary sent by the client in the login / registration @@ -447,6 +450,13 @@ class AuthHandler(BaseHandler): description: A human readable string to be displayed to the user that describes the operation happening on their account. + get_new_session_data: + an optional callback which will be called when starting a new session. + it should return data to be stored as part of the session. + + The keys of the returned data should be entries in + UIAuthSessionDataConstants. + Returns: A tuple of (creds, params, session_id). @@ -474,10 +484,15 @@ class AuthHandler(BaseHandler): # If there's no session ID, create a new session. if not sid: + new_session_data = get_new_session_data() if get_new_session_data else {} + session = await self.store.create_ui_auth_session( clientdict, uri, method, description ) + for k, v in new_session_data.items(): + await self.set_session_data(session.session_id, k, v) + else: try: session = await self.store.get_ui_auth_session(sid) @@ -639,7 +654,8 @@ class AuthHandler(BaseHandler): Args: session_id: The ID of this session as returned from check_auth - key: The key to store the data under + key: The key to store the data under. An entry from + UIAuthSessionDataConstants. value: The data to store """ try: @@ -655,7 +671,8 @@ class AuthHandler(BaseHandler): Args: session_id: The ID of this session as returned from check_auth - key: The key to store the data under + key: The key the data was stored under. An entry from + UIAuthSessionDataConstants. default: Value to return if the key has not been set """ try: @@ -1329,12 +1346,12 @@ class AuthHandler(BaseHandler): else: return False - async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: + async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str: """ Get the HTML for the SSO redirect confirmation page. Args: - redirect_url: The URL to redirect to the SSO provider. + request: The incoming HTTP request session_id: The user interactive authentication session ID. Returns: @@ -1344,6 +1361,35 @@ class AuthHandler(BaseHandler): session = await self.store.get_ui_auth_session(session_id) except StoreError: raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) + + user_id_to_verify = await self.get_session_data( + session_id, UIAuthSessionDataConstants.REQUEST_USER_ID + ) # type: str + + idps = await self.hs.get_sso_handler().get_identity_providers_for_user( + user_id_to_verify + ) + + if not idps: + # we checked that the user had some remote identities before offering an SSO + # flow, so either it's been deleted or the client has requested SSO despite + # it not being offered. + raise SynapseError(400, "User has no SSO identities") + + # for now, just pick one + idp_id, sso_auth_provider = next(iter(idps.items())) + if len(idps) > 0: + logger.warning( + "User %r has previously logged in with multiple SSO IdPs; arbitrarily " + "picking %r", + user_id_to_verify, + idp_id, + ) + + redirect_url = await sso_auth_provider.handle_redirect_request( + request, None, session_id + ) + return self._sso_auth_confirm_template.render( description=session.description, redirect_url=redirect_url, ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 740df7e4a0..d096e0b091 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -167,6 +167,37 @@ class SsoHandler: """Get the configured identity providers""" return self._identity_providers + async def get_identity_providers_for_user( + self, user_id: str + ) -> Mapping[str, SsoIdentityProvider]: + """Get the SsoIdentityProviders which a user has used + + Given a user id, get the identity providers that that user has used to log in + with in the past (and thus could use to re-identify themselves for UI Auth). + + Args: + user_id: MXID of user to look up + + Raises: + a map of idp_id to SsoIdentityProvider + """ + external_ids = await self._store.get_external_ids_by_user(user_id) + + valid_idps = {} + for idp_id, _ in external_ids: + idp = self._identity_providers.get(idp_id) + if not idp: + logger.warning( + "User %r has an SSO mapping for IdP %r, but this is no longer " + "configured.", + user_id, + idp_id, + ) + else: + valid_idps[idp_id] = idp + + return valid_idps + def render_error( self, request: Request, diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py index 824f37f8f8..a68d5e790e 100644 --- a/synapse/handlers/ui_auth/__init__.py +++ b/synapse/handlers/ui_auth/__init__.py @@ -20,3 +20,18 @@ TODO: move more stuff out of AuthHandler in here. """ from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401 + + +class UIAuthSessionDataConstants: + """Constants for use with AuthHandler.set_session_data""" + + # used during registration and password reset to store a hashed copy of the + # password, so that the client does not need to submit it each time. + PASSWORD_HASH = "password_hash" + + # used during registration to store the mxid of the registered user + REGISTERED_USER_ID = "registered_user_id" + + # used by validate_user_via_ui_auth to store the mxid of the user we are validating + # for. + REQUEST_USER_ID = "request_user_id" diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 7f3445fe5d..3b50dc885f 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -20,9 +20,6 @@ from http import HTTPStatus from typing import TYPE_CHECKING from urllib.parse import urlparse -if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer - from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, @@ -31,6 +28,7 @@ from synapse.api.errors import ( ThreepidValidationError, ) from synapse.config.emailconfig import ThreepidBehaviour +from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, @@ -46,6 +44,10 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed from ._base import client_patterns, interactive_auth_handler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + + logger = logging.getLogger(__name__) @@ -200,7 +202,9 @@ class PasswordRestServlet(RestServlet): if new_password: password_hash = await self.auth_handler.hash(new_password) await self.auth_handler.set_session_data( - e.session_id, "password_hash", password_hash + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, ) raise user_id = requester.user.to_string() @@ -222,7 +226,9 @@ class PasswordRestServlet(RestServlet): if new_password: password_hash = await self.auth_handler.hash(new_password) await self.auth_handler.set_session_data( - e.session_id, "password_hash", password_hash + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, ) raise @@ -255,7 +261,7 @@ class PasswordRestServlet(RestServlet): password_hash = await self.auth_handler.hash(new_password) elif session_id is not None: password_hash = await self.auth_handler.get_session_data( - session_id, "password_hash", None + session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None ) else: # UI validation was skipped, but the request did not include a new diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 149697fc23..75ece1c911 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -19,7 +19,6 @@ from typing import TYPE_CHECKING from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX -from synapse.handlers.sso import SsoIdentityProvider from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet, parse_string @@ -46,22 +45,6 @@ class AuthRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() - - # SSO configuration. - self._cas_enabled = hs.config.cas_enabled - if self._cas_enabled: - self._cas_handler = hs.get_cas_handler() - self._cas_server_url = hs.config.cas_server_url - self._cas_service_url = hs.config.cas_service_url - self._saml_enabled = hs.config.saml2_enabled - if self._saml_enabled: - self._saml_handler = hs.get_saml_handler() - self._oidc_enabled = hs.config.oidc_enabled - if self._oidc_enabled: - self._oidc_handler = hs.get_oidc_handler() - self._cas_server_url = hs.config.cas_server_url - self._cas_service_url = hs.config.cas_service_url - self.recaptcha_template = hs.config.recaptcha_template self.terms_template = hs.config.terms_template self.success_template = hs.config.fallback_success_template @@ -90,21 +73,7 @@ class AuthRestServlet(RestServlet): elif stagetype == LoginType.SSO: # Display a confirmation page which prompts the user to # re-authenticate with their SSO provider. - - if self._cas_enabled: - sso_auth_provider = self._cas_handler # type: SsoIdentityProvider - elif self._saml_enabled: - sso_auth_provider = self._saml_handler - elif self._oidc_enabled: - sso_auth_provider = self._oidc_handler - else: - raise SynapseError(400, "Homeserver not configured for SSO.") - - sso_redirect_url = await sso_auth_provider.handle_redirect_request( - request, None, session - ) - - html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) + html = await self.auth_handler.start_sso_ui_auth(request, session) else: raise SynapseError(404, "Unknown auth stage type") diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 35e646390e..b093183e79 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.registration import RegistrationConfig from synapse.config.server import is_threepid_reserved from synapse.handlers.auth import AuthHandler +from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, @@ -494,11 +495,11 @@ class RegisterRestServlet(RestServlet): # user here. We carry on and go through the auth checks though, # for paranoia. registered_user_id = await self.auth_handler.get_session_data( - session_id, "registered_user_id", None + session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None ) # Extract the previously-hashed password from the session. password_hash = await self.auth_handler.get_session_data( - session_id, "password_hash", None + session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None ) # Ensure that the username is valid. @@ -528,7 +529,9 @@ class RegisterRestServlet(RestServlet): if not password_hash and password: password_hash = await self.auth_handler.hash(password) await self.auth_handler.set_session_data( - e.session_id, "password_hash", password_hash + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, ) raise @@ -629,7 +632,9 @@ class RegisterRestServlet(RestServlet): # Remember that the user account has been registered (and the user # ID it was registered with, since it might not have been specified). await self.auth_handler.set_session_data( - session_id, "registered_user_id", registered_user_id + session_id, + UIAuthSessionDataConstants.REGISTERED_USER_ID, + registered_user_id, ) registered = True -- cgit 1.5.1 From 7a2e9b549defe3f55531711a863183a33e7af83c Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 12 Jan 2021 22:30:15 +0100 Subject: Remove user's avatar URL and displayname when deactivated. (#8932) This only applies if the user's data is to be erased. --- changelog.d/8932.feature | 1 + docs/admin_api/user_admin_api.rst | 21 +++ synapse/handlers/deactivate_account.py | 18 ++- synapse/handlers/profile.py | 8 +- synapse/rest/admin/users.py | 22 ++- synapse/rest/client/v2_alpha/account.py | 7 +- synapse/server.py | 2 +- synapse/storage/databases/main/profile.py | 2 +- tests/handlers/test_profile.py | 30 ++++ tests/rest/admin/test_user.py | 220 ++++++++++++++++++++++++++++++ tests/rest/client/v1/test_login.py | 5 +- tests/rest/client/v1/test_rooms.py | 6 +- tests/storage/test_profile.py | 26 ++++ 13 files changed, 351 insertions(+), 17 deletions(-) create mode 100644 changelog.d/8932.feature (limited to 'synapse/rest') diff --git a/changelog.d/8932.feature b/changelog.d/8932.feature new file mode 100644 index 0000000000..a1d17394d7 --- /dev/null +++ b/changelog.d/8932.feature @@ -0,0 +1 @@ +Remove a user's avatar URL and display name when deactivated with the Admin API. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 3115951e1f..b3d413cf57 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -98,6 +98,8 @@ Body parameters: - ``deactivated``, optional. If unspecified, deactivation state will be left unchanged on existing accounts and set to ``false`` for new accounts. + A user cannot be erased by deactivating with this API. For details on deactivating users see + `Deactivate Account <#deactivate-account>`_. If the user already exists then optional parameters default to the current value. @@ -248,6 +250,25 @@ server admin: see `README.rst `_. The erase parameter is optional and defaults to ``false``. An empty body may be passed for backwards compatibility. +The following actions are performed when deactivating an user: + +- Try to unpind 3PIDs from the identity server +- Remove all 3PIDs from the homeserver +- Delete all devices and E2EE keys +- Delete all access tokens +- Delete the password hash +- Removal from all rooms the user is a member of +- Remove the user from the user directory +- Reject all pending invites +- Remove all account validity information related to the user + +The following additional actions are performed during deactivation if``erase`` +is set to ``true``: + +- Remove the user's display name +- Remove the user's avatar URL +- Mark the user as erased + Reset password ============== diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index e808142365..c4a3b26a84 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import UserID, create_requester +from synapse.types import Requester, UserID, create_requester from ._base import BaseHandler @@ -38,6 +38,7 @@ class DeactivateAccountHandler(BaseHandler): self._device_handler = hs.get_device_handler() self._room_member_handler = hs.get_room_member_handler() self._identity_handler = hs.get_identity_handler() + self._profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() self._server_name = hs.hostname @@ -52,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler): self._account_validity_enabled = hs.config.account_validity.enabled async def deactivate_account( - self, user_id: str, erase_data: bool, id_server: Optional[str] = None + self, + user_id: str, + erase_data: bool, + requester: Requester, + id_server: Optional[str] = None, + by_admin: bool = False, ) -> bool: """Deactivate a user's account Args: user_id: ID of user to be deactivated erase_data: whether to GDPR-erase the user's data + requester: The user attempting to make this change. id_server: Use the given identity server when unbinding any threepids. If None then will attempt to unbind using the identity server specified when binding (if known). + by_admin: Whether this change was made by an administrator. Returns: True if identity server supports removing threepids, otherwise False. @@ -121,6 +129,12 @@ class DeactivateAccountHandler(BaseHandler): # Mark the user as erased, if they asked for that if erase_data: + user = UserID.from_string(user_id) + # Remove avatar URL from this user + await self._profile_handler.set_avatar_url(user, requester, "", by_admin) + # Remove displayname from this user + await self._profile_handler.set_displayname(user, requester, "", by_admin) + logger.info("Marking %s as erased", user_id) await self.store.mark_user_erased(user_id) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 36f9ee4b71..c02b951031 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -286,13 +286,19 @@ class ProfileHandler(BaseHandler): 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) ) + avatar_url_to_set = new_avatar_url # type: Optional[str] + if new_avatar_url == "": + avatar_url_to_set = None + # Same like set_displayname if by_admin: requester = create_requester( target_user, authenticated_entity=requester.authenticated_entity ) - await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) + await self.store.set_profile_avatar_url( + target_user.localpart, avatar_url_to_set + ) if self.hs.config.user_directory_search_all_users: profile = await self.store.get_profileinfo(target_user.localpart) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index f8a73e7d9d..f39e3d6d5c 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -244,7 +244,7 @@ class UserRestServletV2(RestServlet): if deactivate and not user["deactivated"]: await self.deactivate_account_handler.deactivate_account( - target_user.to_string(), False + target_user.to_string(), False, requester, by_admin=True ) elif not deactivate and user["deactivated"]: if "password" not in body: @@ -486,12 +486,22 @@ class WhoisRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): PATTERNS = admin_patterns("/deactivate/(?P[^/]*)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self._deactivate_account_handler = hs.get_deactivate_account_handler() self.auth = hs.get_auth() + self.is_mine = hs.is_mine + self.store = hs.get_datastore() + + async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + if not self.is_mine(UserID.from_string(target_user_id)): + raise SynapseError(400, "Can only deactivate local users") + + if not await self.store.get_user_by_id(target_user_id): + raise NotFoundError("User not found") - async def on_POST(self, request, target_user_id): - await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request, allow_empty_body=True) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -501,10 +511,8 @@ class DeactivateAccountRestServlet(RestServlet): Codes.BAD_JSON, ) - UserID.from_string(target_user_id) - result = await self._deactivate_account_handler.deactivate_account( - target_user_id, erase + target_user_id, erase, requester, by_admin=True ) if result: id_server_unbind_result = "success" diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 3b50dc885f..65e68d641b 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -305,7 +305,7 @@ class DeactivateAccountRestServlet(RestServlet): # allow ASes to deactivate their own users if requester.app_service: await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase + requester.user.to_string(), erase, requester ) return 200, {} @@ -313,7 +313,10 @@ class DeactivateAccountRestServlet(RestServlet): requester, request, body, "deactivate your account", ) result = await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, id_server=body.get("id_server") + requester.user.to_string(), + erase, + requester, + id_server=body.get("id_server"), ) if result: id_server_unbind_result = "success" diff --git a/synapse/server.py b/synapse/server.py index 12da92b63c..d4c235cda5 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -501,7 +501,7 @@ class HomeServer(metaclass=abc.ABCMeta): return InitialSyncHandler(self) @cache_in_self - def get_profile_handler(self): + def get_profile_handler(self) -> ProfileHandler: return ProfileHandler(self) @cache_in_self diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 0e25ca3d7a..54ef0f1f54 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -82,7 +82,7 @@ class ProfileWorkerStore(SQLBaseStore): ) async def set_profile_avatar_url( - self, user_localpart: str, new_avatar_url: str + self, user_localpart: str, new_avatar_url: Optional[str] ) -> None: await self.db_pool.simple_update_one( table="profiles", diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 919547556b..022943a10a 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -105,6 +105,21 @@ class ProfileTestCase(unittest.TestCase): "Frank", ) + # Set displayname to an empty string + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "" + ) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ) + ) + @defer.inlineCallbacks def test_set_my_name_if_disabled(self): self.hs.config.enable_set_displayname = False @@ -223,6 +238,21 @@ class ProfileTestCase(unittest.TestCase): "http://my.server/me.png", ) + # Set avatar to an empty string + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, synapse.types.create_requester(self.frank), "", + ) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), + ) + @defer.inlineCallbacks def test_set_my_avatar_if_disabled(self): self.hs.config.enable_set_avatar_url = False diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ad4588c1da..04599c2fcf 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -588,6 +588,200 @@ class UsersListTestCase(unittest.HomeserverTestCase): _search_test(None, "bar", "user_id") +class DeactivateAccountTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass", displayname="User1") + self.other_user_token = self.login("user", "pass") + self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( + self.other_user + ) + self.url = "/_synapse/admin/v1/deactivate/%s" % urllib.parse.quote( + self.other_user + ) + + # set attributes for user + self.get_success( + self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") + ) + self.get_success( + self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) + ) + + def test_no_auth(self): + """ + Try to deactivate users without authentication. + """ + channel = self.make_request("POST", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_not_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + url = "/_synapse/admin/v1/deactivate/@bob:test" + + channel = self.make_request("POST", url, access_token=self.other_user_token) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("You are not a server admin", channel.json_body["error"]) + + channel = self.make_request( + "POST", url, access_token=self.other_user_token, content=b"{}", + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("You are not a server admin", channel.json_body["error"]) + + def test_user_does_not_exist(self): + """ + Tests that deactivation for a user that does not exist returns a 404 + """ + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/deactivate/@unknown_person:test", + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_erase_is_not_bool(self): + """ + If parameter `erase` is not boolean, return an error + """ + body = json.dumps({"erase": "False"}) + + channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that deactivation for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain" + + channel = self.make_request("POST", url, access_token=self.admin_user_tok) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only deactivate local users", channel.json_body["error"]) + + def test_deactivate_user_erase_true(self): + """ + Test deactivating an user and set `erase` to `true` + """ + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User1", channel.json_body["displayname"]) + + # Deactivate user + body = json.dumps({"erase": True}) + + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertIsNone(channel.json_body["avatar_url"]) + self.assertIsNone(channel.json_body["displayname"]) + + self._is_erased("@user:test", True) + + def test_deactivate_user_erase_false(self): + """ + Test deactivating an user and set `erase` to `false` + """ + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User1", channel.json_body["displayname"]) + + # Deactivate user + body = json.dumps({"erase": False}) + + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User1", channel.json_body["displayname"]) + + self._is_erased("@user:test", False) + + def _is_erased(self, user_id: str, expect: bool) -> None: + """Assert that the user is erased or not + """ + d = self.store.is_user_erased(user_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertFalse(self.get_success(d)) + + class UserRestTestCase(unittest.HomeserverTestCase): servlets = [ @@ -987,6 +1181,26 @@ class UserRestTestCase(unittest.HomeserverTestCase): Test deactivating another user. """ + # set attributes for user + self.get_success( + self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") + ) + self.get_success( + self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) + ) + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) + # Deactivate user body = json.dumps({"deactivated": True}) @@ -1000,6 +1214,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) # the user is deactivated, the threepid will be deleted # Get user @@ -1010,6 +1227,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) def test_change_name_deactivate_user_user_directory(self): diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 1d1dc9f8a2..f9b8011961 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -30,6 +30,7 @@ from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.synapse.client.pick_idp import PickIdpResource +from synapse.types import create_requester from tests import unittest from tests.handlers.test_oidc import HAS_OIDC @@ -667,7 +668,9 @@ class CASTestCase(unittest.HomeserverTestCase): # Deactivate the account. self.get_success( - self.deactivate_account_handler.deactivate_account(self.user_id, False) + self.deactivate_account_handler.deactivate_account( + self.user_id, False, create_requester(self.user_id) + ) ) # Request the CAS ticket. diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 6105eac47c..d4e3165436 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -29,7 +29,7 @@ from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import account -from synapse.types import JsonDict, RoomAlias, UserID +from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util.stringutils import random_string from tests import unittest @@ -1687,7 +1687,9 @@ class ContextTestCase(unittest.HomeserverTestCase): deactivate_account_handler = self.hs.get_deactivate_account_handler() self.get_success( - deactivate_account_handler.deactivate_account(self.user_id, erase_data=True) + deactivate_account_handler.deactivate_account( + self.user_id, True, create_requester(self.user_id) + ) ) # Invite another user in the room. This is needed because messages will be diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 3fd0a38cf5..ea63bd56b4 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -48,6 +48,19 @@ class ProfileStoreTestCase(unittest.TestCase): ), ) + # test set to None + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.u_frank.localpart, None) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.u_frank.localpart) + ) + ) + ) + @defer.inlineCallbacks def test_avatar_url(self): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) @@ -66,3 +79,16 @@ class ProfileStoreTestCase(unittest.TestCase): ) ), ) + + # test set to None + yield defer.ensureDeferred( + self.store.set_profile_avatar_url(self.u_frank.localpart, None) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.u_frank.localpart) + ) + ) + ) -- cgit 1.5.1 From d34c6e1279a24c5eb8afb962a29950c85fbfbf8a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 15 Jan 2021 10:57:37 -0500 Subject: Add type hints to media rest resources. (#9093) --- changelog.d/9093.misc | 1 + synapse/rest/media/v1/_base.py | 76 +++++++++++--------- synapse/rest/media/v1/config_resource.py | 14 +++- synapse/rest/media/v1/download_resource.py | 18 +++-- synapse/rest/media/v1/filepath.py | 50 ++++++++----- synapse/rest/media/v1/media_repository.py | 50 +++++++------ synapse/rest/media/v1/media_storage.py | 12 ++-- synapse/rest/media/v1/preview_url_resource.py | 77 ++++++++++++-------- synapse/rest/media/v1/storage_provider.py | 37 ++++++---- synapse/rest/media/v1/thumbnail_resource.py | 81 ++++++++++++++-------- synapse/rest/media/v1/thumbnailer.py | 18 ++--- synapse/rest/media/v1/upload_resource.py | 14 +++- synapse/storage/databases/main/media_repository.py | 3 +- 13 files changed, 286 insertions(+), 165 deletions(-) create mode 100644 changelog.d/9093.misc (limited to 'synapse/rest') diff --git a/changelog.d/9093.misc b/changelog.d/9093.misc new file mode 100644 index 0000000000..53eb8f72a8 --- /dev/null +++ b/changelog.d/9093.misc @@ -0,0 +1 @@ +Add type hints to media repository. diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 47c2b44bff..31a41e4a27 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 New Vector Ltd +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,10 +17,11 @@ import logging import os import urllib -from typing import Awaitable +from typing import Awaitable, Dict, Generator, List, Optional, Tuple from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError, cs_error from synapse.http.server import finish_request, respond_with_json @@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [ ] -def parse_media_id(request): +def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: try: # This allows users to append e.g. /test.png to the URL. Useful for # clients that parse the URL to see content type. @@ -69,7 +70,7 @@ def parse_media_id(request): ) -def respond_404(request): +def respond_404(request: Request) -> None: respond_with_json( request, 404, @@ -79,8 +80,12 @@ def respond_404(request): async def respond_with_file( - request, media_type, file_path, file_size=None, upload_name=None -): + request: Request, + media_type: str, + file_path: str, + file_size: Optional[int] = None, + upload_name: Optional[str] = None, +) -> None: logger.debug("Responding with %r", file_path) if os.path.isfile(file_path): @@ -98,15 +103,20 @@ async def respond_with_file( respond_404(request) -def add_file_headers(request, media_type, file_size, upload_name): +def add_file_headers( + request: Request, + media_type: str, + file_size: Optional[int], + upload_name: Optional[str], +) -> None: """Adds the correct response headers in preparation for responding with the media. Args: - request (twisted.web.http.Request) - media_type (str): The media/content type. - file_size (int): Size in bytes of the media, if known. - upload_name (str): The name of the requested file, if any. + request + media_type: The media/content type. + file_size: Size in bytes of the media, if known. + upload_name: The name of the requested file, if any. """ def _quote(x): @@ -153,7 +163,8 @@ def add_file_headers(request, media_type, file_size, upload_name): # select private. don't bother setting Expires as all our # clients are smart enough to be happy with Cache-Control request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") - request.setHeader(b"Content-Length", b"%d" % (file_size,)) + if file_size is not None: + request.setHeader(b"Content-Length", b"%d" % (file_size,)) # Tell web crawlers to not index, archive, or follow links in media. This # should help to prevent things in the media repo from showing up in web @@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = { } -def _can_encode_filename_as_token(x): +def _can_encode_filename_as_token(x: str) -> bool: for c in x: # from RFC2616: # @@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x): async def respond_with_responder( - request, responder, media_type, file_size, upload_name=None -): + request: Request, + responder: "Optional[Responder]", + media_type: str, + file_size: Optional[int], + upload_name: Optional[str] = None, +) -> None: """Responds to the request with given responder. If responder is None then returns 404. Args: - request (twisted.web.http.Request) - responder (Responder|None) - media_type (str): The media/content type. - file_size (int|None): Size in bytes of the media. If not known it should be None - upload_name (str|None): The name of the requested file, if any. + request + responder + media_type: The media/content type. + file_size: Size in bytes of the media. If not known it should be None + upload_name: The name of the requested file, if any. """ if request._disconnected: logger.warning( @@ -308,22 +323,22 @@ class FileInfo: self.thumbnail_type = thumbnail_type -def get_filename_from_headers(headers): +def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: """ Get the filename of the downloaded file by inspecting the Content-Disposition HTTP header. Args: - headers (dict[bytes, list[bytes]]): The HTTP request headers. + headers: The HTTP request headers. Returns: - A Unicode string of the filename, or None. + The filename, or None. """ content_disposition = headers.get(b"Content-Disposition", [b""]) # No header, bail out. if not content_disposition[0]: - return + return None _, params = _parse_header(content_disposition[0]) @@ -356,17 +371,16 @@ def get_filename_from_headers(headers): return upload_name -def _parse_header(line): +def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: """Parse a Content-type like header. Cargo-culted from `cgi`, but works on bytes rather than strings. Args: - line (bytes): header to be parsed + line: header to be parsed Returns: - Tuple[bytes, dict[bytes, bytes]]: - the main content-type, followed by the parameter dictionary + The main content-type, followed by the parameter dictionary """ parts = _parseparam(b";" + line) key = next(parts) @@ -386,16 +400,16 @@ def _parse_header(line): return key, pdict -def _parseparam(s): +def _parseparam(s: bytes) -> Generator[bytes, None, None]: """Generator which splits the input on ;, respecting double-quoted sequences Cargo-culted from `cgi`, but works on bytes rather than strings. Args: - s (bytes): header to be parsed + s: header to be parsed Returns: - Iterable[bytes]: the split input + The split input """ while s[:1] == b";": s = s[1:] diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 68dd2a1c8a..4e4c6971f7 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 Will Hunt +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,22 +15,29 @@ # limitations under the License. # +from typing import TYPE_CHECKING + +from twisted.web.http import Request + from synapse.http.server import DirectServeJsonResource, respond_with_json +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + class MediaConfigResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() config = hs.get_config() self.clock = hs.get_clock() self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.max_upload_size} - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: await self.auth.get_user_by_req(request) respond_with_json(request, 200, self.limits_dict, send_cors=True) - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: respond_with_json(request, 200, {}, send_cors=True) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index d3d8457303..3ed219ae43 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,24 +14,31 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request -import synapse.http.servlet from synapse.http.server import DirectServeJsonResource, set_cors_headers +from synapse.http.servlet import parse_boolean from ._base import parse_media_id, respond_404 +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class DownloadResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() self.media_repo = media_repo self.server_name = hs.hostname - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: set_cors_headers(request) request.setHeader( b"Content-Security-Policy", @@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource): if server_name == self.server_name: await self.media_repo.get_local_media(request, media_id, name) else: - allow_remote = synapse.http.servlet.parse_boolean( - request, "allow_remote", default=True - ) + allow_remote = parse_boolean(request, "allow_remote", default=True) if not allow_remote: logger.info( "Rejecting request for remote media %s/%s due to allow_remote", diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 9e079f672f..7792f26e78 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,11 +17,12 @@ import functools import os import re +from typing import Callable, List NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") -def _wrap_in_base_path(func): +def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]": """Takes a function that returns a relative path and turns it into an absolute path based on the location of the primary media store """ @@ -41,12 +43,18 @@ class MediaFilePaths: to write to the backup media store (when one is configured) """ - def __init__(self, primary_base_path): + def __init__(self, primary_base_path: str): self.base_path = primary_base_path def default_thumbnail_rel( - self, default_top_level, default_sub_type, width, height, content_type, method - ): + self, + default_top_level: str, + default_sub_type: str, + width: int, + height: int, + content_type: str, + method: str, + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -55,12 +63,14 @@ class MediaFilePaths: default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) - def local_media_filepath_rel(self, media_id): + def local_media_filepath_rel(self, media_id: str) -> str: return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) - def local_media_thumbnail_rel(self, media_id, width, height, content_type, method): + def local_media_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -86,7 +96,7 @@ class MediaFilePaths: media_id[4:], ) - def remote_media_filepath_rel(self, server_name, file_id): + def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: return os.path.join( "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] ) @@ -94,8 +104,14 @@ class MediaFilePaths: remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) def remote_media_thumbnail_rel( - self, server_name, file_id, width, height, content_type, method - ): + self, + server_name: str, + file_id: str, + width: int, + height: int, + content_type: str, + method: str, + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -113,7 +129,7 @@ class MediaFilePaths: # Should be removed after some time, when most of the thumbnails are stored # using the new path. def remote_media_thumbnail_rel_legacy( - self, server_name, file_id, width, height, content_type + self, server_name: str, file_id: str, width: int, height: int, content_type: str ): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) @@ -126,7 +142,7 @@ class MediaFilePaths: file_name, ) - def remote_media_thumbnail_dir(self, server_name, file_id): + def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: return os.path.join( self.base_path, "remote_thumbnail", @@ -136,7 +152,7 @@ class MediaFilePaths: file_id[4:], ) - def url_cache_filepath_rel(self, media_id): + def url_cache_filepath_rel(self, media_id: str) -> str: if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -146,7 +162,7 @@ class MediaFilePaths: url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) - def url_cache_filepath_dirs_to_delete(self, media_id): + def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): return [os.path.join(self.base_path, "url_cache", media_id[:10])] @@ -156,7 +172,9 @@ class MediaFilePaths: os.path.join(self.base_path, "url_cache", media_id[0:2]), ] - def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method): + def url_cache_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -178,7 +196,7 @@ class MediaFilePaths: url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) - def url_cache_thumbnail_directory(self, media_id): + def url_cache_thumbnail_directory(self, media_id: str) -> str: # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -195,7 +213,7 @@ class MediaFilePaths: media_id[4:], ) - def url_cache_thumbnail_dirs_to_delete(self, media_id): + def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id thumbnails" # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 83beb02b05..4c9946a616 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import errno import logging import os import shutil -from typing import IO, Dict, List, Optional, Tuple +from io import BytesIO +from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple import twisted.internet.error import twisted.web.http @@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource from .thumbnailer import Thumbnailer, ThumbnailError from .upload_resource import UploadResource +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 class MediaRepository: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.client = hs.get_federation_http_client() @@ -73,16 +76,16 @@ class MediaRepository: self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels - self.primary_base_path = hs.config.media_store_path - self.filepaths = MediaFilePaths(self.primary_base_path) + self.primary_base_path = hs.config.media_store_path # type: str + self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements self.remote_media_linearizer = Linearizer(name="media_remote") - self.recently_accessed_remotes = set() - self.recently_accessed_locals = set() + self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]] + self.recently_accessed_locals = set() # type: Set[str] self.federation_domain_whitelist = hs.config.federation_domain_whitelist @@ -113,7 +116,7 @@ class MediaRepository: "update_recently_accessed_media", self._update_recently_accessed ) - async def _update_recently_accessed(self): + async def _update_recently_accessed(self) -> None: remote_media = self.recently_accessed_remotes self.recently_accessed_remotes = set() @@ -124,12 +127,12 @@ class MediaRepository: local_media, remote_media, self.clock.time_msec() ) - def mark_recently_accessed(self, server_name, media_id): + def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None: """Mark the given media as recently accessed. Args: - server_name (str|None): Origin server of media, or None if local - media_id (str): The media ID of the content + server_name: Origin server of media, or None if local + media_id: The media ID of the content """ if server_name: self.recently_accessed_remotes.add((server_name, media_id)) @@ -459,7 +462,14 @@ class MediaRepository: def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) - def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): + def _generate_thumbnail( + self, + thumbnailer: Thumbnailer, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + ) -> Optional[BytesIO]: m_width = thumbnailer.width m_height = thumbnailer.height @@ -470,22 +480,20 @@ class MediaRepository: m_height, self.max_image_pixels, ) - return + return None if thumbnailer.transpose_method is not None: m_width, m_height = thumbnailer.transpose() if t_method == "crop": - t_byte_source = thumbnailer.crop(t_width, t_height, t_type) + return thumbnailer.crop(t_width, t_height, t_type) elif t_method == "scale": t_width, t_height = thumbnailer.aspect(t_width, t_height) t_width = min(m_width, t_width) t_height = min(m_height, t_height) - t_byte_source = thumbnailer.scale(t_width, t_height, t_type) - else: - t_byte_source = None + return thumbnailer.scale(t_width, t_height, t_type) - return t_byte_source + return None async def generate_local_exact_thumbnail( self, @@ -776,7 +784,7 @@ class MediaRepository: return {"width": m_width, "height": m_height} - async def delete_old_remote_media(self, before_ts): + async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]: old_media = await self.store.get_remote_media_before(before_ts) deleted = 0 @@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource): within a given rectangle. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): # If we're not configured to use it, raise if we somehow got here. if not hs.config.can_load_media_repo: raise ConfigError("Synapse is not configured to use a media repo.") diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 268e0c8f50..89cdd605aa 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vecotr Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ import os import shutil from typing import IO, TYPE_CHECKING, Any, Optional, Sequence +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from synapse.logging.context import defer_to_thread, make_deferred_yieldable @@ -270,7 +272,7 @@ class MediaStorage: return self.filepaths.local_media_filepath_rel(file_info.file_id) -def _write_file_synchronously(source, dest): +def _write_file_synchronously(source: IO, dest: IO) -> None: """Write `source` to the file like `dest` synchronously. Should be called from a thread. @@ -286,14 +288,14 @@ class FileResponder(Responder): """Wraps an open file that can be sent to a request. Args: - open_file (file): A file like object to be streamed ot the client, + open_file: A file like object to be streamed ot the client, is closed when finished streaming. """ - def __init__(self, open_file): + def __init__(self, open_file: IO): self.open_file = open_file - def write_to_consumer(self, consumer): + def write_to_consumer(self, consumer: IConsumer) -> Deferred: return make_deferred_yieldable( FileSender().beginFileTransfer(self.open_file, consumer) ) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 1082389d9b..a632099167 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import datetime import errno import fnmatch @@ -23,12 +23,13 @@ import re import shutil import sys import traceback -from typing import Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union from urllib import parse as urlparse import attr from twisted.internet.error import DNSLookupError +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError from synapse.http.client import SimpleHttpClient @@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers +from synapse.rest.media.v1.media_storage import MediaStorage from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache @@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string from ._base import FileInfo +if TYPE_CHECKING: + from lxml import etree + + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I) @@ -119,7 +127,12 @@ class OEmbedError(Exception): class PreviewUrlResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo, media_storage): + def __init__( + self, + hs: "HomeServer", + media_repo: "MediaRepository", + media_storage: MediaStorage, + ): super().__init__() self.auth = hs.get_auth() @@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource): self._start_expire_url_cache_data, 10 * 1000 ) - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: request.setHeader(b"Allow", b"OPTIONS, GET") respond_with_json(request, 200, {}, send_cors=True) - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: # XXX: if get_user_by_req fails, what should we do in an async render? requester = await self.auth.get_user_by_req(request) @@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) raise OEmbedError() from e - async def _download_url(self, url: str, user): + async def _download_url(self, url: str, user: str) -> Dict[str, Any]: # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource): "expire_url_cache_data", self._expire_url_cache_data ) - async def _expire_url_cache_data(self): + async def _expire_url_cache_data(self) -> None: """Clean up expired url cache content, media and thumbnails. """ # TODO: Delete from backup media store @@ -676,7 +689,9 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("No media removed from url cache") -def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]: +def decode_and_calc_og( + body: bytes, media_uri: str, request_encoding: Optional[str] = None +) -> Dict[str, Optional[str]]: # If there's no body, nothing useful is going to be found. if not body: return {} @@ -697,7 +712,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str] return og -def _calc_og(tree, media_uri): +def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]: # suck our tree into lxml and define our OG response. # if we see any image URLs in the OG response, then spider them @@ -801,7 +816,9 @@ def _calc_og(tree, media_uri): for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) ) og["og:description"] = summarize_paragraphs(text_nodes) - else: + elif og["og:description"]: + # This must be a non-empty string at this point. + assert isinstance(og["og:description"], str) og["og:description"] = summarize_paragraphs([og["og:description"]]) # TODO: delete the url downloads to stop diskfilling, @@ -809,7 +826,9 @@ def _calc_og(tree, media_uri): return og -def _iterate_over_text(tree, *tags_to_ignore): +def _iterate_over_text( + tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]] +) -> Generator[str, None, None]: """Iterate over the tree returning text nodes in a depth first fashion, skipping text nodes inside certain tags. """ @@ -843,32 +862,32 @@ def _iterate_over_text(tree, *tags_to_ignore): ) -def _rebase_url(url, base): - base = list(urlparse.urlparse(base)) - url = list(urlparse.urlparse(url)) - if not url[0]: # fix up schema - url[0] = base[0] or "http" - if not url[1]: # fix up hostname - url[1] = base[1] - if not url[2].startswith("/"): - url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2] - return urlparse.urlunparse(url) +def _rebase_url(url: str, base: str) -> str: + base_parts = list(urlparse.urlparse(base)) + url_parts = list(urlparse.urlparse(url)) + if not url_parts[0]: # fix up schema + url_parts[0] = base_parts[0] or "http" + if not url_parts[1]: # fix up hostname + url_parts[1] = base_parts[1] + if not url_parts[2].startswith("/"): + url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2] + return urlparse.urlunparse(url_parts) -def _is_media(content_type): - if content_type.lower().startswith("image/"): - return True +def _is_media(content_type: str) -> bool: + return content_type.lower().startswith("image/") -def _is_html(content_type): +def _is_html(content_type: str) -> bool: content_type = content_type.lower() - if content_type.startswith("text/html") or content_type.startswith( + return content_type.startswith("text/html") or content_type.startswith( "application/xhtml" - ): - return True + ) -def summarize_paragraphs(text_nodes, min_size=200, max_size=500): +def summarize_paragraphs( + text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 +) -> Optional[str]: # Try to get a summary of between 200 and 500 words, respecting # first paragraph and then word boundaries. # TODO: Respect sentences? diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 67f67efde7..e92006faa9 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging import os import shutil -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background @@ -27,13 +28,17 @@ from .media_storage import FileResponder logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer -class StorageProvider: + +class StorageProvider(metaclass=abc.ABCMeta): """A storage provider is a service that can store uploaded media and retrieve them. """ - async def store_file(self, path: str, file_info: FileInfo): + @abc.abstractmethod + async def store_file(self, path: str, file_info: FileInfo) -> None: """Store the file described by file_info. The actual contents can be retrieved by reading the file in file_info.upload_path. @@ -42,6 +47,7 @@ class StorageProvider: file_info: The metadata of the file. """ + @abc.abstractmethod async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """Attempt to fetch the file described by file_info and stream it into writer. @@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider): self.store_synchronous = store_synchronous self.store_remote = store_remote - def __str__(self): + def __str__(self) -> str: return "StorageProviderWrapper[%s]" % (self.backend,) - async def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo) -> None: if not file_info.server_name and not self.store_local: return None @@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider): if self.store_synchronous: # store_file is supposed to return an Awaitable, but guard # against improper implementations. - return await maybe_awaitable(self.backend.store_file(path, file_info)) + await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore else: # TODO: Handle errors. async def store(): @@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider): logger.exception("Error storing file") run_in_background(store) - return None - async def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: # store_file is supposed to return an Awaitable, but guard # against improper implementations. return await maybe_awaitable(self.backend.fetch(path, file_info)) @@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider): """A storage provider that stores files in a directory on a filesystem. Args: - hs (HomeServer) + hs config: The config returned by `parse_config`. """ - def __init__(self, hs, config): + def __init__(self, hs: "HomeServer", config: str): self.hs = hs self.cache_directory = hs.config.media_store_path self.base_directory = config @@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider): def __str__(self): return "FileStorageProviderBackend[%s]" % (self.base_directory,) - async def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo) -> None: """See StorageProvider.store_file""" primary_fname = os.path.join(self.cache_directory, path) @@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return await defer_to_thread( + await defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) - async def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """See StorageProvider.fetch""" backup_fname = os.path.join(self.base_directory, path) if os.path.isfile(backup_fname): return FileResponder(open(backup_fname, "rb")) + return None + @staticmethod - def parse_config(config): + def parse_config(config: dict) -> str: """Called on startup to parse config supplied. This should parse the config and raise if there is a problem. diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 30421b663a..d6880f2e6e 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +16,14 @@ import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request from synapse.api.errors import SynapseError from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_integer, parse_string +from synapse.rest.media.v1.media_storage import MediaStorage from ._base import ( FileInfo, @@ -28,13 +33,22 @@ from ._base import ( respond_with_responder, ) +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class ThumbnailResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo, media_storage): + def __init__( + self, + hs: "HomeServer", + media_repo: "MediaRepository", + media_storage: MediaStorage, + ): super().__init__() self.store = hs.get_datastore() @@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource): self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.server_name = hs.hostname - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: set_cors_headers(request) server_name, media_id, _ = parse_media_id(request) width = parse_integer(request, "width", required=True) @@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource): self.media_repo.mark_recently_accessed(server_name, media_id) async def _respond_local_thumbnail( - self, request, media_id, width, height, method, m_type - ): + self, + request: Request, + media_id: str, + width: int, + height: int, + method: str, + m_type: str, + ) -> None: media_info = await self.store.get_local_media(media_id) if not media_info: @@ -114,13 +134,13 @@ class ThumbnailResource(DirectServeJsonResource): async def _select_or_generate_local_thumbnail( self, - request, - media_id, - desired_width, - desired_height, - desired_method, - desired_type, - ): + request: Request, + media_id: str, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + ) -> None: media_info = await self.store.get_local_media(media_id) if not media_info: @@ -178,14 +198,14 @@ class ThumbnailResource(DirectServeJsonResource): async def _select_or_generate_remote_thumbnail( self, - request, - server_name, - media_id, - desired_width, - desired_height, - desired_method, - desired_type, - ): + request: Request, + server_name: str, + media_id: str, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + ) -> None: media_info = await self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = await self.store.get_remote_media_thumbnails( @@ -239,8 +259,15 @@ class ThumbnailResource(DirectServeJsonResource): raise SynapseError(400, "Failed to generate thumbnail.") async def _respond_remote_thumbnail( - self, request, server_name, media_id, width, height, method, m_type - ): + self, + request: Request, + server_name: str, + media_id: str, + width: int, + height: int, + method: str, + m_type: str, + ) -> None: # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. @@ -275,12 +302,12 @@ class ThumbnailResource(DirectServeJsonResource): def _select_thumbnail( self, - desired_width, - desired_height, - desired_method, - desired_type, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, thumbnail_infos, - ): + ) -> dict: d_w = desired_width d_h = desired_height diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 32a8e4f960..07903e4017 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +15,7 @@ # limitations under the License. import logging from io import BytesIO +from typing import Tuple from PIL import Image @@ -39,7 +41,7 @@ class Thumbnailer: FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} - def __init__(self, input_path): + def __init__(self, input_path: str): try: self.image = Image.open(input_path) except OSError as e: @@ -59,11 +61,11 @@ class Thumbnailer: # A lot of parsing errors can happen when parsing EXIF logger.info("Error parsing image EXIF information: %s", e) - def transpose(self): + def transpose(self) -> Tuple[int, int]: """Transpose the image using its EXIF Orientation tag Returns: - Tuple[int, int]: (width, height) containing the new image size in pixels. + A tuple containing the new image size in pixels as (width, height). """ if self.transpose_method is not None: self.image = self.image.transpose(self.transpose_method) @@ -73,7 +75,7 @@ class Thumbnailer: self.image.info["exif"] = None return self.image.size - def aspect(self, max_width, max_height): + def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]: """Calculate the largest size that preserves aspect ratio which fits within the given rectangle:: @@ -91,7 +93,7 @@ class Thumbnailer: else: return (max_height * self.width) // self.height, max_height - def _resize(self, width, height): + def _resize(self, width: int, height: int) -> Image: # 1-bit or 8-bit color palette images need converting to RGB # otherwise they will be scaled using nearest neighbour which # looks awful @@ -99,7 +101,7 @@ class Thumbnailer: self.image = self.image.convert("RGB") return self.image.resize((width, height), Image.ANTIALIAS) - def scale(self, width, height, output_type): + def scale(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales the image to the given dimensions. Returns: @@ -108,7 +110,7 @@ class Thumbnailer: scaled = self._resize(width, height) return self._encode_image(scaled, output_type) - def crop(self, width, height, output_type): + def crop(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales and crops the image to the given dimensions preserving aspect:: (w_in / h_in) = (w_scaled / h_scaled) @@ -136,7 +138,7 @@ class Thumbnailer: cropped = scaled_image.crop((crop_left, 0, crop_right, height)) return self._encode_image(cropped, output_type) - def _encode_image(self, output_image, output_type): + def _encode_image(self, output_image: Image, output_type: str) -> BytesIO: output_bytes_io = BytesIO() fmt = self.FORMATS[output_type] if fmt == "JPEG": diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 42febc9afc..6da76ae994 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,18 +15,25 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_string +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class UploadResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() self.media_repo = media_repo @@ -37,10 +45,10 @@ class UploadResource(DirectServeJsonResource): self.max_upload_size = hs.config.max_upload_size self.clock = hs.get_clock() - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: respond_with_json(request, 200, {}, send_cors=True) - async def _async_render_POST(self, request): + async def _async_render_POST(self, request: Request) -> None: requester = await self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 4b2f224718..283c8a5e22 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -169,7 +170,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_local_media_before( self, before_ts: int, size_gt: int, keep_profiles: bool, - ) -> Optional[List[str]]: + ) -> List[str]: # to find files that have never been accessed (last_access_ts IS NULL) # compare with `created_ts` -- cgit 1.5.1 From 3e4cdfe5d9b6152b9a0b7f7ad22623c2bf19f7ff Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 15 Jan 2021 11:18:09 -0500 Subject: Add an admin API endpoint to protect media. (#9086) Protecting media stops it from being quarantined when e.g. all media in a room is quarantined. This is useful for sticker packs and other media that is uploaded by server administrators, but used by many people. --- changelog.d/9086.feature | 1 + docs/admin_api/media_admin_api.md | 24 +++++++++++++++ synapse/rest/admin/media.py | 64 ++++++++++++++++++++++++++++++--------- tests/rest/admin/test_admin.py | 8 +++-- 4 files changed, 79 insertions(+), 18 deletions(-) create mode 100644 changelog.d/9086.feature (limited to 'synapse/rest') diff --git a/changelog.d/9086.feature b/changelog.d/9086.feature new file mode 100644 index 0000000000..3e678e24d5 --- /dev/null +++ b/changelog.d/9086.feature @@ -0,0 +1 @@ +Add an admin API for protecting local media from quarantine. diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index dfb8c5d751..90faeaaef0 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -4,6 +4,7 @@ * [Quarantining media by ID](#quarantining-media-by-id) * [Quarantining media in a room](#quarantining-media-in-a-room) * [Quarantining all media of a user](#quarantining-all-media-of-a-user) + * [Protecting media from being quarantined](#protecting-media-from-being-quarantined) - [Delete local media](#delete-local-media) * [Delete a specific local media](#delete-a-specific-local-media) * [Delete local media by date or size](#delete-local-media-by-date-or-size) @@ -123,6 +124,29 @@ The following fields are returned in the JSON response body: * `num_quarantined`: integer - The number of media items successfully quarantined +## Protecting media from being quarantined + +This API protects a single piece of local media from being quarantined using the +above APIs. This is useful for sticker packs and other shared media which you do +not want to get quarantined, especially when +[quarantining media in a room](#quarantining-media-in-a-room). + +Request: + +``` +POST /_synapse/admin/v1/media/protect/ + +{} +``` + +Where `media_id` is in the form of `abcdefg12345...`. + +Response: + +```json +{} +``` + # Delete local media This API deletes the *local* media from the disk of your own server. This includes any local thumbnails and copies of media downloaded from diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index c82b4f87d6..8720b1401f 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -15,6 +15,9 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple + +from twisted.web.http import Request from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_boolean, parse_integer @@ -23,6 +26,10 @@ from synapse.rest.admin._base import ( assert_requester_is_admin, assert_user_is_admin, ) +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet): admin_patterns("/quarantine_media/(?P[^/]+)") ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, room_id: str): + async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet): PATTERNS = admin_patterns("/user/(?P[^/]+)/media/quarantine") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, user_id: str): + async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet): "/media/quarantine/(?P[^/]+)/(?P[^/]+)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, server_name: str, media_id: str): + async def on_POST( + self, request: Request, server_name: str, media_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet): return 200, {} +class ProtectMediaByID(RestServlet): + """Protect local media from being quarantined. + """ + + PATTERNS = admin_patterns("/media/protect/(?P[^/]+)") + + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + logging.info("Protecting local media by ID: %s", media_id) + + # Quarantine this media id + await self.store.mark_local_media_as_safe(media_id) + + return 200, {} + + class ListMediaInRoom(RestServlet): """Lists all of the media in a given room. """ PATTERNS = admin_patterns("/room/(?P[^/]+)/media") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) is_admin = await self.auth.is_server_admin(requester.user) if not is_admin: @@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet): class PurgeMediaCacheRestServlet(RestServlet): PATTERNS = admin_patterns("/purge_media_cache") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.media_repository = hs.get_media_repository() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) @@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]+)/(?P[^/]+)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() - async def on_DELETE(self, request, server_name: str, media_id: str): + async def on_DELETE( + self, request: Request, server_name: str, media_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) if self.server_name != server_name: @@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]+)/delete") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() - async def on_POST(self, request, server_name: str): + async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) @@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet): return 200, {"deleted_media": deleted_media, "total": total} -def register_servlets_for_media_repo(hs, http_server): +def register_servlets_for_media_repo(hs: "HomeServer", http_server): """ Media repo specific APIs. """ @@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server): QuarantineMediaInRoom(hs).register(http_server) QuarantineMediaByID(hs).register(http_server) QuarantineMediaByUser(hs).register(http_server) + ProtectMediaByID(hs).register(http_server) ListMediaInRoom(hs).register(http_server) DeleteMediaByID(hs).register(http_server) DeleteMediaByDateSize(hs).register(http_server) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 586b877bda..9d22c04073 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -153,8 +153,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - # Allow for uploading and downloading to/from the media repo self.media_repo = hs.get_media_repository_resource() self.download_resource = self.media_repo.children[b"download"] @@ -428,7 +426,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Mark the second item as safe from quarantine. _, media_id_2 = server_and_media_id_2.split("/") - self.get_success(self.store.mark_local_media_as_safe(media_id_2)) + # Quarantine the media + url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),) + channel = self.make_request("POST", url, access_token=admin_user_tok) + self.pump(1.0) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) # Quarantine all media by this user url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( -- cgit 1.5.1 From 02070c69faa47bf6aef280939c2d5f32cbcb9f25 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 18 Jan 2021 14:52:49 +0000 Subject: Fix bugs in handling clientRedirectUrl, and improve OIDC tests (#9127, #9128) * Factor out a common TestHtmlParser Looks like I'm doing this in a few different places. * Improve OIDC login test Complete the OIDC login flow, rather than giving up halfway through. * Ensure that OIDC login works with multiple OIDC providers * Fix bugs in handling clientRedirectUrl - don't drop duplicate query-params, or params with no value - allow utf-8 in query-params --- changelog.d/9127.feature | 1 + changelog.d/9128.bugfix | 1 + synapse/handlers/auth.py | 4 +- synapse/handlers/oidc_handler.py | 2 +- synapse/rest/synapse/client/pick_idp.py | 4 +- tests/rest/client/v1/test_login.py | 146 ++++++++++++++++++++------------ tests/rest/client/v1/utils.py | 62 ++++++++------ tests/server.py | 2 +- tests/test_utils/html_parsers.py | 53 ++++++++++++ 9 files changed, 189 insertions(+), 86 deletions(-) create mode 100644 changelog.d/9127.feature create mode 100644 changelog.d/9128.bugfix create mode 100644 tests/test_utils/html_parsers.py (limited to 'synapse/rest') diff --git a/changelog.d/9127.feature b/changelog.d/9127.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9127.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/changelog.d/9128.bugfix b/changelog.d/9128.bugfix new file mode 100644 index 0000000000..f87b9fb9aa --- /dev/null +++ b/changelog.d/9128.bugfix @@ -0,0 +1 @@ +Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 18cd2b62f0..0e98db22b3 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1504,8 +1504,8 @@ class AuthHandler(BaseHandler): @staticmethod def add_query_param_to_url(url: str, param_name: str, param: Any): url_parts = list(urllib.parse.urlparse(url)) - query = dict(urllib.parse.parse_qsl(url_parts[4])) - query.update({param_name: param}) + query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True) + query.append((param_name, param)) url_parts[4] = urllib.parse.urlencode(query) return urllib.parse.urlunparse(url_parts) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 5e5fda7b2f..ba686d74b2 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -85,7 +85,7 @@ class OidcHandler: self._token_generator = OidcSessionTokenGenerator(hs) self._providers = { p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs - } + } # type: Dict[str, OidcProvider] async def load_metadata(self) -> None: """Validate the config and load the metadata from the remote endpoint. diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py index e5b720bbca..9550b82998 100644 --- a/synapse/rest/synapse/client/pick_idp.py +++ b/synapse/rest/synapse/client/pick_idp.py @@ -45,7 +45,9 @@ class PickIdpResource(DirectServeHtmlResource): self._server_name = hs.hostname async def _async_render_GET(self, request: SynapseRequest) -> None: - client_redirect_url = parse_string(request, "redirectUrl", required=True) + client_redirect_url = parse_string( + request, "redirectUrl", required=True, encoding="utf-8" + ) idp = parse_string(request, "idp", required=False) # if we need to pick an IdP, do so diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 73a009efd1..2d25490374 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -15,9 +15,8 @@ import time import urllib.parse -from html.parser import HTMLParser -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -from urllib.parse import parse_qs, urlencode, urlparse +from typing import Any, Dict, Union +from urllib.parse import urlencode from mock import Mock @@ -38,6 +37,7 @@ from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless try: @@ -69,6 +69,12 @@ TEST_SAML_METADATA = """ LOGIN_URL = b"/_matrix/client/r0/login" TEST_URL = b"/_matrix/client/r0/account/whoami" +# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is + +TEST_CLIENT_REDIRECT_URL = 'https://x?&q"+%3D%2B"="fö%26=o"' + +# the query params in TEST_CLIENT_REDIRECT_URL +EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("", ""), ('q" =+"', '"fö&=o"')] + class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -389,23 +395,44 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): }, } + # default OIDC provider config["oidc_config"] = TEST_OIDC_CONFIG + # additional OIDC providers + config["oidc_providers"] = [ + { + "idp_id": "idp1", + "idp_name": "IDP1", + "discover": False, + "issuer": "https://issuer1", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["profile"], + "authorization_endpoint": "https://issuer1/auth", + "token_endpoint": "https://issuer1/token", + "userinfo_endpoint": "https://issuer1/userinfo", + "user_mapping_provider": { + "config": {"localpart_template": "{{ user.sub }}"} + }, + } + ] return config def create_resource_dict(self) -> Dict[str, Resource]: + from synapse.rest.oidc import OIDCResource + d = super().create_resource_dict() d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs) + d["/_synapse/oidc"] = OIDCResource(self.hs) return d def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" - client_redirect_url = "https://x?" - # first hit the redirect url, which should redirect to our idp picker channel = self.make_request( "GET", - "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url, + "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), ) self.assertEqual(channel.code, 302, channel.result) uri = channel.headers.getRawHeaders("Location")[0] @@ -415,46 +442,22 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) # parse the form to check it has fields assumed elsewhere in this class - class FormPageParser(HTMLParser): - def __init__(self): - super().__init__() - - # the values of the hidden inputs: map from name to value - self.hiddens = {} # type: Dict[str, Optional[str]] - - # the values of the radio buttons - self.radios = [] # type: List[Optional[str]] - - def handle_starttag( - self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] - ) -> None: - attr_dict = dict(attrs) - if tag == "input": - if attr_dict["type"] == "radio" and attr_dict["name"] == "idp": - self.radios.append(attr_dict["value"]) - elif attr_dict["type"] == "hidden": - input_name = attr_dict["name"] - assert input_name - self.hiddens[input_name] = attr_dict["value"] - - def error(_, message): - self.fail(message) - - p = FormPageParser() + p = TestHtmlParser() p.feed(channel.result["body"].decode("utf-8")) p.close() - self.assertCountEqual(p.radios, ["cas", "oidc", "saml"]) + self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"]) - self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url) + self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL) def test_multi_sso_redirect_to_cas(self): """If CAS is chosen, should redirect to the CAS server""" - client_redirect_url = "https://x?" channel = self.make_request( "GET", - "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=cas", shorthand=False, ) self.assertEqual(channel.code, 302, channel.result) @@ -470,16 +473,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): service_uri = cas_uri_params["service"][0] _, service_uri_query = service_uri.split("?", 1) service_uri_params = urllib.parse.parse_qs(service_uri_query) - self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url) + self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) def test_multi_sso_redirect_to_saml(self): """If SAML is chosen, should redirect to the SAML server""" - client_redirect_url = "https://x?" - channel = self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=" - + client_redirect_url + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=saml", ) self.assertEqual(channel.code, 302, channel.result) @@ -492,16 +493,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # the RelayState is used to carry the client redirect url saml_uri_params = urllib.parse.parse_qs(saml_uri_query) relay_state_param = saml_uri_params["RelayState"][0] - self.assertEqual(relay_state_param, client_redirect_url) + self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - def test_multi_sso_redirect_to_oidc(self): + def test_login_via_oidc(self): """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - client_redirect_url = "https://x?" + # pick the default OIDC provider channel = self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=" - + client_redirect_url + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=oidc", ) self.assertEqual(channel.code, 302, channel.result) @@ -521,9 +522,41 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) self.assertEqual( self._get_value_from_macaroon(macaroon, "client_redirect_url"), - client_redirect_url, + TEST_CLIENT_REDIRECT_URL, ) + channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + self.assertTrue( + channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html") + ) + p = TestHtmlParser() + p.feed(channel.text_body) + p.close() + + # ... which should contain our redirect link + self.assertEqual(len(p.links), 1) + path, query = p.links[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + login_token = params[2][1] + chan = self.make_request( + "POST", "/login", content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@user1:test") + def test_multi_sso_redirect_to_unknown(self): """An unknown IdP should cause a 400""" channel = self.make_request( @@ -1082,7 +1115,7 @@ class UsernamePickerTestCase(HomeserverTestCase): # whitelist this client URI so we redirect straight to it rather than # serving a confirmation page - config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} + config["sso"] = {"client_whitelist": ["https://x"]} return config def create_resource_dict(self) -> Dict[str, Resource]: @@ -1095,11 +1128,10 @@ class UsernamePickerTestCase(HomeserverTestCase): def test_username_picker(self): """Test the happy path of a username picker flow.""" - client_redirect_url = "https://whitelisted.client" # do the start of the login flow channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, client_redirect_url + {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL ) # that should redirect to the username picker @@ -1122,7 +1154,7 @@ class UsernamePickerTestCase(HomeserverTestCase): session = username_mapping_sessions[session_id] self.assertEqual(session.remote_user_id, "tester") self.assertEqual(session.display_name, "Jonny") - self.assertEqual(session.client_redirect_url, client_redirect_url) + self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) # the expiry time should be about 15 minutes away expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) @@ -1146,15 +1178,19 @@ class UsernamePickerTestCase(HomeserverTestCase): ) self.assertEqual(chan.code, 302, chan.result) location_headers = chan.headers.getRawHeaders("Location") - # ensure that the returned location starts with the requested redirect URL - self.assertEqual( - location_headers[0][: len(client_redirect_url)], client_redirect_url + # ensure that the returned location matches the requested redirect URL + path, query = location_headers[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") # fish the login token out of the returned redirect uri - parts = urlparse(location_headers[0]) - query = parse_qs(parts.query) - login_token = query["loginToken"][0] + login_token = params[2][1] # finally, submit the matrix login token to the login API, which gives us our # matrix access token, mxid, and device id. diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index c6647dbe08..b1333df82d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -20,8 +20,7 @@ import json import re import time import urllib.parse -from html.parser import HTMLParser -from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple +from typing import Any, Dict, Mapping, MutableMapping, Optional from mock import patch @@ -35,6 +34,7 @@ from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request from tests.test_utils import FakeResponse +from tests.test_utils.html_parsers import TestHtmlParser @attr.s @@ -440,10 +440,36 @@ class RestHelper: # param that synapse passes to the IdP via query params, as well as the cookie # that synapse passes to the client. - oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1) + oauth_uri_path, _ = oauth_uri.split("?", 1) assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( "unexpected SSO URI " + oauth_uri_path ) + return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) + + def complete_oidc_auth( + self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, + ) -> FakeChannel: + """Mock out an OIDC authentication flow + + Assumes that an OIDC auth has been initiated by one of initiate_sso_login or + initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to + Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get + sent back to the OIDC provider. + + Requires the OIDC callback resource to be mounted at the normal place. + + Args: + oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, + from initiate_sso_login or initiate_sso_ui_auth). + cookies: the cookies set by synapse's redirect endpoint, which will be + sent back to the callback endpoint. + user_info_dict: the remote userinfo that the OIDC provider should present. + Typically this should be '{"sub": ""}'. + + Returns: + A FakeChannel containing the result of calling the OIDC callback endpoint. + """ + _, oauth_uri_qs = oauth_uri.split("?", 1) params = urllib.parse.parse_qs(oauth_uri_qs) callback_uri = "%s?%s" % ( urllib.parse.urlparse(params["redirect_uri"][0]).path, @@ -456,9 +482,9 @@ class RestHelper: expected_requests = [ # first we get a hit to the token endpoint, which we tell to return # a dummy OIDC access token - ("https://issuer.test/token", {"access_token": "TEST"}), + (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), # and then one to the user_info endpoint, which returns our remote user id. - ("https://issuer.test/userinfo", user_info_dict), + (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), ] async def mock_req(method: str, uri: str, data=None, headers=None): @@ -542,25 +568,7 @@ class RestHelper: channel.extract_cookies(cookies) # parse the confirmation page to fish out the link. - class ConfirmationPageParser(HTMLParser): - def __init__(self): - super().__init__() - - self.links = [] # type: List[str] - - def handle_starttag( - self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] - ) -> None: - attr_dict = dict(attrs) - if tag == "a": - href = attr_dict["href"] - if href: - self.links.append(href) - - def error(_, message): - raise AssertionError(message) - - p = ConfirmationPageParser() + p = TestHtmlParser() p.feed(channel.text_body) p.close() assert len(p.links) == 1, "not exactly one link in confirmation page" @@ -570,6 +578,8 @@ class RestHelper: # an 'oidc_config' suitable for login_via_oidc. TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" +TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" +TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" TEST_OIDC_CONFIG = { "enabled": True, "discover": False, @@ -578,7 +588,7 @@ TEST_OIDC_CONFIG = { "client_secret": "test-client-secret", "scopes": ["profile"], "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": "https://issuer.test/token", - "userinfo_endpoint": "https://issuer.test/userinfo", + "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, + "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, } diff --git a/tests/server.py b/tests/server.py index 5a1b66270f..5a85d5fe7f 100644 --- a/tests/server.py +++ b/tests/server.py @@ -74,7 +74,7 @@ class FakeChannel: return int(self.result["code"]) @property - def headers(self): + def headers(self) -> Headers: if not self.result: raise Exception("No result yet.") h = Headers() diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py new file mode 100644 index 0000000000..ad563eb3f0 --- /dev/null +++ b/tests/test_utils/html_parsers.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from html.parser import HTMLParser +from typing import Dict, Iterable, List, Optional, Tuple + + +class TestHtmlParser(HTMLParser): + """A generic HTML page parser which extracts useful things from the HTML""" + + def __init__(self): + super().__init__() + + # a list of links found in the doc + self.links = [] # type: List[str] + + # the values of any hidden s: map from name to value + self.hiddens = {} # type: Dict[str, Optional[str]] + + # the values of any radio buttons: map from name to list of values + self.radios = {} # type: Dict[str, List[Optional[str]]] + + def handle_starttag( + self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] + ) -> None: + attr_dict = dict(attrs) + if tag == "a": + href = attr_dict["href"] + if href: + self.links.append(href) + elif tag == "input": + input_name = attr_dict.get("name") + if attr_dict["type"] == "radio": + assert input_name + self.radios.setdefault(input_name, []).append(attr_dict["value"]) + elif attr_dict["type"] == "hidden": + assert input_name + self.hiddens[input_name] = attr_dict["value"] + + def error(_, message): + raise AssertionError(message) -- cgit 1.5.1 From 6633a4015a7b4ba60f87c5e6f979a9c9d8f9d8fe Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 18 Jan 2021 15:47:59 +0000 Subject: Allow moving account data and receipts streams off master (#9104) --- changelog.d/9104.feature | 1 + synapse/app/generic_worker.py | 15 +- synapse/config/workers.py | 18 +- synapse/handlers/account_data.py | 144 ++++++++++++++++ synapse/handlers/read_marker.py | 5 +- synapse/handlers/receipts.py | 27 ++- synapse/handlers/room_member.py | 7 +- synapse/replication/http/__init__.py | 2 + synapse/replication/http/account_data.py | 187 +++++++++++++++++++++ synapse/replication/slave/storage/_base.py | 10 +- synapse/replication/slave/storage/account_data.py | 40 +---- synapse/replication/slave/storage/receipts.py | 35 +--- synapse/replication/tcp/handler.py | 19 +++ synapse/rest/client/v2_alpha/account_data.py | 22 +-- synapse/rest/client/v2_alpha/tags.py | 11 +- synapse/server.py | 5 + synapse/storage/databases/main/__init__.py | 10 +- synapse/storage/databases/main/account_data.py | 107 +++++++++--- synapse/storage/databases/main/deviceinbox.py | 4 +- .../storage/databases/main/event_push_actions.py | 92 +++++----- synapse/storage/databases/main/events_worker.py | 8 +- synapse/storage/databases/main/receipts.py | 108 ++++++++---- .../main/schema/delta/59/06shard_account_data.sql | 20 +++ .../delta/59/06shard_account_data.sql.postgres | 32 ++++ synapse/storage/databases/main/tags.py | 10 +- synapse/storage/util/id_generators.py | 84 +++++---- tests/storage/test_id_generators.py | 112 +++++++++++- 27 files changed, 855 insertions(+), 280 deletions(-) create mode 100644 changelog.d/9104.feature create mode 100644 synapse/replication/http/account_data.py create mode 100644 synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql create mode 100644 synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres (limited to 'synapse/rest') diff --git a/changelog.d/9104.feature b/changelog.d/9104.feature new file mode 100644 index 0000000000..1c4f88bce9 --- /dev/null +++ b/changelog.d/9104.feature @@ -0,0 +1 @@ +Add experimental support for moving off receipts and account data persistence off master. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index cb202bda44..e60988fa4a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -100,7 +100,16 @@ from synapse.rest.client.v1.profile import ( ) from synapse.rest.client.v1.push_rule import PushRuleRestServlet from synapse.rest.client.v1.voip import VoipRestServlet -from synapse.rest.client.v2_alpha import groups, room_keys, sync, user_directory +from synapse.rest.client.v2_alpha import ( + account_data, + groups, + read_marker, + receipts, + room_keys, + sync, + tags, + user_directory, +) from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha.account import ThreepidRestServlet from synapse.rest.client.v2_alpha.account_data import ( @@ -531,6 +540,10 @@ class GenericWorkerServer(HomeServer): room.register_deprecated_servlets(self, resource) InitialSyncRestServlet(self).register(resource) room_keys.register_servlets(self, resource) + tags.register_servlets(self, resource) + account_data.register_servlets(self, resource) + receipts.register_servlets(self, resource) + read_marker.register_servlets(self, resource) SendToDeviceRestServlet(self).register(resource) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 364583f48b..f10e33f7b8 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -56,6 +56,12 @@ class WriterLocations: to_device = attr.ib( default=["master"], type=List[str], converter=_instance_to_list_converter, ) + account_data = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter, + ) + receipts = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter, + ) class WorkerConfig(Config): @@ -127,7 +133,7 @@ class WorkerConfig(Config): # Check that the configured writers for events and typing also appears in # `instance_map`. - for stream in ("events", "typing", "to_device"): + for stream in ("events", "typing", "to_device", "account_data", "receipts"): instances = _instance_to_list_converter(getattr(self.writers, stream)) for instance in instances: if instance != "master" and instance not in self.instance_map: @@ -141,6 +147,16 @@ class WorkerConfig(Config): "Must only specify one instance to handle `to_device` messages." ) + if len(self.writers.account_data) != 1: + raise ConfigError( + "Must only specify one instance to handle `account_data` messages." + ) + + if len(self.writers.receipts) != 1: + raise ConfigError( + "Must only specify one instance to handle `receipts` messages." + ) + self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) # Whether this worker should run background tasks or not. diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 341135822e..b1a5df9638 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +13,157 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import random from typing import TYPE_CHECKING, List, Tuple +from synapse.replication.http.account_data import ( + ReplicationAddTagRestServlet, + ReplicationRemoveTagRestServlet, + ReplicationRoomAccountDataRestServlet, + ReplicationUserAccountDataRestServlet, +) from synapse.types import JsonDict, UserID if TYPE_CHECKING: from synapse.app.homeserver import HomeServer +class AccountDataHandler: + def __init__(self, hs: "HomeServer"): + self._store = hs.get_datastore() + self._instance_name = hs.get_instance_name() + self._notifier = hs.get_notifier() + + self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs) + self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs) + self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs) + self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs) + self._account_data_writers = hs.config.worker.writers.account_data + + async def add_account_data_to_room( + self, user_id: str, room_id: str, account_data_type: str, content: JsonDict + ) -> int: + """Add some account_data to a room for a user. + + Args: + user_id: The user to add a tag for. + room_id: The room to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + + Returns: + The maximum stream ID. + """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.add_account_data_to_room( + user_id, room_id, account_data_type, content + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + + return max_stream_id + else: + response = await self._room_data_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + account_data_type=account_data_type, + content=content, + ) + return response["max_stream_id"] + + async def add_account_data_for_user( + self, user_id: str, account_data_type: str, content: JsonDict + ) -> int: + """Add some account_data to a room for a user. + + Args: + user_id: The user to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + + Returns: + The maximum stream ID. + """ + + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.add_account_data_for_user( + user_id, account_data_type, content + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + return max_stream_id + else: + response = await self._user_data_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + account_data_type=account_data_type, + content=content, + ) + return response["max_stream_id"] + + async def add_tag_to_room( + self, user_id: str, room_id: str, tag: str, content: JsonDict + ) -> int: + """Add a tag to a room for a user. + + Args: + user_id: The user to add a tag for. + room_id: The room to add a tag for. + tag: The tag name to add. + content: A json object to associate with the tag. + + Returns: + The next account data ID. + """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.add_tag_to_room( + user_id, room_id, tag, content + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + return max_stream_id + else: + response = await self._add_tag_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + tag=tag, + content=content, + ) + return response["max_stream_id"] + + async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int: + """Remove a tag from a room for a user. + + Returns: + The next account data ID. + """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.remove_tag_from_room( + user_id, room_id, tag + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + return max_stream_id + else: + response = await self._remove_tag_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + tag=tag, + ) + return response["max_stream_id"] + + class AccountDataEventSource: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index a7550806e6..6bb2fd936b 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler): super().__init__(hs) self.server_name = hs.config.server_name self.store = hs.get_datastore() + self.account_data_handler = hs.get_account_data_handler() self.read_marker_linearizer = Linearizer(name="read_marker") - self.notifier = hs.get_notifier() async def received_client_read_marker( self, room_id: str, user_id: str, event_id: str @@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler): if should_update: content = {"event_id": event_id} - max_id = await self.store.add_account_data_to_room( + await self.account_data_handler.add_account_data_to_room( user_id, room_id, "m.fully_read", content ) - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index a9abdf42e0..cc21fc2284 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -32,10 +32,26 @@ class ReceiptsHandler(BaseHandler): self.server_name = hs.config.server_name self.store = hs.get_datastore() self.hs = hs - self.federation = hs.get_federation_sender() - hs.get_federation_registry().register_edu_handler( - "m.receipt", self._received_remote_receipt - ) + + # We only need to poke the federation sender explicitly if its on the + # same instance. Other federation sender instances will get notified by + # `synapse.app.generic_worker.FederationSenderHandler` when it sees it + # in the receipts stream. + self.federation_sender = None + if hs.should_send_federation(): + self.federation_sender = hs.get_federation_sender() + + # If we can handle the receipt EDUs we do so, otherwise we route them + # to the appropriate worker. + if hs.get_instance_name() in hs.config.worker.writers.receipts: + hs.get_federation_registry().register_edu_handler( + "m.receipt", self._received_remote_receipt + ) + else: + hs.get_federation_registry().register_instances_for_edu( + "m.receipt", hs.config.worker.writers.receipts, + ) + self.clock = self.hs.get_clock() self.state = hs.get_state_handler() @@ -125,7 +141,8 @@ class ReceiptsHandler(BaseHandler): if not is_new: return - await self.federation.send_read_receipt(receipt) + if self.federation_sender: + await self.federation_sender.send_read_receipt(receipt) class ReceiptEventSource: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index cb5a29bc7e..e001e418f9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.registration_handler = hs.get_registration_handler() self.profile_handler = hs.get_profile_handler() self.event_creation_handler = hs.get_event_creation_handler() + self.account_data_handler = hs.get_account_data_handler() self.member_linearizer = Linearizer(name="member") @@ -253,7 +254,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): direct_rooms[key].append(new_room_id) # Save back to user's m.direct account data - await self.store.add_account_data_for_user( + await self.account_data_handler.add_account_data_for_user( user_id, AccountDataTypes.DIRECT, direct_rooms ) break @@ -263,7 +264,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Copy each room tag to the new room for tag, tag_content in room_tags.items(): - await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) + await self.account_data_handler.add_tag_to_room( + user_id, new_room_id, tag, tag_content + ) async def update_membership( self, diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index a84a064c8d..dd527e807f 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -15,6 +15,7 @@ from synapse.http.server import JsonResource from synapse.replication.http import ( + account_data, devices, federation, login, @@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource): presence.register_servlets(hs, self) membership.register_servlets(hs, self) streams.register_servlets(hs, self) + account_data.register_servlets(hs, self) # The following can't currently be instantiated on workers. if hs.config.worker.worker_app is None: diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py new file mode 100644 index 0000000000..52d32528ee --- /dev/null +++ b/synapse/replication/http/account_data.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint + +logger = logging.getLogger(__name__) + + +class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): + """Add user account data on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/add_user_account_data/:user_id/:type + + { + "content": { ... }, + } + + """ + + NAME = "add_user_account_data" + PATH_ARGS = ("user_id", "account_data_type") + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, account_data_type, content): + payload = { + "content": content, + } + + return payload + + async def _handle_request(self, request, user_id, account_data_type): + content = parse_json_object_from_request(request) + + max_stream_id = await self.handler.add_account_data_for_user( + user_id, account_data_type, content["content"] + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): + """Add room account data on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type + + { + "content": { ... }, + } + + """ + + NAME = "add_room_account_data" + PATH_ARGS = ("user_id", "room_id", "account_data_type") + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, room_id, account_data_type, content): + payload = { + "content": content, + } + + return payload + + async def _handle_request(self, request, user_id, room_id, account_data_type): + content = parse_json_object_from_request(request) + + max_stream_id = await self.handler.add_account_data_to_room( + user_id, room_id, account_data_type, content["content"] + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationAddTagRestServlet(ReplicationEndpoint): + """Add tag on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/add_tag/:user_id/:room_id/:tag + + { + "content": { ... }, + } + + """ + + NAME = "add_tag" + PATH_ARGS = ("user_id", "room_id", "tag") + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, room_id, tag, content): + payload = { + "content": content, + } + + return payload + + async def _handle_request(self, request, user_id, room_id, tag): + content = parse_json_object_from_request(request) + + max_stream_id = await self.handler.add_tag_to_room( + user_id, room_id, tag, content["content"] + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationRemoveTagRestServlet(ReplicationEndpoint): + """Remove tag on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag + + {} + + """ + + NAME = "remove_tag" + PATH_ARGS = ( + "user_id", + "room_id", + "tag", + ) + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, room_id, tag): + + return {} + + async def _handle_request(self, request, user_id, room_id, tag): + max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,) + + return 200, {"max_stream_id": max_stream_id} + + +def register_servlets(hs, http_server): + ReplicationUserAccountDataRestServlet(hs).register(http_server) + ReplicationRoomAccountDataRestServlet(hs).register(http_server) + ReplicationAddTagRestServlet(hs).register(http_server) + ReplicationRemoveTagRestServlet(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index d0089fe06c..693c9ab901 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): database, stream_name="caches", instance_name=hs.get_instance_name(), - table="cache_invalidation_stream_by_instance", - instance_column="instance_name", - id_column="stream_id", + tables=[ + ( + "cache_invalidation_stream_by_instance", + "instance_name", + "stream_id", + ) + ], sequence_name="cache_invalidation_stream_seq", writers=[], ) # type: Optional[MultiWriterIdGenerator] diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 4268565fc8..21afe5f155 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -15,47 +15,9 @@ # limitations under the License. from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream -from synapse.storage.database import DatabasePool from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - self._account_data_id_gen = SlavedIdTracker( - db_conn, - "account_data", - "stream_id", - extra_tables=[ - ("room_account_data", "stream_id"), - ("room_tags_revisions", "stream_id"), - ], - ) - - super().__init__(database, db_conn, hs) - - def get_max_account_data_stream_id(self): - return self._account_data_id_gen.get_current_token() - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - for row in rows: - self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) - elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - for row in rows: - if not row.room_id: - self.get_global_account_data_by_type_for_user.invalidate( - (row.data_type, row.user_id) - ) - self.get_account_data_for_user.invalidate((row.user_id,)) - self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) - self.get_account_data_for_room_and_type.invalidate( - (row.user_id, row.room_id, row.data_type) - ) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + pass diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 6195917376..3dfdd9961d 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -14,43 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.replication.tcp.streams import ReceiptsStream -from synapse.storage.database import DatabasePool from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - # We instantiate this first as the ReceiptsWorkerStore constructor - # needs to be able to call get_max_receipt_stream_id - self._receipts_id_gen = SlavedIdTracker( - db_conn, "receipts_linearized", "stream_id" - ) - - super().__init__(database, db_conn, hs) - - def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_current_token() - - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): - self.get_receipts_for_user.invalidate((user_id, receipt_type)) - self._get_linearized_receipts_for_room.invalidate_many((room_id,)) - self.get_last_receipt_event_id_for_user.invalidate( - (user_id, room_id, receipt_type) - ) - self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) - self.get_receipts_for_room.invalidate((room_id, receipt_type)) - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == ReceiptsStream.NAME: - self._receipts_id_gen.advance(instance_name, token) - for row in rows: - self.invalidate_caches_for_receipt( - row.room_id, row.receipt_type, row.user_id - ) - self._receipts_stream_cache.entity_has_changed(row.room_id, token) - - return super().process_replication_rows(stream_name, instance_name, token, rows) + pass diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1f89249475..317796d5e0 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -51,11 +51,14 @@ from synapse.replication.tcp.commands import ( from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.streams import ( STREAMS_MAP, + AccountDataStream, BackfillStream, CachesStream, EventsStream, FederationStream, + ReceiptsStream, Stream, + TagAccountDataStream, ToDeviceStream, TypingStream, ) @@ -132,6 +135,22 @@ class ReplicationCommandHandler: continue + if isinstance(stream, (AccountDataStream, TagAccountDataStream)): + # Only add AccountDataStream and TagAccountDataStream as a source on the + # instance in charge of account_data persistence. + if hs.get_instance_name() in hs.config.worker.writers.account_data: + self._streams_to_replicate.append(stream) + + continue + + if isinstance(stream, ReceiptsStream): + # Only add ReceiptsStream as a source on the instance in charge of + # receipts. + if hs.get_instance_name() in hs.config.worker.writers.receipts: + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 87a5b1b86b..3f28c0bc3e 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self._is_worker = hs.config.worker_app is not None + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, account_data_type): - if self._is_worker: - raise Exception("Cannot handle PUT /account_data on worker") - requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") body = parse_json_object_from_request(request) - max_id = await self.store.add_account_data_for_user( - user_id, account_data_type, body - ) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.add_account_data_for_user(user_id, account_data_type, body) return 200, {} @@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self._is_worker = hs.config.worker_app is not None + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, room_id, account_data_type): - if self._is_worker: - raise Exception("Cannot handle PUT /account_data on worker") - requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet): " Use /rooms/!roomId:server.name/read_markers", ) - max_id = await self.store.add_account_data_to_room( + await self.handler.add_account_data_to_room( user_id, room_id, account_data_type, body ) - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) - return 200, {} async def on_GET(self, request, user_id, room_id, account_data_type): diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index bf3a79db44..a97cd66c52 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -58,8 +58,7 @@ class TagServlet(RestServlet): def __init__(self, hs): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.notifier = hs.get_notifier() + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, room_id, tag): requester = await self.auth.get_user_by_req(request) @@ -68,9 +67,7 @@ class TagServlet(RestServlet): body = parse_json_object_from_request(request) - max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.add_tag_to_room(user_id, room_id, tag, body) return 200, {} @@ -79,9 +76,7 @@ class TagServlet(RestServlet): if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - max_id = await self.store.remove_tag_from_room(user_id, room_id, tag) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.remove_tag_from_room(user_id, room_id, tag) return 200, {} diff --git a/synapse/server.py b/synapse/server.py index d4c235cda5..9cdda83aa1 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -55,6 +55,7 @@ from synapse.federation.sender import FederationSender from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler +from synapse.handlers.account_data import AccountDataHandler from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.acme import AcmeHandler from synapse.handlers.admin import AdminHandler @@ -711,6 +712,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_module_api(self) -> ModuleApi: return ModuleApi(self, self.get_auth_handler()) + @cache_in_self + def get_account_data_handler(self) -> AccountDataHandler: + return AccountDataHandler(self) + async def remove_pusher(self, app_id: str, push_key: str, user_id: str): return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index c4de07a0a8..ae561a2da3 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -160,9 +160,13 @@ class DataStore( database, stream_name="caches", instance_name=hs.get_instance_name(), - table="cache_invalidation_stream_by_instance", - instance_column="instance_name", - id_column="stream_id", + tables=[ + ( + "cache_invalidation_stream_by_instance", + "instance_name", + "stream_id", + ) + ], sequence_name="cache_invalidation_stream_seq", writers=[], ) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index bad8260892..68896f34af 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,14 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging from typing import Dict, List, Optional, Set, Tuple from synapse.api.constants import AccountDataTypes +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -30,14 +32,57 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -# The ABCMeta metaclass ensures that it cannot be instantiated without -# the abstract methods being implemented. -class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): +class AccountDataWorkerStore(SQLBaseStore): """This is an abstract base class where subclasses must implement `get_max_account_data_stream_id` which can be called in the initializer. """ def __init__(self, database: DatabasePool, db_conn, hs): + self._instance_name = hs.get_instance_name() + + if isinstance(database.engine, PostgresEngine): + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) + + self._account_data_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="account_data", + instance_name=self._instance_name, + tables=[ + ("room_account_data", "instance_name", "stream_id"), + ("room_tags_revisions", "instance_name", "stream_id"), + ("account_data", "instance_name", "stream_id"), + ], + sequence_name="account_data_sequence", + writers=hs.config.worker.writers.account_data, + ) + else: + self._can_write_to_account_data = True + + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. (Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.events: + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + ) + else: + self._account_data_id_gen = SlavedIdTracker( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + ) + account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max @@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): super().__init__(database, db_conn, hs) - @abc.abstractmethod - def get_max_account_data_stream_id(self): + def get_max_account_data_stream_id(self) -> int: """Get the current max stream ID for account data stream Returns: int """ - raise NotImplementedError() + return self._account_data_id_gen.get_current_token() @cached() async def get_account_data_for_user( @@ -307,25 +351,26 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): ) ) - -class AccountDataStore(AccountDataWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) - - super().__init__(database, db_conn, hs) - - def get_max_account_data_stream_id(self) -> int: - """Get the current max stream id for the private user data stream - - Returns: - The maximum stream ID. - """ - return self._account_data_id_gen.get_current_token() + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + for row in rows: + self.get_tags_for_user.invalidate((row.user_id,)) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) + elif stream_name == AccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + for row in rows: + if not row.room_id: + self.get_global_account_data_by_type_for_user.invalidate( + (row.data_type, row.user_id) + ) + self.get_account_data_for_user.invalidate((row.user_id,)) + self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) + self.get_account_data_for_room_and_type.invalidate( + (row.user_id, row.room_id, row.data_type) + ) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict @@ -341,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore): Returns: The maximum stream ID. """ + assert self._can_write_to_account_data + content_json = json_encoder.encode(content) async with self._account_data_id_gen.get_next() as next_id: @@ -381,6 +428,8 @@ class AccountDataStore(AccountDataWorkerStore): Returns: The maximum stream ID. """ + assert self._can_write_to_account_data + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "add_user_account_data", @@ -463,3 +512,7 @@ class AccountDataStore(AccountDataWorkerStore): # Invalidate the cache for any ignored users which were added or removed. for ignored_user_id in previously_ignored_users ^ currently_ignored_users: self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) + + +class AccountDataStore(AccountDataWorkerStore): + pass diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 58d3f71e45..31f70ac5ef 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -54,9 +54,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): db=database, stream_name="to_device", instance_name=self._instance_name, - table="device_inbox", - instance_column="instance_name", - id_column="stream_id", + tables=[("device_inbox", "instance_name", "stream_id")], sequence_name="device_inbox_sequence", writers=hs.config.worker.writers.to_device, ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index e5c03cc609..1b657191a9 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -835,6 +835,52 @@ class EventPushActionsWorkerStore(SQLBaseStore): (rotate_to_stream_ordering,), ) + def _remove_old_push_actions_before_txn( + self, txn, room_id, user_id, stream_ordering + ): + """ + Purges old push actions for a user and room before a given + stream_ordering. + + We however keep a months worth of highlighted notifications, so that + users can still get a list of recent highlights. + + Args: + txn: The transcation + room_id: Room ID to delete from + user_id: user ID to delete for + stream_ordering: The lowest stream ordering which will + not be deleted. + """ + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (room_id, user_id), + ) + + # We need to join on the events table to get the received_ts for + # event_push_actions and sqlite won't let us use a join in a delete so + # we can't just delete where received_ts < x. Furthermore we can + # only identify event_push_actions by a tuple of room_id, event_id + # we we can't use a subquery. + # Instead, we look up the stream ordering for the last event in that + # room received before the threshold time and delete event_push_actions + # in the room with a stream_odering before that. + txn.execute( + "DELETE FROM event_push_actions " + " WHERE user_id = ? AND room_id = ? AND " + " stream_ordering <= ?" + " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", + (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), + ) + + txn.execute( + """ + DELETE FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? + """, + (room_id, user_id, stream_ordering), + ) + class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" @@ -894,52 +940,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore): pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions - def _remove_old_push_actions_before_txn( - self, txn, room_id, user_id, stream_ordering - ): - """ - Purges old push actions for a user and room before a given - stream_ordering. - - We however keep a months worth of highlighted notifications, so that - users can still get a list of recent highlights. - - Args: - txn: The transcation - room_id: Room ID to delete from - user_id: user ID to delete for - stream_ordering: The lowest stream ordering which will - not be deleted. - """ - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id, user_id), - ) - - # We need to join on the events table to get the received_ts for - # event_push_actions and sqlite won't let us use a join in a delete so - # we can't just delete where received_ts < x. Furthermore we can - # only identify event_push_actions by a tuple of room_id, event_id - # we we can't use a subquery. - # Instead, we look up the stream ordering for the last event in that - # room received before the threshold time and delete event_push_actions - # in the room with a stream_odering before that. - txn.execute( - "DELETE FROM event_push_actions " - " WHERE user_id = ? AND room_id = ? AND " - " stream_ordering <= ?" - " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", - (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), - ) - - txn.execute( - """ - DELETE FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? - """, - (room_id, user_id, stream_ordering), - ) - def _action_has_highlight(actions): for action in actions: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 4732685f6e..71d823be72 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -96,9 +96,7 @@ class EventsWorkerStore(SQLBaseStore): db=database, stream_name="events", instance_name=hs.get_instance_name(), - table="events", - instance_column="instance_name", - id_column="stream_ordering", + tables=[("events", "instance_name", "stream_ordering")], sequence_name="events_stream_seq", writers=hs.config.worker.writers.events, ) @@ -107,9 +105,7 @@ class EventsWorkerStore(SQLBaseStore): db=database, stream_name="backfill", instance_name=hs.get_instance_name(), - table="events", - instance_column="instance_name", - id_column="stream_ordering", + tables=[("events", "instance_name", "stream_ordering")], sequence_name="events_backfill_stream_seq", positive=False, writers=hs.config.worker.writers.events, diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 1e7949a323..e0e57f0578 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -14,15 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging from typing import Any, Dict, List, Optional, Tuple from twisted.internet import defer +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -# The ABCMeta metaclass ensures that it cannot be instantiated without -# the abstract methods being implemented. -class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): - """This is an abstract base class where subclasses must implement - `get_max_receipt_stream_id` which can be called in the initializer. - """ - +class ReceiptsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): + self._instance_name = hs.get_instance_name() + + if isinstance(database.engine, PostgresEngine): + self._can_write_to_receipts = ( + self._instance_name in hs.config.worker.writers.receipts + ) + + self._receipts_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="account_data", + instance_name=self._instance_name, + tables=[("receipts_linearized", "instance_name", "stream_id")], + sequence_name="receipts_sequence", + writers=hs.config.worker.writers.receipts, + ) + else: + self._can_write_to_receipts = True + + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. (Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.events: + self._receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + else: + self._receipts_id_gen = SlavedIdTracker( + db_conn, "receipts_linearized", "stream_id" + ) + super().__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) - @abc.abstractmethod def get_max_receipt_stream_id(self): """Get the current max stream ID for receipts stream Returns: int """ - raise NotImplementedError() + return self._receipts_id_gen.get_current_token() @cached() async def get_users_with_read_receipts_in_room(self, room_id): @@ -428,19 +458,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - -class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): - # We instantiate this first as the ReceiptsWorkerStore constructor - # needs to be able to call get_max_receipt_stream_id - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" + def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): + self.get_receipts_for_user.invalidate((user_id, receipt_type)) + self._get_linearized_receipts_for_room.invalidate_many((room_id,)) + self.get_last_receipt_event_id_for_user.invalidate( + (user_id, room_id, receipt_type) ) + self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) + self.get_receipts_for_room.invalidate((room_id, receipt_type)) + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == ReceiptsStream.NAME: + self._receipts_id_gen.advance(instance_name, token) + for row in rows: + self.invalidate_caches_for_receipt( + row.room_id, row.receipt_type, row.user_id + ) + self._receipts_stream_cache.entity_has_changed(row.room_id, token) - super().__init__(database, db_conn, hs) - - def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_current_token() + return super().process_replication_rows(stream_name, instance_name, token, rows) def insert_linearized_receipt_txn( self, txn, room_id, receipt_type, user_id, event_id, data, stream_id @@ -452,6 +488,8 @@ class ReceiptsStore(ReceiptsWorkerStore): otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ + assert self._can_write_to_receipts + res = self.db_pool.simple_select_one_txn( txn, table="events", @@ -483,28 +521,14 @@ class ReceiptsStore(ReceiptsWorkerStore): ) return None - txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) - txn.call_after( - self._invalidate_get_users_with_receipts_in_room, - room_id, - receipt_type, - user_id, - ) - txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) - # FIXME: This shouldn't invalidate the whole cache txn.call_after( - self._get_linearized_receipts_for_room.invalidate_many, (room_id,) + self.invalidate_caches_for_receipt, room_id, receipt_type, user_id ) txn.call_after( self._receipts_stream_cache.entity_has_changed, room_id, stream_id ) - txn.call_after( - self.get_last_receipt_event_id_for_user.invalidate, - (user_id, room_id, receipt_type), - ) - self.db_pool.simple_upsert_txn( txn, table="receipts_linearized", @@ -543,6 +567,8 @@ class ReceiptsStore(ReceiptsWorkerStore): Automatically does conversion between linearized and graph representations. """ + assert self._can_write_to_receipts + if not event_ids: return None @@ -607,6 +633,8 @@ class ReceiptsStore(ReceiptsWorkerStore): async def insert_graph_receipt( self, room_id, receipt_type, user_id, event_ids, data ): + assert self._can_write_to_receipts + return await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, @@ -620,6 +648,8 @@ class ReceiptsStore(ReceiptsWorkerStore): def insert_graph_receipt_txn( self, txn, room_id, receipt_type, user_id, event_ids, data ): + assert self._can_write_to_receipts + txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) txn.call_after( self._invalidate_get_users_with_receipts_in_room, @@ -653,3 +683,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "data": json_encoder.encode(data), }, ) + + +class ReceiptsStore(ReceiptsWorkerStore): + pass diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql new file mode 100644 index 0000000000..46abf8d562 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql @@ -0,0 +1,20 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_account_data ADD COLUMN instance_name TEXT; +ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT; +ALTER TABLE account_data ADD COLUMN instance_name TEXT; + +ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres new file mode 100644 index 0000000000..4a6e6c74f5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres @@ -0,0 +1,32 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS account_data_sequence; + +-- We need to take the max across all the account_data tables as they share the +-- ID generator +SELECT setval('account_data_sequence', ( + SELECT GREATEST( + (SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data), + (SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions), + (SELECT COALESCE(MAX(stream_id), 1) FROM account_data) + ) +)); + +CREATE SEQUENCE IF NOT EXISTS receipts_sequence; + +SELECT setval('receipts_sequence', ( + SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized +)); diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 74da9c49f2..50067eabfc 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -183,8 +183,6 @@ class TagsWorkerStore(AccountDataWorkerStore): ) return {row["tag"]: db_to_json(row["content"]) for row in rows} - -class TagsStore(TagsWorkerStore): async def add_tag_to_room( self, user_id: str, room_id: str, tag: str, content: JsonDict ) -> int: @@ -199,6 +197,8 @@ class TagsStore(TagsWorkerStore): Returns: The next account data ID. """ + assert self._can_write_to_account_data + content_json = json_encoder.encode(content) def add_tag_txn(txn, next_id): @@ -223,6 +223,7 @@ class TagsStore(TagsWorkerStore): Returns: The next account data ID. """ + assert self._can_write_to_account_data def remove_tag_txn(txn, next_id): sql = ( @@ -250,6 +251,7 @@ class TagsStore(TagsWorkerStore): room_id: The ID of the room. next_id: The the revision to advance to. """ + assert self._can_write_to_account_data txn.call_after( self._account_data_stream_cache.entity_has_changed, user_id, next_id @@ -278,3 +280,7 @@ class TagsStore(TagsWorkerStore): # which stream_id ends up in the table, as long as it is higher # than the id that the client has. pass + + +class TagsStore(TagsWorkerStore): + pass diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 133c0e7a28..39a3ab1162 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -17,7 +17,7 @@ import logging import threading from collections import deque from contextlib import contextmanager -from typing import Dict, List, Optional, Set, Union +from typing import Dict, List, Optional, Set, Tuple, Union import attr from typing_extensions import Deque @@ -186,11 +186,12 @@ class MultiWriterIdGenerator: Args: db_conn db - stream_name: A name for the stream. + stream_name: A name for the stream, for use in the `stream_positions` + table. (Does not need to be the same as the replication stream name) instance_name: The name of this instance. - table: Database table associated with stream. - instance_column: Column that stores the row's writer's instance name - id_column: Column that stores the stream ID. + tables: List of tables associated with the stream. Tuple of table + name, column name that stores the writer's instance name, and + column name that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. writers: A list of known writers to use to populate current positions @@ -206,9 +207,7 @@ class MultiWriterIdGenerator: db: DatabasePool, stream_name: str, instance_name: str, - table: str, - instance_column: str, - id_column: str, + tables: List[Tuple[str, str, str]], sequence_name: str, writers: List[str], positive: bool = True, @@ -260,15 +259,16 @@ class MultiWriterIdGenerator: self._sequence_gen = PostgresSequenceGenerator(sequence_name) # We check that the table and sequence haven't diverged. - self._sequence_gen.check_consistency( - db_conn, table=table, id_column=id_column, positive=positive - ) + for table, _, id_column in tables: + self._sequence_gen.check_consistency( + db_conn, table=table, id_column=id_column, positive=positive + ) # This goes and fills out the above state from the database. - self._load_current_ids(db_conn, table, instance_column, id_column) + self._load_current_ids(db_conn, tables) def _load_current_ids( - self, db_conn, table: str, instance_column: str, id_column: str + self, db_conn, tables: List[Tuple[str, str, str]], ): cur = db_conn.cursor(txn_name="_load_current_ids") @@ -306,17 +306,22 @@ class MultiWriterIdGenerator: # We add a GREATEST here to ensure that the result is always # positive. (This can be a problem for e.g. backfill streams where # the server has never backfilled). - sql = """ - SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) - FROM %(table)s - """ % { - "id": id_column, - "table": table, - "agg": "MAX" if self._positive else "-MIN", - } - cur.execute(sql) - (stream_id,) = cur.fetchone() - self._persisted_upto_position = stream_id + max_stream_id = 1 + for table, _, id_column in tables: + sql = """ + SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) + FROM %(table)s + """ % { + "id": id_column, + "table": table, + "agg": "MAX" if self._positive else "-MIN", + } + cur.execute(sql) + (stream_id,) = cur.fetchone() + + max_stream_id = max(max_stream_id, stream_id) + + self._persisted_upto_position = max_stream_id else: # If we have a min_stream_id then we pull out everything greater # than it from the DB so that we can prefill @@ -329,21 +334,28 @@ class MultiWriterIdGenerator: # stream positions table before restart (or the stream position # table otherwise got out of date). - sql = """ - SELECT %(instance)s, %(id)s FROM %(table)s - WHERE ? %(cmp)s %(id)s - """ % { - "id": id_column, - "table": table, - "instance": instance_column, - "cmp": "<=" if self._positive else ">=", - } - cur.execute(sql, (min_stream_id * self._return_factor,)) - self._persisted_upto_position = min_stream_id + rows = [] + for table, instance_column, id_column in tables: + sql = """ + SELECT %(instance)s, %(id)s FROM %(table)s + WHERE ? %(cmp)s %(id)s + """ % { + "id": id_column, + "table": table, + "instance": instance_column, + "cmp": "<=" if self._positive else ">=", + } + cur.execute(sql, (min_stream_id * self._return_factor,)) + + rows.extend(cur) + + # Sort so that we handle rows in order for each instance. + rows.sort() + with self._lock: - for (instance, stream_id,) in cur: + for (instance, stream_id,) in rows: stream_id = self._return_factor * stream_id self._add_persisted_position(stream_id) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index cc0612cf65..3e2fd4da01 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -51,9 +51,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.db_pool, stream_name="test_stream", instance_name=instance_name, - table="foobar", - instance_column="instance_name", - id_column="stream_id", + tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", writers=writers, ) @@ -487,9 +485,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.db_pool, stream_name="test_stream", instance_name=instance_name, - table="foobar", - instance_column="instance_name", - id_column="stream_id", + tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", writers=writers, positive=False, @@ -579,3 +575,107 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) + + +class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.db_pool = self.store.db_pool # type: DatabasePool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn): + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar1 ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + txn.execute( + """ + CREATE TABLE foobar2 ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator( + self, instance_name="master", writers=["master"] + ) -> MultiWriterIdGenerator: + def _create(conn): + return MultiWriterIdGenerator( + conn, + self.db_pool, + stream_name="test_stream", + instance_name=instance_name, + tables=[ + ("foobar1", "instance_name", "stream_id"), + ("foobar2", "instance_name", "stream_id"), + ], + sequence_name="foobar_seq", + writers=writers, + ) + + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) + + def _insert_rows( + self, + table: str, + instance_name: str, + number: int, + update_stream_table: bool = True, + ): + """Insert N rows as the given instance, inserting with stream IDs pulled + from the postgres sequence. + """ + + def _insert(txn): + for _ in range(number): + txn.execute( + "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,), + (instance_name,), + ) + if update_stream_table: + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, lastval()) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval() + """, + (instance_name,), + ) + + self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) + + def test_load_existing_stream(self): + """Test creating ID gens with multiple tables that have rows from after + the position in `stream_positions` table. + """ + self._insert_rows("foobar1", "first", 3) + self._insert_rows("foobar2", "second", 3) + self._insert_rows("foobar2", "second", 1, update_stream_table=False) + + first_id_gen = self._create_id_generator("first", writers=["first", "second"]) + second_id_gen = self._create_id_generator("second", writers=["first", "second"]) + + # The first ID gen will notice that it can advance its token to 7 as it + # has no in progress writes... + self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6}) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6) + self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) + + # ... but the second ID gen doesn't know that. + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) + self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) + self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) -- cgit 1.5.1 From 47d48a5853f3fadd90ace757e4e664097932640a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 19 Jan 2021 14:21:59 -0500 Subject: Validate the server name for the /publicRooms endpoint. (#9161) If a remote server name is provided, ensure it is something reasonable before making remote connections to it. --- changelog.d/9161.bugfix | 1 + synapse/rest/client/v1/room.py | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 changelog.d/9161.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/9161.bugfix b/changelog.d/9161.bugfix new file mode 100644 index 0000000000..6798126b7c --- /dev/null +++ b/changelog.d/9161.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug "ValueError: invalid literal for int() with base 10" when `/publicRooms` is requested with an invalid `server` parameter. diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 5647e8c577..e6725b03b0 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -32,6 +32,7 @@ from synapse.api.errors import ( ) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 +from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -347,8 +348,6 @@ class PublicRoomListRestServlet(TransactionRestServlet): # provided. if server: raise e - else: - pass limit = parse_integer(request, "limit", 0) since_token = parse_string(request, "since", None) @@ -359,6 +358,14 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server and server != self.hs.config.server_name: + # Ensure the server is valid. + try: + parse_and_validate_server_name(server) + except ValueError: + raise SynapseError( + 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM, + ) + try: data = await handler.get_remote_public_room_list( server, limit=limit, since_token=since_token @@ -402,6 +409,14 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server and server != self.hs.config.server_name: + # Ensure the server is valid. + try: + parse_and_validate_server_name(server) + except ValueError: + raise SynapseError( + 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM, + ) + try: data = await handler.get_remote_public_room_list( server, -- cgit 1.5.1 From fa50e4bf4ddcb8e98d44700513a28c490f80f02b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Jan 2021 12:30:41 +0000 Subject: Give `public_baseurl` a default value (#9159) --- changelog.d/9159.feature | 1 + docs/sample_config.yaml | 31 +++++++++++++++++-------------- synapse/api/urls.py | 2 -- synapse/config/_base.py | 11 ++++++----- synapse/config/emailconfig.py | 8 -------- synapse/config/oidc_config.py | 2 -- synapse/config/registration.py | 21 ++++----------------- synapse/config/saml2_config.py | 2 -- synapse/config/server.py | 24 +++++++++++++++--------- synapse/config/sso.py | 13 +++++-------- synapse/handlers/identity.py | 2 -- synapse/rest/well_known.py | 4 ---- tests/rest/test_well_known.py | 9 --------- tests/utils.py | 1 - 14 files changed, 48 insertions(+), 83 deletions(-) create mode 100644 changelog.d/9159.feature (limited to 'synapse/rest') diff --git a/changelog.d/9159.feature b/changelog.d/9159.feature new file mode 100644 index 0000000000..b7748757de --- /dev/null +++ b/changelog.d/9159.feature @@ -0,0 +1 @@ +Give the `public_baseurl` a default value, if it is not explicitly set in the configuration file. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index ae995efe9b..7fdd798d70 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -67,11 +67,16 @@ pid_file: DATADIR/homeserver.pid # #web_client_location: https://riot.example.com/ -# The public-facing base URL that clients use to access this HS -# (not including _matrix/...). This is the same URL a user would -# enter into the 'custom HS URL' field on their client. If you -# use synapse with a reverse proxy, this should be the URL to reach -# synapse via the proxy. +# The public-facing base URL that clients use to access this Homeserver (not +# including _matrix/...). This is the same URL a user might enter into the +# 'Custom Homeserver URL' field on their client. If you use Synapse with a +# reverse proxy, this should be the URL to reach Synapse via the proxy. +# Otherwise, it should be the URL to reach Synapse's client HTTP listener (see +# 'listeners' below). +# +# If this is left unset, it defaults to 'https:///'. (Note that +# that will not work unless you configure Synapse or a reverse-proxy to listen +# on port 443.) # #public_baseurl: https://example.com/ @@ -1150,8 +1155,9 @@ account_validity: # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. + # If you enable this setting, you will also need to fill out the 'email' + # configuration section. You should also check that 'public_baseurl' is set + # correctly. # #renew_at: 1w @@ -1242,8 +1248,7 @@ account_validity: # The identity server which we suggest that clients should use when users log # in on this server. # -# (By default, no suggestion is made, so it is left up to the client. -# This setting is ignored unless public_baseurl is also set.) +# (By default, no suggestion is made, so it is left up to the client.) # #default_identity_server: https://matrix.org @@ -1268,8 +1273,6 @@ account_validity: # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # -# If a delegate is specified, the config option public_baseurl must also be filled out. -# account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process @@ -1901,9 +1904,9 @@ sso: # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # If public_baseurl is set, then the login fallback page (used by clients - # that don't natively support the required login flows) is whitelisted in - # addition to any URLs in this list. + # The login fallback page (used by clients that don't natively support the + # required login flows) is automatically whitelisted in addition to any URLs + # in this list. # # By default, this list is empty. # diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 6379c86dde..e36aeef31f 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -42,8 +42,6 @@ class ConsentURIBuilder: """ if hs_config.form_secret is None: raise ConfigError("form_secret not set in config") - if hs_config.public_baseurl is None: - raise ConfigError("public_baseurl not set in config") self._hmac_secret = hs_config.form_secret.encode("utf-8") self._public_baseurl = hs_config.public_baseurl diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 2931a88207..94144efc87 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -252,11 +252,12 @@ class Config: env = jinja2.Environment(loader=loader, autoescape=autoescape) # Update the environment with our custom filters - env.filters.update({"format_ts": _format_ts_filter}) - if self.public_baseurl: - env.filters.update( - {"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)} - ) + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), + } + ) for filename in filenames: # Load the template diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index d4328c46b9..6a487afd34 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -166,11 +166,6 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") - # public_baseurl is required to build password reset and validation links that - # will be emailed to users - if config.get("public_baseurl") is None: - missing.append("public_baseurl") - if missing: raise ConfigError( MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),) @@ -269,9 +264,6 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") - if config.get("public_baseurl") is None: - missing.append("public_baseurl") - if missing: raise ConfigError( "email.enable_notifs is True but required keys are missing: %s" diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 80a24cfbc9..df55367434 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -43,8 +43,6 @@ class OIDCConfig(Config): raise ConfigError(e.message) from e public_baseurl = self.public_baseurl - if public_baseurl is None: - raise ConfigError("oidc_config requires a public_baseurl to be set") self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" @property diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 740c3fc1b1..4bfc69cb7a 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -49,10 +49,6 @@ class AccountValidityConfig(Config): self.startup_job_max_delta = self.period * 10.0 / 100.0 - if self.renew_by_email_enabled: - if "public_baseurl" not in synapse_config: - raise ConfigError("Can't send renewal emails without 'public_baseurl'") - template_dir = config.get("template_dir") if not template_dir: @@ -109,13 +105,6 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") - if self.account_threepid_delegate_msisdn and not self.public_baseurl: - raise ConfigError( - "The configuration option `public_baseurl` is required if " - "`account_threepid_delegate.msisdn` is set, such that " - "clients know where to submit validation tokens to. Please " - "configure `public_baseurl`." - ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) @@ -240,8 +229,9 @@ class RegistrationConfig(Config): # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. + # If you enable this setting, you will also need to fill out the 'email' + # configuration section. You should also check that 'public_baseurl' is set + # correctly. # #renew_at: 1w @@ -332,8 +322,7 @@ class RegistrationConfig(Config): # The identity server which we suggest that clients should use when users log # in on this server. # - # (By default, no suggestion is made, so it is left up to the client. - # This setting is ignored unless public_baseurl is also set.) + # (By default, no suggestion is made, so it is left up to the client.) # #default_identity_server: https://matrix.org @@ -358,8 +347,6 @@ class RegistrationConfig(Config): # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # - # If a delegate is specified, the config option public_baseurl must also be filled out. - # account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 7b97d4f114..f33dfa0d6a 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -189,8 +189,6 @@ class SAML2Config(Config): import saml2 public_baseurl = self.public_baseurl - if public_baseurl is None: - raise ConfigError("saml2_config requires a public_baseurl to be set") if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) diff --git a/synapse/config/server.py b/synapse/config/server.py index 7242a4aa8e..75ba161f35 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -161,7 +161,11 @@ class ServerConfig(Config): self.print_pidfile = config.get("print_pidfile") self.user_agent_suffix = config.get("user_agent_suffix") self.use_frozen_dicts = config.get("use_frozen_dicts", False) - self.public_baseurl = config.get("public_baseurl") + self.public_baseurl = config.get("public_baseurl") or "https://%s/" % ( + self.server_name, + ) + if self.public_baseurl[-1] != "/": + self.public_baseurl += "/" # Whether to enable user presence. self.use_presence = config.get("use_presence", True) @@ -317,9 +321,6 @@ class ServerConfig(Config): # Always blacklist 0.0.0.0, :: self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) - if self.public_baseurl is not None: - if self.public_baseurl[-1] != "/": - self.public_baseurl += "/" self.start_pushers = config.get("start_pushers", True) # (undocumented) option for torturing the worker-mode replication a bit, @@ -740,11 +741,16 @@ class ServerConfig(Config): # #web_client_location: https://riot.example.com/ - # The public-facing base URL that clients use to access this HS - # (not including _matrix/...). This is the same URL a user would - # enter into the 'custom HS URL' field on their client. If you - # use synapse with a reverse proxy, this should be the URL to reach - # synapse via the proxy. + # The public-facing base URL that clients use to access this Homeserver (not + # including _matrix/...). This is the same URL a user might enter into the + # 'Custom Homeserver URL' field on their client. If you use Synapse with a + # reverse proxy, this should be the URL to reach Synapse via the proxy. + # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see + # 'listeners' below). + # + # If this is left unset, it defaults to 'https:///'. (Note that + # that will not work unless you configure Synapse or a reverse-proxy to listen + # on port 443.) # #public_baseurl: https://example.com/ diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 366f0d4698..59be825532 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -64,11 +64,8 @@ class SSOConfig(Config): # gracefully to the client). This would make it pointless to ask the user for # confirmation, since the URL the confirmation page would be showing wouldn't be # the client's. - # public_baseurl is an optional setting, so we only add the fallback's URL to the - # list if it's provided (because we can't figure out what that URL is otherwise). - if self.public_baseurl: - login_fallback_url = self.public_baseurl + "_matrix/static/client/login" - self.sso_client_whitelist.append(login_fallback_url) + login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + self.sso_client_whitelist.append(login_fallback_url) def generate_config_section(self, **kwargs): return """\ @@ -86,9 +83,9 @@ class SSOConfig(Config): # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # If public_baseurl is set, then the login fallback page (used by clients - # that don't natively support the required login flows) is whitelisted in - # addition to any URLs in this list. + # The login fallback page (used by clients that don't natively support the + # required login flows) is automatically whitelisted in addition to any URLs + # in this list. # # By default, this list is empty. # diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c05036ad1f..f61844d688 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -476,8 +476,6 @@ class IdentityHandler(BaseHandler): except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") - assert self.hs.config.public_baseurl - # we need to tell the client to send the token back to us, since it doesn't # otherwise know where to send it, so add submit_url response parameter # (see also MSC2078) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index f591cc6c5c..241fe746d9 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -34,10 +34,6 @@ class WellKnownBuilder: self._config = hs.config def get_well_known(self): - # if we don't have a public_baseurl, we can't help much here. - if self._config.public_baseurl is None: - return None - result = {"m.homeserver": {"base_url": self._config.public_baseurl}} if self._config.default_identity_server: diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 14de0921be..c5e44af9f7 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -40,12 +40,3 @@ class WellKnownTests(unittest.HomeserverTestCase): "m.identity_server": {"base_url": "https://testis"}, }, ) - - def test_well_known_no_public_baseurl(self): - self.hs.config.public_baseurl = None - - channel = self.make_request( - "GET", "/.well-known/matrix/client", shorthand=False - ) - - self.assertEqual(channel.code, 404) diff --git a/tests/utils.py b/tests/utils.py index 977eeaf6ee..09614093bc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -159,7 +159,6 @@ def default_config(name, parse=False): "remote": {"per_second": 10000, "burst_count": 10000}, }, "saml2_enabled": False, - "public_baseurl": None, "default_identity_server": None, "key_refresh_interval": 24 * 60 * 60 * 1000, "old_signing_keys": {}, -- cgit 1.5.1 From 0cd2938bc854d947ae8102ded688a626c9fac5b5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Jan 2021 13:15:14 +0000 Subject: Support icons for Identity Providers (#9154) --- changelog.d/9154.feature | 1 + docs/sample_config.yaml | 4 ++ mypy.ini | 1 + synapse/config/oidc_config.py | 20 ++++++ synapse/config/server.py | 2 +- synapse/federation/federation_server.py | 2 +- synapse/federation/transport/server.py | 2 +- synapse/handlers/cas_handler.py | 4 ++ synapse/handlers/oidc_handler.py | 3 + synapse/handlers/room.py | 2 +- synapse/handlers/saml_handler.py | 4 ++ synapse/handlers/sso.py | 5 ++ synapse/http/endpoint.py | 79 --------------------- synapse/res/templates/sso_login_idp_picker.html | 3 + synapse/rest/client/v1/room.py | 3 +- synapse/storage/databases/main/room.py | 6 +- synapse/types.py | 2 +- synapse/util/stringutils.py | 92 +++++++++++++++++++++++++ tests/http/test_endpoint.py | 2 +- 19 files changed, 146 insertions(+), 91 deletions(-) create mode 100644 changelog.d/9154.feature delete mode 100644 synapse/http/endpoint.py (limited to 'synapse/rest') diff --git a/changelog.d/9154.feature b/changelog.d/9154.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9154.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 7fdd798d70..b49a5da8cc 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1726,6 +1726,10 @@ saml2_config: # idp_name: A user-facing name for this identity provider, which is used to # offer the user a choice of login mechanisms. # +# idp_icon: An optional icon for this identity provider, which is presented +# by identity picker pages. If given, must be an MXC URI of the format +# mxc:/// +# # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. # diff --git a/mypy.ini b/mypy.ini index b996867121..bd99069c81 100644 --- a/mypy.ini +++ b/mypy.ini @@ -100,6 +100,7 @@ files = synapse/util/async_helpers.py, synapse/util/caches, synapse/util/metrics.py, + synapse/util/stringutils.py, tests/replication, tests/test_utils, tests/handlers/test_password_providers.py, diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index df55367434..f257fcd412 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -23,6 +23,7 @@ from synapse.config._util import validate_config from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import Collection, JsonDict from synapse.util.module_loader import load_module +from synapse.util.stringutils import parse_and_validate_mxc_uri from ._base import Config, ConfigError @@ -66,6 +67,10 @@ class OIDCConfig(Config): # idp_name: A user-facing name for this identity provider, which is used to # offer the user a choice of login mechanisms. # + # idp_icon: An optional icon for this identity provider, which is presented + # by identity picker pages. If given, must be an MXC URI of the format + # mxc:/// + # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. # @@ -207,6 +212,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "properties": { "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, "idp_name": {"type": "string"}, + "idp_icon": {"type": "string"}, "discover": {"type": "boolean"}, "issuer": {"type": "string"}, "client_id": {"type": "string"}, @@ -336,9 +342,20 @@ def _parse_oidc_config_dict( config_path + ("idp_id",), ) + # MSC2858 also specifies that the idp_icon must be a valid MXC uri + idp_icon = oidc_config.get("idp_icon") + if idp_icon is not None: + try: + parse_and_validate_mxc_uri(idp_icon) + except ValueError as e: + raise ConfigError( + "idp_icon must be a valid MXC URI", config_path + ("idp_icon",) + ) from e + return OidcProviderConfig( idp_id=idp_id, idp_name=oidc_config.get("idp_name", "OIDC"), + idp_icon=idp_icon, discover=oidc_config.get("discover", True), issuer=oidc_config["issuer"], client_id=oidc_config["client_id"], @@ -366,6 +383,9 @@ class OidcProviderConfig: # user-facing name for this identity provider. idp_name = attr.ib(type=str) + # Optional MXC URI for icon for this IdP. + idp_icon = attr.ib(type=Optional[str]) + # whether the OIDC discovery mechanism is used to discover endpoints discover = attr.ib(type=bool) diff --git a/synapse/config/server.py b/synapse/config/server.py index 75ba161f35..47a0370173 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -26,7 +26,7 @@ import yaml from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.http.endpoint import parse_and_validate_server_name +from synapse.util.stringutils import parse_and_validate_server_name from ._base import Config, ConfigError diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e5339aca23..171d25c945 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -49,7 +49,6 @@ from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction -from synapse.http.endpoint import parse_server_name from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, @@ -66,6 +65,7 @@ from synapse.types import JsonDict, get_domain_from_id from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import parse_server_name if TYPE_CHECKING: from synapse.server import HomeServer diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index cfd094e58f..95c64510a9 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -28,7 +28,6 @@ from synapse.api.urls import ( FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX, ) -from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( parse_boolean_from_args, @@ -45,6 +44,7 @@ from synapse.logging.opentracing import ( ) from synapse.server import HomeServer from synapse.types import ThirdPartyInstanceID, get_domain_from_id +from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index f3430c6713..0f342c607b 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -80,6 +80,10 @@ class CasHandler: # user-facing name of this auth provider self.idp_name = "CAS" + # we do not currently support icons for CAS auth, but this is required by + # the SsoIdentityProvider protocol type. + self.idp_icon = None + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index ba686d74b2..1607e12935 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -271,6 +271,9 @@ class OidcProvider: # user-facing name of this auth provider self.idp_name = provider.idp_name + # MXC URI for icon for this auth provider + self.idp_icon = provider.idp_icon + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3bece6d668..ee27d99135 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -38,7 +38,6 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents -from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, @@ -55,6 +54,7 @@ from synapse.types import ( from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import parse_and_validate_server_name from synapse.visibility import filter_events_for_client from ._base import BaseHandler diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index a8376543c9..38461cf79d 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -78,6 +78,10 @@ class SamlHandler(BaseHandler): # user-facing name of this auth provider self.idp_name = "SAML" + # we do not currently support icons for SAML auth, but this is required by + # the SsoIdentityProvider protocol type. + self.idp_icon = None + # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index dcc85e9871..d493327a10 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -75,6 +75,11 @@ class SsoIdentityProvider(Protocol): def idp_name(self) -> str: """User-facing name for this provider""" + @property + def idp_icon(self) -> Optional[str]: + """Optional MXC URI for user-facing icon""" + return None + @abc.abstractmethod async def handle_redirect_request( self, diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py deleted file mode 100644 index 92a5b606c8..0000000000 --- a/synapse/http/endpoint.py +++ /dev/null @@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import re - -logger = logging.getLogger(__name__) - - -def parse_server_name(server_name): - """Split a server name into host/port parts. - - Args: - server_name (str): server name to parse - - Returns: - Tuple[str, int|None]: host/port parts. - - Raises: - ValueError if the server name could not be parsed. - """ - try: - if server_name[-1] == "]": - # ipv6 literal, hopefully - return server_name, None - - domain_port = server_name.rsplit(":", 1) - domain = domain_port[0] - port = int(domain_port[1]) if domain_port[1:] else None - return domain, port - except Exception: - raise ValueError("Invalid server name '%s'" % server_name) - - -VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") - - -def parse_and_validate_server_name(server_name): - """Split a server name into host/port parts and do some basic validation. - - Args: - server_name (str): server name to parse - - Returns: - Tuple[str, int|None]: host/port parts. - - Raises: - ValueError if the server name could not be parsed. - """ - host, port = parse_server_name(server_name) - - # these tests don't need to be bulletproof as we'll find out soon enough - # if somebody is giving us invalid data. What we *do* need is to be sure - # that nobody is sneaking IP literals in that look like hostnames, etc. - - # look for ipv6 literals - if host[0] == "[": - if host[-1] != "]": - raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) - return host, port - - # otherwise it should only be alphanumerics. - if not VALID_HOST_REGEX.match(host): - raise ValueError( - "Server name '%s' contains invalid characters" % (server_name,) - ) - - return host, port diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html index f53c9cd679..5b38481012 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -17,6 +17,9 @@
  • +{% if p.idp_icon %} + +{% endif %}
  • {% endfor %} diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index e6725b03b0..f95627ee61 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -32,7 +32,6 @@ from synapse.api.errors import ( ) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 -from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -47,7 +46,7 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.util import json_decoder -from synapse.util.stringutils import random_string +from synapse.util.stringutils import parse_and_validate_server_name, random_string if TYPE_CHECKING: import synapse.server diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 284f2ce77c..a9fcb5f59c 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -16,7 +16,6 @@ import collections import logging -import re from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple @@ -30,6 +29,7 @@ from synapse.storage.databases.main.search import SearchStore from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached +from synapse.util.stringutils import MXC_REGEX logger = logging.getLogger(__name__) @@ -660,8 +660,6 @@ class RoomWorkerStore(SQLBaseStore): The local and remote media as a lists of tuples where the key is the hostname and the value is the media ID. """ - mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") - sql = """ SELECT stream_ordering, json FROM events JOIN event_json USING (room_id, event_id) @@ -688,7 +686,7 @@ class RoomWorkerStore(SQLBaseStore): for url in (content_url, thumbnail_url): if not url: continue - matches = mxc_re.match(url) + matches = MXC_REGEX.match(url) if matches: hostname = matches.group(1) media_id = matches.group(2) diff --git a/synapse/types.py b/synapse/types.py index 20a43d05bf..eafe729dfe 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -37,7 +37,7 @@ from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError -from synapse.http.endpoint import parse_and_validate_server_name +from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.appservice.api import ApplicationService diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index b103c8694c..f8038bf861 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -18,6 +18,7 @@ import random import re import string from collections.abc import Iterable +from typing import Optional, Tuple from synapse.api.errors import Codes, SynapseError @@ -26,6 +27,15 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") +# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris, +# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically +# says "there is no grammar for media ids" +# +# The server_name part of this is purposely lax: use parse_and_validate_mxc for +# additional validation. +# +MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") + # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure # we get cryptographically-secure randoms. @@ -59,6 +69,88 @@ def assert_valid_client_secret(client_secret): ) +def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]: + """Split a server name into host/port parts. + + Args: + server_name: server name to parse + + Returns: + host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + try: + if server_name[-1] == "]": + # ipv6 literal, hopefully + return server_name, None + + domain_port = server_name.rsplit(":", 1) + domain = domain_port[0] + port = int(domain_port[1]) if domain_port[1:] else None + return domain, port + except Exception: + raise ValueError("Invalid server name '%s'" % server_name) + + +VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") + + +def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]: + """Split a server name into host/port parts and do some basic validation. + + Args: + server_name: server name to parse + + Returns: + host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + host, port = parse_server_name(server_name) + + # these tests don't need to be bulletproof as we'll find out soon enough + # if somebody is giving us invalid data. What we *do* need is to be sure + # that nobody is sneaking IP literals in that look like hostnames, etc. + + # look for ipv6 literals + if host[0] == "[": + if host[-1] != "]": + raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) + return host, port + + # otherwise it should only be alphanumerics. + if not VALID_HOST_REGEX.match(host): + raise ValueError( + "Server name '%s' contains invalid characters" % (server_name,) + ) + + return host, port + + +def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]: + """Parse the given string as an MXC URI + + Checks that the "server name" part is a valid server name + + Args: + mxc: the (alleged) MXC URI to be checked + Returns: + hostname, port, media id + Raises: + ValueError if the URI cannot be parsed + """ + m = MXC_REGEX.match(mxc) + if not m: + raise ValueError("mxc URI %r did not match expected format" % (mxc,)) + server_name = m.group(1) + media_id = m.group(2) + host, port = parse_and_validate_server_name(server_name) + return host, port, media_id + + def shortstr(iterable: Iterable, maxitems: int = 5) -> str: """If iterable has maxitems or fewer, return the stringification of a list containing those items. diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index b2e9533b07..d06ea518ce 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name +from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name from tests import unittest -- cgit 1.5.1 From c55e62548c0fddd49e7182133880d2ccb03dbb42 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 21 Jan 2021 15:18:46 +0100 Subject: Add tests for List Users Admin API (#9045) --- changelog.d/9045.misc | 1 + synapse/rest/admin/users.py | 21 +++- tests/rest/admin/test_user.py | 223 +++++++++++++++++++++++++++++++++++++----- 3 files changed, 215 insertions(+), 30 deletions(-) create mode 100644 changelog.d/9045.misc (limited to 'synapse/rest') diff --git a/changelog.d/9045.misc b/changelog.d/9045.misc new file mode 100644 index 0000000000..7f1886a0de --- /dev/null +++ b/changelog.d/9045.misc @@ -0,0 +1 @@ +Add tests to `test_user.UsersListTestCase` for List Users Admin API. \ No newline at end of file diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index f39e3d6d5c..86198bab30 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet): The parameter `deactivated` can be used to include deactivated users. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + 400, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + user_id = parse_string(request, "user_id", default=None) name = parse_string(request, "name", default=None) guests = parse_boolean(request, "guests", default=True) @@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet): start, limit, user_id, name, guests, deactivated ) ret = {"users": users, "total": total} - if len(users) >= limit: + if (start + limit) < total: ret["next_token"] = str(start + len(users)) return 200, ret diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 04599c2fcf..e48f8c1d7b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -28,6 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v2_alpha import devices, sync +from synapse.types import JsonDict from tests import unittest from tests.test_utils import make_awaitable @@ -468,13 +469,6 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.user1 = self.register_user( - "user1", "pass1", admin=False, displayname="Name 1" - ) - self.user2 = self.register_user( - "user2", "pass2", admin=False, displayname="Name 2" - ) - def test_no_auth(self): """ Try to list users without authentication. @@ -488,6 +482,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ If the user is not a server admin, an error is returned. """ + self._create_users(1) other_user_token = self.login("user1", "pass1") channel = self.make_request("GET", self.url, access_token=other_user_token) @@ -499,6 +494,8 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ List all users, including deactivated users. """ + self._create_users(2) + channel = self.make_request( "GET", self.url + "?deactivated=true", @@ -511,14 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(3, channel.json_body["total"]) # Check that all fields are available - for u in channel.json_body["users"]: - self.assertIn("name", u) - self.assertIn("is_guest", u) - self.assertIn("admin", u) - self.assertIn("user_type", u) - self.assertIn("deactivated", u) - self.assertIn("displayname", u) - self.assertIn("avatar_url", u) + self._check_fields(channel.json_body["users"]) def test_search_term(self): """Test that searching for a users works correctly""" @@ -549,6 +539,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): # Check that users were returned self.assertTrue("users" in channel.json_body) + self._check_fields(channel.json_body["users"]) users = channel.json_body["users"] # Check that the expected number of users were returned @@ -561,25 +552,30 @@ class UsersListTestCase(unittest.HomeserverTestCase): u = users[0] self.assertEqual(expected_user_id, u["name"]) + self._create_users(2) + + user1 = "@user1:test" + user2 = "@user2:test" + # Perform search tests - _search_test(self.user1, "er1") - _search_test(self.user1, "me 1") + _search_test(user1, "er1") + _search_test(user1, "me 1") - _search_test(self.user2, "er2") - _search_test(self.user2, "me 2") + _search_test(user2, "er2") + _search_test(user2, "me 2") - _search_test(self.user1, "er1", "user_id") - _search_test(self.user2, "er2", "user_id") + _search_test(user1, "er1", "user_id") + _search_test(user2, "er2", "user_id") # Test case insensitive - _search_test(self.user1, "ER1") - _search_test(self.user1, "NAME 1") + _search_test(user1, "ER1") + _search_test(user1, "NAME 1") - _search_test(self.user2, "ER2") - _search_test(self.user2, "NAME 2") + _search_test(user2, "ER2") + _search_test(user2, "NAME 2") - _search_test(self.user1, "ER1", "user_id") - _search_test(self.user2, "ER2", "user_id") + _search_test(user1, "ER1", "user_id") + _search_test(user2, "ER2", "user_id") _search_test(None, "foo") _search_test(None, "bar") @@ -587,6 +583,179 @@ class UsersListTestCase(unittest.HomeserverTestCase): _search_test(None, "foo", "user_id") _search_test(None, "bar", "user_id") + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. + """ + + # negative limit + channel = self.make_request( + "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # negative from + channel = self.make_request( + "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # invalid guests + channel = self.make_request( + "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # invalid deactivated + channel = self.make_request( + "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + def test_limit(self): + """ + Testing list of users with limit + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 5) + self.assertEqual(channel.json_body["next_token"], "5") + self._check_fields(channel.json_body["users"]) + + def test_from(self): + """ + Testing list of users with a defined starting point (from) + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?from=5", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["users"]) + + def test_limit_and_from(self): + """ + Testing list of users with a defined starting point and limit + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(channel.json_body["next_token"], "15") + self.assertEqual(len(channel.json_body["users"]), 10) + self._check_fields(channel.json_body["users"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + # `next_token` does not appear + # Number of results is the number of entries + channel = self.make_request( + "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), number_users) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + channel = self.make_request( + "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), number_users) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + channel = self.make_request( + "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 19) + self.assertEqual(channel.json_body["next_token"], "19") + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + channel = self.make_request( + "GET", self.url + "?from=19", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def _check_fields(self, content: JsonDict): + """Checks that the expected user attributes are present in content + Args: + content: List that is checked for content + """ + for u in content: + self.assertIn("name", u) + self.assertIn("is_guest", u) + self.assertIn("admin", u) + self.assertIn("user_type", u) + self.assertIn("deactivated", u) + self.assertIn("displayname", u) + self.assertIn("avatar_url", u) + + def _create_users(self, number_users: int): + """ + Create a number of users + Args: + number_users: Number of users to be created + """ + for i in range(1, number_users + 1): + self.register_user( + "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i, + ) + class DeactivateAccountTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From a7882f98874684969910d3a6ed7d85f99114cc45 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 21 Jan 2021 14:53:58 -0500 Subject: Return a 404 if no valid thumbnail is found. (#9163) If no thumbnail of the requested type exists, return a 404 instead of erroring. This doesn't quite match the spec (which does not define what happens if no thumbnail can be found), but is consistent with what Synapse already does. --- changelog.d/9163.bugfix | 1 + synapse/rest/media/v1/_base.py | 3 + synapse/rest/media/v1/thumbnail_resource.py | 236 ++++++++++++++++++---------- tests/rest/media/v1/test_media_storage.py | 25 ++- 4 files changed, 183 insertions(+), 82 deletions(-) create mode 100644 changelog.d/9163.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/9163.bugfix b/changelog.d/9163.bugfix new file mode 100644 index 0000000000..c51cf6ca80 --- /dev/null +++ b/changelog.d/9163.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled). diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 31a41e4a27..f71a03a12d 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -300,6 +300,7 @@ class FileInfo: thumbnail_height (int) thumbnail_method (str) thumbnail_type (str): Content type of thumbnail, e.g. image/png + thumbnail_length (int): The size of the media file, in bytes. """ def __init__( @@ -312,6 +313,7 @@ class FileInfo: thumbnail_height=None, thumbnail_method=None, thumbnail_type=None, + thumbnail_length=None, ): self.server_name = server_name self.file_id = file_id @@ -321,6 +323,7 @@ class FileInfo: self.thumbnail_height = thumbnail_height self.thumbnail_method = thumbnail_method self.thumbnail_type = thumbnail_type + self.thumbnail_length = thumbnail_length def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index d6880f2e6e..d653a58be9 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -16,7 +16,7 @@ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional from twisted.web.http import Request @@ -106,31 +106,17 @@ class ThumbnailResource(DirectServeJsonResource): return thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) - - if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos - ) - - file_info = FileInfo( - server_name=None, - file_id=media_id, - url_cache=media_info["url_cache"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] - - responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) - else: - logger.info("Couldn't find any generated thumbnails") - respond_404(request) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_id, + url_cache=media_info["url_cache"], + server_name=None, + ) async def _select_or_generate_local_thumbnail( self, @@ -276,26 +262,64 @@ class ThumbnailResource(DirectServeJsonResource): thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_info["filesystem_id"], + url_cache=None, + server_name=server_name, + ) + async def _select_and_respond_with_thumbnail( + self, + request: Request, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str] = None, + server_name: Optional[str] = None, + ) -> None: + """ + Respond to a request with an appropriate thumbnail from the previously generated thumbnails. + + Args: + request: The incoming request. + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + """ if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos + file_info = self._select_thumbnail( + desired_width, + desired_height, + desired_method, + desired_type, + thumbnail_infos, + file_id, + url_cache, + server_name, ) - file_info = FileInfo( - server_name=server_name, - file_id=media_info["filesystem_id"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] + if not file_info: + logger.info("Couldn't find a thumbnail matching the desired inputs") + respond_404(request) + return responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder( + request, responder, file_info.thumbnail_type, file_info.thumbnail_length + ) else: logger.info("Failed to find any generated thumbnails") respond_404(request) @@ -306,67 +330,117 @@ class ThumbnailResource(DirectServeJsonResource): desired_height: int, desired_method: str, desired_type: str, - thumbnail_infos, - ) -> dict: + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str], + server_name: Optional[str], + ) -> Optional[FileInfo]: + """ + Choose an appropriate thumbnail from the previously generated thumbnails. + + Args: + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + + Returns: + The thumbnail which best matches the desired parameters. + """ + desired_method = desired_method.lower() + + # The chosen thumbnail. + thumbnail_info = None + d_w = desired_width d_h = desired_height - if desired_method.lower() == "crop": + if desired_method == "crop": + # Thumbnails that match equal or larger sizes of desired width/height. crop_info_list = [] + # Other thumbnails. crop_info_list2 = [] for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "crop": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] - if t_method == "crop": - aspect_quality = abs(d_w * t_h - d_h * t_w) - min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 - size_quality = abs((d_w - t_w) * (d_h - t_h)) - type_quality = desired_type != info["thumbnail_type"] - length_quality = info["thumbnail_length"] - if t_w >= d_w or t_h >= d_h: - crop_info_list.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + aspect_quality = abs(d_w * t_h - d_h * t_w) + min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 + size_quality = abs((d_w - t_w) * (d_h - t_h)) + type_quality = desired_type != info["thumbnail_type"] + length_quality = info["thumbnail_length"] + if t_w >= d_w or t_h >= d_h: + crop_info_list.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) - else: - crop_info_list2.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + ) + else: + crop_info_list2.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) + ) if crop_info_list: - return min(crop_info_list)[-1] - else: - return min(crop_info_list2)[-1] - else: + thumbnail_info = min(crop_info_list)[-1] + elif crop_info_list2: + thumbnail_info = min(crop_info_list2)[-1] + elif desired_method == "scale": + # Thumbnails that match equal or larger sizes of desired width/height. info_list = [] + # Other thumbnails. info_list2 = [] + for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "scale": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] size_quality = abs((d_w - t_w) * (d_h - t_h)) type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] - if t_method == "scale" and (t_w >= d_w or t_h >= d_h): + if t_w >= d_w or t_h >= d_h: info_list.append((size_quality, type_quality, length_quality, info)) - elif t_method == "scale": + else: info_list2.append( (size_quality, type_quality, length_quality, info) ) if info_list: - return min(info_list)[-1] - else: - return min(info_list2)[-1] + thumbnail_info = min(info_list)[-1] + elif info_list2: + thumbnail_info = min(info_list2)[-1] + + if thumbnail_info: + return FileInfo( + file_id=file_id, + url_cache=url_cache, + server_name=server_name, + thumbnail=True, + thumbnail_width=thumbnail_info["thumbnail_width"], + thumbnail_height=thumbnail_info["thumbnail_height"], + thumbnail_type=thumbnail_info["thumbnail_type"], + thumbnail_method=thumbnail_info["thumbnail_method"], + thumbnail_length=thumbnail_info["thumbnail_length"], + ) + + # No matching thumbnail was found. + return None diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index ae2b32b131..a6c6985173 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -202,7 +202,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): config = self.default_config() config["media_store_path"] = self.media_store_path - config["thumbnail_requirements"] = {} config["max_image_pixels"] = 2000000 provider_config = { @@ -313,15 +312,39 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) def test_thumbnail_crop(self): + """Test that a cropped remote thumbnail is available.""" self._test_thumbnail( "crop", self.test_image.expected_cropped, self.test_image.expected_found ) def test_thumbnail_scale(self): + """Test that a scaled remote thumbnail is available.""" self._test_thumbnail( "scale", self.test_image.expected_scaled, self.test_image.expected_found ) + def test_invalid_type(self): + """An invalid thumbnail type is never available.""" + self._test_thumbnail("invalid", None, False) + + @unittest.override_config( + {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} + ) + def test_no_thumbnail_crop(self): + """ + Override the config to generate only scaled thumbnails, but request a cropped one. + """ + self._test_thumbnail("crop", None, False) + + @unittest.override_config( + {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} + ) + def test_no_thumbnail_scale(self): + """ + Override the config to generate only cropped thumbnails, but request a scaled one. + """ + self._test_thumbnail("scale", None, False) + def _test_thumbnail(self, method, expected_body, expected_found): params = "?width=32&height=32&method=" + method channel = make_request( -- cgit 1.5.1 From 4a55d267eef1388690e6781b580910e341358f95 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 25 Jan 2021 14:49:39 -0500 Subject: Add an admin API for shadow-banning users. (#9209) This expands the current shadow-banning feature to be usable via the admin API and adds documentation for it. A shadow-banned users receives successful responses to their client-server API requests, but the events are not propagated into rooms. Shadow-banning a user should be used as a tool of last resort and may lead to confusing or broken behaviour for the client. --- changelog.d/9209.feature | 1 + docs/admin_api/user_admin_api.rst | 30 ++++++++++++ stubs/txredisapi.pyi | 1 - synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/users.py | 36 +++++++++++++++ synapse/storage/databases/main/registration.py | 29 ++++++++++++ tests/rest/admin/test_user.py | 64 ++++++++++++++++++++++++++ tests/rest/client/test_shadow_banned.py | 8 +--- 8 files changed, 164 insertions(+), 7 deletions(-) create mode 100644 changelog.d/9209.feature (limited to 'synapse/rest') diff --git a/changelog.d/9209.feature b/changelog.d/9209.feature new file mode 100644 index 0000000000..ec926e8eb4 --- /dev/null +++ b/changelog.d/9209.feature @@ -0,0 +1 @@ +Add an admin API endpoint for shadow-banning users. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index b3d413cf57..1eb674939e 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -760,3 +760,33 @@ The following fields are returned in the JSON response body: - ``total`` - integer - Number of pushers. See also `Client-Server API Spec `_ + +Shadow-banning users +==================== + +Shadow-banning is a useful tool for moderating malicious or egregiously abusive users. +A shadow-banned users receives successful responses to their client-server API requests, +but the events are not propagated into rooms. This can be an effective tool as it +(hopefully) takes longer for the user to realise they are being moderated before +pivoting to another account. + +Shadow-banning a user should be used as a tool of last resort and may lead to confusing +or broken behaviour for the client. A shadow-banned user will not receive any +notification and it is generally more appropriate to ban or kick abusive users. +A shadow-banned user will be unable to contact anyone on the server. + +The API is:: + + POST /_synapse/admin/v1/users//shadow_ban + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst `_. + +An empty JSON dict is returned. + +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must + be local. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index bfac6840e6..726454ba31 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -15,7 +15,6 @@ """Contains *incomplete* type hints for txredisapi. """ - from typing import List, Optional, Type, Union class RedisProtocol: diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 6f7dc06503..f04740cd38 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -51,6 +51,7 @@ from synapse.rest.admin.users import ( PushersRestServlet, ResetPasswordRestServlet, SearchUsersRestServlet, + ShadowBanRestServlet, UserAdminServlet, UserMediaRestServlet, UserMembershipRestServlet, @@ -230,6 +231,7 @@ def register_servlets(hs, http_server): EventReportsRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server) MakeRoomAdminRestServlet(hs).register(http_server) + ShadowBanRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 86198bab30..68c3c64a0d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -890,3 +890,39 @@ class UserTokenRestServlet(RestServlet): ) return 200, {"access_token": token} + + +class ShadowBanRestServlet(RestServlet): + """An admin API for shadow-banning a user. + + A shadow-banned users receives successful responses to their client-server + API requests, but the events are not propagated into rooms. + + Shadow-banning a user should be used as a tool of last resort and may lead + to confusing or broken behaviour for the client. + + Example: + + POST /_synapse/admin/v1/users/@test:example.com/shadow_ban + {} + + 200 OK + {} + """ + + PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Only local users can be shadow-banned") + + await self.store.set_shadow_banned(UserID.from_string(user_id), True) + + return 200, {} diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 585b4049d6..0618b4387a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -360,6 +360,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) + async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None: + """Sets whether a user shadow-banned. + + Args: + user: user ID of the user to test + shadow_banned: true iff the user is to be shadow-banned, false otherwise. + """ + + def set_shadow_banned_txn(txn): + self.db_pool.simple_update_one_txn( + txn, + table="users", + keyvalues={"name": user.to_string()}, + updatevalues={"shadow_banned": shadow_banned}, + ) + # In order for this to apply immediately, clear the cache for this user. + tokens = self.db_pool.simple_select_onecol_txn( + txn, + table="access_tokens", + keyvalues={"user_id": user.to_string()}, + retcol="token", + ) + for token in tokens: + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (token,) + ) + + await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) + def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: sql = """ SELECT users.name as user_id, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e48f8c1d7b..ee05ee60bc 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2380,3 +2380,67 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) + + +class ShadowBanRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + + self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to get information of an user without authentication. + """ + channel = self.make_request("POST", self.url) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_not_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + channel = self.make_request("POST", self.url, access_token=other_user_token) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that shadow-banning for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" + + channel = self.make_request("POST", url, access_token=self.admin_user_tok) + self.assertEqual(400, channel.code, msg=channel.json_body) + + def test_success(self): + """ + Shadow-banning should succeed for an admin. + """ + # The user starts off as not shadow-banned. + other_user_token = self.login("user", "pass") + result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + self.assertFalse(result.shadow_banned) + + channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual({}, channel.json_body) + + # Ensure the user is shadow-banned (and the cache was cleared). + result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + self.assertTrue(result.shadow_banned) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index e689c3fbea..0ebdf1415b 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -18,6 +18,7 @@ import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet +from synapse.types import UserID from tests import unittest @@ -31,12 +32,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase): self.store = self.hs.get_datastore() self.get_success( - self.store.db_pool.simple_update( - table="users", - keyvalues={"name": self.banned_user_id}, - updatevalues={"shadow_banned": True}, - desc="shadow_ban", - ) + self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True) ) self.other_user_id = self.register_user("otheruser", "pass") -- cgit 1.5.1 From 4937fe3d6be94222b02760866496781f8cc88751 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Jan 2021 07:32:17 -0500 Subject: Try to recover from unknown encodings when previewing media. (#9164) Treat unknown encodings (according to lxml) as UTF-8 when generating a preview for HTML documents. This isn't fully accurate, but will hopefully give a reasonable title and summary. --- changelog.d/9164.bugfix | 1 + synapse/rest/media/v1/preview_url_resource.py | 44 +++++++++++++++++++++------ tests/test_preview.py | 29 ++++++++++++++++++ 3 files changed, 64 insertions(+), 10 deletions(-) create mode 100644 changelog.d/9164.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/9164.bugfix b/changelog.d/9164.bugfix new file mode 100644 index 0000000000..1c54a256c1 --- /dev/null +++ b/changelog.d/9164.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where an internal server error was raised when attempting to preview an HTML document in an unknown character encoding. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index a632099167..bf3be653aa 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -386,7 +386,7 @@ class PreviewUrlResource(DirectServeJsonResource): """ Check whether the URL should be downloaded as oEmbed content instead. - Params: + Args: url: The URL to check. Returns: @@ -403,7 +403,7 @@ class PreviewUrlResource(DirectServeJsonResource): """ Request content from an oEmbed endpoint. - Params: + Args: endpoint: The oEmbed API endpoint. url: The URL to pass to the API. @@ -692,27 +692,51 @@ class PreviewUrlResource(DirectServeJsonResource): def decode_and_calc_og( body: bytes, media_uri: str, request_encoding: Optional[str] = None ) -> Dict[str, Optional[str]]: + """ + Calculate metadata for an HTML document. + + This uses lxml to parse the HTML document into the OG response. If errors + occur during processing of the document, an empty response is returned. + + Args: + body: The HTML document, as bytes. + media_url: The URI used to download the body. + request_encoding: The character encoding of the body, as a string. + + Returns: + The OG response as a dictionary. + """ # If there's no body, nothing useful is going to be found. if not body: return {} from lxml import etree + # Create an HTML parser. If this fails, log and return no metadata. try: parser = etree.HTMLParser(recover=True, encoding=request_encoding) - tree = etree.fromstring(body, parser) - og = _calc_og(tree, media_uri) + except LookupError: + # blindly consider the encoding as utf-8. + parser = etree.HTMLParser(recover=True, encoding="utf-8") + except Exception as e: + logger.warning("Unable to create HTML parser: %s" % (e,)) + return {} + + def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]: + # Attempt to parse the body. If this fails, log and return no metadata. + tree = etree.fromstring(body_attempt, parser) + return _calc_og(tree, media_uri) + + # Attempt to parse the body. If this fails, log and return no metadata. + try: + return _attempt_calc_og(body) except UnicodeDecodeError: # blindly try decoding the body as utf-8, which seems to fix # the charset mismatches on https://google.com - parser = etree.HTMLParser(recover=True, encoding=request_encoding) - tree = etree.fromstring(body.decode("utf-8", "ignore"), parser) - og = _calc_og(tree, media_uri) - - return og + return _attempt_calc_og(body.decode("utf-8", "ignore")) -def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]: +def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: # suck our tree into lxml and define our OG response. # if we see any image URLs in the OG response, then spider them diff --git a/tests/test_preview.py b/tests/test_preview.py index c19facc1cb..0c6cbbd921 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -261,3 +261,32 @@ class PreviewUrlTestCase(unittest.TestCase): html = "" og = decode_and_calc_og(html, "http://example.com/test.html") self.assertEqual(og, {}) + + def test_invalid_encoding(self): + """An invalid character encoding should be ignored and treated as UTF-8, if possible.""" + html = """ + + Foo + + Some text. + + + """ + og = decode_and_calc_og( + html, "http://example.com/test.html", "invalid-encoding" + ) + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) + + def test_invalid_encoding2(self): + """A body which doesn't match the sent character encoding.""" + # Note that this contains an invalid UTF-8 sequence in the title. + html = b""" + + \xff\xff Foo + + Some text. + + + """ + og = decode_and_calc_og(html, "http://example.com/test.html") + self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) -- cgit 1.5.1 From a737cc27134c50059440ca33510b0baea53b4225 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 27 Jan 2021 12:41:24 +0000 Subject: Implement MSC2858 support (#9183) Fixes #8928. --- changelog.d/9183.feature | 1 + synapse/config/_base.pyi | 2 + synapse/config/experimental.py | 29 ++++++++++++ synapse/config/homeserver.py | 2 + synapse/handlers/sso.py | 23 +++++++--- synapse/http/server.py | 44 ++++++++++++++---- synapse/rest/client/v1/login.py | 55 ++++++++++++++++++++--- tests/rest/client/v1/test_login.py | 92 ++++++++++++++++++++++++++++++++++++++ tests/utils.py | 3 +- 9 files changed, 230 insertions(+), 21 deletions(-) create mode 100644 changelog.d/9183.feature create mode 100644 synapse/config/experimental.py (limited to 'synapse/rest') diff --git a/changelog.d/9183.feature b/changelog.d/9183.feature new file mode 100644 index 0000000000..2d5c735042 --- /dev/null +++ b/changelog.d/9183.feature @@ -0,0 +1 @@ +Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858). diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 29aa064e57..3ccea4b02d 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -9,6 +9,7 @@ from synapse.config import ( consent_config, database, emailconfig, + experimental, groups, jwt_config, key, @@ -48,6 +49,7 @@ def path_exists(file_path: str): ... class RootConfig: server: server.ServerConfig + experimental: experimental.ExperimentalConfig tls: tls.TlsConfig database: database.DatabaseConfig logging: logger.LoggingConfig diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py new file mode 100644 index 0000000000..b1c1c51e4d --- /dev/null +++ b/synapse/config/experimental.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config +from synapse.types import JsonDict + + +class ExperimentalConfig(Config): + """Config section for enabling experimental features""" + + section = "experimental" + + def read_config(self, config: JsonDict, **kwargs): + experimental = config.get("experimental_features") or {} + + # MSC2858 (multiple SSO identity providers) + self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 4bd2b3587b..64a2429f77 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -24,6 +24,7 @@ from .cas import CasConfig from .consent_config import ConsentConfig from .database import DatabaseConfig from .emailconfig import EmailConfig +from .experimental import ExperimentalConfig from .federation import FederationConfig from .groups import GroupsConfig from .jwt_config import JWTConfig @@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, + ExperimentalConfig, TlsConfig, FederationConfig, CacheConfig, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index d493327a10..afc1341d09 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -23,7 +23,7 @@ from typing_extensions import NoReturn, Protocol from twisted.web.http import Request from synapse.api.constants import LoginType -from synapse.api.errors import Codes, RedirectException, SynapseError +from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html @@ -235,7 +235,10 @@ class SsoHandler: respond_with_html(request, code, html) async def handle_redirect_request( - self, request: SynapseRequest, client_redirect_url: bytes, + self, + request: SynapseRequest, + client_redirect_url: bytes, + idp_id: Optional[str], ) -> str: """Handle a request to /login/sso/redirect @@ -243,6 +246,7 @@ class SsoHandler: request: incoming HTTP request client_redirect_url: the URL that we should redirect the client to after login. + idp_id: optional identity provider chosen by the client Returns: the URI to redirect to @@ -252,10 +256,19 @@ class SsoHandler: 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED ) + # if the client chose an IdP, use that + idp = None # type: Optional[SsoIdentityProvider] + if idp_id: + idp = self._identity_providers.get(idp_id) + if not idp: + raise NotFoundError("Unknown identity provider") + # if we only have one auth provider, redirect to it directly - if len(self._identity_providers) == 1: - ap = next(iter(self._identity_providers.values())) - return await ap.handle_redirect_request(request, client_redirect_url) + elif len(self._identity_providers) == 1: + idp = next(iter(self._identity_providers.values())) + + if idp: + return await idp.handle_redirect_request(request, client_redirect_url) # otherwise, redirect to the IDP picker return "/_synapse/client/pick_idp?" + urlencode( diff --git a/synapse/http/server.py b/synapse/http/server.py index e464bfe6c7..d69d579b3a 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -22,10 +22,22 @@ import types import urllib from http import HTTPStatus from io import BytesIO -from typing import Any, Callable, Dict, Iterator, List, Tuple, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + Iterator, + List, + Pattern, + Tuple, + Union, +) import jinja2 from canonicaljson import iterencode_canonical_json +from typing_extensions import Protocol from zope.interface import implementer from twisted.internet import defer, interfaces @@ -168,11 +180,25 @@ def wrap_async_request_handler(h): return preserve_fn(wrapped_async_request_handler) -class HttpServer: +# Type of a callback method for processing requests +# it is actually called with a SynapseRequest and a kwargs dict for the params, +# but I can't figure out how to represent that. +ServletCallback = Callable[ + ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]] +] + + +class HttpServer(Protocol): """ Interface for registering callbacks on a HTTP server """ - def register_paths(self, method, path_patterns, callback): + def register_paths( + self, + method: str, + path_patterns: Iterable[Pattern], + callback: ServletCallback, + servlet_classname: str, + ) -> None: """ Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. @@ -180,12 +206,14 @@ class HttpServer: an unpacked tuple. Args: - method (str): The method to listen to. - path_patterns (list): The regex used to match requests. - callback (function): The function to fire if we receive a matched + method: The HTTP method to listen to. + path_patterns: The regex used to match requests. + callback: The function to fire if we receive a matched request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. - This should return a tuple of (code, response). + This should return either tuple of (code, response), or None. + servlet_classname (str): The name of the handler to be used in prometheus + and opentracing logs. """ pass @@ -354,7 +382,7 @@ class JsonResource(DirectServeJsonResource): def _get_handler_for_request( self, request: SynapseRequest - ) -> Tuple[Callable, str, Dict[str, str]]: + ) -> Tuple[ServletCallback, str, Dict[str, str]]: """Finds a callback method to handle the given request. Returns: diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index be938df962..0a561eea60 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.appservice import ApplicationService -from synapse.http.server import finish_request +from synapse.handlers.sso import SsoIdentityProvider +from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, @@ -60,11 +61,14 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self._msc2858_enabled = hs.config.experimental.msc2858_enabled self.auth = hs.get_auth() self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() + self._sso_handler = hs.get_sso_handler() + self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( clock=hs.get_clock(), @@ -89,8 +93,17 @@ class LoginRestServlet(RestServlet): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - flows.append({"type": LoginRestServlet.SSO_TYPE}) - # While its valid for us to advertise this login type generally, + sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict + + if self._msc2858_enabled: + sso_flow["org.matrix.msc2858.identity_providers"] = [ + _get_auth_flow_dict_for_idp(idp) + for idp in self._sso_handler.get_identity_providers().values() + ] + + flows.append(sso_flow) + + # While it's valid for us to advertise this login type generally, # synapse currently only gives out these tokens as part of the # SSO login flow. # Generally we don't want to advertise login flows that clients @@ -311,8 +324,20 @@ class LoginRestServlet(RestServlet): return result +def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: + """Return an entry for the login flow dict + + Returns an entry suitable for inclusion in "identity_providers" in the + response to GET /_matrix/client/r0/login + """ + e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict + if idp.idp_icon: + e["icon"] = idp.idp_icon + return e + + class SsoRedirectServlet(RestServlet): - PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) + PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True) def __init__(self, hs: "HomeServer"): # make sure that the relevant handlers are instantiated, so that they @@ -324,13 +349,31 @@ class SsoRedirectServlet(RestServlet): if hs.config.oidc_enabled: hs.get_oidc_handler() self._sso_handler = hs.get_sso_handler() + self._msc2858_enabled = hs.config.experimental.msc2858_enabled + + def register(self, http_server: HttpServer) -> None: + super().register(http_server) + if self._msc2858_enabled: + # expose additional endpoint for MSC2858 support + http_server.register_paths( + "GET", + client_patterns( + "/org.matrix.msc2858/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$", + releases=(), + unstable=True, + ), + self.on_GET, + self.__class__.__name__, + ) - async def on_GET(self, request: SynapseRequest): + async def on_GET( + self, request: SynapseRequest, idp_id: Optional[str] = None + ) -> None: client_redirect_url = parse_string( request, "redirectUrl", required=True, encoding=None ) sso_url = await self._sso_handler.handle_redirect_request( - request, client_redirect_url + request, client_redirect_url, idp_id, ) logger.info("Redirecting to %s", sso_url) request.redirect(sso_url) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2672ce24c6..e2bb945453 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -75,6 +75,10 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?&q"+%3D%2B"="fö%26=o"' # the query params in TEST_CLIENT_REDIRECT_URL EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("", ""), ('q" =+"', '"fö&=o"')] +# (possibly experimental) login flows we expect to appear in the list after the normal +# ones +ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] + class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -426,6 +430,57 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): d["/_synapse/oidc"] = OIDCResource(self.hs) return d + def test_get_login_flows(self): + """GET /login should return password and SSO flows""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.sso"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(channel.json_body["flows"], expected_flows) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_get_msc2858_login_flows(self): + """The SSO flow should include IdP info if MSC2858 is enabled""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + # stick the flows results in a dict by type + flow_results = {} # type: Dict[str, Any] + for f in channel.json_body["flows"]: + flow_type = f["type"] + self.assertNotIn( + flow_type, flow_results, "duplicate flow type %s" % (flow_type,) + ) + flow_results[flow_type] = f + + self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned") + sso_flow = flow_results.pop("m.login.sso") + # we should have a set of IdPs + self.assertCountEqual( + sso_flow["org.matrix.msc2858.identity_providers"], + [ + {"id": "cas", "name": "CAS"}, + {"id": "saml", "name": "SAML"}, + {"id": "oidc-idp1", "name": "IDP1"}, + {"id": "oidc", "name": "OIDC"}, + ], + ) + + # the rest of the flows are simple + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(flow_results.values(), expected_flows) + def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker @@ -564,6 +619,43 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) + def test_client_idp_redirect_msc2858_disabled(self): + """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_idp_redirect_to_unknown(self): + """If the client tries to pick an unknown IdP, return a 404""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + self.assertEqual(channel.code, 404, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_idp_redirect_to_oidc(self): + """If the client pick a known IdP, redirect to it""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + @staticmethod def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: prefix = key + " = " diff --git a/tests/utils.py b/tests/utils.py index 09614093bc..022223cf24 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,7 +33,6 @@ from synapse.api.room_versions import RoomVersions from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.http.server import HttpServer from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore @@ -351,7 +350,7 @@ def mock_getRawHeaders(headers=None): # This is a mock /resource/ not an entire server -class MockHttpResource(HttpServer): +class MockHttpResource: def __init__(self, prefix=""): self.callbacks = [] # 3-tuple of method/pattern/function self.prefix = prefix -- cgit 1.5.1 From 2e537a02804af47862b8c7e0454ae85fde616f2d Mon Sep 17 00:00:00 2001 From: Pankaj Yadav <42418662+y-pankaj@users.noreply.github.com> Date: Wed, 27 Jan 2021 23:08:08 +0530 Subject: Check if a user is in the room before sending a PowerLevel event on their behalf (#9235) --- changelog.d/9235.bugfix | 1 + synapse/rest/admin/rooms.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9235.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/9235.bugfix b/changelog.d/9235.bugfix new file mode 100644 index 0000000000..7809c8673b --- /dev/null +++ b/changelog.d/9235.bugfix @@ -0,0 +1 @@ +Fix a bug in the `make_room_admin` admin API where it failed if the admin with the greatest power level was not in the room. Contributed by Pankaj Yadav. diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index da1499cab3..f14915d47e 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -431,7 +431,17 @@ class MakeRoomAdminRestServlet(RestServlet): if not admin_users: raise SynapseError(400, "No local admin user in room") - admin_user_id = admin_users[-1] + admin_user_id = None + + for admin_user in reversed(admin_users): + if room_state.get((EventTypes.Member, admin_user)): + admin_user_id = admin_user + break + + if not admin_user_id: + raise SynapseError( + 400, "No local admin user in room", + ) pl_content = power_levels.content else: -- cgit 1.5.1 From a083aea396dbd455858e93d6a57a236e192b68e2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 27 Jan 2021 21:31:45 +0000 Subject: Add 'brand' field to MSC2858 response (#9242) We've decided to add a 'brand' field to help clients decide how to style the buttons. Also, fix up the allowed characters for idp_id, while I'm in the area. --- changelog.d/9183.feature | 2 +- changelog.d/9242.feature | 1 + docs/openid.md | 3 +++ docs/sample_config.yaml | 13 ++++++---- synapse/config/oidc_config.py | 52 +++++++++++++++++++++------------------- synapse/handlers/cas_handler.py | 3 ++- synapse/handlers/oidc_handler.py | 3 +++ synapse/handlers/saml_handler.py | 3 ++- synapse/handlers/sso.py | 5 ++++ synapse/rest/client/v1/login.py | 2 ++ 10 files changed, 55 insertions(+), 32 deletions(-) create mode 100644 changelog.d/9242.feature (limited to 'synapse/rest') diff --git a/changelog.d/9183.feature b/changelog.d/9183.feature index 2d5c735042..3bcd9f15d1 100644 --- a/changelog.d/9183.feature +++ b/changelog.d/9183.feature @@ -1 +1 @@ -Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858). +Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). diff --git a/changelog.d/9242.feature b/changelog.d/9242.feature new file mode 100644 index 0000000000..3bcd9f15d1 --- /dev/null +++ b/changelog.d/9242.feature @@ -0,0 +1 @@ +Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). diff --git a/docs/openid.md b/docs/openid.md index b86ae89768..f01f46d326 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -225,6 +225,7 @@ Synapse config: oidc_providers: - idp_id: github idp_name: Github + idp_brand: "org.matrix.github" # optional: styling hint for clients discover: false issuer: "https://github.com/" client_id: "your-client-id" # TO BE FILLED @@ -250,6 +251,7 @@ oidc_providers: oidc_providers: - idp_id: google idp_name: Google + idp_brand: "org.matrix.google" # optional: styling hint for clients issuer: "https://accounts.google.com/" client_id: "your-client-id" # TO BE FILLED client_secret: "your-client-secret" # TO BE FILLED @@ -296,6 +298,7 @@ Synapse config: oidc_providers: - idp_id: gitlab idp_name: Gitlab + idp_brand: "org.matrix.gitlab" # optional: styling hint for clients issuer: "https://gitlab.com/" client_id: "your-client-id" # TO BE FILLED client_secret: "your-client-secret" # TO BE FILLED diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 1c90156db9..8777e3254d 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1727,10 +1727,14 @@ saml2_config: # offer the user a choice of login mechanisms. # # idp_icon: An optional icon for this identity provider, which is presented -# by identity picker pages. If given, must be an MXC URI of the format -# mxc:///. (An easy way to obtain such an MXC URI -# is to upload an image to an (unencrypted) room and then copy the "url" -# from the source of the event.) +# by clients and Synapse's own IdP picker page. If given, must be an +# MXC URI of the format mxc:///. (An easy way to +# obtain such an MXC URI is to upload an image to an (unencrypted) room +# and then copy the "url" from the source of the event.) +# +# idp_brand: An optional brand for this identity provider, allowing clients +# to style the login flow according to the identity provider in question. +# See the spec for possible options here. # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. @@ -1860,6 +1864,7 @@ oidc_providers: # #- idp_id: github # idp_name: Github + # idp_brand: org.matrix.github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 8237b2e797..f31511e039 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import string from collections import Counter from typing import Iterable, Optional, Tuple, Type @@ -79,10 +78,14 @@ class OIDCConfig(Config): # offer the user a choice of login mechanisms. # # idp_icon: An optional icon for this identity provider, which is presented - # by identity picker pages. If given, must be an MXC URI of the format - # mxc:///. (An easy way to obtain such an MXC URI - # is to upload an image to an (unencrypted) room and then copy the "url" - # from the source of the event.) + # by clients and Synapse's own IdP picker page. If given, must be an + # MXC URI of the format mxc:///. (An easy way to + # obtain such an MXC URI is to upload an image to an (unencrypted) room + # and then copy the "url" from the source of the event.) + # + # idp_brand: An optional brand for this identity provider, allowing clients + # to style the login flow according to the identity provider in question. + # See the spec for possible options here. # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. @@ -212,6 +215,7 @@ class OIDCConfig(Config): # #- idp_id: github # idp_name: Github + # idp_brand: org.matrix.github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED @@ -235,11 +239,22 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "type": "object", "required": ["issuer", "client_id", "client_secret"], "properties": { - # TODO: fix the maxLength here depending on what MSC2528 decides - # remember that we prefix the ID given here with `oidc-` - "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, + "idp_id": { + "type": "string", + "minLength": 1, + # MSC2858 allows a maxlen of 255, but we prefix with "oidc-" + "maxLength": 250, + "pattern": "^[A-Za-z0-9._~-]+$", + }, "idp_name": {"type": "string"}, "idp_icon": {"type": "string"}, + "idp_brand": { + "type": "string", + # MSC2758-style namespaced identifier + "minLength": 1, + "maxLength": 255, + "pattern": "^[a-z][a-z0-9_.-]*$", + }, "discover": {"type": "boolean"}, "issuer": {"type": "string"}, "client_id": {"type": "string"}, @@ -358,25 +373,8 @@ def _parse_oidc_config_dict( config_path + ("user_mapping_provider", "module"), ) - # MSC2858 will apply certain limits in what can be used as an IdP id, so let's - # enforce those limits now. - # TODO: factor out this stuff to a generic function idp_id = oidc_config.get("idp_id", "oidc") - # TODO: update this validity check based on what MSC2858 decides. - valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._") - - if any(c not in valid_idp_chars for c in idp_id): - raise ConfigError( - 'idp_id may only contain a-z, 0-9, "-", ".", "_"', - config_path + ("idp_id",), - ) - - if idp_id[0] not in string.ascii_lowercase: - raise ConfigError( - "idp_id must start with a-z", config_path + ("idp_id",), - ) - # prefix the given IDP with a prefix specific to the SSO mechanism, to avoid # clashes with other mechs (such as SAML, CAS). # @@ -402,6 +400,7 @@ def _parse_oidc_config_dict( idp_id=idp_id, idp_name=oidc_config.get("idp_name", "OIDC"), idp_icon=idp_icon, + idp_brand=oidc_config.get("idp_brand"), discover=oidc_config.get("discover", True), issuer=oidc_config["issuer"], client_id=oidc_config["client_id"], @@ -432,6 +431,9 @@ class OidcProviderConfig: # Optional MXC URI for icon for this IdP. idp_icon = attr.ib(type=Optional[str]) + # Optional brand identifier for this IdP. + idp_brand = attr.ib(type=Optional[str]) + # whether the OIDC discovery mechanism is used to discover endpoints discover = attr.ib(type=bool) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 0f342c607b..048523ec94 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -80,9 +80,10 @@ class CasHandler: # user-facing name of this auth provider self.idp_name = "CAS" - # we do not currently support icons for CAS auth, but this is required by + # we do not currently support brands/icons for CAS auth, but this is required by # the SsoIdentityProvider protocol type. self.idp_icon = None + self.idp_brand = None self._sso_handler = hs.get_sso_handler() diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 324ddb798c..ca647fa78f 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -274,6 +274,9 @@ class OidcProvider: # MXC URI for icon for this auth provider self.idp_icon = provider.idp_icon + # optional brand identifier for this auth provider + self.idp_brand = provider.idp_brand + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 38461cf79d..5946919c33 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -78,9 +78,10 @@ class SamlHandler(BaseHandler): # user-facing name of this auth provider self.idp_name = "SAML" - # we do not currently support icons for SAML auth, but this is required by + # we do not currently support icons/brands for SAML auth, but this is required by # the SsoIdentityProvider protocol type. self.idp_icon = None + self.idp_brand = None # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index afc1341d09..3308b037d2 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -80,6 +80,11 @@ class SsoIdentityProvider(Protocol): """Optional MXC URI for user-facing icon""" return None + @property + def idp_brand(self) -> Optional[str]: + """Optional branding identifier""" + return None + @abc.abstractmethod async def handle_redirect_request( self, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 0a561eea60..0fb9419e58 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -333,6 +333,8 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict if idp.idp_icon: e["icon"] = idp.idp_icon + if idp.idp_brand: + e["brand"] = idp.idp_brand return e -- cgit 1.5.1 From 4b73488e811714089ba447884dccb9b6ae3ac16c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2021 17:39:21 +0000 Subject: Ratelimit 3PID /requestToken API (#9238) --- changelog.d/9238.feature | 1 + docs/sample_config.yaml | 6 +- synapse/config/_base.pyi | 2 +- synapse/config/ratelimiting.py | 13 ++++- synapse/handlers/identity.py | 28 ++++++++++ synapse/rest/client/v2_alpha/account.py | 12 +++- synapse/rest/client/v2_alpha/register.py | 6 ++ tests/rest/client/v2_alpha/test_account.py | 90 ++++++++++++++++++++++++++++-- tests/server.py | 9 ++- tests/unittest.py | 5 ++ tests/utils.py | 1 + 11 files changed, 159 insertions(+), 14 deletions(-) create mode 100644 changelog.d/9238.feature (limited to 'synapse/rest') diff --git a/changelog.d/9238.feature b/changelog.d/9238.feature new file mode 100644 index 0000000000..143a3e14f5 --- /dev/null +++ b/changelog.d/9238.feature @@ -0,0 +1 @@ +Add ratelimited to 3PID `/requestToken` API. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index c2ccd68f3a..e5b6268087 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -824,6 +824,7 @@ log_config: "CONFDIR/SERVERNAME.log.config" # users are joining rooms the server is already in (this is cheap) vs # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) +# - one for ratelimiting how often a user or IP can attempt to validate a 3PID. # # The defaults are as shown below. # @@ -857,7 +858,10 @@ log_config: "CONFDIR/SERVERNAME.log.config" # remote: # per_second: 0.01 # burst_count: 3 - +# +#rc_3pid_validation: +# per_second: 0.003 +# burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 7ed07a801d..70025b5d60 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -54,7 +54,7 @@ class RootConfig: tls: tls.TlsConfig database: database.DatabaseConfig logging: logger.LoggingConfig - ratelimit: ratelimiting.RatelimitConfig + ratelimiting: ratelimiting.RatelimitConfig media: repository.ContentRepositoryConfig captcha: captcha.CaptchaConfig voip: voip.VoipConfig diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 14b8836197..76f382527d 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -24,7 +24,7 @@ class RateLimitConfig: defaults={"per_second": 0.17, "burst_count": 3.0}, ): self.per_second = config.get("per_second", defaults["per_second"]) - self.burst_count = config.get("burst_count", defaults["burst_count"]) + self.burst_count = int(config.get("burst_count", defaults["burst_count"])) class FederationRateLimitConfig: @@ -102,6 +102,11 @@ class RatelimitConfig(Config): defaults={"per_second": 0.01, "burst_count": 3}, ) + self.rc_3pid_validation = RateLimitConfig( + config.get("rc_3pid_validation") or {}, + defaults={"per_second": 0.003, "burst_count": 5}, + ) + def generate_config_section(self, **kwargs): return """\ ## Ratelimiting ## @@ -131,6 +136,7 @@ class RatelimitConfig(Config): # users are joining rooms the server is already in (this is cheap) vs # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) + # - one for ratelimiting how often a user or IP can attempt to validate a 3PID. # # The defaults are as shown below. # @@ -164,7 +170,10 @@ class RatelimitConfig(Config): # remote: # per_second: 0.01 # burst_count: 3 - + # + #rc_3pid_validation: + # per_second: 0.003 + # burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index f61844d688..4f7137539b 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -27,9 +27,11 @@ from synapse.api.errors import ( HttpResponseException, SynapseError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.config.emailconfig import ThreepidBehaviour from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient +from synapse.http.site import SynapseRequest from synapse.types import JsonDict, Requester from synapse.util import json_decoder from synapse.util.hash import sha256_and_url_safe_base64 @@ -57,6 +59,32 @@ class IdentityHandler(BaseHandler): self._web_client_location = hs.config.invite_client_location + # Ratelimiters for `/requestToken` endpoints. + self._3pid_validation_ratelimiter_ip = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + self._3pid_validation_ratelimiter_address = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + + def ratelimit_request_token_requests( + self, request: SynapseRequest, medium: str, address: str, + ): + """Used to ratelimit requests to `/requestToken` by IP and address. + + Args: + request: The associated request + medium: The type of threepid, e.g. "msisdn" or "email" + address: The actual threepid ID, e.g. the phone number or email address + """ + + self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) + self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) + async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] ) -> Optional[JsonDict]: diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 65e68d641b..a84a2fb385 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) class EmailPasswordRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/password/email/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.datastore = hs.get_datastore() @@ -103,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + # The email will be sent to the stored address. # This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after @@ -379,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) @@ -430,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): class MsisdnThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs super().__init__() self.store = self.hs.get_datastore() @@ -458,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b093183e79..10e1891174 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -126,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) @@ -205,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "msisdn", msisdn ) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index cb87b80e33..177dc476da 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,7 +24,7 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership -from synapse.api.errors import Codes +from synapse.api.errors import Codes, HttpResponseException from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource @@ -112,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the old password self.attempt_wrong_password_login("kermit", old_password) + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_email(self): + """Test that we ratelimit /requestToken for the same email. + """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test1@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + def reset(ip): + client_secret = "foobar" + session_id = self._request_token(email, client_secret, ip) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + self.email_attempts.clear() + + # We expect to be able to make three requests before getting rate + # limited. + # + # We change IPs to ensure that we're not being ratelimited due to the + # same IP + reset("127.0.0.1") + reset("127.0.0.2") + reset("127.0.0.3") + + with self.assertRaises(HttpResponseException) as cm: + reset("127.0.0.4") + + self.assertEqual(cm.exception.code, 429) + def test_basic_password_reset_canonicalise_email(self): """Test basic password reset flow Request password reset with different spelling @@ -239,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(session_id) - def _request_token(self, email, client_secret): + def _request_token(self, email, client_secret, ip="127.0.0.1"): channel = self.make_request( "POST", b"account/password/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, + client_ip=ip, ) - self.assertEquals(200, channel.code, channel.result) + + if channel.code != 200: + raise HttpResponseException( + channel.code, channel.result["reason"], channel.result["body"], + ) return channel.json_body["sid"] @@ -509,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_address_trim(self): self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_ip(self): + """Tests that adding emails is ratelimited by IP + """ + + # We expect to be able to set three emails before getting ratelimited. + self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) + self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) + self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) + + with self.assertRaises(HttpResponseException) as cm: + self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) + + self.assertEqual(cm.exception.code, 429) + def test_add_email_if_disabled(self): """Test adding email to profile when doing so is disallowed """ @@ -777,7 +847,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): body["next_link"] = next_link channel = self.make_request("POST", b"account/3pid/email/requestToken", body,) - self.assertEquals(expect_code, channel.code, channel.result) + + if channel.code != expect_code: + raise HttpResponseException( + channel.code, channel.result["reason"], channel.result["body"], + ) return channel.json_body.get("sid") @@ -823,10 +897,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def _add_email(self, request_email, expected_email): """Test adding an email to profile """ + previous_email_attempts = len(self.email_attempts) + client_secret = "foobar" session_id = self._request_token(request_email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) link = self._get_link_from_email() self._validate_token(link) @@ -855,4 +931,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"]) + + threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} + self.assertIn(expected_email, threepids) diff --git a/tests/server.py b/tests/server.py index 5a85d5fe7f..6419c445ec 100644 --- a/tests/server.py +++ b/tests/server.py @@ -47,6 +47,7 @@ class FakeChannel: site = attr.ib(type=Site) _reactor = attr.ib() result = attr.ib(type=dict, default=attr.Factory(dict)) + _ip = attr.ib(type=str, default="127.0.0.1") _producer = None @property @@ -120,7 +121,7 @@ class FakeChannel: def getPeer(self): # We give an address so that getClientIP returns a non null entry, # causing us to record the MAU - return address.IPv4Address("TCP", "127.0.0.1", 3423) + return address.IPv4Address("TCP", self._ip, 3423) def getHost(self): return None @@ -196,6 +197,7 @@ def make_request( custom_headers: Optional[ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] ] = None, + client_ip: str = "127.0.0.1", ) -> FakeChannel: """ Make a web request using the given method, path and content, and render it @@ -223,6 +225,9 @@ def make_request( will pump the reactor until the the renderer tells the channel the request is finished. + client_ip: The IP to use as the requesting IP. Useful for testing + ratelimiting. + Returns: channel """ @@ -250,7 +255,7 @@ def make_request( if isinstance(content, str): content = content.encode("utf8") - channel = FakeChannel(site, reactor) + channel = FakeChannel(site, reactor, ip=client_ip) req = request(channel) req.content = BytesIO(content) diff --git a/tests/unittest.py b/tests/unittest.py index bbd295687c..767d5d6077 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -386,6 +386,7 @@ class HomeserverTestCase(TestCase): custom_headers: Optional[ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] ] = None, + client_ip: str = "127.0.0.1", ) -> FakeChannel: """ Create a SynapseRequest at the path using the method and containing the @@ -410,6 +411,9 @@ class HomeserverTestCase(TestCase): custom_headers: (name, value) pairs to add as request headers + client_ip: The IP to use as the requesting IP. Useful for testing + ratelimiting. + Returns: The FakeChannel object which stores the result of the request. """ @@ -426,6 +430,7 @@ class HomeserverTestCase(TestCase): content_is_form, await_result, custom_headers, + client_ip, ) def setup_test_homeserver(self, *args, **kwargs): diff --git a/tests/utils.py b/tests/utils.py index 022223cf24..68033d7535 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -157,6 +157,7 @@ def default_config(name, parse=False): "local": {"per_second": 10000, "burst_count": 10000}, "remote": {"per_second": 10000, "burst_count": 10000}, }, + "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000}, "saml2_enabled": False, "default_identity_server": None, "key_refresh_interval": 24 * 60 * 60 * 1000, -- cgit 1.5.1 From f78d07bf005f7212bcc74256721677a3b255ea0e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 13:15:51 +0000 Subject: Split out a separate endpoint to complete SSO registration (#9262) There are going to be a couple of paths to get to the final step of SSO reg, and I want the URL in the browser to consistent. So, let's move the final step onto a separate path, which we redirect to. --- changelog.d/9262.feature | 1 + synapse/app/homeserver.py | 2 + synapse/handlers/sso.py | 81 ++++++++++++++++++++++------ synapse/http/server.py | 7 +++ synapse/rest/synapse/client/pick_username.py | 16 +++--- synapse/rest/synapse/client/sso_register.py | 50 +++++++++++++++++ tests/rest/client/v1/test_login.py | 14 ++++- 7 files changed, 145 insertions(+), 26 deletions(-) create mode 100644 changelog.d/9262.feature create mode 100644 synapse/rest/synapse/client/sso_register.py (limited to 'synapse/rest') diff --git a/changelog.d/9262.feature b/changelog.d/9262.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9262.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 57a2f5237c..86d6f73674 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -62,6 +62,7 @@ from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client.sso_register import SsoRegisterResource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore @@ -192,6 +193,7 @@ class SynapseHomeServer(HomeServer): "/_synapse/admin": AdminRestResource(self), "/_synapse/client/pick_username": pick_username_resource(self), "/_synapse/client/pick_idp": PickIdpResource(self), + "/_synapse/client/sso_register": SsoRegisterResource(self), } ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 3308b037d2..50c5ae142a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -21,12 +21,13 @@ import attr from typing_extensions import NoReturn, Protocol from twisted.web.http import Request +from twisted.web.iweb import IRequest from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent -from synapse.http.server import respond_with_html +from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer @@ -141,6 +142,9 @@ class UsernameMappingSession: # expiry time for the session, in milliseconds expiry_time_ms = attr.ib(type=int) + # choices made by the user + chosen_localpart = attr.ib(type=Optional[str], default=None) + # the HTTP cookie used to track the mapping session id USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session" @@ -647,6 +651,25 @@ class SsoHandler: ) respond_with_html(request, 200, html) + def get_mapping_session(self, session_id: str) -> UsernameMappingSession: + """Look up the given username mapping session + + If it is not found, raises a SynapseError with an http code of 400 + + Args: + session_id: session to look up + Returns: + active mapping session + Raises: + SynapseError if the session is not found/has expired + """ + self._expire_old_sessions() + session = self._username_mapping_sessions.get(session_id) + if session: + return session + logger.info("Couldn't find session id %s", session_id) + raise SynapseError(400, "unknown session") + async def check_username_availability( self, localpart: str, session_id: str, ) -> bool: @@ -663,12 +686,7 @@ class SsoHandler: # make sure that there is a valid mapping session, to stop people dictionary- # scanning for accounts - - self._expire_old_sessions() - session = self._username_mapping_sessions.get(session_id) - if not session: - logger.info("Couldn't find session id %s", session_id) - raise SynapseError(400, "unknown session") + self.get_mapping_session(session_id) logger.info( "[session %s] Checking for availability of username %s", @@ -696,16 +714,33 @@ class SsoHandler: localpart: localpart requested by the user session_id: ID of the username mapping session, extracted from a cookie """ - self._expire_old_sessions() - session = self._username_mapping_sessions.get(session_id) - if not session: - logger.info("Couldn't find session id %s", session_id) - raise SynapseError(400, "unknown session") + session = self.get_mapping_session(session_id) + + # update the session with the user's choices + session.chosen_localpart = localpart + + # we're done; now we can register the user + respond_with_redirect(request, b"/_synapse/client/sso_register") + + async def register_sso_user(self, request: Request, session_id: str) -> None: + """Called once we have all the info we need to register a new user. - logger.info("[session %s] Registering localpart %s", session_id, localpart) + Does so and serves an HTTP response + + Args: + request: HTTP request + session_id: ID of the username mapping session, extracted from a cookie + """ + session = self.get_mapping_session(session_id) + + logger.info( + "[session %s] Registering localpart %s", + session_id, + session.chosen_localpart, + ) attributes = UserAttributes( - localpart=localpart, + localpart=session.chosen_localpart, display_name=session.display_name, emails=session.emails, ) @@ -720,7 +755,12 @@ class SsoHandler: request.getClientIP(), ) - logger.info("[session %s] Registered userid %s", session_id, user_id) + logger.info( + "[session %s] Registered userid %s with attributes %s", + session_id, + user_id, + attributes, + ) # delete the mapping session and the cookie del self._username_mapping_sessions[session_id] @@ -751,3 +791,14 @@ class SsoHandler: for session_id in to_expire: logger.info("Expiring mapping session %s", session_id) del self._username_mapping_sessions[session_id] + + +def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: + """Extract the session ID from the cookie + + Raises a SynapseError if the cookie isn't found + """ + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + return session_id.decode("ascii", errors="replace") diff --git a/synapse/http/server.py b/synapse/http/server.py index d69d579b3a..8249732b27 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -761,6 +761,13 @@ def set_clickjacking_protection_headers(request: Request): request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") +def respond_with_redirect(request: Request, url: bytes) -> None: + """Write a 302 response to the request, if it is still alive.""" + logger.debug("Redirect to %s", url.decode("utf-8")) + request.redirect(url) + finish_request(request) + + def finish_request(request: Request): """ Finish writing the response to the request. diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index d3b6803e65..1bc737bad0 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import TYPE_CHECKING import pkg_resources @@ -20,8 +21,7 @@ from twisted.web.http import Request from twisted.web.resource import Resource from twisted.web.static import File -from synapse.api.errors import SynapseError -from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME +from synapse.handlers.sso import get_username_mapping_session_cookie_from_request from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest @@ -61,12 +61,10 @@ class AvailabilityCheckResource(DirectServeJsonResource): async def _async_render_GET(self, request: Request): localpart = parse_string(request, "username", required=True) - session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) - if not session_id: - raise SynapseError(code=400, msg="missing session_id") + session_id = get_username_mapping_session_cookie_from_request(request) is_available = await self._sso_handler.check_username_availability( - localpart, session_id.decode("ascii", errors="replace") + localpart, session_id ) return 200, {"available": is_available} @@ -79,10 +77,8 @@ class SubmitResource(DirectServeHtmlResource): async def _async_render_POST(self, request: SynapseRequest): localpart = parse_string(request, "username", required=True) - session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) - if not session_id: - raise SynapseError(code=400, msg="missing session_id") + session_id = get_username_mapping_session_cookie_from_request(request) await self._sso_handler.handle_submit_username_request( - request, localpart, session_id.decode("ascii", errors="replace") + request, localpart, session_id ) diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py new file mode 100644 index 0000000000..dfefeb7796 --- /dev/null +++ b/synapse/rest/synapse/client/sso_register.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request + +from synapse.api.errors import SynapseError +from synapse.handlers.sso import get_username_mapping_session_cookie_from_request +from synapse.http.server import DirectServeHtmlResource + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class SsoRegisterResource(DirectServeHtmlResource): + """A resource which completes SSO registration + + This resource gets mounted at /_synapse/client/sso_register, and is shown + after we collect username and/or consent for a new SSO user. It (finally) registers + the user, and confirms redirect to the client + """ + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + async def _async_render_GET(self, request: Request) -> None: + try: + session_id = get_username_mapping_session_cookie_from_request(request) + except SynapseError as e: + logger.warning("Error fetching session cookie: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + await self._sso_handler.register_sso_user(request, session_id) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index e2bb945453..f01215ed1c 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client.sso_register import SsoRegisterResource from synapse.types import create_requester from tests import unittest @@ -1215,6 +1216,7 @@ class UsernamePickerTestCase(HomeserverTestCase): d = super().create_resource_dict() d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) + d["/_synapse/client/sso_register"] = SsoRegisterResource(self.hs) d["/_synapse/oidc"] = OIDCResource(self.hs) return d @@ -1253,7 +1255,7 @@ class UsernamePickerTestCase(HomeserverTestCase): self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) # Now, submit a username to the username picker, which should serve a redirect - # back to the client + # to the completion page submit_path = picker_url + "/submit" content = urlencode({b"username": b"bobby"}).encode("utf8") chan = self.make_request( @@ -1270,6 +1272,16 @@ class UsernamePickerTestCase(HomeserverTestCase): ) self.assertEqual(chan.code, 302, chan.result) location_headers = chan.headers.getRawHeaders("Location") + + # send a request to the completion page, which should 302 to the client redirectUrl + chan = self.make_request( + "GET", + path=location_headers[0], + custom_headers=[("Cookie", "username_mapping_session=" + session_id)], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + # ensure that the returned location matches the requested redirect URL path, query = location_headers[0].split("?", 1) self.assertEqual(path, "https://x") -- cgit 1.5.1 From 9c715a5f1981891815c124353ba15cf4d17bf9bb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:47:59 +0000 Subject: Fix SSO on workers (#9271) Fixes #8966. * Factor out build_synapse_client_resource_tree Start a function which will mount resources common to all workers. * Move sso init into build_synapse_client_resource_tree ... so that we don't have to do it for each worker * Fix SSO-login-via-a-worker Expose the SSO login endpoints on workers, like the documentation says. * Update workers config for new endpoints Add documentation for endpoints recently added (#8942, #9017, #9262) * remove submit_token from workers endpoints list this *doesn't* work on workers (yet). * changelog * Add a comment about the odd path for SAML2Resource --- changelog.d/9271.bugfix | 1 + docs/workers.md | 18 +++++----- synapse/app/generic_worker.py | 11 +++--- synapse/app/homeserver.py | 18 ++-------- synapse/rest/synapse/client/__init__.py | 49 +++++++++++++++++++++++++- synapse/storage/databases/main/registration.py | 40 ++++++++++----------- tests/rest/client/v1/test_login.py | 15 ++------ tests/rest/client/v2_alpha/test_auth.py | 6 ++-- 8 files changed, 93 insertions(+), 65 deletions(-) create mode 100644 changelog.d/9271.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/9271.bugfix b/changelog.d/9271.bugfix new file mode 100644 index 0000000000..ef30c6570f --- /dev/null +++ b/changelog.d/9271.bugfix @@ -0,0 +1 @@ +Fix single-sign-on when the endpoints are routed to synapse workers. diff --git a/docs/workers.md b/docs/workers.md index d01683681f..6b8887de36 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -225,7 +225,6 @@ expressions: ^/_matrix/client/(api/v1|r0|unstable)/joined_groups$ ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$ ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/ - ^/_synapse/client/password_reset/email/submit_token$ # Registration/login requests ^/_matrix/client/(api/v1|r0|unstable)/login$ @@ -256,25 +255,28 @@ Additionally, the following endpoints should be included if Synapse is configure to use SSO (you only need to include the ones for whichever SSO provider you're using): + # for all SSO providers + ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect + ^/_synapse/client/pick_idp$ + ^/_synapse/client/pick_username + ^/_synapse/client/sso_register$ + # OpenID Connect requests. - ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$ ^/_synapse/oidc/callback$ # SAML requests. - ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$ ^/_matrix/saml2/authn_response$ # CAS requests. - ^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$ ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$ -Note that a HTTP listener with `client` and `federation` resources must be -configured in the `worker_listeners` option in the worker config. - -Ensure that all SSO logins go to a single process (usually the main process). +Ensure that all SSO logins go to a single process. For multiple workers not handling the SSO endpoints properly, see [#7530](https://github.com/matrix-org/synapse/issues/7530). +Note that a HTTP listener with `client` and `federation` resources must be +configured in the `worker_listeners` option in the worker config. + #### Load balancing It is possible to run multiple instances of this worker app, with incoming requests diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index e60988fa4a..516f2464b4 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -22,6 +22,7 @@ from typing import Dict, Iterable, Optional, Set from typing_extensions import ContextManager from twisted.internet import address +from twisted.web.resource import IResource import synapse import synapse.events @@ -90,9 +91,8 @@ from synapse.replication.tcp.streams import ( ToDeviceStream, ) from synapse.rest.admin import register_servlets_for_media_repo -from synapse.rest.client.v1 import events, room +from synapse.rest.client.v1 import events, login, room from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet -from synapse.rest.client.v1.login import LoginRestServlet from synapse.rest.client.v1.profile import ( ProfileAvatarURLRestServlet, ProfileDisplaynameRestServlet, @@ -127,6 +127,7 @@ from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.server import HomeServer, cache_in_self from synapse.storage.databases.main.censor_events import CensorEventsStore from synapse.storage.databases.main.client_ips import ClientIpWorkerStore @@ -507,7 +508,7 @@ class GenericWorkerServer(HomeServer): site_tag = port # We always include a health resource. - resources = {"/health": HealthResource()} + resources = {"/health": HealthResource()} # type: Dict[str, IResource] for res in listener_config.http_options.resources: for name in res.names: @@ -517,7 +518,7 @@ class GenericWorkerServer(HomeServer): resource = JsonResource(self, canonical_json=False) RegisterRestServlet(self).register(resource) - LoginRestServlet(self).register(resource) + login.register_servlets(self, resource) ThreepidRestServlet(self).register(resource) DevicesRestServlet(self).register(resource) KeyQueryServlet(self).register(resource) @@ -557,6 +558,8 @@ class GenericWorkerServer(HomeServer): groups.register_servlets(self, resource) resources.update({CLIENT_API_PREFIX: resource}) + + resources.update(build_synapse_client_resource_tree(self)) elif name == "federation": resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) elif name == "media": diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 86d6f73674..244657cb88 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -60,9 +60,7 @@ from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource -from synapse.rest.synapse.client.pick_idp import PickIdpResource -from synapse.rest.synapse.client.pick_username import pick_username_resource -from synapse.rest.synapse.client.sso_register import SsoRegisterResource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore @@ -191,22 +189,10 @@ class SynapseHomeServer(HomeServer): "/_matrix/client/versions": client_resource, "/.well-known/matrix/client": WellKnownResource(self), "/_synapse/admin": AdminRestResource(self), - "/_synapse/client/pick_username": pick_username_resource(self), - "/_synapse/client/pick_idp": PickIdpResource(self), - "/_synapse/client/sso_register": SsoRegisterResource(self), + **build_synapse_client_resource_tree(self), } ) - if self.get_config().oidc_enabled: - from synapse.rest.oidc import OIDCResource - - resources["/_synapse/oidc"] = OIDCResource(self) - - if self.get_config().saml2_enabled: - from synapse.rest.saml2 import SAML2Resource - - resources["/_matrix/saml2"] = SAML2Resource(self) - if self.get_config().threepid_behaviour_email == ThreepidBehaviour.LOCAL: from synapse.rest.synapse.client.password_reset import ( PasswordResetSubmitTokenResource, diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index c0b733488b..6acbc03d73 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2020 The Matrix.org Foundation C.I.C. +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,3 +12,50 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from typing import TYPE_CHECKING, Mapping + +from twisted.web.resource import Resource + +from synapse.rest.synapse.client.pick_idp import PickIdpResource +from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client.sso_register import SsoRegisterResource + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resource]: + """Builds a resource tree to include synapse-specific client resources + + These are resources which should be loaded on all workers which expose a C-S API: + ie, the main process, and any generic workers so configured. + + Returns: + map from path to Resource. + """ + resources = { + # SSO bits. These are always loaded, whether or not SSO login is actually + # enabled (they just won't work very well if it's not) + "/_synapse/client/pick_idp": PickIdpResource(hs), + "/_synapse/client/pick_username": pick_username_resource(hs), + "/_synapse/client/sso_register": SsoRegisterResource(hs), + } + + # provider-specific SSO bits. Only load these if they are enabled, since they + # rely on optional dependencies. + if hs.config.oidc_enabled: + from synapse.rest.oidc import OIDCResource + + resources["/_synapse/oidc"] = OIDCResource(hs) + + if hs.config.saml2_enabled: + from synapse.rest.saml2 import SAML2Resource + + # This is mounted under '/_matrix' for backwards-compatibility. + resources["/_matrix/saml2"] = SAML2Resource(hs) + + return resources + + +__all__ = ["build_synapse_client_resource_tree"] diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 8d05288ed4..14c0878d81 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -443,6 +443,26 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) + async def record_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> None: + """Record a mapping from an external user id to a mxid + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + await self.db_pool.simple_insert( + table="user_external_ids", + values={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="record_user_external_id", + ) + async def get_user_by_external_id( self, auth_provider: str, external_id: str ) -> Optional[str]: @@ -1371,26 +1391,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - async def record_user_external_id( - self, auth_provider: str, external_id: str, user_id: str - ) -> None: - """Record a mapping from an external user id to a mxid - - Args: - auth_provider: identifier for the remote auth provider - external_id: id on that system - user_id: complete mxid that it is mapped to - """ - await self.db_pool.simple_insert( - table="user_external_ids", - values={ - "auth_provider": auth_provider, - "external_id": external_id, - "user_id": user_id, - }, - desc="record_user_external_id", - ) - async def user_set_password_hash( self, user_id: str, password_hash: Optional[str] ) -> None: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index f01215ed1c..ded22a9767 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -29,9 +29,7 @@ from synapse.appservice import ApplicationService from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet -from synapse.rest.synapse.client.pick_idp import PickIdpResource -from synapse.rest.synapse.client.pick_username import pick_username_resource -from synapse.rest.synapse.client.sso_register import SsoRegisterResource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import create_requester from tests import unittest @@ -424,11 +422,8 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): return config def create_resource_dict(self) -> Dict[str, Resource]: - from synapse.rest.oidc import OIDCResource - d = super().create_resource_dict() - d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs) - d["/_synapse/oidc"] = OIDCResource(self.hs) + d.update(build_synapse_client_resource_tree(self.hs)) return d def test_get_login_flows(self): @@ -1212,12 +1207,8 @@ class UsernamePickerTestCase(HomeserverTestCase): return config def create_resource_dict(self) -> Dict[str, Resource]: - from synapse.rest.oidc import OIDCResource - d = super().create_resource_dict() - d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) - d["/_synapse/client/sso_register"] = SsoRegisterResource(self.hs) - d["/_synapse/oidc"] = OIDCResource(self.hs) + d.update(build_synapse_client_resource_tree(self.hs)) return d def test_username_picker(self): diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index a6488a3d29..3f50c56745 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -22,7 +22,7 @@ from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import auth, devices, register -from synapse.rest.oidc import OIDCResource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import JsonDict, UserID from tests import unittest @@ -173,9 +173,7 @@ class UIAuthTests(unittest.HomeserverTestCase): def create_resource_dict(self): resource_dict = super().create_resource_dict() - if HAS_OIDC: - # mount the OIDC resource at /_synapse/oidc - resource_dict["/_synapse/oidc"] = OIDCResource(self.hs) + resource_dict.update(build_synapse_client_resource_tree(self.hs)) return resource_dict def prepare(self, reactor, clock, hs): -- cgit 1.5.1 From 4167494c90bc0477bdf4855a79e81dc81bba1377 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:52:50 +0000 Subject: Replace username picker with a template (#9275) There's some prelimiary work here to pull out the construction of a jinja environment to a separate function. I wanted to load the template at display time rather than load time, so that it's easy to update on the fly. Honestly, I think we should do this with all our templates: the risk of ending up with malformed templates is far outweighed by the improved turnaround time for an admin trying to update them. --- changelog.d/9275.feature | 1 + docs/sample_config.yaml | 32 +++++- synapse/config/_base.py | 39 +------ synapse/config/oidc_config.py | 3 +- synapse/config/sso.py | 33 +++++- synapse/handlers/sso.py | 2 +- .../res/templates/sso_auth_account_details.html | 115 +++++++++++++++++++++ synapse/res/templates/sso_auth_account_details.js | 76 ++++++++++++++ synapse/res/username_picker/index.html | 19 ---- synapse/res/username_picker/script.js | 95 ----------------- synapse/res/username_picker/style.css | 27 ----- synapse/rest/consent/consent_resource.py | 1 + synapse/rest/synapse/client/pick_username.py | 79 ++++++++++---- synapse/util/templates.py | 106 +++++++++++++++++++ tests/rest/client/v1/test_login.py | 5 +- 15 files changed, 429 insertions(+), 204 deletions(-) create mode 100644 changelog.d/9275.feature create mode 100644 synapse/res/templates/sso_auth_account_details.html create mode 100644 synapse/res/templates/sso_auth_account_details.js delete mode 100644 synapse/res/username_picker/index.html delete mode 100644 synapse/res/username_picker/script.js delete mode 100644 synapse/res/username_picker/style.css create mode 100644 synapse/util/templates.py (limited to 'synapse/rest') diff --git a/changelog.d/9275.feature b/changelog.d/9275.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9275.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 05506a7787..a6fbcc6080 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1801,7 +1801,8 @@ saml2_config: # # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their -# own username. +# own username (see 'sso_auth_account_details.html' in the 'sso' +# section of this file). # # display_name_template: Jinja2 template for the display name to set # on first login. If unset, no displayname will be set. @@ -1968,6 +1969,35 @@ sso: # # * idp: the 'idp_id' of the chosen IDP. # + # * HTML page to prompt new users to enter a userid and confirm other + # details: 'sso_auth_account_details.html'. This is only shown if the + # SSO implementation (with any user_mapping_provider) does not return + # a localpart. + # + # When rendering, this template is given the following variables: + # + # * server_name: the homeserver's name. + # + # * idp: details of the SSO Identity Provider that the user logged in + # with: an object with the following attributes: + # + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # * idp_icon: if specified in the IdP config, an MXC URI for an icon + # for the IdP + # * idp_brand: if specified in the IdP config, a textual identifier + # for the brand of the IdP + # + # * user_attributes: an object containing details about the user that + # we received from the IdP. May have the following attributes: + # + # * display_name: the user's display_name + # * emails: a list of email addresses + # + # The template should render a form which submits the following fields: + # + # * username: the localpart of the user's chosen user id + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 94144efc87..35e5594b73 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -18,18 +18,18 @@ import argparse import errno import os -import time -import urllib.parse from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, Callable, Iterable, List, MutableMapping, Optional +from typing import Any, Iterable, List, MutableMapping, Optional import attr import jinja2 import pkg_resources import yaml +from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter + class ConfigError(Exception): """Represents a problem parsing the configuration @@ -248,6 +248,7 @@ class Config: # Search the custom template directory as well search_directories.insert(0, custom_template_directory) + # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(search_directories) env = jinja2.Environment(loader=loader, autoescape=autoescape) @@ -267,38 +268,6 @@ class Config: return templates -def _format_ts_filter(value: int, format: str): - return time.strftime(format, time.localtime(value / 1000)) - - -def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: - """Create and return a jinja2 filter that converts MXC urls to HTTP - - Args: - public_baseurl: The public, accessible base URL of the homeserver - """ - - def mxc_to_http_filter(value, width, height, resize_method="crop"): - if value[0:6] != "mxc://": - return "" - - server_and_media_id = value[6:] - fragment = None - if "#" in server_and_media_id: - server_and_media_id, fragment = server_and_media_id.split("#", 1) - fragment = "#" + fragment - - params = {"width": width, "height": height, "method": resize_method} - return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( - public_baseurl, - server_and_media_id, - urllib.parse.urlencode(params), - fragment or "", - ) - - return mxc_to_http_filter - - class RootConfig: """ Holder of an application's configuration. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index f31511e039..784b416f95 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -152,7 +152,8 @@ class OIDCConfig(Config): # # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their - # own username. + # own username (see 'sso_auth_account_details.html' in the 'sso' + # section of this file). # # display_name_template: Jinja2 template for the display name to set # on first login. If unset, no displayname will be set. diff --git a/synapse/config/sso.py b/synapse/config/sso.py index a470112ed4..e308fc9333 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -27,7 +27,7 @@ class SSOConfig(Config): sso_config = config.get("sso") or {} # type: Dict[str, Any] # The sso-specific template_dir - template_dir = sso_config.get("template_dir") + self.sso_template_dir = sso_config.get("template_dir") # Read templates from disk ( @@ -48,7 +48,7 @@ class SSOConfig(Config): "sso_auth_success.html", "sso_auth_bad_user.html", ], - template_dir, + self.sso_template_dir, ) # These templates have no placeholders, so render them here @@ -124,6 +124,35 @@ class SSOConfig(Config): # # * idp: the 'idp_id' of the chosen IDP. # + # * HTML page to prompt new users to enter a userid and confirm other + # details: 'sso_auth_account_details.html'. This is only shown if the + # SSO implementation (with any user_mapping_provider) does not return + # a localpart. + # + # When rendering, this template is given the following variables: + # + # * server_name: the homeserver's name. + # + # * idp: details of the SSO Identity Provider that the user logged in + # with: an object with the following attributes: + # + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # * idp_icon: if specified in the IdP config, an MXC URI for an icon + # for the IdP + # * idp_brand: if specified in the IdP config, a textual identifier + # for the brand of the IdP + # + # * user_attributes: an object containing details about the user that + # we received from the IdP. May have the following attributes: + # + # * display_name: the user's display_name + # * emails: a list of email addresses + # + # The template should render a form which submits the following fields: + # + # * username: the localpart of the user's chosen user id + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index ceaeb5a376..ff4750999a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -530,7 +530,7 @@ class SsoHandler: logger.info("Recorded registration session id %s", session_id) # Set the cookie and redirect to the username picker - e = RedirectException(b"/_synapse/client/pick_username") + e = RedirectException(b"/_synapse/client/pick_username/account_details") e.cookies.append( b"%s=%s; path=/" % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii")) diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html new file mode 100644 index 0000000000..f22b09aec1 --- /dev/null +++ b/synapse/res/templates/sso_auth_account_details.html @@ -0,0 +1,115 @@ + + + + Synapse Login + + + + + +
    +

    Your account is nearly ready

    +

    Check your details before creating an account on {{ server_name }}

    +
    +
    +
    +
    + +
    @
    + +
    :{{ server_name }}
    +
    + + {% if user_attributes %} +
    +

    Information from {{ idp.idp_name }}

    + {% if user_attributes.avatar_url %} +
    + +
    + {% endif %} + {% if user_attributes.display_name %} +
    +

    {{ user_attributes.display_name }}

    +
    + {% endif %} + {% for email in user_attributes.emails %} +
    +

    {{ email }}

    +
    + {% endfor %} +
    + {% endif %} +
    +
    + + + diff --git a/synapse/res/templates/sso_auth_account_details.js b/synapse/res/templates/sso_auth_account_details.js new file mode 100644 index 0000000000..deef419bb6 --- /dev/null +++ b/synapse/res/templates/sso_auth_account_details.js @@ -0,0 +1,76 @@ +const usernameField = document.getElementById("field-username"); + +function throttle(fn, wait) { + let timeout; + return function() { + const args = Array.from(arguments); + if (timeout) { + clearTimeout(timeout); + } + timeout = setTimeout(fn.bind.apply(fn, [null].concat(args)), wait); + } +} + +function checkUsernameAvailable(username) { + let check_uri = 'check?username=' + encodeURIComponent(username); + return fetch(check_uri, { + // include the cookie + "credentials": "same-origin", + }).then((response) => { + if(!response.ok) { + // for non-200 responses, raise the body of the response as an exception + return response.text().then((text) => { throw new Error(text); }); + } else { + return response.json(); + } + }).then((json) => { + if(json.error) { + return {message: json.error}; + } else if(json.available) { + return {available: true}; + } else { + return {message: username + " is not available, please choose another."}; + } + }); +} + +function validateUsername(username) { + usernameField.setCustomValidity(""); + if (usernameField.validity.valueMissing) { + usernameField.setCustomValidity("Please provide a username"); + return; + } + if (usernameField.validity.patternMismatch) { + usernameField.setCustomValidity("Invalid username, please only use " + allowedCharactersString); + return; + } + usernameField.setCustomValidity("Checking if username is available …"); + throttledCheckUsernameAvailable(username); +} + +const throttledCheckUsernameAvailable = throttle(function(username) { + const handleError = function(err) { + // don't prevent form submission on error + usernameField.setCustomValidity(""); + console.log(err.message); + }; + try { + checkUsernameAvailable(username).then(function(result) { + if (!result.available) { + usernameField.setCustomValidity(result.message); + usernameField.reportValidity(); + } else { + usernameField.setCustomValidity(""); + } + }, handleError); + } catch (err) { + handleError(err); + } +}, 500); + +usernameField.addEventListener("input", function(evt) { + validateUsername(usernameField.value); +}); +usernameField.addEventListener("change", function(evt) { + validateUsername(usernameField.value); +}); diff --git a/synapse/res/username_picker/index.html b/synapse/res/username_picker/index.html deleted file mode 100644 index 37ea8bb6d8..0000000000 --- a/synapse/res/username_picker/index.html +++ /dev/null @@ -1,19 +0,0 @@ - - - - Synapse Login - - - -
    -
    - - - -
    - - - -
    - - diff --git a/synapse/res/username_picker/script.js b/synapse/res/username_picker/script.js deleted file mode 100644 index 416a7c6f41..0000000000 --- a/synapse/res/username_picker/script.js +++ /dev/null @@ -1,95 +0,0 @@ -let inputField = document.getElementById("field-username"); -let inputForm = document.getElementById("form"); -let submitButton = document.getElementById("button-submit"); -let message = document.getElementById("message"); - -// Submit username and receive response -function showMessage(messageText) { - // Unhide the message text - message.classList.remove("hidden"); - - message.textContent = messageText; -}; - -function doSubmit() { - showMessage("Success. Please wait a moment for your browser to redirect."); - - // remove the event handler before re-submitting the form. - delete inputForm.onsubmit; - inputForm.submit(); -} - -function onResponse(response) { - // Display message - showMessage(response); - - // Enable submit button and input field - submitButton.classList.remove('button--disabled'); - submitButton.value = "Submit"; -}; - -let allowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]"); -function usernameIsValid(username) { - return !allowedUsernameCharacters.test(username); -} -let allowedCharactersString = "lowercase letters, digits, ., _, -, /, ="; - -function buildQueryString(params) { - return Object.keys(params) - .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k])) - .join('&'); -} - -function submitUsername(username) { - if(username.length == 0) { - onResponse("Please enter a username."); - return; - } - if(!usernameIsValid(username)) { - onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString); - return; - } - - // if this browser doesn't support fetch, skip the availability check. - if(!window.fetch) { - doSubmit(); - return; - } - - let check_uri = 'check?' + buildQueryString({"username": username}); - fetch(check_uri, { - // include the cookie - "credentials": "same-origin", - }).then((response) => { - if(!response.ok) { - // for non-200 responses, raise the body of the response as an exception - return response.text().then((text) => { throw text; }); - } else { - return response.json(); - } - }).then((json) => { - if(json.error) { - throw json.error; - } else if(json.available) { - doSubmit(); - } else { - onResponse("This username is not available, please choose another."); - } - }).catch((err) => { - onResponse("Error checking username availability: " + err); - }); -} - -function clickSubmit() { - event.preventDefault(); - if(submitButton.classList.contains('button--disabled')) { return; } - - // Disable submit button and input field - submitButton.classList.add('button--disabled'); - - // Submit username - submitButton.value = "Checking..."; - submitUsername(inputField.value); -}; - -inputForm.onsubmit = clickSubmit; diff --git a/synapse/res/username_picker/style.css b/synapse/res/username_picker/style.css deleted file mode 100644 index 745bd4c684..0000000000 --- a/synapse/res/username_picker/style.css +++ /dev/null @@ -1,27 +0,0 @@ -input[type="text"] { - font-size: 100%; - background-color: #ededf0; - border: 1px solid #fff; - border-radius: .2em; - padding: .5em .9em; - display: block; - width: 26em; -} - -.button--disabled { - border-color: #fff; - background-color: transparent; - color: #000; - text-transform: none; -} - -.hidden { - display: none; -} - -.tooltip { - background-color: #f9f9fa; - padding: 1em; - margin: 1em 0; -} - diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index b3e4d5612e..8b9ef26cf2 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -100,6 +100,7 @@ class ConsentResource(DirectServeHtmlResource): consent_template_directory = hs.config.user_consent_template_dir + # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(consent_template_directory) self._jinja_env = jinja2.Environment( loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"]) diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index 1bc737bad0..27540d3bbe 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -13,41 +13,41 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import TYPE_CHECKING -import pkg_resources - from twisted.web.http import Request from twisted.web.resource import Resource -from twisted.web.static import File +from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request -from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource +from synapse.http.server import ( + DirectServeHtmlResource, + DirectServeJsonResource, + respond_with_html, +) from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest +from synapse.util.templates import build_jinja_env if TYPE_CHECKING: from synapse.server import HomeServer +logger = logging.getLogger(__name__) + def pick_username_resource(hs: "HomeServer") -> Resource: """Factory method to generate the username picker resource. - This resource gets mounted under /_synapse/client/pick_username. The top-level - resource is just a File resource which serves up the static files in the resources - "res" directory, but it has a couple of children: + This resource gets mounted under /_synapse/client/pick_username and has two + children: - * "submit", which does the mechanics of registering the new user, and redirects the - browser back to the client URL - - * "check": checks if a userid is free. + * "account_details": renders the form and handles the POSTed response + * "check": a JSON endpoint which checks if a userid is free. """ - # XXX should we make this path customisable so that admins can restyle it? - base_path = pkg_resources.resource_filename("synapse", "res/username_picker") - - res = File(base_path) - res.putChild(b"submit", SubmitResource(hs)) + res = Resource() + res.putChild(b"account_details", AccountDetailsResource(hs)) res.putChild(b"check", AvailabilityCheckResource(hs)) return res @@ -69,15 +69,54 @@ class AvailabilityCheckResource(DirectServeJsonResource): return 200, {"available": is_available} -class SubmitResource(DirectServeHtmlResource): +class AccountDetailsResource(DirectServeHtmlResource): def __init__(self, hs: "HomeServer"): super().__init__() self._sso_handler = hs.get_sso_handler() - async def _async_render_POST(self, request: SynapseRequest): - localpart = parse_string(request, "username", required=True) + def template_search_dirs(): + if hs.config.sso.sso_template_dir: + yield hs.config.sso.sso_template_dir + yield hs.config.sso.default_template_dir + + self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) + + async def _async_render_GET(self, request: Request) -> None: + try: + session_id = get_username_mapping_session_cookie_from_request(request) + session = self._sso_handler.get_mapping_session(session_id) + except SynapseError as e: + logger.warning("Error fetching session: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + idp_id = session.auth_provider_id + template_params = { + "idp": self._sso_handler.get_identity_providers()[idp_id], + "user_attributes": { + "display_name": session.display_name, + "emails": session.emails, + }, + } + + template = self._jinja_env.get_template("sso_auth_account_details.html") + html = template.render(template_params) + respond_with_html(request, 200, html) - session_id = get_username_mapping_session_cookie_from_request(request) + async def _async_render_POST(self, request: SynapseRequest): + try: + session_id = get_username_mapping_session_cookie_from_request(request) + except SynapseError as e: + logger.warning("Error fetching session cookie: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + try: + localpart = parse_string(request, "username", required=True) + except SynapseError as e: + logger.warning("[session %s] bad param: %s", session_id, e) + self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code) + return await self._sso_handler.handle_submit_username_request( request, localpart, session_id diff --git a/synapse/util/templates.py b/synapse/util/templates.py new file mode 100644 index 0000000000..7e5109d206 --- /dev/null +++ b/synapse/util/templates.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for dealing with jinja2 templates""" + +import time +import urllib.parse +from typing import TYPE_CHECKING, Callable, Iterable, Union + +import jinja2 + +if TYPE_CHECKING: + from synapse.config.homeserver import HomeServerConfig + + +def build_jinja_env( + template_search_directories: Iterable[str], + config: "HomeServerConfig", + autoescape: Union[bool, Callable[[str], bool], None] = None, +) -> jinja2.Environment: + """Set up a Jinja2 environment to load templates from the given search path + + The returned environment defines the following filters: + - format_ts: formats timestamps as strings in the server's local timezone + (XXX: why is that useful??) + - mxc_to_http: converts mxc: uris to http URIs. Args are: + (uri, width, height, resize_method="crop") + + and the following global variables: + - server_name: matrix server name + + Args: + template_search_directories: directories to search for templates + + config: homeserver config, for things like `server_name` and `public_baseurl` + + autoescape: whether template variables should be autoescaped. bool, or + a function mapping from template name to bool. Defaults to escaping templates + whose names end in .html, .xml or .htm. + + Returns: + jinja environment + """ + + if autoescape is None: + autoescape = jinja2.select_autoescape() + + loader = jinja2.FileSystemLoader(template_search_directories) + env = jinja2.Environment(loader=loader, autoescape=autoescape) + + # Update the environment with our custom filters + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(config.public_baseurl), + } + ) + + # common variables for all templates + env.globals.update({"server_name": config.server_name}) + + return env + + +def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: + """Create and return a jinja2 filter that converts MXC urls to HTTP + + Args: + public_baseurl: The public, accessible base URL of the homeserver + """ + + def mxc_to_http_filter(value, width, height, resize_method="crop"): + if value[0:6] != "mxc://": + return "" + + server_and_media_id = value[6:] + fragment = None + if "#" in server_and_media_id: + server_and_media_id, fragment = server_and_media_id.split("#", 1) + fragment = "#" + fragment + + params = {"width": width, "height": height, "method": resize_method} + return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( + public_baseurl, + server_and_media_id, + urllib.parse.urlencode(params), + fragment or "", + ) + + return mxc_to_http_filter + + +def _format_ts_filter(value: int, format: str): + return time.strftime(format, time.localtime(value / 1000)) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index ded22a9767..66dfdaffbc 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1222,7 +1222,7 @@ class UsernamePickerTestCase(HomeserverTestCase): # that should redirect to the username picker self.assertEqual(channel.code, 302, channel.result) picker_url = channel.headers.getRawHeaders("Location")[0] - self.assertEqual(picker_url, "/_synapse/client/pick_username") + self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") # ... with a username_mapping_session cookie cookies = {} # type: Dict[str,str] @@ -1247,11 +1247,10 @@ class UsernamePickerTestCase(HomeserverTestCase): # Now, submit a username to the username picker, which should serve a redirect # to the completion page - submit_path = picker_url + "/submit" content = urlencode({b"username": b"bobby"}).encode("utf8") chan = self.make_request( "POST", - path=submit_path, + path=picker_url, content=content, content_is_form=True, custom_headers=[ -- cgit 1.5.1 From 85c56b5a679add887ec9716f176949561dca581b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 17:30:42 +0000 Subject: Make importing display name and email optional (#9277) --- changelog.d/9277.feature | 1 + synapse/handlers/register.py | 5 ++- synapse/handlers/sso.py | 52 ++++++++++++++++++---- .../res/templates/sso_auth_account_details.html | 23 ++++++++++ synapse/rest/synapse/client/pick_username.py | 14 ++++-- 5 files changed, 82 insertions(+), 13 deletions(-) create mode 100644 changelog.d/9277.feature (limited to 'synapse/rest') diff --git a/changelog.d/9277.feature b/changelog.d/9277.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9277.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a2cf0f6f3e..b20a5d8605 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -14,8 +14,9 @@ # limitations under the License. """Contains functions for registering clients.""" + import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType @@ -152,7 +153,7 @@ class RegistrationHandler(BaseHandler): user_type: Optional[str] = None, default_display_name: Optional[str] = None, address: Optional[str] = None, - bind_emails: List[str] = [], + bind_emails: Iterable[str] = [], by_admin: bool = False, user_agent_ips: Optional[List[Tuple[str, str]]] = None, ) -> str: diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index ff4750999a..d7ca2918f8 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -14,7 +14,16 @@ # limitations under the License. import abc import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional +from typing import ( + TYPE_CHECKING, + Awaitable, + Callable, + Dict, + Iterable, + Mapping, + Optional, + Set, +) from urllib.parse import urlencode import attr @@ -29,7 +38,7 @@ from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters +from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer from synapse.util.stringutils import random_string @@ -115,7 +124,7 @@ class UserAttributes: # enter one. localpart = attr.ib(type=Optional[str]) display_name = attr.ib(type=Optional[str], default=None) - emails = attr.ib(type=List[str], default=attr.Factory(list)) + emails = attr.ib(type=Collection[str], default=attr.Factory(list)) @attr.s(slots=True) @@ -130,7 +139,7 @@ class UsernameMappingSession: # attributes returned by the ID mapper display_name = attr.ib(type=Optional[str]) - emails = attr.ib(type=List[str]) + emails = attr.ib(type=Collection[str]) # An optional dictionary of extra attributes to be provided to the client in the # login response. @@ -144,6 +153,8 @@ class UsernameMappingSession: # choices made by the user chosen_localpart = attr.ib(type=Optional[str], default=None) + use_display_name = attr.ib(type=bool, default=True) + emails_to_use = attr.ib(type=Collection[str], default=()) # the HTTP cookie used to track the mapping session id @@ -710,7 +721,12 @@ class SsoHandler: return not user_infos async def handle_submit_username_request( - self, request: SynapseRequest, localpart: str, session_id: str + self, + request: SynapseRequest, + session_id: str, + localpart: str, + use_display_name: bool, + emails_to_use: Iterable[str], ) -> None: """Handle a request to the username-picker 'submit' endpoint @@ -720,11 +736,30 @@ class SsoHandler: request: HTTP request localpart: localpart requested by the user session_id: ID of the username mapping session, extracted from a cookie + use_display_name: whether the user wants to use the suggested display name + emails_to_use: emails that the user would like to use """ session = self.get_mapping_session(session_id) # update the session with the user's choices session.chosen_localpart = localpart + session.use_display_name = use_display_name + + emails_from_idp = set(session.emails) + filtered_emails = set() # type: Set[str] + + # we iterate through the list rather than just building a set conjunction, so + # that we can log attempts to use unknown addresses + for email in emails_to_use: + if email in emails_from_idp: + filtered_emails.add(email) + else: + logger.warning( + "[session %s] ignoring user request to use unknown email address %r", + session_id, + email, + ) + session.emails_to_use = filtered_emails # we're done; now we can register the user respond_with_redirect(request, b"/_synapse/client/sso_register") @@ -747,11 +782,12 @@ class SsoHandler: ) attributes = UserAttributes( - localpart=session.chosen_localpart, - display_name=session.display_name, - emails=session.emails, + localpart=session.chosen_localpart, emails=session.emails_to_use, ) + if session.use_display_name: + attributes.display_name = session.display_name + # the following will raise a 400 error if the username has been taken in the # meantime. user_id = await self._register_mapped_user( diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html index f22b09aec1..105063825a 100644 --- a/synapse/res/templates/sso_auth_account_details.html +++ b/synapse/res/templates/sso_auth_account_details.html @@ -53,6 +53,14 @@ border-top: 1px solid #E9ECF1; padding: 12px; } + .idp-pick-details .check-row { + display: flex; + align-items: center; + } + + .idp-pick-details .check-row .name { + flex: 1; + } .idp-pick-details .use, .idp-pick-details .idp-value { color: #737D8C; @@ -91,16 +99,31 @@

    Information from {{ idp.idp_name }}

    {% if user_attributes.avatar_url %}
    +
    + + + +
    {% endif %} {% if user_attributes.display_name %}
    +
    + + + +

    {{ user_attributes.display_name }}

    {% endif %} {% for email in user_attributes.emails %}
    +
    + + + +

    {{ email }}

    {% endfor %} diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index 27540d3bbe..96077cfcd1 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List from twisted.web.http import Request from twisted.web.resource import Resource @@ -26,7 +26,7 @@ from synapse.http.server import ( DirectServeJsonResource, respond_with_html, ) -from synapse.http.servlet import parse_string +from synapse.http.servlet import parse_boolean, parse_string from synapse.http.site import SynapseRequest from synapse.util.templates import build_jinja_env @@ -113,11 +113,19 @@ class AccountDetailsResource(DirectServeHtmlResource): try: localpart = parse_string(request, "username", required=True) + use_display_name = parse_boolean(request, "use_display_name", default=False) + + try: + emails_to_use = [ + val.decode("utf-8") for val in request.args.get(b"use_email", []) + ] # type: List[str] + except ValueError: + raise SynapseError(400, "Query parameter use_email must be utf-8") except SynapseError as e: logger.warning("[session %s] bad param: %s", session_id, e) self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code) return await self._sso_handler.handle_submit_username_request( - request, localpart, session_id + request, session_id, localpart, use_display_name, emails_to_use ) -- cgit 1.5.1 From c543bf87ecf295fa68311beabd1dc013288a2e98 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 18:37:41 +0000 Subject: Collect terms consent from the user during SSO registration (#9276) --- changelog.d/9276.feature | 1 + docs/sample_config.yaml | 22 ++++++ docs/workers.md | 1 + synapse/config/sso.py | 22 ++++++ synapse/handlers/register.py | 2 + synapse/handlers/sso.py | 44 +++++++++++ synapse/res/templates/sso_new_user_consent.html | 39 ++++++++++ synapse/rest/synapse/client/__init__.py | 2 + synapse/rest/synapse/client/new_user_consent.py | 97 +++++++++++++++++++++++++ 9 files changed, 230 insertions(+) create mode 100644 changelog.d/9276.feature create mode 100644 synapse/res/templates/sso_new_user_consent.html create mode 100644 synapse/rest/synapse/client/new_user_consent.py (limited to 'synapse/rest') diff --git a/changelog.d/9276.feature b/changelog.d/9276.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9276.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index eec082ca8c..15e9746696 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2003,6 +2003,28 @@ sso: # # * username: the localpart of the user's chosen user id # + # * HTML page allowing the user to consent to the server's terms and + # conditions. This is only shown for new users, and only if + # `user_consent.require_at_registration` is set. + # + # When rendering, this template is given the following variables: + # + # * server_name: the homeserver's name. + # + # * user_id: the user's matrix proposed ID. + # + # * user_profile.display_name: the user's proposed display name, if any. + # + # * consent_version: the version of the terms that the user will be + # shown + # + # * terms_url: a link to the page showing the terms. + # + # The template should render a form which submits the following fields: + # + # * accepted_version: the version of the terms accepted by the user + # (ie, 'consent_version' from the input variables). + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/docs/workers.md b/docs/workers.md index 6b8887de36..0da805c333 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -259,6 +259,7 @@ using): ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect ^/_synapse/client/pick_idp$ ^/_synapse/client/pick_username + ^/_synapse/client/new_user_consent$ ^/_synapse/client/sso_register$ # OpenID Connect requests. diff --git a/synapse/config/sso.py b/synapse/config/sso.py index bf82183cdc..939eeac6de 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -158,6 +158,28 @@ class SSOConfig(Config): # # * username: the localpart of the user's chosen user id # + # * HTML page allowing the user to consent to the server's terms and + # conditions. This is only shown for new users, and only if + # `user_consent.require_at_registration` is set. + # + # When rendering, this template is given the following variables: + # + # * server_name: the homeserver's name. + # + # * user_id: the user's matrix proposed ID. + # + # * user_profile.display_name: the user's proposed display name, if any. + # + # * consent_version: the version of the terms that the user will be + # shown + # + # * terms_url: a link to the page showing the terms. + # + # The template should render a form which submits the following fields: + # + # * accepted_version: the version of the terms accepted by the user + # (ie, 'consent_version' from the input variables). + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index b20a5d8605..49b085269b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -694,6 +694,8 @@ class RegistrationHandler(BaseHandler): access_token: The access token of the newly logged in device, or None if `inhibit_login` enabled. """ + # TODO: 3pid registration can actually happen on the workers. Consider + # refactoring it. if self.hs.config.worker_app: await self._post_registration_client( user_id=user_id, auth_result=auth_result, access_token=access_token diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index d7ca2918f8..b450668f1c 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -155,6 +155,7 @@ class UsernameMappingSession: chosen_localpart = attr.ib(type=Optional[str], default=None) use_display_name = attr.ib(type=bool, default=True) emails_to_use = attr.ib(type=Collection[str], default=()) + terms_accepted_version = attr.ib(type=Optional[str], default=None) # the HTTP cookie used to track the mapping session id @@ -190,6 +191,8 @@ class SsoHandler: # map from idp_id to SsoIdentityProvider self._identity_providers = {} # type: Dict[str, SsoIdentityProvider] + self._consent_at_registration = hs.config.consent.user_consent_at_registration + def register_identity_provider(self, p: SsoIdentityProvider): p_id = p.idp_id assert p_id not in self._identity_providers @@ -761,6 +764,38 @@ class SsoHandler: ) session.emails_to_use = filtered_emails + # we may now need to collect consent from the user, in which case, redirect + # to the consent-extraction-unit + if self._consent_at_registration: + redirect_url = b"/_synapse/client/new_user_consent" + + # otherwise, redirect to the completion page + else: + redirect_url = b"/_synapse/client/sso_register" + + respond_with_redirect(request, redirect_url) + + async def handle_terms_accepted( + self, request: Request, session_id: str, terms_version: str + ): + """Handle a request to the new-user 'consent' endpoint + + Will serve an HTTP response to the request. + + Args: + request: HTTP request + session_id: ID of the username mapping session, extracted from a cookie + terms_version: the version of the terms which the user viewed and consented + to + """ + logger.info( + "[session %s] User consented to terms version %s", + session_id, + terms_version, + ) + session = self.get_mapping_session(session_id) + session.terms_accepted_version = terms_version + # we're done; now we can register the user respond_with_redirect(request, b"/_synapse/client/sso_register") @@ -816,6 +851,15 @@ class SsoHandler: path=b"/", ) + auth_result = {} + if session.terms_accepted_version: + # TODO: make this less awful. + auth_result[LoginType.TERMS] = True + + await self._registration_handler.post_registration_actions( + user_id, auth_result, access_token=None + ) + await self._auth_handler.complete_sso_login( user_id, request, diff --git a/synapse/res/templates/sso_new_user_consent.html b/synapse/res/templates/sso_new_user_consent.html new file mode 100644 index 0000000000..8c33787c54 --- /dev/null +++ b/synapse/res/templates/sso_new_user_consent.html @@ -0,0 +1,39 @@ + + + + + SSO redirect confirmation + + + + +
    +

    Your account is nearly ready

    +

    Agree to the terms to create your account.

    +
    +
    + +
    + +
    +
    {{ user_profile.display_name }}
    +
    {{ user_id }}
    +
    +
    + + +
    + + diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index 6acbc03d73..02310c1900 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Mapping from twisted.web.resource import Resource +from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.rest.synapse.client.sso_register import SsoRegisterResource @@ -39,6 +40,7 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc # enabled (they just won't work very well if it's not) "/_synapse/client/pick_idp": PickIdpResource(hs), "/_synapse/client/pick_username": pick_username_resource(hs), + "/_synapse/client/new_user_consent": NewUserConsentResource(hs), "/_synapse/client/sso_register": SsoRegisterResource(hs), } diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py new file mode 100644 index 0000000000..b2e0f93810 --- /dev/null +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request + +from synapse.api.errors import SynapseError +from synapse.handlers.sso import get_username_mapping_session_cookie_from_request +from synapse.http.server import DirectServeHtmlResource, respond_with_html +from synapse.http.servlet import parse_string +from synapse.types import UserID +from synapse.util.templates import build_jinja_env + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class NewUserConsentResource(DirectServeHtmlResource): + """A resource which collects consent to the server's terms from a new user + + This resource gets mounted at /_synapse/client/new_user_consent, and is shown + when we are automatically creating a new user due to an SSO login. + + It shows a template which prompts the user to go and read the Ts and Cs, and click + a clickybox if they have done so. + """ + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + self._server_name = hs.hostname + self._consent_version = hs.config.consent.user_consent_version + + def template_search_dirs(): + if hs.config.sso.sso_template_dir: + yield hs.config.sso.sso_template_dir + yield hs.config.sso.default_template_dir + + self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) + + async def _async_render_GET(self, request: Request) -> None: + try: + session_id = get_username_mapping_session_cookie_from_request(request) + session = self._sso_handler.get_mapping_session(session_id) + except SynapseError as e: + logger.warning("Error fetching session: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + user_id = UserID(session.chosen_localpart, self._server_name) + user_profile = { + "display_name": session.display_name, + } + + template_params = { + "user_id": user_id.to_string(), + "user_profile": user_profile, + "consent_version": self._consent_version, + "terms_url": "/_matrix/consent?v=%s" % (self._consent_version,), + } + + template = self._jinja_env.get_template("sso_new_user_consent.html") + html = template.render(template_params) + respond_with_html(request, 200, html) + + async def _async_render_POST(self, request: Request): + try: + session_id = get_username_mapping_session_cookie_from_request(request) + except SynapseError as e: + logger.warning("Error fetching session cookie: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + try: + accepted_version = parse_string(request, "accepted_version", required=True) + except SynapseError as e: + self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code) + return + + await self._sso_handler.handle_terms_accepted( + request, session_id, accepted_version + ) -- cgit 1.5.1 From 846b9d3df033be1043710e49e89bcba68722071e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 22:56:01 +0000 Subject: Put OIDC callback URI under /_synapse/client. (#9288) --- CHANGES.md | 4 +++ UPGRADE.rst | 13 ++++++++- changelog.d/9288.feature | 1 + docs/openid.md | 19 ++++++------- docs/workers.md | 2 +- synapse/config/oidc_config.py | 2 +- synapse/handlers/oidc_handler.py | 8 +++--- synapse/rest/oidc/__init__.py | 27 ------------------- synapse/rest/oidc/callback_resource.py | 30 --------------------- synapse/rest/synapse/client/__init__.py | 4 +-- synapse/rest/synapse/client/oidc/__init__.py | 31 ++++++++++++++++++++++ .../rest/synapse/client/oidc/callback_resource.py | 30 +++++++++++++++++++++ tests/handlers/test_oidc.py | 15 +++++------ 13 files changed, 102 insertions(+), 84 deletions(-) create mode 100644 changelog.d/9288.feature delete mode 100644 synapse/rest/oidc/__init__.py delete mode 100644 synapse/rest/oidc/callback_resource.py create mode 100644 synapse/rest/synapse/client/oidc/__init__.py create mode 100644 synapse/rest/synapse/client/oidc/callback_resource.py (limited to 'synapse/rest') diff --git a/CHANGES.md b/CHANGES.md index fcd782fa94..e9ff14a03d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,6 +3,10 @@ Unreleased Note that this release includes a change in Synapse to use Redis as a cache ─ as well as a pub/sub mechanism ─ if Redis support is enabled. No action is needed by server administrators, and we do not expect resource usage of the Redis instance to change dramatically. +This release also changes the callback URI for OpenID Connect (OIDC) identity +providers. If your server is configured to use single sign-on via an +OIDC/OAuth2 IdP, you may need to make configuration changes. Please review +[UPGRADE.rst](UPGRADE.rst) for more details on these changes. Synapse 1.26.0 (2021-01-27) =========================== diff --git a/UPGRADE.rst b/UPGRADE.rst index eea0322695..d00f718cae 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -88,6 +88,17 @@ for example: Upgrading to v1.27.0 ==================== +Changes to callback URI for OAuth2 / OpenID Connect +--------------------------------------------------- + +This version changes the URI used for callbacks from OAuth2 identity providers. If +your server is configured for single sign-on via an OpenID Connect or OAuth2 identity +provider, you will need to add ``[synapse public baseurl]/_synapse/client/oidc/callback`` +to the list of permitted "redirect URIs" at the identity provider. + +See `docs/openid.md `_ for more information on setting up OpenID +Connect. + Changes to HTML templates ------------------------- @@ -235,7 +246,7 @@ shown below: return {"localpart": localpart} -Removal historical Synapse Admin API +Removal historical Synapse Admin API ------------------------------------ Historically, the Synapse Admin API has been accessible under: diff --git a/changelog.d/9288.feature b/changelog.d/9288.feature new file mode 100644 index 0000000000..efde69fb3c --- /dev/null +++ b/changelog.d/9288.feature @@ -0,0 +1 @@ +Update the redirect URI for OIDC authentication. diff --git a/docs/openid.md b/docs/openid.md index 3d07220967..9d19368845 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -54,7 +54,8 @@ Here are a few configs for providers that should work with Synapse. ### Microsoft Azure Active Directory Azure AD can act as an OpenID Connect Provider. Register a new application under *App registrations* in the Azure AD management console. The RedirectURI for your -application should point to your matrix server: `[synapse public baseurl]/_synapse/oidc/callback` +application should point to your matrix server: +`[synapse public baseurl]/_synapse/client/oidc/callback` Go to *Certificates & secrets* and register a new client secret. Make note of your Directory (tenant) ID as it will be used in the Azure links. @@ -94,7 +95,7 @@ staticClients: - id: synapse secret: secret redirectURIs: - - '[synapse public baseurl]/_synapse/oidc/callback' + - '[synapse public baseurl]/_synapse/client/oidc/callback' name: 'Synapse' ``` @@ -140,7 +141,7 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to | Enabled | `On` | | Client Protocol | `openid-connect` | | Access Type | `confidential` | -| Valid Redirect URIs | `[synapse public baseurl]/_synapse/oidc/callback` | +| Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` | 5. Click `Save` 6. On the Credentials tab, update the fields: @@ -168,7 +169,7 @@ oidc_providers: ### [Auth0][auth0] 1. Create a regular web application for Synapse -2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/oidc/callback` +2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/client/oidc/callback` 3. Add a rule to add the `preferred_username` claim.
    Code sample @@ -217,7 +218,7 @@ login mechanism needs an attribute to uniquely identify users, and that endpoint does not return a `sub` property, an alternative `subject_claim` has to be set. 1. Create a new OAuth application: https://github.com/settings/applications/new. -2. Set the callback URL to `[synapse public baseurl]/_synapse/oidc/callback`. +2. Set the callback URL to `[synapse public baseurl]/_synapse/client/oidc/callback`. Synapse config: @@ -262,13 +263,13 @@ oidc_providers: display_name_template: "{{ user.name }}" ``` 4. Back in the Google console, add this Authorized redirect URI: `[synapse - public baseurl]/_synapse/oidc/callback`. + public baseurl]/_synapse/client/oidc/callback`. ### Twitch 1. Setup a developer account on [Twitch](https://dev.twitch.tv/) 2. Obtain the OAuth 2.0 credentials by [creating an app](https://dev.twitch.tv/console/apps/) -3. Add this OAuth Redirect URL: `[synapse public baseurl]/_synapse/oidc/callback` +3. Add this OAuth Redirect URL: `[synapse public baseurl]/_synapse/client/oidc/callback` Synapse config: @@ -290,7 +291,7 @@ oidc_providers: 1. Create a [new application](https://gitlab.com/profile/applications). 2. Add the `read_user` and `openid` scopes. -3. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback` +3. Add this Callback URL: `[synapse public baseurl]/_synapse/client/oidc/callback` Synapse config: @@ -323,7 +324,7 @@ one so requires a little more configuration. 2. Once the app is created, add "Facebook Login" and choose "Web". You don't need to go through the whole form here. 3. In the left-hand menu, open "Products"/"Facebook Login"/"Settings". - * Add `[synapse public baseurl]/_synapse/oidc/callback` as an OAuth Redirect + * Add `[synapse public baseurl]/_synapse/client/oidc/callback` as an OAuth Redirect URL. 4. In the left-hand menu, open "Settings/Basic". Here you can copy the "App ID" and "App Secret" for use below. diff --git a/docs/workers.md b/docs/workers.md index c36549c621..c4a6c79238 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -266,7 +266,7 @@ using): ^/_synapse/client/sso_register$ # OpenID Connect requests. - ^/_synapse/oidc/callback$ + ^/_synapse/client/oidc/callback$ # SAML requests. ^/_matrix/saml2/authn_response$ diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index bb122ef182..4c24c50629 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -53,7 +53,7 @@ class OIDCConfig(Config): "Multiple OIDC providers have the idp_id %r." % idp_id ) - self.oidc_callback_url = self.public_baseurl + "_synapse/oidc/callback" + self.oidc_callback_url = self.public_baseurl + "_synapse/client/oidc/callback" @property def oidc_enabled(self) -> bool: diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index ca647fa78f..71008ec50d 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -102,7 +102,7 @@ class OidcHandler: ) from e async def handle_oidc_callback(self, request: SynapseRequest) -> None: - """Handle an incoming request to /_synapse/oidc/callback + """Handle an incoming request to /_synapse/client/oidc/callback Since we might want to display OIDC-related errors in a user-friendly way, we don't raise SynapseError from here. Instead, we call @@ -643,7 +643,7 @@ class OidcProvider: - ``client_id``: the client ID set in ``oidc_config.client_id`` - ``response_type``: ``code`` - - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback`` + - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback`` - ``scope``: the list of scopes set in ``oidc_config.scopes`` - ``state``: a random string - ``nonce``: a random string @@ -684,7 +684,7 @@ class OidcProvider: request.addCookie( SESSION_COOKIE_NAME, cookie, - path="/_synapse/oidc", + path="/_synapse/client/oidc", max_age="3600", httpOnly=True, sameSite="lax", @@ -705,7 +705,7 @@ class OidcProvider: async def handle_oidc_callback( self, request: SynapseRequest, session_data: "OidcSessionData", code: str ) -> None: - """Handle an incoming request to /_synapse/oidc/callback + """Handle an incoming request to /_synapse/client/oidc/callback By this time we have already validated the session on the synapse side, and now need to do the provider-specific operations. This includes: diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/oidc/__init__.py deleted file mode 100644 index d958dd65bb..0000000000 --- a/synapse/rest/oidc/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 Quentin Gliech -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from twisted.web.resource import Resource - -from synapse.rest.oidc.callback_resource import OIDCCallbackResource - -logger = logging.getLogger(__name__) - - -class OIDCResource(Resource): - def __init__(self, hs): - Resource.__init__(self) - self.putChild(b"callback", OIDCCallbackResource(hs)) diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py deleted file mode 100644 index f7a0bc4bdb..0000000000 --- a/synapse/rest/oidc/callback_resource.py +++ /dev/null @@ -1,30 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 Quentin Gliech -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from synapse.http.server import DirectServeHtmlResource - -logger = logging.getLogger(__name__) - - -class OIDCCallbackResource(DirectServeHtmlResource): - isLeaf = 1 - - def __init__(self, hs): - super().__init__() - self._oidc_handler = hs.get_oidc_handler() - - async def _async_render_GET(self, request): - await self._oidc_handler.handle_oidc_callback(request) diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index 02310c1900..381baf9729 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -47,9 +47,9 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc # provider-specific SSO bits. Only load these if they are enabled, since they # rely on optional dependencies. if hs.config.oidc_enabled: - from synapse.rest.oidc import OIDCResource + from synapse.rest.synapse.client.oidc import OIDCResource - resources["/_synapse/oidc"] = OIDCResource(hs) + resources["/_synapse/client/oidc"] = OIDCResource(hs) if hs.config.saml2_enabled: from synapse.rest.saml2 import SAML2Resource diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py new file mode 100644 index 0000000000..64c0deb75d --- /dev/null +++ b/synapse/rest/synapse/client/oidc/__init__.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.web.resource import Resource + +from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource + +logger = logging.getLogger(__name__) + + +class OIDCResource(Resource): + def __init__(self, hs): + Resource.__init__(self) + self.putChild(b"callback", OIDCCallbackResource(hs)) + + +__all__ = ["OIDCResource"] diff --git a/synapse/rest/synapse/client/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py new file mode 100644 index 0000000000..f7a0bc4bdb --- /dev/null +++ b/synapse/rest/synapse/client/oidc/callback_resource.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.http.server import DirectServeHtmlResource + +logger = logging.getLogger(__name__) + + +class OIDCCallbackResource(DirectServeHtmlResource): + isLeaf = 1 + + def __init__(self, hs): + super().__init__() + self._oidc_handler = hs.get_oidc_handler() + + async def _async_render_GET(self, request): + await self._oidc_handler.handle_oidc_callback(request) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index d8f90b9a80..ad20400b1d 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -40,7 +40,7 @@ ISSUER = "https://issuer/" CLIENT_ID = "test-client-id" CLIENT_SECRET = "test-client-secret" BASE_URL = "https://synapse/" -CALLBACK_URL = BASE_URL + "_synapse/oidc/callback" +CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" SCOPES = ["openid"] AUTHORIZATION_ENDPOINT = ISSUER + "authorize" @@ -58,12 +58,6 @@ COMMON_CONFIG = { } -# The cookie name and path don't really matter, just that it has to be coherent -# between the callback & redirect handlers. -COOKIE_NAME = b"oidc_session" -COOKIE_PATH = "/_synapse/oidc" - - class TestMappingProvider: @staticmethod def parse_config(config): @@ -340,8 +334,11 @@ class OidcHandlerTestCase(HomeserverTestCase): # For some reason, call.args does not work with python3.5 args = calls[0][0] kwargs = calls[0][1] - self.assertEqual(args[0], COOKIE_NAME) - self.assertEqual(kwargs["path"], COOKIE_PATH) + + # The cookie name and path don't really matter, just that it has to be coherent + # between the callback & redirect handlers. + self.assertEqual(args[0], b"oidc_session") + self.assertEqual(kwargs["path"], "/_synapse/client/oidc") cookie = args[1] macaroon = pymacaroons.Macaroon.deserialize(cookie) -- cgit 1.5.1 From 8f75bf1df7f2bcb3ffe0bb89f8fe3351a48769c0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 2 Feb 2021 09:43:50 +0000 Subject: Put SAML callback URI under /_synapse/client. (#9289) --- UPGRADE.rst | 4 +++ changelog.d/9289.removal | 1 + docs/sample_config.yaml | 4 +-- docs/workers.md | 2 +- synapse/config/saml2_config.py | 8 ++--- synapse/handlers/saml_handler.py | 2 +- synapse/rest/saml2/__init__.py | 29 ---------------- synapse/rest/saml2/metadata_resource.py | 36 -------------------- synapse/rest/saml2/response_resource.py | 39 ---------------------- synapse/rest/synapse/client/__init__.py | 9 +++-- synapse/rest/synapse/client/saml2/__init__.py | 33 ++++++++++++++++++ .../rest/synapse/client/saml2/metadata_resource.py | 36 ++++++++++++++++++++ .../rest/synapse/client/saml2/response_resource.py | 39 ++++++++++++++++++++++ 13 files changed, 127 insertions(+), 115 deletions(-) create mode 100644 changelog.d/9289.removal delete mode 100644 synapse/rest/saml2/__init__.py delete mode 100644 synapse/rest/saml2/metadata_resource.py delete mode 100644 synapse/rest/saml2/response_resource.py create mode 100644 synapse/rest/synapse/client/saml2/__init__.py create mode 100644 synapse/rest/synapse/client/saml2/metadata_resource.py create mode 100644 synapse/rest/synapse/client/saml2/response_resource.py (limited to 'synapse/rest') diff --git a/UPGRADE.rst b/UPGRADE.rst index d00f718cae..22edfe0d60 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -99,6 +99,10 @@ to the list of permitted "redirect URIs" at the identity provider. See `docs/openid.md `_ for more information on setting up OpenID Connect. +(Note: a similar change is being made for SAML2; in this case the old URI +``[synapse public baseurl]/_matrix/saml2`` is being deprecated, but will continue to +work, so no immediate changes are required for existing installations.) + Changes to HTML templates ------------------------- diff --git a/changelog.d/9289.removal b/changelog.d/9289.removal new file mode 100644 index 0000000000..49158fc4d3 --- /dev/null +++ b/changelog.d/9289.removal @@ -0,0 +1 @@ +Add new endpoint `/_synapse/client/saml2` for SAML2 authentication callbacks, and deprecate the old endpoint `/_matrix/saml2`. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index dd2981717d..6d265d2972 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1566,10 +1566,10 @@ trusted_key_servers: # enable SAML login. # # Once SAML support is enabled, a metadata file will be exposed at -# https://:/_matrix/saml2/metadata.xml, which you may be able to +# https://:/_synapse/client/saml2/metadata.xml, which you may be able to # use to configure your SAML IdP with. Alternatively, you can manually configure # the IdP to use an ACS location of -# https://:/_matrix/saml2/authn_response. +# https://:/_synapse/client/saml2/authn_response. # saml2_config: # `sp_config` is the configuration for the pysaml2 Service Provider. diff --git a/docs/workers.md b/docs/workers.md index c4a6c79238..f7fc6df119 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -269,7 +269,7 @@ using): ^/_synapse/client/oidc/callback$ # SAML requests. - ^/_matrix/saml2/authn_response$ + ^/_synapse/client/saml2/authn_response$ # CAS requests. ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$ diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index f33dfa0d6a..ad865a667f 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -194,8 +194,8 @@ class SAML2Config(Config): optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) optional_attributes -= required_attributes - metadata_url = public_baseurl + "_matrix/saml2/metadata.xml" - response_url = public_baseurl + "_matrix/saml2/authn_response" + metadata_url = public_baseurl + "_synapse/client/saml2/metadata.xml" + response_url = public_baseurl + "_synapse/client/saml2/authn_response" return { "entityid": metadata_url, "service": { @@ -233,10 +233,10 @@ class SAML2Config(Config): # enable SAML login. # # Once SAML support is enabled, a metadata file will be exposed at - # https://:/_matrix/saml2/metadata.xml, which you may be able to + # https://:/_synapse/client/saml2/metadata.xml, which you may be able to # use to configure your SAML IdP with. Alternatively, you can manually configure # the IdP to use an ACS location of - # https://:/_matrix/saml2/authn_response. + # https://:/_synapse/client/saml2/authn_response. # saml2_config: # `sp_config` is the configuration for the pysaml2 Service Provider. diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 5946919c33..e88fd59749 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -133,7 +133,7 @@ class SamlHandler(BaseHandler): raise Exception("prepare_for_authenticate didn't return a Location header") async def handle_saml_response(self, request: SynapseRequest) -> None: - """Handle an incoming request to /_matrix/saml2/authn_response + """Handle an incoming request to /_synapse/client/saml2/authn_response Args: request: the incoming request from the browser. We'll diff --git a/synapse/rest/saml2/__init__.py b/synapse/rest/saml2/__init__.py deleted file mode 100644 index 68da37ca6a..0000000000 --- a/synapse/rest/saml2/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from twisted.web.resource import Resource - -from synapse.rest.saml2.metadata_resource import SAML2MetadataResource -from synapse.rest.saml2.response_resource import SAML2ResponseResource - -logger = logging.getLogger(__name__) - - -class SAML2Resource(Resource): - def __init__(self, hs): - Resource.__init__(self) - self.putChild(b"metadata.xml", SAML2MetadataResource(hs)) - self.putChild(b"authn_response", SAML2ResponseResource(hs)) diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/saml2/metadata_resource.py deleted file mode 100644 index 1e8526e22e..0000000000 --- a/synapse/rest/saml2/metadata_resource.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import saml2.metadata - -from twisted.web.resource import Resource - - -class SAML2MetadataResource(Resource): - """A Twisted web resource which renders the SAML metadata""" - - isLeaf = 1 - - def __init__(self, hs): - Resource.__init__(self) - self.sp_config = hs.config.saml2_sp_config - - def render_GET(self, request): - metadata_xml = saml2.metadata.create_metadata_string( - configfile=None, config=self.sp_config - ) - request.setHeader(b"Content-Type", b"text/xml; charset=utf-8") - return metadata_xml diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py deleted file mode 100644 index f6668fb5e3..0000000000 --- a/synapse/rest/saml2/response_resource.py +++ /dev/null @@ -1,39 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.http.server import DirectServeHtmlResource - - -class SAML2ResponseResource(DirectServeHtmlResource): - """A Twisted web resource which handles the SAML response""" - - isLeaf = 1 - - def __init__(self, hs): - super().__init__() - self._saml_handler = hs.get_saml_handler() - - async def _async_render_GET(self, request): - # We're not expecting any GET request on that resource if everything goes right, - # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. - # In this case, just tell the user that something went wrong and they should - # try to authenticate again. - self._saml_handler._render_error( - request, "unexpected_get", "Unexpected GET request on /saml2/authn_response" - ) - - async def _async_render_POST(self, request): - await self._saml_handler.handle_saml_response(request) diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index 381baf9729..e5ef515090 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -52,10 +52,13 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc resources["/_synapse/client/oidc"] = OIDCResource(hs) if hs.config.saml2_enabled: - from synapse.rest.saml2 import SAML2Resource + from synapse.rest.synapse.client.saml2 import SAML2Resource - # This is mounted under '/_matrix' for backwards-compatibility. - resources["/_matrix/saml2"] = SAML2Resource(hs) + res = SAML2Resource(hs) + resources["/_synapse/client/saml2"] = res + + # This is also mounted under '/_matrix' for backwards-compatibility. + resources["/_matrix/saml2"] = res return resources diff --git a/synapse/rest/synapse/client/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py new file mode 100644 index 0000000000..3e8235ee1e --- /dev/null +++ b/synapse/rest/synapse/client/saml2/__init__.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.web.resource import Resource + +from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource +from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource + +logger = logging.getLogger(__name__) + + +class SAML2Resource(Resource): + def __init__(self, hs): + Resource.__init__(self) + self.putChild(b"metadata.xml", SAML2MetadataResource(hs)) + self.putChild(b"authn_response", SAML2ResponseResource(hs)) + + +__all__ = ["SAML2Resource"] diff --git a/synapse/rest/synapse/client/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py new file mode 100644 index 0000000000..1e8526e22e --- /dev/null +++ b/synapse/rest/synapse/client/saml2/metadata_resource.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import saml2.metadata + +from twisted.web.resource import Resource + + +class SAML2MetadataResource(Resource): + """A Twisted web resource which renders the SAML metadata""" + + isLeaf = 1 + + def __init__(self, hs): + Resource.__init__(self) + self.sp_config = hs.config.saml2_sp_config + + def render_GET(self, request): + metadata_xml = saml2.metadata.create_metadata_string( + configfile=None, config=self.sp_config + ) + request.setHeader(b"Content-Type", b"text/xml; charset=utf-8") + return metadata_xml diff --git a/synapse/rest/synapse/client/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py new file mode 100644 index 0000000000..f6668fb5e3 --- /dev/null +++ b/synapse/rest/synapse/client/saml2/response_resource.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import DirectServeHtmlResource + + +class SAML2ResponseResource(DirectServeHtmlResource): + """A Twisted web resource which handles the SAML response""" + + isLeaf = 1 + + def __init__(self, hs): + super().__init__() + self._saml_handler = hs.get_saml_handler() + + async def _async_render_GET(self, request): + # We're not expecting any GET request on that resource if everything goes right, + # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. + # In this case, just tell the user that something went wrong and they should + # try to authenticate again. + self._saml_handler._render_error( + request, "unexpected_get", "Unexpected GET request on /saml2/authn_response" + ) + + async def _async_render_POST(self, request): + await self._saml_handler.handle_saml_response(request) -- cgit 1.5.1 From b60bb28bbc3d916586a913970298baba483efc1f Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Tue, 2 Feb 2021 04:16:29 -0700 Subject: Add an admin API to get the current room state (#9168) This could arguably replace the existing admin API for `/members`, however that is out of scope of this change. This sort of endpoint is ideal for moderation use cases as well as other applications, such as needing to retrieve various bits of information about a room to perform a task (like syncing power levels between two places). This endpoint exposes nothing more than an admin would be able to access with a `select *` query on their database. --- changelog.d/9168.feature | 1 + docs/admin_api/rooms.md | 30 ++++++++++++++++++++++++++++++ synapse/handlers/message.py | 2 +- synapse/rest/admin/__init__.py | 2 ++ synapse/rest/admin/rooms.py | 39 +++++++++++++++++++++++++++++++++++++++ tests/rest/admin/test_room.py | 15 +++++++++++++++ 6 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9168.feature (limited to 'synapse/rest') diff --git a/changelog.d/9168.feature b/changelog.d/9168.feature new file mode 100644 index 0000000000..8be1950eee --- /dev/null +++ b/changelog.d/9168.feature @@ -0,0 +1 @@ +Add an admin API for retrieving the current room state of a room. \ No newline at end of file diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index f34cec1ff7..3832b36407 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -368,6 +368,36 @@ Response: } ``` +# Room State API + +The Room State admin API allows server admins to get a list of all state events in a room. + +The response includes the following fields: + +* `state` - The current state of the room at the time of request. + +## Usage + +A standard request: + +``` +GET /_synapse/admin/v1/rooms//state + +{} +``` + +Response: + +```json +{ + "state": [ + {"type": "m.room.create", "state_key": "", "etc": true}, + {"type": "m.room.power_levels", "state_key": "", "etc": true}, + {"type": "m.room.name", "state_key": "", "etc": true} + ] +} +``` + # Delete Room API The Delete Room admin API allows server admins to remove rooms from server diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e2a7d567fa..a15336bf00 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -174,7 +174,7 @@ class MessageHandler: raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = await filter_events_for_client( - self.storage, user_id, last_events, filter_send_to_client=False + self.storage, user_id, last_events, filter_send_to_client=False, ) event = last_events[0] diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 57e0a10837..f5c5d164f9 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.admin.rooms import ( MakeRoomAdminRestServlet, RoomMembersRestServlet, RoomRestServlet, + RoomStateRestServlet, ShutdownRoomRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet @@ -213,6 +214,7 @@ def register_servlets(hs, http_server): """ register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) + RoomStateRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) RoomMembersRestServlet(hs).register(http_server) DeleteRoomRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f14915d47e..3e57e6a4d0 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -292,6 +292,45 @@ class RoomMembersRestServlet(RestServlet): return 200, ret +class RoomStateRestServlet(RestServlet): + """ + Get full state within a room. + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]+)/state") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + ret = await self.store.get_room(room_id) + if not ret: + raise NotFoundError("Room not found") + + event_ids = await self.store.get_current_state_ids(room_id) + events = await self.store.get_events(event_ids.values()) + now = self.clock.time_msec() + room_state = await self._event_serializer.serialize_events( + events.values(), + now, + # We don't bother bundling aggregations in when asked for state + # events, as clients won't use them. + bundle_aggregations=False, + ) + ret = {"state": room_state} + + return 200, ret + + class JoinRoomAliasServlet(RestServlet): PATTERNS = admin_patterns("/join/(?P[^/]*)") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index a0f32c5512..7c47aa7e0a 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1180,6 +1180,21 @@ class RoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["total"], 3) + def test_room_state(self): + """Test that room state can be requested correctly""" + # Create two test rooms + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,) + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("state", channel.json_body) + # testing that the state events match is painful and not done here. We assume that + # the create_room already does the right thing, so no need to verify that we got + # the state events it created. + class JoinAliasRoomTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From e40d88cff3cca3d5186d5f623ad1107bc403d69b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 11 Feb 2021 11:16:54 -0500 Subject: Backout changes for automatically calculating the public baseurl. (#9313) This breaks some people's configurations (if their Client-Server API is not accessed via port 443). --- changelog.d/9313.bugfix | 1 + docs/sample_config.yaml | 20 +++++++++----------- synapse/api/urls.py | 2 ++ synapse/config/cas.py | 16 +++++++++------- synapse/config/emailconfig.py | 8 ++++++++ synapse/config/oidc_config.py | 5 ++++- synapse/config/registration.py | 21 +++++++++++++++++---- synapse/config/saml2_config.py | 2 ++ synapse/config/server.py | 13 ++++--------- synapse/config/sso.py | 13 ++++++++----- synapse/handlers/identity.py | 4 ++++ synapse/rest/well_known.py | 4 ++++ synapse/util/templates.py | 15 ++++++++++++--- tests/rest/client/v1/test_login.py | 4 +++- tests/rest/test_well_known.py | 9 +++++++++ tests/utils.py | 1 + 16 files changed, 97 insertions(+), 41 deletions(-) create mode 100644 changelog.d/9313.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/9313.bugfix b/changelog.d/9313.bugfix new file mode 100644 index 0000000000..f578fd13dd --- /dev/null +++ b/changelog.d/9313.bugfix @@ -0,0 +1 @@ +Do not automatically calculate `public_baseurl` since it can be wrong in some situations. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 236abd9a3f..d395da11b4 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -74,10 +74,6 @@ pid_file: DATADIR/homeserver.pid # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see # 'listeners' below). # -# If this is left unset, it defaults to 'https:///'. (Note that -# that will not work unless you configure Synapse or a reverse-proxy to listen -# on port 443.) -# #public_baseurl: https://example.com/ # Set the soft limit on the number of file descriptors synapse can use @@ -1169,9 +1165,8 @@ account_validity: # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' - # configuration section. You should also check that 'public_baseurl' is set - # correctly. + # If you enable this setting, you will also need to fill out the 'email' and + # 'public_baseurl' configuration sections. # #renew_at: 1w @@ -1262,7 +1257,8 @@ account_validity: # The identity server which we suggest that clients should use when users log # in on this server. # -# (By default, no suggestion is made, so it is left up to the client.) +# (By default, no suggestion is made, so it is left up to the client. +# This setting is ignored unless public_baseurl is also set.) # #default_identity_server: https://matrix.org @@ -1287,6 +1283,8 @@ account_validity: # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # +# If a delegate is specified, the config option public_baseurl must also be filled out. +# account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process @@ -1938,9 +1936,9 @@ sso: # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # The login fallback page (used by clients that don't natively support the - # required login flows) is automatically whitelisted in addition to any URLs - # in this list. + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. # # By default, this list is empty. # diff --git a/synapse/api/urls.py b/synapse/api/urls.py index e36aeef31f..6379c86dde 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -42,6 +42,8 @@ class ConsentURIBuilder: """ if hs_config.form_secret is None: raise ConfigError("form_secret not set in config") + if hs_config.public_baseurl is None: + raise ConfigError("public_baseurl not set in config") self._hmac_secret = hs_config.form_secret.encode("utf-8") self._public_baseurl = hs_config.public_baseurl diff --git a/synapse/config/cas.py b/synapse/config/cas.py index b226890c2a..aaa7eba110 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +from ._base import Config, ConfigError class CasConfig(Config): @@ -30,13 +30,15 @@ class CasConfig(Config): if self.cas_enabled: self.cas_server_url = cas_config["server_url"] - public_base_url = cas_config.get("service_url") or self.public_baseurl - if public_base_url[-1] != "/": - public_base_url += "/" + + # The public baseurl is required because it is used by the redirect + # template. + public_baseurl = self.public_baseurl + if not public_baseurl: + raise ConfigError("cas_config requires a public_baseurl to be set") + # TODO Update this to a _synapse URL. - self.cas_service_url = ( - public_base_url + "_matrix/client/r0/login/cas/ticket" - ) + self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket" self.cas_displayname_attribute = cas_config.get("displayname_attribute") self.cas_required_attributes = cas_config.get("required_attributes") or {} else: diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 6a487afd34..d4328c46b9 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -166,6 +166,11 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") + # public_baseurl is required to build password reset and validation links that + # will be emailed to users + if config.get("public_baseurl") is None: + missing.append("public_baseurl") + if missing: raise ConfigError( MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),) @@ -264,6 +269,9 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") + if config.get("public_baseurl") is None: + missing.append("public_baseurl") + if missing: raise ConfigError( "email.enable_notifs is True but required keys are missing: %s" diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 4c24c50629..4d0f24a9d5 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -53,7 +53,10 @@ class OIDCConfig(Config): "Multiple OIDC providers have the idp_id %r." % idp_id ) - self.oidc_callback_url = self.public_baseurl + "_synapse/client/oidc/callback" + public_baseurl = self.public_baseurl + if public_baseurl is None: + raise ConfigError("oidc_config requires a public_baseurl to be set") + self.oidc_callback_url = public_baseurl + "_synapse/client/oidc/callback" @property def oidc_enabled(self) -> bool: diff --git a/synapse/config/registration.py b/synapse/config/registration.py index ac48913a0b..eb650af7fb 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -49,6 +49,10 @@ class AccountValidityConfig(Config): self.startup_job_max_delta = self.period * 10.0 / 100.0 + if self.renew_by_email_enabled: + if "public_baseurl" not in synapse_config: + raise ConfigError("Can't send renewal emails without 'public_baseurl'") + template_dir = config.get("template_dir") if not template_dir: @@ -105,6 +109,13 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") + if self.account_threepid_delegate_msisdn and not self.public_baseurl: + raise ConfigError( + "The configuration option `public_baseurl` is required if " + "`account_threepid_delegate.msisdn` is set, such that " + "clients know where to submit validation tokens to. Please " + "configure `public_baseurl`." + ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) @@ -227,9 +238,8 @@ class RegistrationConfig(Config): # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' - # configuration section. You should also check that 'public_baseurl' is set - # correctly. + # If you enable this setting, you will also need to fill out the 'email' and + # 'public_baseurl' configuration sections. # #renew_at: 1w @@ -320,7 +330,8 @@ class RegistrationConfig(Config): # The identity server which we suggest that clients should use when users log # in on this server. # - # (By default, no suggestion is made, so it is left up to the client.) + # (By default, no suggestion is made, so it is left up to the client. + # This setting is ignored unless public_baseurl is also set.) # #default_identity_server: https://matrix.org @@ -345,6 +356,8 @@ class RegistrationConfig(Config): # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # + # If a delegate is specified, the config option public_baseurl must also be filled out. + # account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index ad865a667f..7226abd829 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -189,6 +189,8 @@ class SAML2Config(Config): import saml2 public_baseurl = self.public_baseurl + if public_baseurl is None: + raise ConfigError("saml2_config requires a public_baseurl to be set") if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) diff --git a/synapse/config/server.py b/synapse/config/server.py index 47a0370173..5d72cf2d82 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -161,11 +161,7 @@ class ServerConfig(Config): self.print_pidfile = config.get("print_pidfile") self.user_agent_suffix = config.get("user_agent_suffix") self.use_frozen_dicts = config.get("use_frozen_dicts", False) - self.public_baseurl = config.get("public_baseurl") or "https://%s/" % ( - self.server_name, - ) - if self.public_baseurl[-1] != "/": - self.public_baseurl += "/" + self.public_baseurl = config.get("public_baseurl") # Whether to enable user presence. self.use_presence = config.get("use_presence", True) @@ -321,6 +317,9 @@ class ServerConfig(Config): # Always blacklist 0.0.0.0, :: self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + if self.public_baseurl is not None: + if self.public_baseurl[-1] != "/": + self.public_baseurl += "/" self.start_pushers = config.get("start_pushers", True) # (undocumented) option for torturing the worker-mode replication a bit, @@ -748,10 +747,6 @@ class ServerConfig(Config): # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see # 'listeners' below). # - # If this is left unset, it defaults to 'https:///'. (Note that - # that will not work unless you configure Synapse or a reverse-proxy to listen - # on port 443.) - # #public_baseurl: https://example.com/ # Set the soft limit on the number of file descriptors synapse can use diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 6c60c6fea4..19bdfd462b 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -64,8 +64,11 @@ class SSOConfig(Config): # gracefully to the client). This would make it pointless to ask the user for # confirmation, since the URL the confirmation page would be showing wouldn't be # the client's. - login_fallback_url = self.public_baseurl + "_matrix/static/client/login" - self.sso_client_whitelist.append(login_fallback_url) + # public_baseurl is an optional setting, so we only add the fallback's URL to the + # list if it's provided (because we can't figure out what that URL is otherwise). + if self.public_baseurl: + login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + self.sso_client_whitelist.append(login_fallback_url) def generate_config_section(self, **kwargs): return """\ @@ -83,9 +86,9 @@ class SSOConfig(Config): # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # The login fallback page (used by clients that don't natively support the - # required login flows) is automatically whitelisted in addition to any URLs - # in this list. + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. # # By default, this list is empty. # diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 4f7137539b..8fc1e8b91c 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -504,6 +504,10 @@ class IdentityHandler(BaseHandler): except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") + # It is already checked that public_baseurl is configured since this code + # should only be used if account_threepid_delegate_msisdn is true. + assert self.hs.config.public_baseurl + # we need to tell the client to send the token back to us, since it doesn't # otherwise know where to send it, so add submit_url response parameter # (see also MSC2078) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 241fe746d9..f591cc6c5c 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -34,6 +34,10 @@ class WellKnownBuilder: self._config = hs.config def get_well_known(self): + # if we don't have a public_baseurl, we can't help much here. + if self._config.public_baseurl is None: + return None + result = {"m.homeserver": {"base_url": self._config.public_baseurl}} if self._config.default_identity_server: diff --git a/synapse/util/templates.py b/synapse/util/templates.py index 7e5109d206..392dae4a40 100644 --- a/synapse/util/templates.py +++ b/synapse/util/templates.py @@ -17,7 +17,7 @@ import time import urllib.parse -from typing import TYPE_CHECKING, Callable, Iterable, Union +from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union import jinja2 @@ -74,14 +74,23 @@ def build_jinja_env( return env -def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: +def _create_mxc_to_http_filter( + public_baseurl: Optional[str], +) -> Callable[[str, int, int, str], str]: """Create and return a jinja2 filter that converts MXC urls to HTTP Args: public_baseurl: The public, accessible base URL of the homeserver """ - def mxc_to_http_filter(value, width, height, resize_method="crop"): + def mxc_to_http_filter( + value: str, width: int, height: int, resize_method: str = "crop" + ) -> str: + if not public_baseurl: + raise RuntimeError( + "public_baseurl must be set in the homeserver config to convert MXC URLs to HTTP URLs." + ) + if value[0:6] != "mxc://": return "" diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 66dfdaffbc..bfcb786af8 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -672,10 +672,12 @@ class CASTestCase(unittest.HomeserverTestCase): self.redirect_path = "_synapse/client/login/sso/redirect/confirm" config = self.default_config() + config["public_baseurl"] = ( + config.get("public_baseurl") or "https://matrix.goodserver.com:8448" + ) config["cas_config"] = { "enabled": True, "server_url": CAS_SERVER, - "service_url": "https://matrix.goodserver.com:8448", } cas_user_id = "username" diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index c5e44af9f7..14de0921be 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -40,3 +40,12 @@ class WellKnownTests(unittest.HomeserverTestCase): "m.identity_server": {"base_url": "https://testis"}, }, ) + + def test_well_known_no_public_baseurl(self): + self.hs.config.public_baseurl = None + + channel = self.make_request( + "GET", "/.well-known/matrix/client", shorthand=False + ) + + self.assertEqual(channel.code, 404) diff --git a/tests/utils.py b/tests/utils.py index 68033d7535..840b657f82 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -159,6 +159,7 @@ def default_config(name, parse=False): }, "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000}, "saml2_enabled": False, + "public_baseurl": None, "default_identity_server": None, "key_refresh_interval": 24 * 60 * 60 * 1000, "old_signing_keys": {}, -- cgit 1.5.1