diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/federation/transport/server.py | 6 | ||||
-rw-r--r-- | synapse/handlers/identity.py | 163 | ||||
-rw-r--r-- | synapse/http/server.py | 13 | ||||
-rw-r--r-- | synapse/http/servlet.py | 6 | ||||
-rw-r--r-- | synapse/logging/opentracing.py | 22 | ||||
-rw-r--r-- | synapse/replication/http/_base.py | 16 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/account.py | 46 |
7 files changed, 202 insertions, 70 deletions
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index f9930b6460..132a8fb5e6 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -342,7 +342,11 @@ class BaseFederationServlet(object): continue server.register_paths( - method, (pattern,), self._wrap(code), self.__class__.__name__ + method, + (pattern,), + self._wrap(code), + self.__class__.__name__, + trace=False, ) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index d199521b58..583b612dd9 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -61,21 +61,76 @@ class IdentityHandler(BaseHandler): return False return True + def _extract_items_from_creds_dict(self, creds): + """ + Retrieve entries from a "credentials" dictionary + + Args: + creds (dict[str, str]): Dictionary of credentials that contain the following keys: + * client_secret|clientSecret: A unique secret str provided by the client + * id_server|idServer: the domain of the identity server to query + * id_access_token: The access token to authenticate to the identity + server with. + + Returns: + tuple(str, str, str|None): A tuple containing the client_secret, the id_server, + and the id_access_token value if available. + """ + client_secret = creds.get("client_secret") or creds.get("clientSecret") + if not client_secret: + raise SynapseError( + 400, "No client_secret in creds", errcode=Codes.MISSING_PARAM + ) + + id_server = creds.get("id_server") or creds.get("idServer") + if not id_server: + raise SynapseError( + 400, "No id_server in creds", errcode=Codes.MISSING_PARAM + ) + + id_access_token = creds.get("id_access_token") + return client_secret, id_server, id_access_token + @defer.inlineCallbacks - def threepid_from_creds(self, creds): - if "id_server" in creds: - id_server = creds["id_server"] - elif "idServer" in creds: - id_server = creds["idServer"] - else: - raise SynapseError(400, "No id_server in creds") + def threepid_from_creds(self, creds, use_v2=True): + """ + Retrieve and validate a threepid identitier from a "credentials" dictionary + + Args: + creds (dict[str, str]): Dictionary of credentials that contain the following keys: + * client_secret|clientSecret: A unique secret str provided by the client + * id_server|idServer: the domain of the identity server to query + * id_access_token: The access token to authenticate to the identity + server with. Required if use_v2 is true + use_v2 (bool): Whether to use v2 Identity Service API endpoints + + Returns: + Deferred[dict[str,str|int]|None]: A dictionary consisting of response params to + the /getValidated3pid endpoint of the Identity Service API, or None if the + threepid was not found + """ + client_secret, id_server, id_access_token = self._extract_items_from_creds_dict( + creds + ) - if "client_secret" in creds: - client_secret = creds["client_secret"] - elif "clientSecret" in creds: - client_secret = creds["clientSecret"] + # If an id_access_token is not supplied, force usage of v1 + if id_access_token is None: + use_v2 = False + + query_params = {"sid": creds["sid"], "client_secret": client_secret} + + # Decide which API endpoint URLs and query parameters to use + if use_v2: + url = "https://%s%s" % ( + id_server, + "/_matrix/identity/v2/3pid/getValidated3pid", + ) + query_params["id_access_token"] = id_access_token else: - raise SynapseError(400, "No client_secret in creds") + url = "https://%s%s" % ( + id_server, + "/_matrix/identity/api/v1/3pid/getValidated3pid", + ) if not self._should_trust_id_server(id_server): logger.warn( @@ -85,43 +140,55 @@ class IdentityHandler(BaseHandler): return None try: - data = yield self.http_client.get_json( - "https://%s%s" - % (id_server, "/_matrix/identity/api/v1/3pid/getValidated3pid"), - {"sid": creds["sid"], "client_secret": client_secret}, - ) + data = yield self.http_client.get_json(url, query_params) + return data if "medium" in data else None except HttpResponseException as e: - logger.info("getValidated3pid failed with Matrix error: %r", e) - raise e.to_synapse_error() + if e.code != 404 or not use_v2: + # Generic failure + logger.info("getValidated3pid failed with Matrix error: %r", e) + raise e.to_synapse_error() - if "medium" in data: - return data - return None + # This identity server is too old to understand Identity Service API v2 + # Attempt v1 endpoint + logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", url) + return (yield self.threepid_from_creds(creds, use_v2=False)) @defer.inlineCallbacks - def bind_threepid(self, creds, mxid): + def bind_threepid(self, creds, mxid, use_v2=True): + """Bind a 3PID to an identity server + + Args: + creds (dict[str, str]): Dictionary of credentials that contain the following keys: + * client_secret|clientSecret: A unique secret str provided by the client + * id_server|idServer: the domain of the identity server to query + * id_access_token: The access token to authenticate to the identity + server with. Required if use_v2 is true + mxid (str): The MXID to bind the 3PID to + use_v2 (bool): Whether to use v2 Identity Service API endpoints + + Returns: + Deferred[dict]: The response from the identity server + """ logger.debug("binding threepid %r to %s", creds, mxid) - data = None - if "id_server" in creds: - id_server = creds["id_server"] - elif "idServer" in creds: - id_server = creds["idServer"] - else: - raise SynapseError(400, "No id_server in creds") + client_secret, id_server, id_access_token = self._extract_items_from_creds_dict( + creds + ) + + # If an id_access_token is not supplied, force usage of v1 + if id_access_token is None: + use_v2 = False - if "client_secret" in creds: - client_secret = creds["client_secret"] - elif "clientSecret" in creds: - client_secret = creds["clientSecret"] + # Decide which API endpoint URLs to use + bind_data = {"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid} + if use_v2: + bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,) + bind_data["id_access_token"] = id_access_token else: - raise SynapseError(400, "No client_secret in creds") + bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,) try: - data = yield self.http_client.post_json_get_json( - "https://%s%s" % (id_server, "/_matrix/identity/api/v1/3pid/bind"), - {"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid}, - ) + data = yield self.http_client.post_json_get_json(bind_url, bind_data) logger.debug("bound threepid %r to %s", creds, mxid) # Remember where we bound the threepid @@ -131,13 +198,23 @@ class IdentityHandler(BaseHandler): address=data["address"], id_server=id_server, ) + + return data + except HttpResponseException as e: + if e.code != 404 or not use_v2: + logger.error("3PID bind failed with Matrix error: %r", e) + raise e.to_synapse_error() except CodeMessageException as e: data = json.loads(e.msg) # XXX WAT? - return data + return data + + logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url) + return (yield self.bind_threepid(creds, mxid, use_v2=False)) @defer.inlineCallbacks def try_unbind_threepid(self, mxid, threepid): - """Removes a binding from an identity server + """Attempt to remove a 3PID from an identity server, or if one is not provided, all + identity servers we're aware the binding is present on Args: mxid (str): Matrix user ID of binding to be removed @@ -188,6 +265,8 @@ class IdentityHandler(BaseHandler): server doesn't support unbinding """ url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) + url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") + content = { "mxid": mxid, "threepid": {"medium": threepid["medium"], "address": threepid["address"]}, @@ -199,7 +278,7 @@ class IdentityHandler(BaseHandler): auth_headers = self.federation_http_client.build_auth_headers( destination=None, method="POST", - url_bytes="/_matrix/identity/api/v1/3pid/unbind".encode("ascii"), + url_bytes=url_bytes, content=content, destination_is=id_server, ) diff --git a/synapse/http/server.py b/synapse/http/server.py index e6f351ba3b..cb9158fe1b 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -40,6 +40,7 @@ from synapse.api.errors import ( UnrecognizedRequestError, ) from synapse.logging.context import preserve_fn +from synapse.logging.opentracing import trace_servlet from synapse.util.caches import intern_dict logger = logging.getLogger(__name__) @@ -257,7 +258,9 @@ class JsonResource(HttpServer, resource.Resource): self.path_regexs = {} self.hs = hs - def register_paths(self, method, path_patterns, callback, servlet_classname): + def register_paths( + self, method, path_patterns, callback, servlet_classname, trace=True + ): """ Registers a request handler against a regular expression. Later request URLs are checked against these regular expressions in order to identify an appropriate @@ -273,8 +276,16 @@ class JsonResource(HttpServer, resource.Resource): servlet_classname (str): The name of the handler to be used in prometheus and opentracing logs. + + trace (bool): Whether we should start a span to trace the servlet. """ method = method.encode("utf-8") # method is bytes on py3 + + if trace: + # We don't extract the context from the servlet because we can't + # trust the sender + callback = trace_servlet(servlet_classname)(callback) + for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) self.path_regexs.setdefault(method, []).append( diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index c186b31f59..274c1a6a87 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -20,7 +20,6 @@ import logging from canonicaljson import json from synapse.api.errors import Codes, SynapseError -from synapse.logging.opentracing import trace_servlet logger = logging.getLogger(__name__) @@ -298,10 +297,7 @@ class RestServlet(object): servlet_classname = self.__class__.__name__ method_handler = getattr(self, "on_%s" % (method,)) http_server.register_paths( - method, - patterns, - trace_servlet(servlet_classname)(method_handler), - servlet_classname, + method, patterns, method_handler, servlet_classname ) else: diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 256b972aaa..2c34b54702 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -319,7 +319,7 @@ def whitelisted_homeserver(destination): Args: destination (str) """ - _homeserver_whitelist + if _homeserver_whitelist: return _homeserver_whitelist.match(destination) return False @@ -493,6 +493,11 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T Args: headers (twisted.web.http_headers.Headers) + destination (str): address of entity receiving the span context. If check_destination + is true the context will only be injected if the destination matches the + opentracing whitelist + check_destination (bool): If false, destination will be ignored and the context + will always be injected. span (opentracing.Span) Returns: @@ -525,6 +530,11 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True): Args: headers (dict) + destination (str): address of entity receiving the span context. If check_destination + is true the context will only be injected if the destination matches the + opentracing whitelist + check_destination (bool): If false, destination will be ignored and the context + will always be injected. span (opentracing.Span) Returns: @@ -537,7 +547,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True): here: https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py """ - if not whitelisted_homeserver(destination): + if check_destination and not whitelisted_homeserver(destination): return span = opentracing.tracer.active_span @@ -556,9 +566,11 @@ def inject_active_span_text_map(carrier, destination, check_destination=True): Args: carrier (dict) - destination (str): the name of the remote server. The span context - will only be injected if the destination matches the homeserver_whitelist - or destination is None. + destination (str): address of entity receiving the span context. If check_destination + is true the context will only be injected if the destination matches the + opentracing whitelist + check_destination (bool): If false, destination will be ignored and the context + will always be injected. Returns: In-place modification of carrier diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index c4be9273f6..afc9a8ff29 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -22,13 +22,13 @@ from six.moves import urllib from twisted.internet import defer -import synapse.logging.opentracing as opentracing from synapse.api.errors import ( CodeMessageException, HttpResponseException, RequestSendFailed, SynapseError, ) +from synapse.logging.opentracing import inject_active_span_byte_dict, trace_servlet from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import random_string @@ -167,9 +167,7 @@ class ReplicationEndpoint(object): # the master, and so whether we should clean up or not. while True: headers = {} - opentracing.inject_active_span_byte_dict( - headers, None, check_destination=False - ) + inject_active_span_byte_dict(headers, None, check_destination=False) try: result = yield request_func(uri, data, headers=headers) break @@ -210,13 +208,11 @@ class ReplicationEndpoint(object): args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) + handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler) + # We don't let register paths trace this servlet using the default tracing + # options because we wish to extract the context explicitly. http_server.register_paths( - method, - [pattern], - opentracing.trace_servlet(self.__class__.__name__, extract_context=True)( - handler - ), - self.__class__.__name__, + method, [pattern], handler, self.__class__.__name__, trace=False ) def _cached_handler(self, request, txn_id, **kwargs): diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 0620a4d0cf..e9cc953bdd 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -542,15 +542,16 @@ class ThreepidRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - threePidCreds = body.get("threePidCreds") - threePidCreds = body.get("three_pid_creds", threePidCreds) - if threePidCreds is None: - raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) + threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds") + if threepid_creds is None: + raise SynapseError( + 400, "Missing param three_pid_creds", Codes.MISSING_PARAM + ) requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() - threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) + threepid = yield self.identity_handler.threepid_from_creds(threepid_creds) if not threepid: raise SynapseError(400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED) @@ -566,11 +567,43 @@ class ThreepidRestServlet(RestServlet): if "bind" in body and body["bind"]: logger.debug("Binding threepid %s to %s", threepid, user_id) - yield self.identity_handler.bind_threepid(threePidCreds, user_id) + yield self.identity_handler.bind_threepid(threepid_creds, user_id) return 200, {} +class ThreepidUnbindRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/unbind$") + + def __init__(self, hs): + super(ThreepidUnbindRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + self.auth = hs.get_auth() + self.datastore = self.hs.get_datastore() + + @defer.inlineCallbacks + def on_POST(self, request): + """Unbind the given 3pid from a specific identity server, or identity servers that are + known to have this 3pid bound + """ + requester = yield self.auth.get_user_by_req(request) + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["medium", "address"]) + + medium = body.get("medium") + address = body.get("address") + id_server = body.get("id_server") + + # Attempt to unbind the threepid from an identity server. If id_server is None, try to + # unbind from all identity servers this threepid has been added to in the past + result = yield self.identity_handler.try_unbind_threepid( + requester.user.to_string(), + {"address": address, "medium": medium, "id_server": id_server}, + ) + return 200, {"id_server_unbind_result": "success" if result else "no-support"} + + class ThreepidDeleteRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/delete$") @@ -629,5 +662,6 @@ def register_servlets(hs, http_server): EmailThreepidRequestTokenRestServlet(hs).register(http_server) MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) + ThreepidUnbindRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) |