summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/database.py83
-rw-r--r--synapse/storage/databases/main/client_ips.py1
-rw-r--r--synapse/storage/databases/main/devices.py2
-rw-r--r--synapse/storage/databases/main/events.py7
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py1
-rw-r--r--synapse/storage/databases/main/group_server.py4
-rw-r--r--synapse/storage/databases/main/state.py6
-rw-r--r--synapse/storage/databases/state/bg_updates.py5
-rw-r--r--synapse/storage/databases/state/store.py5
-rw-r--r--synapse/storage/engines/_base.py8
-rw-r--r--synapse/storage/engines/postgres.py11
-rw-r--r--synapse/storage/engines/sqlite.py15
-rw-r--r--synapse/storage/state.py26
-rw-r--r--synapse/storage/util/id_generators.py11
14 files changed, 72 insertions, 113 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 94590e7b45..fa15b0ce5b 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -900,7 +900,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         desc: str = "simple_upsert",
         lock: bool = True,
     ) -> Optional[bool]:
@@ -927,6 +927,8 @@ class DatabasePool:
             Native upserts always return None. Emulated upserts return True if a
             new entry was created, False if an existing one was updated.
         """
+        insertion_values = insertion_values or {}
+
         attempts = 0
         while True:
             try:
@@ -964,7 +966,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         lock: bool = True,
     ) -> Optional[bool]:
         """
@@ -982,6 +984,8 @@ class DatabasePool:
             Native upserts always return None. Emulated upserts return True if a
             new entry was created, False if an existing one was updated.
         """
+        insertion_values = insertion_values or {}
+
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
             self.simple_upsert_txn_native_upsert(
                 txn, table, keyvalues, values, insertion_values=insertion_values
@@ -1003,7 +1007,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         lock: bool = True,
     ) -> bool:
         """
@@ -1017,6 +1021,8 @@ class DatabasePool:
             Returns True if a new entry was created, False if an existing
             one was updated.
         """
+        insertion_values = insertion_values or {}
+
         # We need to lock the table :(, unless we're *really* careful
         if lock:
             self.engine.lock_table(txn, table)
@@ -1077,7 +1083,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Use the native UPSERT functionality in recent PostgreSQL versions.
@@ -1090,7 +1096,7 @@ class DatabasePool:
         """
         allvalues = {}  # type: Dict[str, Any]
         allvalues.update(keyvalues)
-        allvalues.update(insertion_values)
+        allvalues.update(insertion_values or {})
 
         if not values:
             latter = "NOTHING"
@@ -1513,7 +1519,7 @@ class DatabasePool:
         column: str,
         iterable: Iterable[Any],
         retcols: Iterable[str],
-        keyvalues: Dict[str, Any] = {},
+        keyvalues: Optional[Dict[str, Any]] = None,
         desc: str = "simple_select_many_batch",
         batch_size: int = 100,
     ) -> List[Any]:
@@ -1531,6 +1537,8 @@ class DatabasePool:
             desc: description of the transaction, for logging and metrics
             batch_size: the number of rows for each select query
         """
+        keyvalues = keyvalues or {}
+
         results = []  # type: List[Dict[str, Any]]
 
         if not iterable:
@@ -2059,69 +2067,18 @@ def make_in_list_sql_clause(
 KV = TypeVar("KV")
 
 
-def make_tuple_comparison_clause(
-    database_engine: BaseDatabaseEngine, keys: List[Tuple[str, KV]]
-) -> Tuple[str, List[KV]]:
+def make_tuple_comparison_clause(keys: List[Tuple[str, KV]]) -> Tuple[str, List[KV]]:
     """Returns a tuple comparison SQL clause
 
-    Depending what the SQL engine supports, builds a SQL clause that looks like either
-    "(a, b) > (?, ?)", or "(a > ?) OR (a == ? AND b > ?)".
+    Builds a SQL clause that looks like "(a, b) > (?, ?)"
 
     Args:
