summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py50
1 files changed, 47 insertions, 3 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 967c732bda..ad1157f979 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -40,7 +40,7 @@ class RegistrationStore(SQLBaseStore):
         Raises:
             StoreError if there was a problem adding this.
         """
-        next_id = yield self._access_tokens_id_gen.get_next()
+        next_id = self._access_tokens_id_gen.get_next()
 
         yield self._simple_insert(
             "access_tokens",
@@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore):
         Raises:
             StoreError if there was a problem adding this.
         """
-        next_id = yield self._refresh_tokens_id_gen.get_next()
+        next_id = self._refresh_tokens_id_gen.get_next()
 
         yield self._simple_insert(
             "refresh_tokens",
@@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore):
     def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
         now = int(self.clock.time())
 
-        next_id = self._access_tokens_id_gen.get_next_txn(txn)
+        next_id = self._access_tokens_id_gen.get_next()
 
         try:
             if was_guest:
@@ -387,3 +387,47 @@ class RegistrationStore(SQLBaseStore):
             "find_next_generated_user_id",
             _find_next_generated_user_id
         )))
+
+    @defer.inlineCallbacks
+    def get_3pid_guest_access_token(self, medium, address):
+        ret = yield self._simple_select_one(
+            "threepid_guest_access_tokens",
+            {
+                "medium": medium,
+                "address": address
+            },
+            ["guest_access_token"], True, 'get_3pid_guest_access_token'
+        )
+        if ret:
+            defer.returnValue(ret["guest_access_token"])
+        defer.returnValue(None)
+
+    @defer.inlineCallbacks
+    def save_or_get_3pid_guest_access_token(
+            self, medium, address, access_token, inviter_user_id
+    ):
+        """
+        Gets the 3pid's guest access token if exists, else saves access_token.
+
+        :param medium (str): Medium of the 3pid. Must be "email".
+        :param address (str): 3pid address.
+        :param access_token (str): The access token to persist if none is
+            already persisted.
+        :param inviter_user_id (str): User ID of the inviter.
+        :return (deferred str): Whichever access token is persisted at the end
+            of this function call.
+        """
+        def insert(txn):
+            txn.execute(
+                "INSERT INTO threepid_guest_access_tokens "
+                "(medium, address, guest_access_token, first_inviter) "
+                "VALUES (?, ?, ?, ?)",
+                (medium, address, access_token, inviter_user_id)
+            )
+
+        try:
+            yield self.runInteraction("save_3pid_guest_access_token", insert)
+            defer.returnValue(access_token)
+        except self.database_engine.module.IntegrityError:
+            ret = yield self.get_3pid_guest_access_token(medium, address)
+            defer.returnValue(ret)