diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index f9930b6460..132a8fb5e6 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -342,7 +342,11 @@ class BaseFederationServlet(object):
continue
server.register_paths(
- method, (pattern,), self._wrap(code), self.__class__.__name__
+ method,
+ (pattern,),
+ self._wrap(code),
+ self.__class__.__name__,
+ trace=False,
)
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index d199521b58..583b612dd9 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -61,21 +61,76 @@ class IdentityHandler(BaseHandler):
return False
return True
+ def _extract_items_from_creds_dict(self, creds):
+ """
+ Retrieve entries from a "credentials" dictionary
+
+ Args:
+ creds (dict[str, str]): Dictionary of credentials that contain the following keys:
+ * client_secret|clientSecret: A unique secret str provided by the client
+ * id_server|idServer: the domain of the identity server to query
+ * id_access_token: The access token to authenticate to the identity
+ server with.
+
+ Returns:
+ tuple(str, str, str|None): A tuple containing the client_secret, the id_server,
+ and the id_access_token value if available.
+ """
+ client_secret = creds.get("client_secret") or creds.get("clientSecret")
+ if not client_secret:
+ raise SynapseError(
+ 400, "No client_secret in creds", errcode=Codes.MISSING_PARAM
+ )
+
+ id_server = creds.get("id_server") or creds.get("idServer")
+ if not id_server:
+ raise SynapseError(
+ 400, "No id_server in creds", errcode=Codes.MISSING_PARAM
+ )
+
+ id_access_token = creds.get("id_access_token")
+ return client_secret, id_server, id_access_token
+
@defer.inlineCallbacks
- def threepid_from_creds(self, creds):
- if "id_server" in creds:
- id_server = creds["id_server"]
- elif "idServer" in creds:
- id_server = creds["idServer"]
- else:
- raise SynapseError(400, "No id_server in creds")
+ def threepid_from_creds(self, creds, use_v2=True):
+ """
+ Retrieve and validate a threepid identitier from a "credentials" dictionary
+
+ Args:
+ creds (dict[str, str]): Dictionary of credentials that contain the following keys:
+ * client_secret|clientSecret: A unique secret str provided by the client
+ * id_server|idServer: the domain of the identity server to query
+ * id_access_token: The access token to authenticate to the identity
+ server with. Required if use_v2 is true
+ use_v2 (bool): Whether to use v2 Identity Service API endpoints
+
+ Returns:
+ Deferred[dict[str,str|int]|None]: A dictionary consisting of response params to
+ the /getValidated3pid endpoint of the Identity Service API, or None if the
+ threepid was not found
+ """
+ client_secret, id_server, id_access_token = self._extract_items_from_creds_dict(
+ creds
+ )
- if "client_secret" in creds:
- client_secret = creds["client_secret"]
- elif "clientSecret" in creds:
- client_secret = creds["clientSecret"]
+ # If an id_access_token is not supplied, force usage of v1
+ if id_access_token is None:
+ use_v2 = False
+
+ query_params = {"sid": creds["sid"], "client_secret": client_secret}
+
+ # Decide which API endpoint URLs and query parameters to use
+ if use_v2:
+ url = "https://%s%s" % (
+ id_server,
+ "/_matrix/identity/v2/3pid/getValidated3pid",
+ )
+ query_params["id_access_token"] = id_access_token
else:
- raise SynapseError(400, "No client_secret in creds")
+ url = "https://%s%s" % (
+ id_server,
+ "/_matrix/identity/api/v1/3pid/getValidated3pid",
+ )
if not self._should_trust_id_server(id_server):
logger.warn(
@@ -85,43 +140,55 @@ class IdentityHandler(BaseHandler):
return None
try:
- data = yield self.http_client.get_json(
- "https://%s%s"
- % (id_server, "/_matrix/identity/api/v1/3pid/getValidated3pid"),
- {"sid": creds["sid"], "client_secret": client_secret},
- )
+ data = yield self.http_client.get_json(url, query_params)
+ return data if "medium" in data else None
except HttpResponseException as e:
- logger.info("getValidated3pid failed with Matrix error: %r", e)
- raise e.to_synapse_error()
+ if e.code != 404 or not use_v2:
+ # Generic failure
+ logger.info("getValidated3pid failed with Matrix error: %r", e)
+ raise e.to_synapse_error()
- if "medium" in data:
- return data
- return None
+ # This identity server is too old to understand Identity Service API v2
+ # Attempt v1 endpoint
+ logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", url)
+ return (yield self.threepid_from_creds(creds, use_v2=False))
@defer.inlineCallbacks
- def bind_threepid(self, creds, mxid):
+ def bind_threepid(self, creds, mxid, use_v2=True):
+ """Bind a 3PID to an identity server
+
+ Args:
+ creds (dict[str, str]): Dictionary of credentials that contain the following keys:
+ * client_secret|clientSecret: A unique secret str provided by the client
+ * id_server|idServer: the domain of the identity server to query
+ * id_access_token: The access token to authenticate to the identity
+ server with. Required if use_v2 is true
+ mxid (str): The MXID to bind the 3PID to
+ use_v2 (bool): Whether to use v2 Identity Service API endpoints
+
+ Returns:
+ Deferred[dict]: The response from the identity server
+ """
logger.debug("binding threepid %r to %s", creds, mxid)
- data = None
- if "id_server" in creds:
- id_server = creds["id_server"]
- elif "idServer" in creds:
- id_server = creds["idServer"]
- else:
- raise SynapseError(400, "No id_server in creds")
+ client_secret, id_server, id_access_token = self._extract_items_from_creds_dict(
+ creds
+ )
+
+ # If an id_access_token is not supplied, force usage of v1
+ if id_access_token is None:
+ use_v2 = False
- if "client_secret" in creds:
- client_secret = creds["client_secret"]
- elif "clientSecret" in creds:
- client_secret = creds["clientSecret"]
+ # Decide which API endpoint URLs to use
+ bind_data = {"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid}
+ if use_v2:
+ bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
+ bind_data["id_access_token"] = id_access_token
else:
- raise SynapseError(400, "No client_secret in creds")
+ bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
try:
- data = yield self.http_client.post_json_get_json(
- "https://%s%s" % (id_server, "/_matrix/identity/api/v1/3pid/bind"),
- {"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid},
- )
+ data = yield self.http_client.post_json_get_json(bind_url, bind_data)
logger.debug("bound threepid %r to %s", creds, mxid)
# Remember where we bound the threepid
@@ -131,13 +198,23 @@ class IdentityHandler(BaseHandler):
address=data["address"],
id_server=id_server,
)
+
+ return data
+ except HttpResponseException as e:
+ if e.code != 404 or not use_v2:
+ logger.error("3PID bind failed with Matrix error: %r", e)
+ raise e.to_synapse_error()
except CodeMessageException as e:
data = json.loads(e.msg) # XXX WAT?
- return data
+ return data
+
+ logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
+ return (yield self.bind_threepid(creds, mxid, use_v2=False))
@defer.inlineCallbacks
def try_unbind_threepid(self, mxid, threepid):
- """Removes a binding from an identity server
+ """Attempt to remove a 3PID from an identity server, or if one is not provided, all
+ identity servers we're aware the binding is present on
Args:
mxid (str): Matrix user ID of binding to be removed
@@ -188,6 +265,8 @@ class IdentityHandler(BaseHandler):
server doesn't support unbinding
"""
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
+ url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
+
content = {
"mxid": mxid,
"threepid": {"medium": threepid["medium"], "address": threepid["address"]},
@@ -199,7 +278,7 @@ class IdentityHandler(BaseHandler):
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
method="POST",
- url_bytes="/_matrix/identity/api/v1/3pid/unbind".encode("ascii"),
+ url_bytes=url_bytes,
content=content,
destination_is=id_server,
)
diff --git a/synapse/http/server.py b/synapse/http/server.py
index e6f351ba3b..cb9158fe1b 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -40,6 +40,7 @@ from synapse.api.errors import (
UnrecognizedRequestError,
)
from synapse.logging.context import preserve_fn
+from synapse.logging.opentracing import trace_servlet
from synapse.util.caches import intern_dict
logger = logging.getLogger(__name__)
@@ -257,7 +258,9 @@ class JsonResource(HttpServer, resource.Resource):
self.path_regexs = {}
self.hs = hs
- def register_paths(self, method, path_patterns, callback, servlet_classname):
+ def register_paths(
+ self, method, path_patterns, callback, servlet_classname, trace=True
+ ):
"""
Registers a request handler against a regular expression. Later request URLs are
checked against these regular expressions in order to identify an appropriate
@@ -273,8 +276,16 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname (str): The name of the handler to be used in prometheus
and opentracing logs.
+
+ trace (bool): Whether we should start a span to trace the servlet.
"""
method = method.encode("utf-8") # method is bytes on py3
+
+ if trace:
+ # We don't extract the context from the servlet because we can't
+ # trust the sender
+ callback = trace_servlet(servlet_classname)(callback)
+
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index c186b31f59..274c1a6a87 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -20,7 +20,6 @@ import logging
from canonicaljson import json
from synapse.api.errors import Codes, SynapseError
-from synapse.logging.opentracing import trace_servlet
logger = logging.getLogger(__name__)
@@ -298,10 +297,7 @@ class RestServlet(object):
servlet_classname = self.__class__.__name__
method_handler = getattr(self, "on_%s" % (method,))
http_server.register_paths(
- method,
- patterns,
- trace_servlet(servlet_classname)(method_handler),
- servlet_classname,
+ method, patterns, method_handler, servlet_classname
)
else:
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 256b972aaa..2c34b54702 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -319,7 +319,7 @@ def whitelisted_homeserver(destination):
Args:
destination (str)
"""
- _homeserver_whitelist
+
if _homeserver_whitelist:
return _homeserver_whitelist.match(destination)
return False
@@ -493,6 +493,11 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
Args:
headers (twisted.web.http_headers.Headers)
+ destination (str): address of entity receiving the span context. If check_destination
+ is true the context will only be injected if the destination matches the
+ opentracing whitelist
+ check_destination (bool): If false, destination will be ignored and the context
+ will always be injected.
span (opentracing.Span)
Returns:
@@ -525,6 +530,11 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
Args:
headers (dict)
+ destination (str): address of entity receiving the span context. If check_destination
+ is true the context will only be injected if the destination matches the
+ opentracing whitelist
+ check_destination (bool): If false, destination will be ignored and the context
+ will always be injected.
span (opentracing.Span)
Returns:
@@ -537,7 +547,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
here:
https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
"""
- if not whitelisted_homeserver(destination):
+ if check_destination and not whitelisted_homeserver(destination):
return
span = opentracing.tracer.active_span
@@ -556,9 +566,11 @@ def inject_active_span_text_map(carrier, destination, check_destination=True):
Args:
carrier (dict)
- destination (str): the name of the remote server. The span context
- will only be injected if the destination matches the homeserver_whitelist
- or destination is None.
+ destination (str): address of entity receiving the span context. If check_destination
+ is true the context will only be injected if the destination matches the
+ opentracing whitelist
+ check_destination (bool): If false, destination will be ignored and the context
+ will always be injected.
Returns:
In-place modification of carrier
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index c4be9273f6..afc9a8ff29 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -22,13 +22,13 @@ from six.moves import urllib
from twisted.internet import defer
-import synapse.logging.opentracing as opentracing
from synapse.api.errors import (
CodeMessageException,
HttpResponseException,
RequestSendFailed,
SynapseError,
)
+from synapse.logging.opentracing import inject_active_span_byte_dict, trace_servlet
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@@ -167,9 +167,7 @@ class ReplicationEndpoint(object):
# the master, and so whether we should clean up or not.
while True:
headers = {}
- opentracing.inject_active_span_byte_dict(
- headers, None, check_destination=False
- )
+ inject_active_span_byte_dict(headers, None, check_destination=False)
try:
result = yield request_func(uri, data, headers=headers)
break
@@ -210,13 +208,11 @@ class ReplicationEndpoint(object):
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
+ handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler)
+ # We don't let register paths trace this servlet using the default tracing
+ # options because we wish to extract the context explicitly.
http_server.register_paths(
- method,
- [pattern],
- opentracing.trace_servlet(self.__class__.__name__, extract_context=True)(
- handler
- ),
- self.__class__.__name__,
+ method, [pattern], handler, self.__class__.__name__, trace=False
)
def _cached_handler(self, request, txn_id, **kwargs):
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 0620a4d0cf..e9cc953bdd 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -542,15 +542,16 @@ class ThreepidRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- threePidCreds = body.get("threePidCreds")
- threePidCreds = body.get("three_pid_creds", threePidCreds)
- if threePidCreds is None:
- raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
+ threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds")
+ if threepid_creds is None:
+ raise SynapseError(
+ 400, "Missing param three_pid_creds", Codes.MISSING_PARAM
+ )
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
+ threepid = yield self.identity_handler.threepid_from_creds(threepid_creds)
if not threepid:
raise SynapseError(400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED)
@@ -566,11 +567,43 @@ class ThreepidRestServlet(RestServlet):
if "bind" in body and body["bind"]:
logger.debug("Binding threepid %s to %s", threepid, user_id)
- yield self.identity_handler.bind_threepid(threePidCreds, user_id)
+ yield self.identity_handler.bind_threepid(threepid_creds, user_id)
return 200, {}
+class ThreepidUnbindRestServlet(RestServlet):
+ PATTERNS = client_patterns("/account/3pid/unbind$")
+
+ def __init__(self, hs):
+ super(ThreepidUnbindRestServlet, self).__init__()
+ self.hs = hs
+ self.identity_handler = hs.get_handlers().identity_handler
+ self.auth = hs.get_auth()
+ self.datastore = self.hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ """Unbind the given 3pid from a specific identity server, or identity servers that are
+ known to have this 3pid bound
+ """
+ requester = yield self.auth.get_user_by_req(request)
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, ["medium", "address"])
+
+ medium = body.get("medium")
+ address = body.get("address")
+ id_server = body.get("id_server")
+
+ # Attempt to unbind the threepid from an identity server. If id_server is None, try to
+ # unbind from all identity servers this threepid has been added to in the past
+ result = yield self.identity_handler.try_unbind_threepid(
+ requester.user.to_string(),
+ {"address": address, "medium": medium, "id_server": id_server},
+ )
+ return 200, {"id_server_unbind_result": "success" if result else "no-support"}
+
+
class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/delete$")
@@ -629,5 +662,6 @@ def register_servlets(hs, http_server):
EmailThreepidRequestTokenRestServlet(hs).register(http_server)
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
+ ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
|