summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/generic_worker.py31
-rw-r--r--synapse/handlers/identity.py94
-rw-r--r--synapse/handlers/saml_handler.py12
-rw-r--r--synapse/handlers/ui_auth/checkers.py15
-rw-r--r--synapse/storage/prepare_database.py5
-rw-r--r--synapse/util/async_helpers.py12
6 files changed, 94 insertions, 75 deletions
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 5afe52f8d4..f3ec2a34ec 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -863,9 +863,24 @@ class FederationSenderHandler(object):
         a FEDERATION_ACK back to the master, and stores the token that we have processed
          in `federation_stream_position` so that we can restart where we left off.
         """
-        try:
-            self.federation_position = token
+        self.federation_position = token
+
+        # We save and send the ACK to master asynchronously, so we don't block
+        # processing on persistence. We don't need to do this operation for
+        # every single RDATA we receive, we just need to do it periodically.
+
+        if self._fed_position_linearizer.is_queued(None):
+            # There is already a task queued up to save and send the token, so
+            # no need to queue up another task.
+            return
+
+        run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
 
+    async def _save_and_send_ack(self):
+        """Save the current federation position in the database and send an ACK
+        to master with where we're up to.
+        """
+        try:
             # We linearize here to ensure we don't have races updating the token
             #
             # XXX this appears to be redundant, since the ReplicationCommandHandler
@@ -875,16 +890,18 @@ class FederationSenderHandler(object):
             # we're not being re-entered?
 
             with (await self._fed_position_linearizer.queue(None)):
+                # We persist and ack the same position, so we take a copy of it
+                # here as otherwise it can get modified from underneath us.
+                current_position = self.federation_position
+
                 await self.store.update_federation_out_pos(
-                    "federation", self.federation_position
+                    "federation", current_position
                 )
 
                 # We ACK this token over replication so that the master can drop
                 # its in memory queues
-                self._hs.get_tcp_replication().send_federation_ack(
-                    self.federation_position
-                )
-                self._last_ack = self.federation_position
+                self._hs.get_tcp_replication().send_federation_ack(current_position)
+                self._last_ack = current_position
         except Exception:
             logger.exception("Error updating federation stream position")
 
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9ed0d23b0f..4ba0042768 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -25,7 +25,6 @@ from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json
 from unpaddedbase64 import decode_base64
 
-from twisted.internet import defer
 from twisted.internet.error import TimeoutError
 
 from synapse.api.errors import (
@@ -60,8 +59,7 @@ class IdentityHandler(BaseHandler):
         self.federation_http_client = hs.get_http_client()
         self.hs = hs
 
-    @defer.inlineCallbacks
-    def threepid_from_creds(self, id_server, creds):
+    async def threepid_from_creds(self, id_server, creds):
         """
         Retrieve and validate a threepid identifier from a "credentials" dictionary against a
         given identity server
@@ -97,7 +95,7 @@ class IdentityHandler(BaseHandler):
         url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
 
         try:
-            data = yield self.http_client.get_json(url, query_params)
+            data = await self.http_client.get_json(url, query_params)
         except TimeoutError:
             raise SynapseError(500, "Timed out contacting identity server")
         except HttpResponseException as e:
@@ -120,8 +118,7 @@ class IdentityHandler(BaseHandler):
         logger.info("%s reported non-validated threepid: %s", id_server, creds)
         return None
 
