summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r--synapse/storage/database.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a252f8eaa0..b4469eb964 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2461,6 +2461,66 @@ def make_in_list_sql_clause(
         return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
 
 
+# These overloads ensure that `columns` and `iterable` values have the same length.
+# Suppress "Single overload definition, multiple required" complaint.
+@overload  # type: ignore[misc]
+def make_tuple_in_list_sql_clause(
+    database_engine: BaseDatabaseEngine,
+    columns: Tuple[str, str],
+    iterable: Collection[Tuple[Any, Any]],
+) -> Tuple[str, list]:
+    ...
+
+
+def make_tuple_in_list_sql_clause(
+    database_engine: BaseDatabaseEngine,
+    columns: Tuple[str, ...],
+    iterable: Collection[Tuple[Any, ...]],
+) -> Tuple[str, list]:
+    """Returns an SQL clause that checks the given tuple of columns is in the iterable.
+
+    Args:
+        database_engine
+        columns: Names of the columns in the tuple.
+        iterable: The tuples to check the columns against.
+
+    Returns:
+        A tuple of SQL query and the args
+    """
+    if len(columns) == 0:
+        # Should be unreachable due to mypy, as long as the overloads are set up right.
+        if () in iterable:
+            return "TRUE", []
+        else:
+            return "FALSE", []
+
+    if len(columns) == 1:
+        # Use `= ANY(?)` on postgres.
+        return make_in_list_sql_clause(
+            database_engine, next(iter(columns)), [values[0] for values in iterable]
+        )
+
+    # There are multiple columns. Avoid using an `= ANY(?)` clause on postgres, as
+    # indices are not used when there are multiple columns. Instead, use an `IN`
+    # expression.
+    #
+    # `IN ((?, ...), ...)` with tuples is supported by postgres only, whereas
+    # `IN (VALUES (?, ...), ...)` is supported by both sqlite and postgres.
+    # Thus, the latter is chosen.
+
+    if len(iterable) == 0:
+        # A 0-length `VALUES` list is not allowed in sqlite or postgres.
+        # Also note that a 0-length `IN (...)` clause (not using `VALUES`) is not
+        # allowed in postgres.
+        return "FALSE", []
+
+    tuple_sql = "(%s)" % (",".join("?" for _ in columns),)
+    return "(%s) IN (VALUES %s)" % (
+        ",".join(column for column in columns),
+        ",".join(tuple_sql for _ in iterable),
+    ), [value for values in iterable for value in values]
+
+
 KV = TypeVar("KV")