diff --git a/changelog.d/7385.feature b/changelog.d/7385.feature
new file mode 100644
index 0000000000..9d8fb2311a
--- /dev/null
+++ b/changelog.d/7385.feature
@@ -0,0 +1 @@
+For SAML authentication, add the ability to populate new users' email addresses from SAML attributes. Contributed by Christopher Cooper.
diff --git a/changelog.d/7561.misc b/changelog.d/7561.misc
new file mode 100644
index 0000000000..448dbd5699
--- /dev/null
+++ b/changelog.d/7561.misc
@@ -0,0 +1 @@
+Convert the identity handler to async/await.
diff --git a/changelog.d/7575.bugfix b/changelog.d/7575.bugfix
new file mode 100644
index 0000000000..0ab5516eb3
--- /dev/null
+++ b/changelog.d/7575.bugfix
@@ -0,0 +1 @@
+Fix str placeholders in an instance of `PrepareDatabaseException`. Introduced in Synapse v1.8.0.
diff --git a/changelog.d/7584.misc b/changelog.d/7584.misc
new file mode 100644
index 0000000000..55d1689f77
--- /dev/null
+++ b/changelog.d/7584.misc
@@ -0,0 +1 @@
+Speed up processing of federation stream RDATA rows.
diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
index 4cd3a568f2..abea432343 100644
--- a/docs/sso_mapping_providers.md
+++ b/docs/sso_mapping_providers.md
@@ -138,6 +138,8 @@ A custom mapping provider must specify the following methods:
* `mxid_localpart` - Required. The mxid localpart of the new user.
* `displayname` - The displayname of the new user. If not provided, will default to
the value of `mxid_localpart`.
+ * `emails` - A list of email addresses for the new user. If not provided, will
+ default to an empty list.
### Default SAML Mapping Provider
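
As a hedged sketch of how a custom mapping provider might return the new `emails` key described above, the snippet below assumes an IdP that exposes addresses under a "mail" SAML attribute; the class name, the attribute name, and the naive lowercasing used for the localpart are illustrative rather than part of Synapse, and trailing positional arguments to saml_response_to_user_attributes vary between Synapse versions (hence the *args/**kwargs).

# Illustrative custom SAML mapping provider returning the new "emails" key.
class ExampleSamlMappingProvider:
    def __init__(self, parsed_config):
        self._config = parsed_config

    @staticmethod
    def parse_config(config):
        # This sketch has no provider-specific options.
        return None

    @staticmethod
    def get_saml_attributes(config):
        # (required attributes, optional attributes)
        return {"uid"}, {"displayName", "mail"}

    def saml_response_to_user_attributes(self, saml_response, failures, *args, **kwargs):
        # saml_response.ava maps SAML attribute names to lists of values.
        username = saml_response.ava["uid"][0]
        return {
            # Hypothetical mapping: a real provider should ensure this is a
            # valid localpart (and handle collisions via `failures`).
            "mxid_localpart": username.lower(),
            "displayname": saml_response.ava.get("displayName", [None])[0],
            "emails": saml_response.ava.get("mail", []),
        }
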
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 5afe52f8d4..f3ec2a34ec 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -863,9 +863,24 @@ class FederationSenderHandler(object):
a FEDERATION_ACK back to the master, and stores the token that we have processed
in `federation_stream_position` so that we can restart where we left off.
"""
- try:
- self.federation_position = token
+ self.federation_position = token
+
+ # We save and send the ACK to master asynchronously, so we don't block
+ # processing on persistence. We don't need to do this operation for
+ # every single RDATA we receive; it just needs to happen periodically.
+
+ if self._fed_position_linearizer.is_queued(None):
+ # There is already a task queued up to save and send the token, so
+ # no need to queue up another task.
+ return
+
+ run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
+ async def _save_and_send_ack(self):
+ """Save the current federation position in the database and send an ACK
+ to master with where we're up to.
+ """
+ try:
# We linearize here to ensure we don't have races updating the token
#
# XXX this appears to be redundant, since the ReplicationCommandHandler
@@ -875,16 +890,18 @@ class FederationSenderHandler(object):
# we're not being re-entered?
with (await self._fed_position_linearizer.queue(None)):
+ # We persist and ack the same position, so we take a copy of it
+ # here as otherwise it can get modified from underneath us.
+ current_position = self.federation_position
+
await self.store.update_federation_out_pos(
- "federation", self.federation_position
+ "federation", current_position
)
# We ACK this token over replication so that the master can drop
# its in memory queues
- self._hs.get_tcp_replication().send_federation_ack(
- self.federation_position
- )
- self._last_ack = self.federation_position
+ self._hs.get_tcp_replication().send_federation_ack(current_position)
+ self._last_ack = current_position
except Exception:
logger.exception("Error updating federation stream position")
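
The hunk above debounces position persistence: each RDATA batch only records the newest token, and a `_save_and_send_ack` background task is scheduled only if one is not already queued behind the linearizer. Below is a stand-alone sketch of the same "persist just the latest value, with at most one flusher in flight" idea, written with plain asyncio and made-up names rather than Synapse's Twisted/Linearizer machinery.

import asyncio


class PositionTracker:
    """Toy version of the worker's federation-position bookkeeping."""

    def __init__(self):
        self.position = 0
        self._flush_scheduled = False

    def update(self, token: int) -> None:
        # Always record the newest position...
        self.position = token
        # ...but never schedule more than one flush task at a time.
        if self._flush_scheduled:
            return
        self._flush_scheduled = True
        asyncio.create_task(self._save_and_ack())

    async def _save_and_ack(self) -> None:
        while True:
            # Copy the position so concurrent updates can't change what we persist.
            current = self.position
            await asyncio.sleep(0.05)  # stand-in for the database write and ACK
            print("persisted and acked position", current)
            if current == self.position:
                # Caught up; a later update may schedule a fresh flush.
                self._flush_scheduled = False
                return


async def main():
    tracker = PositionTracker()
    for token in range(1, 6):
        tracker.update(token)  # the whole burst results in a single flush
    await asyncio.sleep(0.2)


asyncio.run(main())
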
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9ed0d23b0f..4ba0042768 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -25,7 +25,6 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
-from twisted.internet import defer
from twisted.internet.error import TimeoutError
from synapse.api.errors import (
@@ -60,8 +59,7 @@ class IdentityHandler(BaseHandler):
self.federation_http_client = hs.get_http_client()
self.hs = hs
- @defer.inlineCallbacks
- def threepid_from_creds(self, id_server, creds):
+ async def threepid_from_creds(self, id_server, creds):
"""
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server
@@ -97,7 +95,7 @@ class IdentityHandler(BaseHandler):
url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
try:
- data = yield self.http_client.get_json(url, query_params)
+ data = await self.http_client.get_json(url, query_params)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
@@ -120,8 +118,7 @@ class IdentityHandler(BaseHandler):
logger.info("%s reported non-validated threepid: %s", id_server, creds)
return None
- @defer.inlineCallbacks
- def bind_threepid(
+ async def bind_threepid(
self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True
):
"""Bind a 3PID to an identity server
@@ -161,12 +158,12 @@ class IdentityHandler(BaseHandler):
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
- data = yield self.blacklisting_http_client.post_json_get_json(
+ data = await self.blacklisting_http_client.post_json_get_json(
bind_url, bind_data, headers=headers
)
# Remember where we bound the threepid
- yield self.store.add_user_bound_threepid(
+ await self.store.add_user_bound_threepid(
user_id=mxid,
medium=data["medium"],
address=data["address"],
@@ -185,13 +182,12 @@ class IdentityHandler(BaseHandler):
return data
logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
- res = yield self.bind_threepid(
+ res = await self.bind_threepid(
client_secret, sid, mxid, id_server, id_access_token, use_v2=False
)
return res
- @defer.inlineCallbacks
- def try_unbind_threepid(self, mxid, threepid):
+ async def try_unbind_threepid(self, mxid, threepid):
"""Attempt to remove a 3PID from an identity server, or if one is not provided, all
identity servers we're aware the binding is present on
@@ -211,7 +207,7 @@ class IdentityHandler(BaseHandler):
if threepid.get("id_server"):
id_servers = [threepid["id_server"]]
else:
- id_servers = yield self.store.get_id_servers_user_bound(
+ id_servers = await self.store.get_id_servers_user_bound(
user_id=mxid, medium=threepid["medium"], address=threepid["address"]
)
@@ -221,14 +217,13 @@ class IdentityHandler(BaseHandler):
changed = True
for id_server in id_servers:
- changed &= yield self.try_unbind_threepid_with_id_server(
+ changed &= await self.try_unbind_threepid_with_id_server(
mxid, threepid, id_server
)
return changed
- @defer.inlineCallbacks
- def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
+ async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
"""Removes a binding from an identity server
Args:
@@ -266,7 +261,7 @@ class IdentityHandler(BaseHandler):
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
- yield self.blacklisting_http_client.post_json_get_json(
+ await self.blacklisting_http_client.post_json_get_json(
url, content, headers
)
changed = True
@@ -281,7 +276,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
- yield self.store.remove_user_bound_threepid(
+ await self.store.remove_user_bound_threepid(
user_id=mxid,
medium=threepid["medium"],
address=threepid["address"],
@@ -376,8 +371,7 @@ class IdentityHandler(BaseHandler):
return session_id
- @defer.inlineCallbacks
- def requestEmailToken(
+ async def requestEmailToken(
self, id_server, email, client_secret, send_attempt, next_link=None
):
"""
@@ -412,7 +406,7 @@ class IdentityHandler(BaseHandler):
)
try:
- data = yield self.http_client.post_json_get_json(
+ data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
params,
)
@@ -423,8 +417,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
- @defer.inlineCallbacks
- def requestMsisdnToken(
+ async def requestMsisdnToken(
self,
id_server,
country,
@@ -466,7 +459,7 @@ class IdentityHandler(BaseHandler):
)
try:
- data = yield self.http_client.post_json_get_json(
+ data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
params,
)
@@ -487,8 +480,7 @@ class IdentityHandler(BaseHandler):
)
return data
- @defer.inlineCallbacks
- def validate_threepid_session(self, client_secret, sid):
+ async def validate_threepid_session(self, client_secret, sid):
"""Validates a threepid session with only the client secret and session ID
Tries validating against any configured account_threepid_delegates as well as locally.
@@ -510,12 +502,12 @@ class IdentityHandler(BaseHandler):
# Try to validate as email
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
# Ask our delegated email identity server
- validation_session = yield self.threepid_from_creds(
+ validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
# Get a validated session matching these details
- validation_session = yield self.store.get_threepid_validation_session(
+ validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
)
@@ -525,14 +517,13 @@ class IdentityHandler(BaseHandler):
# Try to validate as msisdn
if self.hs.config.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
- validation_session = yield self.threepid_from_creds(
+ validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
return validation_session
- @defer.inlineCallbacks
- def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
+ async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
"""Proxy a POST submitToken request to an identity server for verification purposes
Args:
@@ -553,11 +544,9 @@ class IdentityHandler(BaseHandler):
body = {"client_secret": client_secret, "sid": sid, "token": token}
try:
- return (
- yield self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
- body,
- )
+ return await self.http_client.post_json_get_json(
+ id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
+ body,
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
@@ -565,8 +554,7 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")
- @defer.inlineCallbacks
- def lookup_3pid(self, id_server, medium, address, id_access_token=None):
+ async def lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""Looks up a 3pid in the passed identity server.
Args:
@@ -582,7 +570,7 @@ class IdentityHandler(BaseHandler):
"""
if id_access_token is not None:
try:
- results = yield self._lookup_3pid_v2(
+ results = await self._lookup_3pid_v2(
id_server, id_access_token, medium, address
)
return results
@@ -601,10 +589,9 @@ class IdentityHandler(BaseHandler):
logger.warning("Error when looking up hashing details: %s", e)
return None
- return (yield self._lookup_3pid_v1(id_server, medium, address))
+ return await self._lookup_3pid_v1(id_server, medium, address)
- @defer.inlineCallbacks
- def _lookup_3pid_v1(self, id_server, medium, address):
+ async def _lookup_3pid_v1(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
@@ -617,7 +604,7 @@ class IdentityHandler(BaseHandler):
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
- data = yield self.blacklisting_http_client.get_json(
+ data = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
{"medium": medium, "address": address},
)
@@ -625,7 +612,7 @@ class IdentityHandler(BaseHandler):
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
- yield self._verify_any_signature(data, id_server)
+ await self._verify_any_signature(data, id_server)
return data["mxid"]
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
@@ -634,8 +621,7 @@ class IdentityHandler(BaseHandler):
return None
- @defer.inlineCallbacks
- def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
+ async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
@@ -650,7 +636,7 @@ class IdentityHandler(BaseHandler):
"""
# Check what hashing details are supported by this identity server
try:
- hash_details = yield self.blacklisting_http_client.get_json(
+ hash_details = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
{"access_token": id_access_token},
)
@@ -717,7 +703,7 @@ class IdentityHandler(BaseHandler):
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
- lookup_results = yield self.blacklisting_http_client.post_json_get_json(
+ lookup_results = await self.blacklisting_http_client.post_json_get_json(
"%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
{
"addresses": [lookup_value],
@@ -745,13 +731,12 @@ class IdentityHandler(BaseHandler):
mxid = lookup_results["mappings"].get(lookup_value)
return mxid
- @defer.inlineCallbacks
- def _verify_any_signature(self, data, server_hostname):
+ async def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
try:
- key_data = yield self.blacklisting_http_client.get_json(
+ key_data = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s"
% (id_server_scheme, server_hostname, key_name)
)
@@ -770,8 +755,7 @@ class IdentityHandler(BaseHandler):
)
return
- @defer.inlineCallbacks
- def ask_id_server_for_third_party_invite(
+ async def ask_id_server_for_third_party_invite(
self,
requester,
id_server,
@@ -844,7 +828,7 @@ class IdentityHandler(BaseHandler):
# Attempt a v2 lookup
url = base_url + "/v2/store-invite"
try:
- data = yield self.blacklisting_http_client.post_json_get_json(
+ data = await self.blacklisting_http_client.post_json_get_json(
url,
invite_config,
{"Authorization": create_id_access_token_header(id_access_token)},
@@ -864,7 +848,7 @@ class IdentityHandler(BaseHandler):
url = base_url + "/api/v1/store-invite"
try:
- data = yield self.blacklisting_http_client.post_json_get_json(
+ data = await self.blacklisting_http_client.post_json_get_json(
url, invite_config
)
except TimeoutError:
@@ -882,7 +866,7 @@ class IdentityHandler(BaseHandler):
# types. This is especially true with old instances of Sydent, see
# https://github.com/matrix-org/sydent/pull/170
try:
- data = yield self.blacklisting_http_client.post_urlencoded_get_json(
+ data = await self.blacklisting_http_client.post_urlencoded_get_json(
url, invite_config
)
except HttpResponseException as e:
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index e7015c704f..de6ba4ab55 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -271,6 +271,7 @@ class SamlHandler:
raise SynapseError(500, "Error parsing SAML2 response")
displayname = attribute_dict.get("displayname")
+ emails = attribute_dict.get("emails", [])
# Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive(
@@ -288,7 +289,9 @@ class SamlHandler:
logger.info("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=displayname
+ localpart=localpart,
+ default_display_name=displayname,
+ bind_emails=emails,
)
await self._datastore.record_user_external_id(
@@ -381,6 +384,7 @@ class DefaultSamlMappingProvider(object):
dict: A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid
* displayname (str): The displayname of the user
+ * emails (list[str]): Any email addresses for the user
"""
try:
mxid_source = saml_response.ava[self._mxid_source_attribute][0]
@@ -403,9 +407,13 @@ class DefaultSamlMappingProvider(object):
# If displayname is None, the mxid_localpart will be used instead
displayname = saml_response.ava.get("displayName", [None])[0]
+ # Retrieve any emails present in the saml response
+ emails = saml_response.ava.get("email", [])
+
return {
"mxid_localpart": localpart,
"displayname": displayname,
+ "emails": emails,
}
@staticmethod
@@ -444,4 +452,4 @@ class DefaultSamlMappingProvider(object):
second set consists of those attributes which can be used if
available, but are not necessary
"""
- return {"uid", config.mxid_source_attribute}, {"displayName"}
+ return {"uid", config.mxid_source_attribute}, {"displayName", "email"}
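
For concreteness, here is a stand-alone illustration (not Synapse code) of what the default provider now extracts from a SAML attribute-value map; the values are made up, and the real provider runs the source attribute through a configurable localpart mapper rather than the lowercasing used here.

ava = {
    "uid": ["jdoe"],
    "displayName": ["Jane Doe"],
    "email": ["jdoe@example.com", "jane@example.com"],
}

attributes = {
    "mxid_localpart": ava["uid"][0].lower(),
    # Falls back to None; Synapse then uses the localpart as the display name.
    "displayname": ava.get("displayName", [None])[0],
    # Every value of the "email" attribute is passed to registration as
    # bind_emails and attached to the new account.
    "emails": ava.get("email", []),
}

print(attributes)
# {'mxid_localpart': 'jdoe', 'displayname': 'Jane Doe',
#  'emails': ['jdoe@example.com', 'jane@example.com']}
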
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 8363d887a9..8b24a73319 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -138,8 +138,7 @@ class _BaseThreepidAuthChecker:
self.hs = hs
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def _check_threepid(self, medium, authdict):
+ async def _check_threepid(self, medium, authdict):
if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
@@ -155,18 +154,18 @@ class _BaseThreepidAuthChecker:
raise SynapseError(
400, "Phone number verification is not enabled on this homeserver"
)
- threepid = yield identity_handler.threepid_from_creds(
+ threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
elif medium == "email":
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
assert self.hs.config.account_threepid_delegate_email
- threepid = yield identity_handler.threepid_from_creds(
+ threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
threepid = None
- row = yield self.store.get_threepid_validation_session(
+ row = await self.store.get_threepid_validation_session(
medium,
threepid_creds["client_secret"],
sid=threepid_creds["sid"],
@@ -181,7 +180,7 @@ class _BaseThreepidAuthChecker:
}
# Valid threepid returned, delete from the db
- yield self.store.delete_threepid_session(threepid_creds["sid"])
+ await self.store.delete_threepid_session(threepid_creds["sid"])
else:
raise SynapseError(
400, "Email address verification is not enabled on this homeserver"
@@ -220,7 +219,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
)
def check_auth(self, authdict, clientip):
- return self._check_threepid("email", authdict)
+ return defer.ensureDeferred(self._check_threepid("email", authdict))
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
@@ -234,7 +233,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
return bool(self.hs.config.account_threepid_delegate_msisdn)
def check_auth(self, authdict, clientip):
- return self._check_threepid("msisdn", authdict)
+ return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
INTERACTIVE_AUTH_CHECKERS = [
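
Most of the identity-handler and UI-auth changes above are the mechanical `@defer.inlineCallbacks`/`yield` to `async`/`await` conversion, with `defer.ensureDeferred` bridging back to a Deferred where callers (such as `check_auth`) still expect one. A minimal, self-contained illustration of that pattern; the function names and the fake `get_json` are invented for the example.

from twisted.internet import defer, task


# Before: a generator-based "coroutine" driven by inlineCallbacks.
@defer.inlineCallbacks
def threepid_from_creds_old(get_json):
    data = yield get_json("/_matrix/identity/api/v1/3pid/getValidated3pid")
    return data


# After: a native coroutine using async/await.
async def threepid_from_creds_new(get_json):
    data = await get_json("/_matrix/identity/api/v1/3pid/getValidated3pid")
    return data


def check_auth(get_json):
    # Callers that must keep returning a Deferred wrap the coroutine.
    return defer.ensureDeferred(threepid_from_creds_new(get_json))


def main(reactor):
    def fake_get_json(url):
        # Stand-in for the identity-server HTTP call; returns a Deferred.
        return defer.succeed({"medium": "email", "address": "user@example.com"})

    d_old = threepid_from_creds_old(fake_get_json)
    d_new = check_auth(fake_get_json)
    return defer.gatherResults([d_old, d_new]).addCallback(print)


task.react(main)
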
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 9afc145340..b95434f031 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -366,9 +366,8 @@ def _upgrade_existing_database(
if duplicates:
# We don't support using the same file name in the same delta version.
raise PrepareDatabaseException(
- "Found multiple delta files with the same name in v%d: %s",
- v,
- duplicates,
+ "Found multiple delta files with the same name in v%d: %s"
+ % (v, duplicates,)
)
# We sort to ensure that we apply the delta files in a consistent
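
The prepare_database fix above addresses a logging-style call on an exception constructor: `logger.*` expands `%s` placeholders lazily, but `Exception(...)` merely stores the extra arguments unformatted. A small demonstration with made-up values:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

v, duplicates = 58, {"01example.sql"}  # made-up values

# The logging module expands the placeholders for us:
logger.info("Found multiple delta files with the same name in v%d: %s", v, duplicates)

# An exception constructor does not; it just stores the arguments:
broken = Exception("Found multiple delta files with the same name in v%d: %s", v, duplicates)
print(str(broken))  # ('Found multiple delta files ... in v%d: %s', 58, {'01example.sql'})

# The fix formats eagerly before constructing the exception:
fixed = Exception("Found multiple delta files with the same name in v%d: %s" % (v, duplicates))
print(str(fixed))  # Found multiple delta files with the same name in v58: {'01example.sql'}
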
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 581dffd8a0..f7af2bca7f 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -225,6 +225,18 @@ class Linearizer(object):
{}
) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
+ def is_queued(self, key) -> bool:
+ """Checks whether there is a process queued up waiting
+ """
+ entry = self.key_to_defer.get(key)
+ if not entry:
+ # No entry so nothing is waiting.
+ return False
+
+ # There are waiting deferreds only if the OrderedDict of deferreds is
+ # non-empty.
+ return bool(entry[1])
+
def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
# (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 852ef23185..ca3858b184 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -45,6 +45,38 @@ class LinearizerTestCase(unittest.TestCase):
with (yield d2):
pass
+ @defer.inlineCallbacks
+ def test_linearizer_is_queued(self):
+ linearizer = Linearizer()
+
+ key = object()
+
+ d1 = linearizer.queue(key)
+ cm1 = yield d1
+
+ # Since d1 gets called immediately, "is_queued" should return false.
+ self.assertFalse(linearizer.is_queued(key))
+
+ d2 = linearizer.queue(key)
+ self.assertFalse(d2.called)
+
+ # Now d2 is queued up behind successful completion of cm1
+ self.assertTrue(linearizer.is_queued(key))
+
+ with cm1:
+ self.assertFalse(d2.called)
+
+ # cm1 still not done, so d2 still queued.
+ self.assertTrue(linearizer.is_queued(key))
+
+ # And now d2 is called and nothing is in the queue again
+ self.assertFalse(linearizer.is_queued(key))
+
+ with (yield d2):
+ self.assertFalse(linearizer.is_queued(key))
+
+ self.assertFalse(linearizer.is_queued(key))
+
def test_lots_of_queued_things(self):
# we have one slow thing, and lots of fast things queued up behind it.
# it should *not* explode the stack.
|