-    @defer.inlineCallbacks
-    def bind_threepid(
+    async def bind_threepid(
         self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True
     ):
         """Bind a 3PID to an identity server
@@ -161,12 +158,12 @@ class IdentityHandler(BaseHandler):
         try:
             # Use the blacklisting http client as this call is only to identity servers
             # provided by a client
-            data = yield self.blacklisting_http_client.post_json_get_json(
+            data = await self.blacklisting_http_client.post_json_get_json(
                 bind_url, bind_data, headers=headers
             )
 
             # Remember where we bound the threepid
-            yield self.store.add_user_bound_threepid(
+            await self.store.add_user_bound_threepid(
                 user_id=mxid,
                 medium=data["medium"],
                 address=data["address"],
@@ -185,13 +182,12 @@ class IdentityHandler(BaseHandler):
             return data
 
         logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
-        res = yield self.bind_threepid(
+        res = await self.bind_threepid(
             client_secret, sid, mxid, id_server, id_access_token, use_v2=False
         )
         return res
 
-    @defer.inlineCallbacks
-    def try_unbind_threepid(self, mxid, threepid):
+    async def try_unbind_threepid(self, mxid, threepid):
         """Attempt to remove a 3PID from an identity server, or if one is not provided, all
         identity servers we're aware the binding is present on
 
@@ -211,7 +207,7 @@ class IdentityHandler(BaseHandler):
         if threepid.get("id_server"):
             id_servers = [threepid["id_server"]]
         else:
-            id_servers = yield self.store.get_id_servers_user_bound(
+            id_servers = await self.store.get_id_servers_user_bound(
                 user_id=mxid, medium=threepid["medium"], address=threepid["address"]
             )
 
@@ -221,14 +217,13 @@ class IdentityHandler(BaseHandler):
 
         changed = True
         for id_server in id_servers:
-            changed &= yield self.try_unbind_threepid_with_id_server(
+            changed &= await self.try_unbind_threepid_with_id_server(
                 mxid, threepid, id_server
             )
 
         return changed
 
-    @defer.inlineCallbacks
-    def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
+    async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
         """Removes a binding from an identity server
 
         Args:
@@ -266,7 +261,7 @@ class IdentityHandler(BaseHandler):
         try:
             # Use the blacklisting http client as this call is only to identity servers
             # provided by a client
-            yield self.blacklisting_http_client.post_json_get_json(
+            await self.blacklisting_http_client.post_json_get_json(
                 url, content, headers
             )
             changed = True
@@ -281,7 +276,7 @@ class IdentityHandler(BaseHandler):
         except TimeoutError:
             raise SynapseError(500, "Timed out contacting identity server")
 
-        yield self.store.remove_user_bound_threepid(
+        await self.store.remove_user_bound_threepid(
             user_id=mxid,
             medium=threepid["medium"],
             address=threepid["address"],
@@ -376,8 +371,7 @@ class IdentityHandler(BaseHandler):
 
         return session_id
 
-    @defer.inlineCallbacks
-    def requestEmailToken(
+    async def requestEmailToken(
         self, id_server, email, client_secret, send_attempt, next_link=None
     ):
         """
@@ -412,7 +406,7 @@ class IdentityHandler(BaseHandler):
             )
 
         try:
-            data = yield self.http_client.post_json_get_json(
+            data = await self.http_client.post_json_get_json(
                 id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
                 params,
             )
@@ -423,8 +417,7 @@ class IdentityHandler(BaseHandler):
         except TimeoutError:
             raise SynapseError(500, "Timed out contacting identity server")
 
-    @defer.inlineCallbacks
-    def requestMsisdnToken(
+    async def requestMsisdnToken(
         self,
         id_server,
         country,
@@ -466,7 +459,7 @@ class IdentityHandler(BaseHandler):
             )
 
         try:
-            data = yield self.http_client.post_json_get_json(
+            data = await self.http_client.post_json_get_json(
                 id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
                 params,
             )
@@ -487,8 +480,7 @@ class IdentityHandler(BaseHandler):
         )
         return data
 
-    @defer.inlineCallbacks
-    def validate_threepid_session(self, client_secret, sid):
+    async def validate_threepid_session(self, client_secret, sid):
         """Validates a threepid session with only the client secret and session ID
         Tries validating against any configured account_threepid_delegates as well as locally.
 
@@ -510,12 +502,12 @@ class IdentityHandler(BaseHandler):
         # Try to validate as email
         if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
             # Ask our delegated email identity server
-            validation_session = yield self.threepid_from_creds(
+            validation_session = await self.threepid_from_creds(
                 self.hs.config.account_threepid_delegate_email, threepid_creds
             )
         elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
             # Get a validated session matching these details
-            validation_session = yield self.store.get_threepid_validation_session(
+            validation_session = await self.store.get_threepid_validation_session(
                 "email", client_secret, sid=sid, validated=True
             )
 
@@ -525,14 +517,13 @@ class IdentityHandler(BaseHandler):
         # Try to validate as msisdn
         if self.hs.config.account_threepid_delegate_msisdn:
             # Ask our delegated msisdn identity server
-            validation_session = yield self.threepid_from_creds(
+            validation_session = await self.threepid_from_creds(
                 self.hs.config.account_threepid_delegate_msisdn, threepid_creds
             )
 
         return validation_session
 
-    @defer.inlineCallbacks
-    def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
+    async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
         """Proxy a POST submitToken request to an identity server for verification purposes
 
         Args:
@@ -553,11 +544,9 @@ class IdentityHandler(BaseHandler):
         body = {"client_secret": client_secret, "sid": sid, "token": token}
 
         try:
-            return (
-                yield self.http_client.post_json_get_json(
-                    id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
-                    body,
-                )
+            return await self.http_client.post_json_get_json(
+                id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
+                body,
             )
         except TimeoutError:
             raise SynapseError(500, "Timed out contacting identity server")
@@ -565,8 +554,7 @@ class IdentityHandler(BaseHandler):
             logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
             raise SynapseError(400, "Error contacting the identity server")
 
-    @defer.inlineCallbacks
-    def lookup_3pid(self, id_server, medium, address, id_access_token=None):
+    async def lookup_3pid(self, id_server, medium, address, id_access_token=None):
         """Looks up a 3pid in the passed identity server.
 
         Args:
@@ -582,7 +570,7 @@ class IdentityHandler(BaseHandler):
         """
         if id_access_token is not None:
             try:
-                results = yield self._lookup_3pid_v2(
+                results = await self._lookup_3pid_v2(
                     id_server, id_access_token, medium, address
                 )
                 return results
@@ -601,10 +589,9 @@ class IdentityHandler(BaseHandler):
                     logger.warning("Error when looking up hashing details: %s", e)
                     return None
 
-        return (yield self._lookup_3pid_v1(id_server, medium, address))
+        return await self._lookup_3pid_v1(id_server, medium, address)
 
-    @defer.inlineCallbacks
-    def _lookup_3pid_v1(self, id_server, medium, address):
+    async def _lookup_3pid_v1(self, id_server, medium, address):
         """Looks up a 3pid in the passed identity server using v1 lookup.
 
         Args:
@@ -617,7 +604,7 @@ class IdentityHandler(BaseHandler):
             str: the matrix ID of the 3pid, or None if it is not recognized.
         """
         try:
-            data = yield self.blacklisting_http_client.get_json(
+            data = await self.blacklisting_http_client.get_json(
                 "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
                 {"medium": medium, "address": address},
             )
@@ -625,7 +612,7 @@ class IdentityHandler(BaseHandler):
             if "mxid" in data:
                 if "signatures" not in data:
                     raise AuthError(401, "No signatures on 3pid binding")
-                yield self._verify_any_signature(data, id_server)
+                await self._verify_any_signature(data, id_server)
                 return data["mxid"]
         except TimeoutError:
             raise SynapseError(500, "Timed out contacting identity server")
@@ -634,8 +621,7 @@ class IdentityHandler(BaseHandler):
 
         return None
 
-    @defer.inlineCallbacks
-    def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
+    async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
         """Looks up a 3pid in the passed identity server using v2 lookup.
 
         Args:
@@ -650,7 +636,7 @@ class IdentityHandler(BaseHandler):
         """
         # Check what hashing details are supported by this identity server
         try:
-            hash_details = yield self.blacklisting_http_client.get_json(
+            hash_details = await self.blacklisting_http_client.get_json(
                 "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
                 {"access_token": id_access_token},
             )
@@ -717,7 +703,7 @@ class IdentityHandler(BaseHandler):
         headers = {"Authorization": create_id_access_token_header(id_access_token)}
 
         try:
-            lookup_results = yield self.blacklisting_http_client.post_json_get_json(
+            lookup_results = await self.blacklisting_http_client.post_json_get_json(
                 "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
                 {
                     "addresses": [lookup_value],
@@ -745,13 +731,12 @@ class IdentityHandler(BaseHandler):
         mxid = lookup_results["mappings"].get(lookup_value)
         return mxid
 
-    @defer.inlineCallbacks
-    def _verify_any_signature(self, data, server_hostname):
+    async def _verify_any_signature(self, data, server_hostname):
         if server_hostname not in data["signatures"]:
             raise AuthError(401, "No signature from server %s" % (server_hostname,))
         for key_name, signature in data["signatures"][server_hostname].items():
             try:
-                key_data = yield self.blacklisting_http_client.get_json(
+                key_data = await self.blacklisting_http_client.get_json(
                     "%s%s/_matrix/identity/api/v1/pubkey/%s"
                     % (id_server_scheme, server_hostname, key_name)
                 )
@@ -770,8 +755,7 @@ class IdentityHandler(BaseHandler):
             )
             return
 
-    @defer.inlineCallbacks
-    def ask_id_server_for_third_party_invite(
+    async def ask_id_server_for_third_party_invite(
         self,
         requester,
         id_server,
@@ -844,7 +828,7 @@ class IdentityHandler(BaseHandler):
             # Attempt a v2 lookup
             url = base_url + "/v2/store-invite"
             try:
-                data = yield self.blacklisting_http_client.post_json_get_json(
+                data = await self.blacklisting_http_client.post_json_get_json(
                     url,
                     invite_config,
                     {"Authorization": create_id_access_token_header(id_access_token)},
@@ -864,7 +848,7 @@ class IdentityHandler(BaseHandler):
             url = base_url + "/api/v1/store-invite"
 
             try:
-                data = yield self.blacklisting_http_client.post_json_get_json(
+                data = await self.blacklisting_http_client.post_json_get_json(
                     url, invite_config
                 )
             except TimeoutError:
@@ -882,7 +866,7 @@ class IdentityHandler(BaseHandler):
                 # types. This is especially true with old instances of Sydent, see
                 # https://github.com/matrix-org/sydent/pull/170
                 try:
-                    data = yield self.blacklisting_http_client.post_urlencoded_get_json(
+                    data = await self.blacklisting_http_client.post_urlencoded_get_json(
                         url, invite_config
                     )
                 except HttpResponseException as e:
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index e7015c704f..de6ba4ab55 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -271,6 +271,7 @@ class SamlHandler:
                     raise SynapseError(500, "Error parsing SAML2 response")
 
                 displayname = attribute_dict.get("displayname")
+                emails = attribute_dict.get("emails", [])
 
                 # Check if this mxid already exists
                 if not await self._datastore.get_users_by_id_case_insensitive(
@@ -288,7 +289,9 @@ class SamlHandler:
             logger.info("Mapped SAML user to local part %s", localpart)
 
             registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart, default_display_name=displayname
+                localpart=localpart,
+                default_display_name=displayname,
+                bind_emails=emails,
             )
 
             await self._datastore.record_user_external_id(
@@ -381,6 +384,7 @@ class DefaultSamlMappingProvider(object):
             dict: A dict containing new user attributes. Possible keys:
                 * mxid_localpart (str): Required. The localpart of the user's mxid
                 * displayname (str): The displayname of the user
+                * emails (list[str]): Any emails for the user
         """
         try:
             mxid_source = saml_response.ava[self._mxid_source_attribute][0]
@@ -403,9 +407,13 @@ class DefaultSamlMappingProvider(object):
         # If displayname is None, the mxid_localpart will be used instead
         displayname = saml_response.ava.get("displayName", [None])[0]
 
+        # Retrieve any emails present in the saml response
+        emails = saml_response.ava.get("email", [])
+
         return {
             "mxid_localpart": localpart,
             "displayname": displayname,
+            "emails": emails,
         }
 
     @staticmethod
@@ -444,4 +452,4 @@ class DefaultSamlMappingProvider(object):
                 second set consists of those attributes which can be used if
                 available, but are not necessary
         """
-        return {"uid", config.mxid_source_attribute}, {"displayName"}
+        return {"uid", config.mxid_source_attribute}, {"displayName", "email"}
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 8363d887a9..8b24a73319 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -138,8 +138,7 @@ class _BaseThreepidAuthChecker:
         self.hs = hs
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
-    def _check_threepid(self, medium, authdict):
+    async def _check_threepid(self, medium, authdict):
         if "threepid_creds" not in authdict:
             raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
 
@@ -155,18 +154,18 @@ class _BaseThreepidAuthChecker:
                 raise SynapseError(
                     400, "Phone number verification is not enabled on this homeserver"
                 )
-            threepid = yield identity_handler.threepid_from_creds(
+            threepid = await identity_handler.threepid_from_creds(
                 self.hs.config.account_threepid_delegate_msisdn, threepid_creds
             )
         elif medium == "email":
             if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
                 assert self.hs.config.account_threepid_delegate_email
-                threepid = yield identity_handler.threepid_from_creds(
+                threepid = await identity_handler.threepid_from_creds(
                     self.hs.config.account_threepid_delegate_email, threepid_creds
                 )
             elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
                 threepid = None
-                row = yield self.store.get_threepid_validation_session(
+                row = await self.store.get_threepid_validation_session(
                     medium,
                     threepid_creds["client_secret"],
                     sid=threepid_creds["sid"],
@@ -181,7 +180,7 @@ class _BaseThreepidAuthChecker:
                     }
 
                     # Valid threepid returned, delete from the db
-                    yield self.store.delete_threepid_session(threepid_creds["sid"])
+                    await self.store.delete_threepid_session(threepid_creds["sid"])
             else:
                 raise SynapseError(
                     400, "Email address verification is not enabled on this homeserver"
@@ -220,7 +219,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
         )
 
     def check_auth(self, authdict, clientip):
-        return self._check_threepid("email", authdict)
+        return defer.ensureDeferred(self._check_threepid("email", authdict))
 
 
 class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
@@ -234,7 +233,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
         return bool(self.hs.config.account_threepid_delegate_msisdn)
 
     def check_auth(self, authdict, clientip):
-        return self._check_threepid("msisdn", authdict)
+        return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
 
 
 INTERACTIVE_AUTH_CHECKERS = [
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 9afc145340..b95434f031 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -366,9 +366,8 @@ def _upgrade_existing_database(
         if duplicates:
             # We don't support using the same file name in the same delta version.
             raise PrepareDatabaseException(
-                "Found multiple delta files with the same name in v%d: %s",
-                v,
-                duplicates,
+                "Found multiple delta files with the same name in v%d: %s"
+                % (v, duplicates,)
             )
 
         # We sort to ensure that we apply the delta files in a consistent
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 581dffd8a0..f7af2bca7f 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -225,6 +225,18 @@ class Linearizer(object):
             {}
         )  # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
 
+    def is_queued(self, key) -> bool:
+        """Checks whether there is a process queued up waiting
+        """
+        entry = self.key_to_defer.get(key)
+        if not entry:
+            # No entry so nothing is waiting.
+            return False
+
+        # There are waiting deferreds only in the OrderedDict of deferreds is
+        # non-empty.
+        return bool(entry[1])
+
     def queue(self, key):
         # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
         # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not