summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py28
1 files changed, 18 insertions, 10 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index d957a629dc..26ef1cfd8a 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -31,12 +31,14 @@ class RegistrationStore(SQLBaseStore):
         self.clock = hs.get_clock()
 
     @defer.inlineCallbacks
-    def add_access_token_to_user(self, user_id, token):
+    def add_access_token_to_user(self, user_id, token, device_id=None):
         """Adds an access token for the given user.
 
         Args:
             user_id (str): The user ID.
             token (str): The new access token to add.
+            device_id (str): ID of the device to associate with the access
+               token
         Raises:
             StoreError if there was a problem adding this.
         """
@@ -47,18 +49,21 @@ class RegistrationStore(SQLBaseStore):
             {
                 "id": next_id,
                 "user_id": user_id,
-                "token": token
+                "token": token,
+                "device_id": device_id,
             },
             desc="add_access_token_to_user",
         )
 
     @defer.inlineCallbacks
-    def add_refresh_token_to_user(self, user_id, token):
+    def add_refresh_token_to_user(self, user_id, token, device_id=None):
         """Adds a refresh token for the given user.
 
         Args:
             user_id (str): The user ID.
             token (str): The new refresh token to add.
+            device_id (str): ID of the device to associate with the access
+               token
         Raises:
             StoreError if there was a problem adding this.
         """
@@ -69,7 +74,8 @@ class RegistrationStore(SQLBaseStore):
             {
                 "id": next_id,
                 "user_id": user_id,
-                "token": token
+                "token": token,
+                "device_id": device_id,
             },
             desc="add_refresh_token_to_user",
         )
@@ -291,18 +297,18 @@ class RegistrationStore(SQLBaseStore):
         )
 
     def exchange_refresh_token(self, refresh_token, token_generator):
-        """Exchange a refresh token for a new access token and refresh token.
+        """Exchange a refresh token for a new one.
 
         Doing so invalidates the old refresh token - refresh tokens are single
         use.
 
         Args:
-            token (str): The refresh token of a user.
+            refresh_token (str): The refresh token of a user.
             token_generator (fn: str -> str): Function which, when given a
                 user ID, returns a unique refresh token for that user. This
                 function must never return the same value twice.
         Returns:
-            tuple of (user_id, refresh_token)
+            tuple of (user_id, new_refresh_token, device_id)
         Raises:
             StoreError if no user was found with that refresh token.
         """
@@ -314,12 +320,13 @@ class RegistrationStore(SQLBaseStore):
         )
 
     def _exchange_refresh_token(self, txn, old_token, token_generator):
-        sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
+        sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?"
         txn.execute(sql, (old_token,))
         rows = self.cursor_to_dict(txn)
         if not rows:
             raise StoreError(403, "Did not recognize refresh token")
         user_id = rows[0]["user_id"]
+        device_id = rows[0]["device_id"]
 
         # TODO(danielwh): Maybe perform a validation on the macaroon that
         # macaroon.user_id == user_id.
@@ -328,7 +335,7 @@ class RegistrationStore(SQLBaseStore):
         sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
         txn.execute(sql, (new_token, old_token,))
 
-        return user_id, new_token
+        return user_id, new_token, device_id
 
     @defer.inlineCallbacks
     def is_server_admin(self, user):
@@ -356,7 +363,8 @@ class RegistrationStore(SQLBaseStore):
 
     def _query_for_auth(self, txn, token):
         sql = (
-            "SELECT users.name, users.is_guest, access_tokens.id as token_id"
+            "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+            " access_tokens.device_id"
             " FROM users"
             " INNER JOIN access_tokens on users.name = access_tokens.user_id"
             " WHERE token = ?"