summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/_base.py35
-rw-r--r--synapse/storage/engines/postgres.py6
-rw-r--r--synapse/storage/engines/sqlite.py6
-rw-r--r--synapse/storage/search.py12
4 files changed, 25 insertions, 34 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 085b8ae871..6176838aa6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -20,7 +20,7 @@ import random
 import sys
 import threading
 import time
-from typing import Iterable, List, Tuple
+from typing import Iterable, Tuple
 
 from six import PY2, iteritems, iterkeys, itervalues
 from six.moves import builtins, intern, range
@@ -1164,10 +1164,8 @@ class SQLBaseStore(object):
         if not iterable:
             return []
 
-        clauses = []
-        values = []
-
-        add_in_list_sql_clause(txn.database_engine, column, iterable, clauses, values)
+        clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+        clauses = [clause]
 
         for key, value in iteritems(keyvalues):
             clauses.append("%s = ?" % (key,))
@@ -1326,10 +1324,8 @@ class SQLBaseStore(object):
 
         sql = "DELETE FROM %s" % table
 
-        clauses = []
-        values = []
-
-        add_in_list_sql_clause(txn.database_engine, column, iterable, clauses, values)
+        clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+        clauses = [clause]
 
         for key, value in iteritems(keyvalues):
             clauses.append("%s = ?" % (key,))
@@ -1698,25 +1694,6 @@ def db_to_json(db_content):
         raise
 
 
-def add_in_list_sql_clause(
-    database_engine, column: str, iterable: Iterable, clauses: List[str], args: List
-):
-    """Adds an SQL clause to the given list of clauses/args that checks the
-    given column is in the iterable. c.f. `make_in_list_sql_clause`
-
-    Args:
-        database_engine
-        column: Name of the column
-        iterable: The values to check the column against.
-        clauses: A list to add the expanded clause to
-        args: A list of arguments that we append the args to.
-    """
-
-    clause, new_args = make_in_list_sql_clause(database_engine, column, iterable)
-    clauses.append(clause)
-    args.extend(new_args)
-
-
 def make_in_list_sql_clause(
     database_engine, column: str, iterable: Iterable
 ) -> Tuple[str, Iterable]:
@@ -1736,7 +1713,7 @@ def make_in_list_sql_clause(
         A tuple of SQL query and the args
     """
 
-    if isinstance(database_engine, PostgresEngine):
+    if database_engine.supports_using_any_list:
         # This should hopefully be faster, but also makes postgres query
         # stats easier to understand.
         return "%s = ANY(?)" % (column,), [list(iterable)]
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 601617b21e..f36600b4bb 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -79,6 +79,12 @@ class PostgresEngine(object):
         """
         return True
 
+    @property
+    def supports_using_any_list(self):
+        """Do we support using `a = ANY(?)` and passing a list
+        """
+        return True
+
     def is_deadlock(self, error):
         if isinstance(error, self.module.DatabaseError):
             # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index ac92109366..2526258060 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -46,6 +46,12 @@ class Sqlite3Engine(object):
         """
         return self.module.sqlite_version_info >= (3, 15, 0)
 
+    @property
+    def supports_any_list(self):
+        """Do we support using `a = ANY(?)` and passing a list
+        """
+        return False
+
     def check_database(self, txn):
         pass
 
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 4be6e56dfa..7695bf09fc 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -24,7 +24,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
-from synapse.storage._base import add_in_list_sql_clause
+from synapse.storage._base import make_in_list_sql_clause
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
 from .background_updates import BackgroundUpdateStore
@@ -386,9 +386,10 @@ class SearchStore(SearchBackgroundUpdateStore):
         # Make sure we don't explode because the person is in too many rooms.
         # We filter the results below regardless.
         if len(room_ids) < 500:
-            add_in_list_sql_clause(
-                self.database_engine, "room_id", room_ids, clauses, args
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", room_ids
             )
+            clauses = [clause]
 
         local_clauses = []
         for key in keys:
@@ -494,9 +495,10 @@ class SearchStore(SearchBackgroundUpdateStore):
         # Make sure we don't explode because the person is in too many rooms.
         # We filter the results below regardless.
         if len(room_ids) < 500:
-            add_in_list_sql_clause(
-                self.database_engine, "room_id", room_ids, clauses, args
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", room_ids
             )
+            clauses = [clause]
 
         local_clauses = []
         for key in keys: