diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index abe16334ec..a94cbc27d3 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -20,6 +20,7 @@ import random
import sys
import threading
import time
+from typing import Iterable, List, Tuple
from six import PY2, iteritems, iterkeys, itervalues
from six.moves import builtins, intern, range
@@ -1162,19 +1163,20 @@ class SQLBaseStore(object):
if not iterable:
return []
- sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
-
clauses = []
values = []
- clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
- values.extend(iterable)
+
+ add_in_list_sql_clause(txn.database_engine, column, iterable, clauses, values)
for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
values.append(value)
- if clauses:
- sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
+ sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join(clauses),
+ )
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
@@ -1325,8 +1327,8 @@ class SQLBaseStore(object):
clauses = []
values = []
- clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
- values.extend(iterable)
+
+ add_in_list_sql_clause(txn.database_engine, column, iterable, clauses, values)
for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
@@ -1693,3 +1695,49 @@ def db_to_json(db_content):
except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise
+
+
+def add_in_list_sql_clause(
+ database_engine, column: str, iterable: Iterable, clauses: List[str], args: List
+):
+ """Adds an SQL clause to the given list of clauses/args that checks the
+ given column is in the iterable. c.f. `make_in_list_sql_clause`
+
+ Args:
+ database_engine
+ column: Name of the column
+ iterable: The values to check the column against.
+ clauses: A list to add the expanded clause to
+ args: A list of arguments that we append the args to.
+ """
+
+ clause, new_args = make_in_list_sql_clause(database_engine, column, iterable)
+ clauses.append(clause)
+ args.extend(new_args)
+
+
+def make_in_list_sql_clause(
+ database_engine, column: str, iterable: Iterable
+) -> Tuple[str, Iterable]:
+ """Returns an SQL clause that checks the given column is in the iterable.
+
+ On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
+ it expands to `column = ANY(?)`. While both DBs support the `IN` form,
+ using the `ANY` form on postgres means that it views queries with
+ different length iterables as the same, helping the query stats.
+
+ Args:
+ database_engine
+ column: Name of the column
+ iterable: The values to check the column against.
+
+ Returns:
+ A tuple of SQL query and the args
+ """
+
+ if isinstance(database_engine, PostgresEngine):
+ # This should hopefully be faster, but also makes postgres query
+ # stats easier to understand.
+ return "%s = ANY(?)" % (column,), [list(iterable)]
+ else:
+ return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), iterable
|