Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/database.py                      | 20
-rw-r--r--  synapse/storage/databases/main/events.py         |  7
-rw-r--r--  synapse/storage/databases/main/group_server.py   |  4
-rw-r--r--  synapse/storage/databases/main/state.py          |  6
-rw-r--r--  synapse/storage/databases/state/bg_updates.py    |  5
-rw-r--r--  synapse/storage/databases/state/store.py         |  5
-rw-r--r--  synapse/storage/state.py                         | 26
-rw-r--r--  synapse/storage/util/id_generators.py            | 11
8 files changed, 58 insertions, 26 deletions
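
The hunks below all apply one pattern: a mutable or call-time-constructed default argument (a dict literal, a list literal, or StateFilter.all()) is replaced by an Optional[...] = None parameter (or an immutable tuple), and the real default is rebuilt inside the function body. The sketch below is not part of the patch and the function names are made up for illustration; it only shows the Python behaviour being avoided: a default expression is evaluated once, when the def statement runs, so every call that falls back to the default shares the same object.

    def record_event(event_id, metadata={}):
        # The {} above is created once, at definition time, and shared by every call.
        metadata.setdefault("seen", 0)
        metadata["seen"] += 1
        return metadata

    first = record_event("$a")
    second = record_event("$b")
    print(first is second)  # True: both calls mutated the same dict

    def record_event_fixed(event_id, metadata=None):
        # The idiom used throughout this diff: default to None and rebuild the
        # value inside the body, so each call gets a fresh dict.
        metadata = metadata or {}
        metadata.setdefault("seen", 0)
        metadata["seen"] += 1
        return metadata

Note that "x = x or {}" also replaces an explicitly passed empty dict with a fresh one; that is harmless at these call sites, though "if x is None" is the stricter spelling.
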
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b302cd5786..fa15b0ce5b 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -900,7 +900,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         desc: str = "simple_upsert",
         lock: bool = True,
     ) -> Optional[bool]:
@@ -927,6 +927,8 @@ class DatabasePool:
             Native upserts always return None. Emulated upserts return True if a
             new entry was created, False if an existing one was updated.
         """
+        insertion_values = insertion_values or {}
+
         attempts = 0
         while True:
             try:
@@ -964,7 +966,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         lock: bool = True,
     ) -> Optional[bool]:
         """
@@ -982,6 +984,8 @@ class DatabasePool:
             Native upserts always return None. Emulated upserts return True if a
             new entry was created, False if an existing one was updated.
         """
+        insertion_values = insertion_values or {}
+
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
             self.simple_upsert_txn_native_upsert(
                 txn, table, keyvalues, values, insertion_values=insertion_values
             )
@@ -1003,7 +1007,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         lock: bool = True,
     ) -> bool:
         """
@@ -1017,6 +1021,8 @@ class DatabasePool:
             Returns True if a new entry was created, False if an existing
             one was updated.
         """
+        insertion_values = insertion_values or {}
+
         # We need to lock the table :(, unless we're *really* careful
         if lock:
             self.engine.lock_table(txn, table)
@@ -1077,7 +1083,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Use the native UPSERT functionality in recent PostgreSQL versions.
@@ -1090,7 +1096,7 @@ class DatabasePool:
         """
         allvalues = {}  # type: Dict[str, Any]
         allvalues.update(keyvalues)
-        allvalues.update(insertion_values)
+        allvalues.update(insertion_values or {})

         if not values:
             latter = "NOTHING"
@@ -1513,7 +1519,7 @@ class DatabasePool:
         column: str,
         iterable: Iterable[Any],
         retcols: Iterable[str],
-        keyvalues: Dict[str, Any] = {},
+        keyvalues: Optional[Dict[str, Any]] = None,
         desc: str = "simple_select_many_batch",
         batch_size: int = 100,
     ) -> List[Any]:
@@ -1531,6 +1537,8 @@ class DatabasePool:
             desc: description of the transaction, for logging and metrics
             batch_size: the number of rows for each select query
         """
+        keyvalues = keyvalues or {}
+
         results = []  # type: List[Dict[str, Any]]

         if not iterable:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 98dac19a95..ad17123915 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -320,8 +320,8 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool,
-        state_delta_for_room: Dict[str, DeltaState] = {},
-        new_forward_extremeties: Dict[str, List[str]] = {},
+        state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
+        new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
     ):
         """Insert some number of room events into the necessary database tables.
@@ -342,6 +342,9 @@ class PersistEventsStore:
                 extremities.
         """
+        state_delta_for_room = state_delta_for_room or {}
+        new_forward_extremeties = new_forward_extremeties or {}
+
         all_events_and_contexts = events_and_contexts

         min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 8f462dfc31..bd7826f4e9 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1171,7 +1171,7 @@ class GroupServerStore(GroupServerWorkerStore):
         user_id: str,
         membership: str,
         is_admin: bool = False,
-        content: JsonDict = {},
+        content: Optional[JsonDict] = None,
         local_attestation: Optional[dict] = None,
         remote_attestation: Optional[dict] = None,
         is_publicised: bool = False,
@@ -1192,6 +1192,8 @@ class GroupServerStore(GroupServerWorkerStore):
             is_publicised: Whether this should be publicised.
         """
+        content = content or {}
+
         def _register_user_group_membership_txn(txn, next_id):
             # TODO: Upsert?
             self.db_pool.simple_delete_txn(
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index a7f371732f..93431efe00 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -190,7 +190,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     # FIXME: how should this be cached?
     async def get_filtered_current_state_ids(
-        self, room_id: str, state_filter: StateFilter = StateFilter.all()
+        self, room_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """Get the current state event of a given type for a room based on the
         current_state_events table.  This may not be as up-to-date as the result
@@ -205,7 +205,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             Map from type/state_key to event ID.
         """
-        where_clause, where_args = state_filter.make_sql_filter_clause()
+        where_clause, where_args = (
+            state_filter or StateFilter.all()
+        ).make_sql_filter_clause()

         if not where_clause:
             # We delegate to the cached version
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 1fd333b707..75c09b3687 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import logging
+from typing import Optional

 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
@@ -73,8 +74,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
         return count

     def _get_state_groups_from_groups_txn(
-        self, txn, groups, state_filter=StateFilter.all()
+        self, txn, groups, state_filter: Optional[StateFilter] = None
     ):
+        state_filter = state_filter or StateFilter.all()
+
         results = {group: {} for group in groups}

         where_clause, where_args = state_filter.make_sql_filter_clause()
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 97ec65f757..dfcf89d91c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -15,7 +15,7 @@

 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple

 from synapse.api.constants import EventTypes
 from synapse.storage._base import SQLBaseStore
@@ -210,7 +210,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         return state_filter.filter_state(state_dict_ids), not missing_types

     async def _get_state_for_groups(
-        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Dict[int, MutableStateMap[str]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
@@ -223,6 +223,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         Returns:
             Dict of state group to state map.
         """
+        state_filter = state_filter or StateFilter.all()

         member_filter, non_member_filter = state_filter.get_member_split()
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 2e277a21c4..c1c147c62a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -449,7 +449,7 @@ class StateGroupStorage:
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)

     async def get_state_for_events(
-        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[EventBase]]:
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
@@ -465,7 +465,7 @@ class StateGroupStorage:
         groups = set(event_to_groups.values())

         group_to_state = await self.stores.state._get_state_for_groups(
-            groups, state_filter
+            groups, state_filter or StateFilter.all()
         )

         state_event_map = await self.stores.main.get_events(
@@ -485,7 +485,7 @@ class StateGroupStorage:

         return {event: event_to_state[event] for event in event_ids}

     async def get_state_ids_for_events(
-        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
@@ -502,7 +502,7 @@ class StateGroupStorage:
         groups = set(event_to_groups.values())

         group_to_state = await self.stores.state._get_state_for_groups(
-            groups, state_filter
+            groups, state_filter or StateFilter.all()
         )

         event_to_state = {
@@ -513,7 +513,7 @@ class StateGroupStorage:

         return {event: event_to_state[event] for event in event_ids}

     async def get_state_for_event(
-        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[EventBase]:
         """
         Get the state dict corresponding to a particular event
@@ -525,11 +525,13 @@ class StateGroupStorage:
         Returns:
             A dict from (type, state_key) -> state_event
         """
-        state_map = await self.get_state_for_events([event_id], state_filter)
+        state_map = await self.get_state_for_events(
+            [event_id], state_filter or StateFilter.all()
+        )
         return state_map[event_id]

     async def get_state_ids_for_event(
-        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event
@@ -541,11 +543,13 @@ class StateGroupStorage:
         Returns:
             A dict from (type, state_key) -> state_event
         """
-        state_map = await self.get_state_ids_for_events([event_id], state_filter)
+        state_map = await self.get_state_ids_for_events(
+            [event_id], state_filter or StateFilter.all()
+        )
         return state_map[event_id]

     def _get_state_for_groups(
-        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
@@ -558,7 +562,9 @@ class StateGroupStorage:
         Returns:
             Dict of state group to state map.
         """
-        return self.stores.state._get_state_for_groups(groups, state_filter)
+        return self.stores.state._get_state_for_groups(
+            groups, state_filter or StateFilter.all()
+        )

     async def store_state_group(
         self,
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index d4643c4fdf..32d6cc16b9 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -17,7 +17,7 @@ import logging
 import threading
 from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

 import attr

@@ -91,7 +91,14 @@ class StreamIdGenerator:
             # ... persist event ...
""" - def __init__(self, db_conn, table, column, extra_tables=[], step=1): + def __init__( + self, + db_conn, + table, + column, + extra_tables: Iterable[Tuple[str, str]] = (), + step=1, + ): assert step != 0 self._lock = threading.Lock() self._step = step |