-        database_engine
         keys: A set of (column, value) pairs to be compared.
 
     Returns:
         A tuple of SQL query and the args
     """
-    if database_engine.supports_tuple_comparison:
-        return (
-            "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
-            [k[1] for k in keys],
-        )
-
-    # we want to build a clause
-    #    (a > ?) OR
-    #    (a == ? AND b > ?) OR
-    #    (a == ? AND b == ? AND c > ?)
-    #    ...
-    #    (a == ? AND b == ? AND ... AND z > ?)
-    #
-    # or, equivalently:
-    #
-    #  (a > ? OR (a == ? AND
-    #    (b > ? OR (b == ? AND
-    #      ...
-    #        (y > ? OR (y == ? AND
-    #          z > ?
-    #        ))
-    #      ...
-    #    ))
-    #  ))
-    #
-    # which itself is equivalent to (and apparently easier for the query optimiser):
-    #
-    #  (a >= ? AND (a > ? OR
-    #    (b >= ? AND (b > ? OR
-    #      ...
-    #        (y >= ? AND (y > ? OR
-    #          z > ?
-    #        ))
-    #      ...
-    #    ))
-    #  ))
-    #
-    #
-
-    clause = ""
-    args = []  # type: List[KV]
-    for k, v in keys[:-1]:
-        clause = clause + "(%s >= ? AND (%s > ? OR " % (k, k)
-        args.extend([v, v])
-
-    (k, v) = keys[-1]
-    clause += "%s > ?" % (k,)
-    args.append(v)
-
-    clause += "))" * (len(keys) - 1)
-    return clause, args
+    return (
+        "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
+        [k[1] for k in keys],
+    )
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 6d18e692b0..ea3c15fd0e 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -298,7 +298,6 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             #      times, which is fine.
 
             where_clause, where_args = make_tuple_comparison_clause(
-                self.database_engine,
                 [("user_id", last_user_id), ("device_id", last_device_id)],
             )
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d327e9aa0b..9bf8ba888f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -985,7 +985,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
         def _txn(txn):
             clause, args = make_tuple_comparison_clause(
-                self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS]
+                [(x, last_row[x]) for x in KEY_COLS]
             )
             sql = """
                 SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 98dac19a95..ad17123915 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -320,8 +320,8 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool,
-        state_delta_for_room: Dict[str, DeltaState] = {},
-        new_forward_extremeties: Dict[str, List[str]] = {},
+        state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
+        new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
     ):
         """Insert some number of room events into the necessary database tables.
 
@@ -342,6 +342,9 @@ class PersistEventsStore:
                 extremities.
 
         """
