diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a252f8eaa0..b4469eb964 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2461,6 +2461,66 @@ def make_in_list_sql_clause(
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
+# These overloads ensure that `columns` and `iterable` values have the same length.
+# Suppress "Single overload definition, multiple required" complaint.
+@overload # type: ignore[misc]
+def make_tuple_in_list_sql_clause(
+ database_engine: BaseDatabaseEngine,
+ columns: Tuple[str, str],
+ iterable: Collection[Tuple[Any, Any]],
+) -> Tuple[str, list]:
+ ...
+
+
+def make_tuple_in_list_sql_clause(
+ database_engine: BaseDatabaseEngine,
+ columns: Tuple[str, ...],
+ iterable: Collection[Tuple[Any, ...]],
+) -> Tuple[str, list]:
+ """Returns an SQL clause that checks the given tuple of columns is in the iterable.
+
+ Args:
+ database_engine
+ columns: Names of the columns in the tuple.
+ iterable: The tuples to check the columns against.
+
+ Returns:
+ A tuple of SQL query and the args
+ """
+ if len(columns) == 0:
+ # Should be unreachable due to mypy, as long as the overloads are set up right.
+ if () in iterable:
+ return "TRUE", []
+ else:
+ return "FALSE", []
+
+ if len(columns) == 1:
+ # Use `= ANY(?)` on postgres.
+ return make_in_list_sql_clause(
+ database_engine, next(iter(columns)), [values[0] for values in iterable]
+ )
+
+ # There are multiple columns. Avoid using an `= ANY(?)` clause on postgres, as
+ # indices are not used when there are multiple columns. Instead, use an `IN`
+ # expression.
+ #
+ # `IN ((?, ...), ...)` with tuples is supported by postgres only, whereas
+ # `IN (VALUES (?, ...), ...)` is supported by both sqlite and postgres.
+ # Thus, the latter is chosen.
+
+ if len(iterable) == 0:
+ # A 0-length `VALUES` list is not allowed in sqlite or postgres.
+ # Also note that a 0-length `IN (...)` clause (not using `VALUES`) is not
+ # allowed in postgres.
+ return "FALSE", []
+
+ tuple_sql = "(%s)" % (",".join("?" for _ in columns),)
+ return "(%s) IN (VALUES %s)" % (
+ ",".join(column for column in columns),
+ ",".join(tuple_sql for _ in iterable),
+ ), [value for values in iterable for value in values]
+
+
KV = TypeVar("KV")
|