summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorAmber Brown <hawkowl@atleastfornow.net>2019-04-05 00:21:16 +1100
committerGitHub <noreply@github.com>2019-04-05 00:21:16 +1100
commita33a5abc4cd0a73fdf851850da21d749d842ae32 (patch)
tree4ccfc5a2403a17d4a529608ecc89bb4506c90f2e /synapse
parentMerge pull request #5002 from matrix-org/erikj/delete_group (diff)
downloadsynapse-a33a5abc4cd0a73fdf851850da21d749d842ae32.tar.xz
Clean up the database pagination code (#5007)
* rewrite & simplify

* changelog

* cleanup potential sql injection
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/__init__.py20
-rw-r--r--synapse/storage/_base.py110
2 files changed, 63 insertions, 67 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index e9aa2fc9dd..c432041b4e 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -18,6 +18,8 @@ import calendar
 import logging
 import time
 
+from twisted.internet import defer
+
 from synapse.api.constants import PresenceState
 from synapse.storage.devices import DeviceStore
 from synapse.storage.user_erasure_store import UserErasureStore
@@ -453,6 +455,7 @@ class DataStore(
             desc="get_users",
         )
 
+    @defer.inlineCallbacks
     def get_users_paginate(self, order, start, limit):
         """Function to reterive a paginated list of users from
         users list. This will return a json object, which contains
@@ -465,16 +468,19 @@ class DataStore(
         Returns:
             defer.Deferred: resolves to json object {list[dict[str, Any]], count}
         """
-        is_guest = 0
-        i_start = (int)(start)
-        i_limit = (int)(limit)
-        return self.get_user_list_paginate(
+        users = yield self.runInteraction(
+            "get_users_paginate",
+            self._simple_select_list_paginate_txn,
             table="users",
-            keyvalues={"is_guest": is_guest},
-            pagevalues=[order, i_limit, i_start],
+            keyvalues={"is_guest": False},
+            orderby=order,
+            start=start,
+            limit=limit,
             retcols=["name", "password_hash", "is_guest", "admin"],
-            desc="get_users_paginate",
         )
+        count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
+        retval = {"users": users, "total": count}
+        defer.returnValue(retval)
 
     def search_users(self, term):
         """Function to search users list for one or more users with
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 131820628a..983ce026e1 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -595,7 +595,7 @@ class SQLBaseStore(object):
 
         Args:
             table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
+            keyvalues (dict): The unique key columns and their new values
             values (dict): The nonunique columns and their new values
             insertion_values (dict): additional key/values to use only when
                 inserting
@@ -627,7 +627,7 @@ class SQLBaseStore(object):
 
                 # presumably we raced with another transaction: let's retry.
                 logger.warn(
-                    "%s when upserting into %s; retrying: %s", e.__name__, table, e
+                    "IntegrityError when upserting into %s; retrying: %s", table, e
                 )
 
     def _simple_upsert_txn(
@@ -1398,21 +1398,31 @@ class SQLBaseStore(object):
             return 0
 
     def _simple_select_list_paginate(
-        self, table, keyvalues, pagevalues, retcols, desc="_simple_select_list_paginate"
+        self,
+        table,
+        keyvalues,
+        orderby,
+        start,
+        limit,
+        retcols,
+        order_direction="ASC",
+        desc="_simple_select_list_paginate",
     ):
-        """Executes a SELECT query on the named table with start and limit,
+        """
+        Executes a SELECT query on the named table with start and limit,
         of row numbers, which may return zero or number of rows from start to limit,
         returning the result as a list of dicts.
 
         Args:
             table (str): the table name
-            keyvalues (dict[str, Any] | None):
+            keyvalues (dict[str, T] | None):
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
+            orderby (str): Column to order the results by.
+            start (int): Index to begin the query at.
+            limit (int): Number of results to return.
             retcols (iterable[str]): the names of the columns to return
-            order (str): order the select by this column
-            start (int): start number to begin the query from
-            limit (int): number of rows to reterive
+            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
@@ -1421,15 +1431,27 @@ class SQLBaseStore(object):
             self._simple_select_list_paginate_txn,
             table,
             keyvalues,
-            pagevalues,
+            orderby,
+            start,
+            limit,
             retcols,
+            order_direction=order_direction,
         )
 
     @classmethod
     def _simple_select_list_paginate_txn(
-        cls, txn, table, keyvalues, pagevalues, retcols
+        cls,
+        txn,
+        table,
+        keyvalues,
+        orderby,
+        start,
+        limit,
+        retcols,
+        order_direction="ASC",
     ):
-        """Executes a SELECT query on the named table with start and limit,
+        """
+        Executes a SELECT query on the named table with start and limit,
         of row numbers, which may return zero or number of rows from start to limit,
         returning the result as a list of dicts.
 
@@ -1439,64 +1461,32 @@ class SQLBaseStore(object):
             keyvalues (dict[str, T] | None):
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            pagevalues ([]):
-                order (str): order the select by this column
-                start (int): start number to begin the query from
-                limit (int): number of rows to reterive
+            orderby (str): Column to order the results by.
+            start (int): Index to begin the query at.
+            limit (int): Number of results to return.
             retcols (iterable[str]): the names of the columns to return
+            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
-
         """
+        if order_direction not in ["ASC", "DESC"]:
+            raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
+
         if keyvalues:
-            sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % (
-                ", ".join(retcols),
-                table,
-                " AND ".join("%s = ?" % (k,) for k in keyvalues),
-                " ? ASC LIMIT ? OFFSET ?",
-            )
-            txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
+            where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
         else:
-            sql = "SELECT %s FROM %s ORDER BY %s" % (
-                ", ".join(retcols),
-                table,
-                " ? ASC LIMIT ? OFFSET ?",
-            )
-            txn.execute(sql, pagevalues)
+            where_clause = ""
 
-        return cls.cursor_to_dict(txn)
-
-    @defer.inlineCallbacks
-    def get_user_list_paginate(
-        self, table, keyvalues, pagevalues, retcols, desc="get_user_list_paginate"
-    ):
-        """Get a list of users from start row to a limit number of rows. This will
-        return a json object with users and total number of users in users list.
-
-        Args:
-            table (str): the table name
-            keyvalues (dict[str, Any] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            pagevalues ([]):
-                order (str): order the select by this column
-                start (int): start number to begin the query from
-                limit (int): number of rows to reterive
-            retcols (iterable[str]): the names of the columns to return
-        Returns:
-            defer.Deferred: resolves to json object {list[dict[str, Any]], count}
-        """
-        users = yield self.runInteraction(
-            desc,
-            self._simple_select_list_paginate_txn,
+        sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
+            ", ".join(retcols),
             table,
-            keyvalues,
-            pagevalues,
-            retcols,
+            where_clause,
+            orderby,
+            order_direction,
         )
-        count = yield self.runInteraction(desc, self.get_user_count_txn)
-        retval = {"users": users, "total": count}
-        defer.returnValue(retval)
+        txn.execute(sql, list(keyvalues.values()) + [limit, start])
+
+        return cls.cursor_to_dict(txn)
 
     def get_user_count_txn(self, txn):
         """Get a total number of registered users in the users list.