summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2019-06-06 15:53:40 +0100
committerAndrew Morgan <andrew@amorgan.xyz>2019-06-06 15:53:40 +0100
commit3b0a477db32c03096e3130ea5c233ddbf2d171bf (patch)
tree85dfb6b0f24854a902a9e748dcc38ac681ccd4a7 /synapse/storage/registration.py
parentFix clientip bug (diff)
downloadsynapse-3b0a477db32c03096e3130ea5c233ddbf2d171bf.tar.xz
Fix bugs with database
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py89
1 files changed, 29 insertions, 60 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 43650d7a48..9b41cbd757 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -1021,19 +1021,20 @@ class RegistrationStore(
         keyvalues = {
             "medium": medium,
             "client_secret": client_secret,
-            "session_id": sid,
-            "address": address,
         }
-        cols_to_return = [
-            "session_id", "medium", "address",
-            "client_secret", "last_send_attempt", "validated_at",
-        ]
+        if address:
+            keyvalues["address"] = address
+        if sid:
+            keyvalues["session_id"] = sid
+
+        assert(address or sid)
 
         def get_threepid_validation_session_txn(txn):
-            sql = "SELECT %s FROM threepid_validation_session WHERE %s" % (
-                ", ".join(cols_to_return),
-                " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
-            )
+            sql = """
+                SELECT address, session_id, medium, client_secret,
+                last_send_attempt, validated_at
+                FROM threepid_validation_session WHERE %s
+                """ % (" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),)
 
             if validated is not None:
                 sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
@@ -1088,10 +1089,6 @@ class RegistrationStore(
             retrieved_client_secret = row["client_secret"]
             validated_at = row["validated_at"]
 
-            if validated_at:
-                raise ThreepidValidationError(
-                    400, "This session has already been validated",
-                )
             if retrieved_client_secret != client_secret:
                 raise ThreepidValidationError(
                     400, "This client_secret does not match the provided session_id",
@@ -1112,6 +1109,10 @@ class RegistrationStore(
             expires = row["expires"]
             next_link = row["next_link"]
 
+            # If the session is already validated, no need to revalidate
+            if validated_at:
+                return next_link
+
             if expires <= current_ts:
                 raise ThreepidValidationError(
                     400, "This token has expired. Please request a new one",
@@ -1228,35 +1229,6 @@ class RegistrationStore(
             start_or_continue_validation_session_txn,
         )
 
-    def insert_threepid_validation_token(
-        self,
-        session_id,
-        token,
-        expires,
-        next_link=None,
-    ):
-        """Insert a new 3PID validation token and details
-
-        Args:
-            session_id (str): The id of the validation session this attempt
-                is related to
-            token (str): The validation token
-            expires (int): The timestamp for which after this token will no
-                longer be valid
-            next_link (str|None): The link to redirect the user to upon successful
-                validation
-        """
-        return self._simple_insert(
-            table="threepid_validation_token",
-            values={
-                "session_id": session_id,
-                "token": token,
-                "next_link": next_link,
-                "expires": expires,
-            },
-            desc="insert_threepid_validation_token",
-        )
-
     def cull_expired_threepid_validation_tokens(self):
         """Remove threepid validation tokens with expiry dates that have passed"""
         def cull_expired_threepid_validation_tokens_txn(txn, ts):
@@ -1280,22 +1252,19 @@ class RegistrationStore(
         Args:
             session_id (str): The ID of the session to delete
         """
-        return self._simple_delete(
-            table="threepid_validation_session",
-            keyvalues={"session_id": session_id},
-            desc="delete_threepid_session",
-        )
-
-    def delete_threepid_tokens(self, session_id):
-        """Removes threepid validation tokens from the database which match a
-        given session ID.
+        def delete_threepid_session_txn(txn):
+            self._simple_delete_txn(
+                txn,
+                table="threepid_validation_token",
+                keyvalues={"session_id": session_id},
+            )
+            self._simple_delete_txn(
+                txn,
+                table="threepid_validation_session",
+                keyvalues={"session_id": session_id},
+            )
 
-        Args:
-            session_id (str): The ID of the session to delete
-        """
-        # Delete tokens associated with this session id
-        return self._simple_delete_one(
-            table="threepid_validation_token",
-            keyvalues={"session_id": session_id},
-            desc="delete_threepid_session_tokens",
+        return self.runInteraction(
+            "delete_threepid_session",
+            delete_threepid_session_txn,
         )