Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/database.py                          | 83
-rw-r--r--  synapse/storage/databases/main/client_ips.py         |  1
-rw-r--r--  synapse/storage/databases/main/devices.py            |  2
-rw-r--r--  synapse/storage/databases/main/events.py             |  7
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py  |  1
-rw-r--r--  synapse/storage/databases/main/group_server.py       |  4
-rw-r--r--  synapse/storage/databases/main/state.py              |  6
-rw-r--r--  synapse/storage/databases/state/bg_updates.py        |  5
-rw-r--r--  synapse/storage/databases/state/store.py             |  5
-rw-r--r--  synapse/storage/engines/_base.py                     |  8
-rw-r--r--  synapse/storage/engines/postgres.py                  | 11
-rw-r--r--  synapse/storage/engines/sqlite.py                    | 15
-rw-r--r--  synapse/storage/state.py                             | 26
-rw-r--r--  synapse/storage/util/id_generators.py                | 11
14 files changed, 72 insertions, 113 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 94590e7b45..fa15b0ce5b 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -900,7 +900,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         desc: str = "simple_upsert",
         lock: bool = True,
     ) -> Optional[bool]:
@@ -927,6 +927,8 @@ class DatabasePool:
             Native upserts always return None. Emulated upserts return True if a
             new entry was created, False if an existing one was updated.
         """
+        insertion_values = insertion_values or {}
+
         attempts = 0
         while True:
             try:
@@ -964,7 +966,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         lock: bool = True,
     ) -> Optional[bool]:
         """
@@ -982,6 +984,8 @@ class DatabasePool:
             Native upserts always return None. Emulated upserts return True if a
             new entry was created, False if an existing one was updated.
         """
+        insertion_values = insertion_values or {}
+
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
             self.simple_upsert_txn_native_upsert(
                 txn, table, keyvalues, values, insertion_values=insertion_values
@@ -1003,7 +1007,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         lock: bool = True,
     ) -> bool:
         """
@@ -1017,6 +1021,8 @@ class DatabasePool:
         Returns True if a new entry was created, False if an existing
         one was updated.
         """
+        insertion_values = insertion_values or {}
+
         # We need to lock the table :(, unless we're *really* careful
         if lock:
             self.engine.lock_table(txn, table)
@@ -1077,7 +1083,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Use the native UPSERT functionality in recent PostgreSQL versions.
@@ -1090,7 +1096,7 @@ class DatabasePool:
         """
         allvalues = {}  # type: Dict[str, Any]
         allvalues.update(keyvalues)
-        allvalues.update(insertion_values)
+        allvalues.update(insertion_values or {})
 
         if not values:
             latter = "NOTHING"
@@ -1513,7 +1519,7 @@ class DatabasePool:
         column: str,
         iterable: Iterable[Any],
         retcols: Iterable[str],
-        keyvalues: Dict[str, Any] = {},
+        keyvalues: Optional[Dict[str, Any]] = None,
         desc: str = "simple_select_many_batch",
         batch_size: int = 100,
     ) -> List[Any]:
@@ -1531,6 +1537,8 @@ class DatabasePool:
             desc: description of the transaction, for logging and metrics
             batch_size: the number of rows for each select query
         """
+        keyvalues = keyvalues or {}
+
         results = []  # type: List[Dict[str, Any]]
 
         if not iterable:
@@ -2059,69 +2067,18 @@ def make_in_list_sql_clause(
 KV = TypeVar("KV")
 
 
-def make_tuple_comparison_clause(
-    database_engine: BaseDatabaseEngine, keys: List[Tuple[str, KV]]
-) -> Tuple[str, List[KV]]:
+def make_tuple_comparison_clause(keys: List[Tuple[str, KV]]) -> Tuple[str, List[KV]]:
     """Returns a tuple comparison SQL clause
 
-    Depending what the SQL engine supports, builds a SQL clause that looks like either
-    "(a, b) > (?, ?)", or "(a > ?) OR (a == ? AND b > ?)".
+    Builds a SQL clause that looks like "(a, b) > (?, ?)"
 
     Args:
-        database_engine
         keys: A set of (column, value) pairs to be compared.
 
    Returns:
        A tuple of SQL query and the args
    """
-    if database_engine.supports_tuple_comparison:
-        return (
-            "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
-            [k[1] for k in keys],
-        )
-
-    # we want to build a clause
-    #  (a > ?) OR
-    #  (a == ? AND b > ?) OR
-    #  (a == ? AND b == ? AND c > ?)
-    #  ...
-    #  (a == ? AND b == ? AND ... AND z > ?)
-    #
-    # or, equivalently:
-    #
-    #  (a > ? OR (a == ? AND
-    #   (b > ? OR (b == ? AND
-    #    ...
-    #     (y > ? OR (y == ? AND
-    #      z > ?
-    #     ))
-    #    ...
-    #   ))
-    #  ))
-    #
-    # which itself is equivalent to (and apparently easier for the query optimiser):
-    #
-    #  (a >= ? AND (a > ? OR
-    #   (b >= ? AND (b > ? OR
-    #    ...
-    #     (y >= ? AND (y > ? OR
-    #      z > ?
-    #     ))
-    #    ...
-    #   ))
-    #  ))
-    #
-    #
-
-    clause = ""
-    args = []  # type: List[KV]
-    for k, v in keys[:-1]:
-        clause = clause + "(%s >= ? AND (%s > ? OR " % (k, k)
-        args.extend([v, v])
-
-    (k, v) = keys[-1]
-    clause += "%s > ?" % (k,)
-    args.append(v)
-
-    clause += "))" * (len(keys) - 1)
-    return clause, args
+    return (
+        "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
+        [k[1] for k in keys],
+    )
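The change above is the pattern repeated throughout this patch: a `Dict[...] = {}` default becomes `Optional[Dict[...]] = None` plus `insertion_values = insertion_values or {}`. This sidesteps Python's mutable-default pitfall, where the default expression is evaluated once at definition time and the resulting object is shared by every call. A minimal standalone illustration (not Synapse code):

from typing import Optional


def counts_with_shared_default(acc: dict = {}) -> dict:
    # The single default dict is created when the function is defined,
    # so mutations made here persist across calls that omit `acc`.
    acc["calls"] = acc.get("calls", 0) + 1
    return acc


def counts_with_none_default(acc: Optional[dict] = None) -> dict:
    # A fresh dict is created on every call that omits `acc`.  Note that
    # `or` also replaces an explicitly passed falsy value such as {}, which
    # is harmless here; the stricter spelling is `if acc is None: acc = {}`.
    acc = acc or {}
    acc["calls"] = acc.get("calls", 0) + 1
    return acc


assert counts_with_shared_default() == {"calls": 1}
assert counts_with_shared_default() == {"calls": 2}  # state leaked between calls
assert counts_with_none_default() == {"calls": 1}
assert counts_with_none_default() == {"calls": 1}    # independent every time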
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 6d18e692b0..ea3c15fd0e 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -298,7 +298,6 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             # times, which is fine.
 
             where_clause, where_args = make_tuple_comparison_clause(
-                self.database_engine,
                 [("user_id", last_user_id), ("device_id", last_device_id)],
             )
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d327e9aa0b..9bf8ba888f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -985,7 +985,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
         def _txn(txn):
             clause, args = make_tuple_comparison_clause(
-                self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS]
+                [(x, last_row[x]) for x in KEY_COLS]
             )
             sql = """
                 SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 98dac19a95..ad17123915 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -320,8 +320,8 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool,
-        state_delta_for_room: Dict[str, DeltaState] = {},
-        new_forward_extremeties: Dict[str, List[str]] = {},
+        state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
+        new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
     ):
         """Insert some number of room events into the necessary database tables.
 
@@ -342,6 +342,9 @@ class PersistEventsStore:
                 extremities.
         """
+        state_delta_for_room = state_delta_for_room or {}
+        new_forward_extremeties = new_forward_extremeties or {}
+
         all_events_and_contexts = events_and_contexts
 
         min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 78367ea58d..79e7df6ca9 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -838,7 +838,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
             # comparison, but that is not supported on older SQLite versions
             tuple_clause, tuple_args = make_tuple_comparison_clause(
-                self.database_engine,
                 [
                     ("events.room_id", last_room_id),
                     ("topological_ordering", last_depth),
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 8f462dfc31..bd7826f4e9 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1171,7 +1171,7 @@ class GroupServerStore(GroupServerWorkerStore):
         user_id: str,
         membership: str,
         is_admin: bool = False,
-        content: JsonDict = {},
+        content: Optional[JsonDict] = None,
         local_attestation: Optional[dict] = None,
         remote_attestation: Optional[dict] = None,
         is_publicised: bool = False,
@@ -1192,6 +1192,8 @@ class GroupServerStore(GroupServerWorkerStore):
             is_publicised: Whether this should be publicised.
         """
+        content = content or {}
+
         def _register_user_group_membership_txn(txn, next_id):
             # TODO: Upsert?
             self.db_pool.simple_delete_txn(
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index a7f371732f..93431efe00 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -190,7 +190,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     # FIXME: how should this be cached?
     async def get_filtered_current_state_ids(
-        self, room_id: str, state_filter: StateFilter = StateFilter.all()
+        self, room_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """Get the current state event of a given type for a room based on the
         current_state_events table.  This may not be as up-to-date as the result
@@ -205,7 +205,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             Map from type/state_key to event ID.
         """
 
-        where_clause, where_args = state_filter.make_sql_filter_clause()
+        where_clause, where_args = (
+            state_filter or StateFilter.all()
+        ).make_sql_filter_clause()
 
         if not where_clause:
             # We delegate to the cached version
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 1fd333b707..75c09b3687 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
@@ -73,8 +74,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
         return count
 
     def _get_state_groups_from_groups_txn(
-        self, txn, groups, state_filter=StateFilter.all()
+        self, txn, groups, state_filter: Optional[StateFilter] = None
     ):
+        state_filter = state_filter or StateFilter.all()
+
         results = {group: {} for group in groups}
 
         where_clause, where_args = state_filter.make_sql_filter_clause()
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 97ec65f757..dfcf89d91c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -15,7 +15,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.constants import EventTypes
 from synapse.storage._base import SQLBaseStore
@@ -210,7 +210,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         return state_filter.filter_state(state_dict_ids), not missing_types
 
     async def _get_state_for_groups(
-        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Dict[int, MutableStateMap[str]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
@@ -223,6 +223,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         Returns:
             Dict of state group to state map.
         """
+        state_filter = state_filter or StateFilter.all()
 
         member_filter, non_member_filter = state_filter.get_member_split()
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index cca839c70f..21db1645d3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -44,14 +44,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
 
     @property
     @abc.abstractmethod
-    def supports_tuple_comparison(self) -> bool:
-        """
-        Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
-        """
-        ...
-
-    @property
-    @abc.abstractmethod
     def supports_using_any_list(self) -> bool:
         """
         Do we support using `a = ANY(?)` and passing a list
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 80a3558aec..dba8cc51d3 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -47,8 +47,8 @@ class PostgresEngine(BaseDatabaseEngine):
             self._version = db_conn.server_version
 
             # Are we on a supported PostgreSQL version?
-            if not allow_outdated_version and self._version < 90500:
-                raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
+            if not allow_outdated_version and self._version < 90600:
+                raise RuntimeError("Synapse requires PostgreSQL 9.6 or above.")
 
             with db_conn.cursor() as txn:
                 txn.execute("SHOW SERVER_ENCODING")
@@ -130,13 +130,6 @@ class PostgresEngine(BaseDatabaseEngine):
         return True
 
     @property
-    def supports_tuple_comparison(self):
-        """
-        Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
-        """
-        return True
-
-    @property
     def supports_using_any_list(self):
         """Do we support using `a = ANY(?)` and passing a list"""
         return True
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index b87e7798da..f4f16456f2 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -57,14 +57,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         return self.module.sqlite_version_info >= (3, 24, 0)
 
     @property
-    def supports_tuple_comparison(self):
-        """
-        Do we support comparing tuples, i.e. `(a, b) > (c, d)`? This requires
-        SQLite 3.15+.
-        """
-        return self.module.sqlite_version_info >= (3, 15, 0)
-
-    @property
     def supports_using_any_list(self):
         """Do we support using `a = ANY(?)` and passing a list"""
         return False
@@ -72,8 +64,11 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
     def check_database(self, db_conn, allow_outdated_version: bool = False):
         if not allow_outdated_version:
             version = self.module.sqlite_version_info
-            if version < (3, 11, 0):
-                raise RuntimeError("Synapse requires sqlite 3.11 or above.")
+            # Synapse is untested against older SQLite versions, and we don't want
+            # to let users upgrade to a version of Synapse with broken support for their
+            # sqlite version, because it risks leaving them with a half-upgraded db.
+            if version < (3, 22, 0):
+                raise RuntimeError("Synapse requires sqlite 3.22 or above.")
 
     def check_new_database(self, txn):
         """Gets called when setting up a brand new database. This allows us to
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 2e277a21c4..c1c147c62a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -449,7 +449,7 @@ class StateGroupStorage:
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
 
     async def get_state_for_events(
-        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[EventBase]]:
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
@@ -465,7 +465,7 @@ class StateGroupStorage:
         groups = set(event_to_groups.values())
 
         group_to_state = await self.stores.state._get_state_for_groups(
-            groups, state_filter
+            groups, state_filter or StateFilter.all()
         )
 
         state_event_map = await self.stores.main.get_events(
@@ -485,7 +485,7 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
 
     async def get_state_ids_for_events(
-        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
@@ -502,7 +502,7 @@ class StateGroupStorage:
         groups = set(event_to_groups.values())
 
         group_to_state = await self.stores.state._get_state_for_groups(
-            groups, state_filter
+            groups, state_filter or StateFilter.all()
         )
 
         event_to_state = {
@@ -513,7 +513,7 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
 
     async def get_state_for_event(
-        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[EventBase]:
         """
         Get the state dict corresponding to a particular event
@@ -525,11 +525,13 @@ class StateGroupStorage:
         Returns:
             A dict from (type, state_key) -> state_event
         """
-        state_map = await self.get_state_for_events([event_id], state_filter)
+        state_map = await self.get_state_for_events(
+            [event_id], state_filter or StateFilter.all()
+        )
         return state_map[event_id]
 
     async def get_state_ids_for_event(
-        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event
@@ -541,11 +543,13 @@ class StateGroupStorage:
         Returns:
             A dict from (type, state_key) -> state_event
         """
-        state_map = await self.get_state_ids_for_events([event_id], state_filter)
+        state_map = await self.get_state_ids_for_events(
+            [event_id], state_filter or StateFilter.all()
+        )
         return state_map[event_id]
 
     def _get_state_for_groups(
-        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
@@ -558,7 +562,9 @@ class StateGroupStorage:
         Returns:
             Dict of state group to state map.
         """
-        return self.stores.state._get_state_for_groups(groups, state_filter)
+        return self.stores.state._get_state_for_groups(
+            groups, state_filter or StateFilter.all()
+        )
 
     async def store_state_group(
         self,
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index d4643c4fdf..32d6cc16b9 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -17,7 +17,7 @@ import logging
 import threading
 from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import attr
 
@@ -91,7 +91,14 @@ class StreamIdGenerator:
             # ... persist event ...
         """
 
-    def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+    def __init__(
+        self,
+        db_conn,
+        table,
+        column,
+        extra_tables: Iterable[Tuple[str, str]] = (),
+        step=1,
+    ):
         assert step != 0
         self._lock = threading.Lock()
         self._step = step