summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-10-02 19:06:12 +0100
committerErik Johnston <erik@matrix.org>2019-10-10 13:15:24 +0100
commitb4fbf71187545748edf3ebd931b49350e5b1ca74 (patch)
tree271d3d8630ec8b5b0701ca2bdb34bd9cfdd3529f /synapse/storage
parentAdd domain validation when creating room with list of invitees (#6121) (diff)
downloadsynapse-b4fbf71187545748edf3ebd931b49350e5b1ca74.tar.xz
Add helper funcs to use postgres ANY
This means that we can write queries with `col = ANY(?)`, which helps
postgres.
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py64
1 files changed, 56 insertions, 8 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index abe16334ec..a94cbc27d3 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -20,6 +20,7 @@ import random
 import sys
 import threading
 import time
+from typing import Iterable, List, Tuple
 
 from six import PY2, iteritems, iterkeys, itervalues
 from six.moves import builtins, intern, range
@@ -1162,19 +1163,20 @@ class SQLBaseStore(object):
         if not iterable:
             return []
 
-        sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
-
         clauses = []
         values = []
-        clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
-        values.extend(iterable)
+
+        add_in_list_sql_clause(txn.database_engine, column, iterable, clauses, values)
 
         for key, value in iteritems(keyvalues):
             clauses.append("%s = ?" % (key,))
             values.append(value)
 
-        if clauses:
-            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
+        sql = "SELECT %s FROM %s WHERE %s" % (
+            ", ".join(retcols),
+            table,
+            " AND ".join(clauses),
+        )
 
         txn.execute(sql, values)
         return cls.cursor_to_dict(txn)
@@ -1325,8 +1327,8 @@ class SQLBaseStore(object):
 
         clauses = []
         values = []
-        clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
-        values.extend(iterable)
+
+        add_in_list_sql_clause(txn.database_engine, column, iterable, clauses, values)
 
         for key, value in iteritems(keyvalues):
             clauses.append("%s = ?" % (key,))
@@ -1693,3 +1695,49 @@ def db_to_json(db_content):
     except Exception:
         logging.warning("Tried to decode '%r' as JSON and failed", db_content)
         raise
+
+
+def add_in_list_sql_clause(
+    database_engine, column: str, iterable: Iterable, clauses: List[str], args: List
+):
+    """Adds an SQL clause to the given list of clauses/args that checks the
+    given column is in the iterable. c.f. `make_in_list_sql_clause`
+
+    Args:
+        database_engine
+        column: Name of the column
+        iterable: The values to check the column against.
+        clauses: A list to add the expanded clause to
+        args: A list of arguments that we append the args to.
+    """
+
+    clause, new_args = make_in_list_sql_clause(database_engine, column, iterable)
+    clauses.append(clause)
+    args.extend(new_args)
+
+
+def make_in_list_sql_clause(
+    database_engine, column: str, iterable: Iterable
+) -> Tuple[str, Iterable]:
+    """Returns an SQL clause that checks the given column is in the iterable.
+
+    On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
+    it expands to `column = ANY(?)`. While both DBs support the `IN` form,
+    using the `ANY` form on postgres means that it views queries with
+    different length iterables as the same, helping the query stats.
+
+    Args:
+        database_engine
+        column: Name of the column
+        iterable: The values to check the column against.
+
+    Returns:
+        A tuple of SQL query and the args
+    """
+
+    if isinstance(database_engine, PostgresEngine):
+        # This should hopefully be faster, but also makes postgres query
+        # stats easier to understand.
+        return "%s = ANY(?)" % (column,), [list(iterable)]
+    else:
+        return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), iterable