+        state_delta_for_room = state_delta_for_room or {}
+        new_forward_extremeties = new_forward_extremeties or {}
+
         all_events_and_contexts = events_and_contexts
 
         min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 78367ea58d..79e7df6ca9 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -838,7 +838,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
         # comparison, but that is not supported on older SQLite versions
         tuple_clause, tuple_args = make_tuple_comparison_clause(
-            self.database_engine,
             [
                 ("events.room_id", last_room_id),
                 ("topological_ordering", last_depth),
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 8f462dfc31..bd7826f4e9 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1171,7 +1171,7 @@ class GroupServerStore(GroupServerWorkerStore):
         user_id: str,
         membership: str,
         is_admin: bool = False,
-        content: JsonDict = {},
+        content: Optional[JsonDict] = None,
         local_attestation: Optional[dict] = None,
         remote_attestation: Optional[dict] = None,
         is_publicised: bool = False,
@@ -1192,6 +1192,8 @@ class GroupServerStore(GroupServerWorkerStore):
             is_publicised: Whether this should be publicised.
         """
 
+        content = content or {}
+
         def _register_user_group_membership_txn(txn, next_id):
             # TODO: Upsert?
             self.db_pool.simple_delete_txn(
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index a7f371732f..93431efe00 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -190,7 +190,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     # FIXME: how should this be cached?
     async def get_filtered_current_state_ids(
-        self, room_id: str, state_filter: StateFilter = StateFilter.all()
+        self, room_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """Get the current state event of a given type for a room based on the
         current_state_events table.  This may not be as up-to-date as the result
@@ -205,7 +205,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             Map from type/state_key to event ID.
         """
 
-        where_clause, where_args = state_filter.make_sql_filter_clause()
+        where_clause, where_args = (
+            state_filter or StateFilter.all()
+        ).make_sql_filter_clause()
 
         if not where_clause:
             # We delegate to the cached version
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 1fd333b707..75c09b3687 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
@@ -73,8 +74,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
             return count
 
     def _get_state_groups_from_groups_txn(
-        self, txn, groups, state_filter=StateFilter.all()
+        self, txn, groups, state_filter: Optional[StateFilter] = None
     ):
+        state_filter = state_filter or StateFilter.all()
+
         results = {group: {} for group in groups}
 
         where_clause, where_args = state_filter.make_sql_filter_clause()
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 97ec65f757..dfcf89d91c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -15,7 +15,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.constants import EventTypes
 from synapse.storage._base import SQLBaseStore
@@ -210,7 +210,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         return state_filter.filter_state(state_dict_ids), not missing_types
 
     async def _get_state_for_groups(
-        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Dict[int, MutableStateMap[str]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
@@ -223,6 +223,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         Returns:
             Dict of state group to state map.
         """
+        state_filter = state_filter or StateFilter.all()
 
         member_filter, non_member_filter = state_filter.get_member_split()
 
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index cca839c70f..21db1645d3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -44,14 +44,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
 
     @property
     @abc.abstractmethod
-    def supports_tuple_comparison(self) -> bool:
-        """
-        Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
-        """
-        ...
-
-    @property
-    @abc.abstractmethod
     def supports_using_any_list(self) -> bool:
         """
         Do we support using `a = ANY(?)` and passing a list
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 80a3558aec..dba8cc51d3 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -47,8 +47,8 @@ class PostgresEngine(BaseDatabaseEngine):
         self._version = db_conn.server_version
 
         # Are we on a supported PostgreSQL version?
-        if not allow_outdated_version and self._version < 90500:
-            raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
+        if not allow_outdated_version and self._version < 90600:
+            raise RuntimeError("Synapse requires PostgreSQL 9.6 or above.")
 
         with db_conn.cursor() as txn:
             txn.execute("SHOW SERVER_ENCODING")
@@ -130,13 +130,6 @@ class PostgresEngine(BaseDatabaseEngine):
         return True
 
     @property
-    def supports_tuple_comparison(self):
-        """
-        Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
-        """
-        return True
-
-    @property
     def supports_using_any_list(self):
         """Do we support using `a = ANY(?)` and passing a list"""
         return True
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index b87e7798da..f4f16456f2 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -57,14 +57,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         return self.module.sqlite_version_info >= (3, 24, 0)
 
     @property
-    def supports_tuple_comparison(self):
-        """
-        Do we support comparing tuples, i.e. `(a, b) > (c, d)`? This requires
-        SQLite 3.15+.
-        """
-        return self.module.sqlite_version_info >= (3, 15, 0)
-
-    @property
     def supports_using_any_list(self):
         """Do we support using `a = ANY(?)` and passing a list"""
         return False
@@ -72,8 +64,11 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
     def check_database(self, db_conn, allow_outdated_version: bool = False):
         if not allow_outdated_version:
             version = self.module.sqlite_version_info
-            if version < (3, 11, 0):
-                raise RuntimeError("Synapse requires sqlite 3.11 or above.")
+            # Synapse is untested against older SQLite versions, and we don't want
+            # to let users upgrade to a version of Synapse with broken support for their
+            # sqlite version, because it risks leaving them with a half-upgraded db.
+            if version < (3, 22, 0):
+                raise RuntimeError("Synapse requires sqlite 3.22 or above.")
 
     def check_new_database(self, txn):
         """Gets called when setting up a brand new database. This allows us to
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 2e277a21c4..c1c147c62a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -449,7 +449,7 @@ class StateGroupStorage:
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
 
     async def get_state_for_events(
-        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[EventBase]]:
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
@@ -465,7 +465,7 @@ class StateGroupStorage:
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
-            groups, state_filter
+            groups, state_filter or StateFilter.all()
         )
 
         state_event_map = await self.stores.main.get_events(
@@ -485,7 +485,7 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
 
     async def get_state_ids_for_events(
-        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
@@ -502,7 +502,7 @@ class StateGroupStorage:
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
-            groups, state_filter
+            groups, state_filter or StateFilter.all()
         )
 
         event_to_state = {
@@ -513,7 +513,7 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
 
     async def get_state_for_event(
-        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[EventBase]:
         """
         Get the state dict corresponding to a particular event
@@ -525,11 +525,13 @@ class StateGroupStorage:
         Returns:
             A dict from (type, state_key) -> state_event
         """
-        state_map = await self.get_state_for_events([event_id], state_filter)
+        state_map = await self.get_state_for_events(
+            [event_id], state_filter or StateFilter.all()
+        )
         return state_map[event_id]
 
     async def get_state_ids_for_event(
-        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event
@@ -541,11 +543,13 @@ class StateGroupStorage:
         Returns:
             A dict from (type, state_key) -> state_event
         """
-        state_map = await self.get_state_ids_for_events([event_id], state_filter)
+        state_map = await self.get_state_ids_for_events(
+            [event_id], state_filter or StateFilter.all()
+        )
         return state_map[event_id]
 
     def _get_state_for_groups(
-        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
@@ -558,7 +562,9 @@ class StateGroupStorage:
         Returns:
             Dict of state group to state map.
         """
-        return self.stores.state._get_state_for_groups(groups, state_filter)
+        return self.stores.state._get_state_for_groups(
+            groups, state_filter or StateFilter.all()
+        )
 
     async def store_state_group(
         self,
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index d4643c4fdf..32d6cc16b9 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -17,7 +17,7 @@ import logging
 import threading
 from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import attr
 
@@ -91,7 +91,14 @@ class StreamIdGenerator:
             # ... persist event ...
     """
 
-    def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+    def __init__(
+        self,
+        db_conn,
+        table,
+        column,
+        extra_tables: Iterable[Tuple[str, str]] = (),
+        step=1,
+    ):
         assert step != 0
         self._lock = threading.Lock()
         self._step = step