diff options
Diffstat (limited to 'synapse/storage')
27 files changed, 317 insertions, 145 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py index bc327e344e..181c3ec249 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -29,9 +29,11 @@ from typing import ( Tuple, TypeVar, Union, + overload, ) from prometheus_client import Histogram +from typing_extensions import Literal from twisted.enterprise import adbapi from twisted.internet import defer @@ -1020,14 +1022,36 @@ class DatabasePool(object): return txn.execute_batch(sql, args) - def simple_select_one( + @overload + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: Literal[False] = False, + desc: str = "simple_select_one", + ) -> Dict[str, Any]: + ... + + @overload + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: Literal[True] = True, + desc: str = "simple_select_one", + ) -> Optional[Dict[str, Any]]: + ... + + async def simple_select_one( self, table: str, keyvalues: Dict[str, Any], retcols: Iterable[str], allow_none: bool = False, desc: str = "simple_select_one", - ) -> defer.Deferred: + ) -> Optional[Dict[str, Any]]: """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -1038,18 +1062,18 @@ class DatabasePool(object): allow_none: If true, return None instead of failing if the SELECT statement returns no rows """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none ) - def simple_select_one_onecol( + async def simple_select_one_onecol( self, table: str, keyvalues: Dict[str, Any], retcol: Iterable[str], allow_none: bool = False, desc: str = "simple_select_one_onecol", - ) -> defer.Deferred: + ) -> Optional[Any]: """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. @@ -1061,7 +1085,7 @@ class DatabasePool(object): statement returns no rows desc: description of the transaction, for logging and metrics """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_one_onecol_txn, table, diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 17fa470919..0934ae276c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -498,7 +498,7 @@ class DataStore( ) def get_users_paginate( - self, start, limit, name=None, guests=True, deactivated=False + self, start, limit, user_id=None, name=None, guests=True, deactivated=False ): """Function to retrieve a paginated list of users from users list. This will return a json list of users and the @@ -507,7 +507,8 @@ class DataStore( Args: start (int): start number to begin the query from limit (int): number of rows to retrieve - name (string): filter for user names + user_id (string): search for user_id. ignored if name is not None + name (string): search for local part of user_id or display name guests (bool): whether to in include guest users deactivated (bool): whether to include deactivated users Returns: @@ -516,11 +517,14 @@ class DataStore( def get_users_paginate_txn(txn): filters = [] - args = [] + args = [self.hs.config.server_name] if name: + filters.append("(name LIKE ? OR displayname LIKE ?)") + args.extend(["@%" + name + "%:%", "%" + name + "%"]) + elif user_id: filters.append("name LIKE ?") - args.append("%" + name + "%") + args.extend(["%" + user_id + "%"]) if not guests: filters.append("is_guest = 0") @@ -530,20 +534,23 @@ class DataStore( where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" - sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause) - txn.execute(sql, args) - count = txn.fetchone()[0] - - args = [self.hs.config.server_name] + args + [limit, start] - sql = """ - SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url + sql_base = """ FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? {} - ORDER BY u.name LIMIT ? OFFSET ? """.format( where_clause ) + sql = "SELECT COUNT(*) as total_users " + sql_base + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = ( + "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url " + + sql_base + + " ORDER BY u.name LIMIT ? OFFSET ?" + ) + args += [limit, start] txn.execute(sql, args) users = self.db_pool.cursor_to_dict(txn) return users, count diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 82aac2bbf3..04042a2c98 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. @@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 1f6e995c4f..bb85637a95 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) - with self._device_inbox_id_gen.get_next() as stream_id: + with await self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id @@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) txn, stream_id, local_messages_by_user_then_device ) - with self._device_inbox_id_gen.get_next() as stream_id: + with await self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 9a786e2929..a811a39eb5 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): - def get_device(self, user_id: str, device_id: str): + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore): user_id: The ID of the user which owns the device device_id: The ID of the device to retrieve Returns: - defer.Deferred for a dict containing the device information + A dict containing the device information Raises: StoreError: if the device is not found """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore): THe new stream ID. """ - with self._device_list_id_gen.get_next() as stream_id: + with await self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore): ) @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id: str): + async def get_device_list_last_stream_id_for_remote( + self, user_id: str + ) -> Optional[Any]: """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", @@ -1146,7 +1148,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return - with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: + with await self._device_list_id_gen.get_next_mult( + len(device_ids) + ) as stream_ids: await self.db_pool.runInteraction( "add_device_change_to_stream", self._add_device_change_to_stream_txn, @@ -1159,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] context = get_active_span_text_map() - with self._device_list_id_gen.get_next_mult( + with await self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 037e02603c..301d5d845a 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore): return RoomAliasMapping(room_id, room_alias.to_string(), servers) - def get_room_alias_creator(self, room_alias): - return self.db_pool.simple_select_one_onecol( + async def get_room_alias_creator(self, room_alias: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 2eeb9f97dc..46c3e33cc6 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore): return ret - def count_e2e_room_keys(self, user_id, version): + async def count_e2e_room_keys(self, user_id: str, version: str) -> int: """Get the number of keys in a backup version. Args: - user_id (str): the user whose backup we're querying - version (str): the version ID of the backup we're querying about + user_id: the user whose backup we're querying + version: the version ID of the backup we're querying about """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="e2e_room_keys", keyvalues={"user_id": user_id, "version": version}, retcol="COUNT(*)", diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index f93e0d320d..385868bdab 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): + def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id): """Set a user's cross-signing key. Args: @@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key key (dict): the key data + stream_id (int) """ # the 'key' dict will look something like: # { @@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) # and finally, store the key itself - with self._cross_signing_id_gen.get_next() as stream_id: - self.db_pool.simple_insert_txn( - txn, - "e2e_cross_signing_keys", - values={ - "user_id": user_id, - "keytype": key_type, - "keydata": json_encoder.encode(key), - "stream_id": stream_id, - }, - ) + self.db_pool.simple_insert_txn( + txn, + "e2e_cross_signing_keys", + values={ + "user_id": user_id, + "keytype": key_type, + "keydata": json_encoder.encode(key), + "stream_id": stream_id, + }, + ) self._invalidate_cache_and_stream( txn, self._get_bare_e2e_cross_signing_keys, (user_id,) ) - def set_e2e_cross_signing_key(self, user_id, key_type, key): + async def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. Args: @@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key_type (str): the type of cross-signing key to set key (dict): the key data """ - return self.db_pool.runInteraction( - "add_e2e_cross_signing_key", - self._set_e2e_cross_signing_key_txn, - user_id, - key_type, - key, - ) + + with await self._cross_signing_id_gen.get_next() as stream_id: + return await self.db_pool.runInteraction( + "add_e2e_cross_signing_key", + self._set_e2e_cross_signing_key_txn, + user_id, + key_type, + key, + stream_id, + ) def store_e2e_cross_signing_signatures(self, user_id, signatures): """Stores cross-signing signatures. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b90e6de2d5..6313b41eef 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -153,11 +153,11 @@ class PersistEventsStore: # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( + stream_ordering_manager = await self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) else: - stream_ordering_manager = self._stream_id_gen.get_next_mult( + stream_ordering_manager = await self._stream_id_gen.get_next_mult( len(events_and_contexts) ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e1241a724b..e6247d682d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -113,25 +113,25 @@ class EventsWorkerStore(SQLBaseStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == EventsStream.NAME: - self._stream_id_gen.advance(token) + self._stream_id_gen.advance(instance_name, token) elif stream_name == BackfillStream.NAME: - self._backfill_id_gen.advance(-token) + self._backfill_id_gen.advance(instance_name, -token) super().process_replication_rows(stream_name, instance_name, token, rows) - def get_received_ts(self, event_id): + async def get_received_ts(self, event_id: str) -> Optional[int]: """Get received_ts (when it was persisted) for the event. Raises an exception for unknown events. Args: - event_id (str) + event_id: The event ID to query. Returns: - Deferred[int|None]: Timestamp in milliseconds, or None for events - that were persisted before received_ts was implemented. + Timestamp in milliseconds, or None for events that were persisted + before received_ts was implemented. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 0e3b8739c6..c39864f59f 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json @@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = "" class GroupServerWorkerStore(SQLBaseStore): - def get_group(self, group_id): - return self.db_pool.simple_select_one( + async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore): ) return bool(result) - def is_user_admin_in_group(self, group_id, user_id): - return self.db_pool.simple_select_one_onecol( + async def is_user_admin_in_group( + self, group_id: str, user_id: str + ) -> Optional[bool]: + return await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore): desc="is_user_admin_in_group", ) - def is_user_invited_to_local_group(self, group_id, user_id): + async def is_user_invited_to_local_group( + self, group_id: str, user_id: str + ) -> Optional[bool]: """Has the group server invited a user? """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -1182,7 +1186,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - with self._group_updates_id_gen.get_next() as next_id: + with await self._group_updates_id_gen.get_next() as next_id: res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 80fc1cd009..4ae255ebd8 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Optional + from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(MediaRepositoryStore, self).__init__(database, db_conn, hs) - def get_local_media(self, media_id): + async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media + Returns: None if the media_id doesn't exist. """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_local_thumbnail", ) - def get_cached_remote_media(self, origin, media_id): - return self.db_pool.simple_select_one( + async def get_cached_remote_media( + self, origin, media_id: str + ) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e71cdd2cb4..fe30552c08 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): return users @cached(num_args=1) - def user_last_seen_monthly_active(self, user_id): + async def user_last_seen_monthly_active(self, user_id: str) -> int: """ - Checks if a given user is part of the monthly active user group - Arguments: - user_id (str): user to add/update - Return: - Deferred[int] : timestamp since last seen, None if never seen + Checks if a given user is part of the monthly active user group + Arguments: + user_id: user to add/update + + Return: + Timestamp since last seen, None if never seen """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 4e3ec02d14..c9f655dfb7 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): async def update_presence(self, presence_states): - stream_ordering_manager = self._presence_id_gen.get_next_mult( + stream_ordering_manager = await self._presence_id_gen.get_next_mult( len(presence_states) ) diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index b8261357d4..b8233c4848 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore @@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo class ProfileWorkerStore(SQLBaseStore): - async def get_profileinfo(self, user_localpart): + async def get_profileinfo(self, user_localpart: str) -> ProfileInfo: try: profile = await self.db_pool.simple_select_one( table="profiles", @@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) - def get_profile_displayname(self, user_localpart): - return self.db_pool.simple_select_one_onecol( + async def get_profile_displayname(self, user_localpart: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", desc="get_profile_displayname", ) - def get_profile_avatar_url(self, user_localpart): - return self.db_pool.simple_select_one_onecol( + async def get_profile_avatar_url(self, user_localpart: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", desc="get_profile_avatar_url", ) - def get_from_remote_profile_cache(self, user_id): - return self.db_pool.simple_select_one( + async def get_from_remote_profile_cache( + self, user_id: str + ) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index a585e54812..2fb5b02d7d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore): ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() if before or after: @@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore): ) async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None: - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 1126fd0751..c388468273 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering, profile_tag="", ) -> None: - with self._pushers_id_gen.get_next() as stream_id: + with await self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( @@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore): }, ) - with self._pushers_id_gen.get_next() as stream_id: + with await self._pushers_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 19ad1c056f..cea5ac9a68 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore): ) @cached(num_args=3) - def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self.db_pool.simple_select_one_onecol( + async def get_last_receipt_event_id_for_user( + self, user_id: str, room_id: str, receipt_type: str + ) -> Optional[str]: + return await self.db_pool.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, @@ -520,8 +522,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "insert_receipt_conv", graph_to_linear ) - stream_id_manager = self._receipts_id_gen.get_next() - with stream_id_manager as stream_id: + with await self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 068ad22b30..eced53d470 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import Awaitable, Dict, List, Optional +from typing import Any, Awaitable, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore): ) @cached() - def get_user_by_id(self, user_id): - return self.db_pool.simple_select_one( + async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -889,6 +889,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity + self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors if self._account_validity.enabled: self._clock.call_later( @@ -1258,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="del_user_pending_deactivation", ) - def get_user_pending_deactivation(self): + async def get_user_pending_deactivation(self) -> Optional[str]: """ Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", @@ -1302,15 +1303,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) if not row: - raise ThreepidValidationError(400, "Unknown session_id") + if self._ignore_unknown_session_error: + # If we need to inhibit the error caused by an incorrect session ID, + # use None as placeholder values for the client secret and the + # validation timestamp. + # It shouldn't be an issue because they're both only checked after + # the token check, which should fail. And if it doesn't for some + # reason, the next check is on the client secret, which is NOT NULL, + # so we don't have to worry about the client secret matching by + # accident. + row = {"client_secret": None, "validated_at": None} + else: + raise ThreepidValidationError(400, "Unknown session_id") + retrieved_client_secret = row["client_secret"] validated_at = row["validated_at"] - if retrieved_client_secret != client_secret: - raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id" - ) - row = self.db_pool.simple_select_one_txn( txn, table="threepid_validation_token", @@ -1326,6 +1334,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): expires = row["expires"] next_link = row["next_link"] + if retrieved_client_secret != client_secret: + raise ThreepidValidationError( + 400, "This client_secret does not match the provided session_id" + ) + # If the session is already validated, no need to revalidate if validated_at: return next_link diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py index cf9ba51205..1e361aaa9a 100644 --- a/synapse/storage/databases/main/rejections.py +++ b/synapse/storage/databases/main/rejections.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Optional from synapse.storage._base import SQLBaseStore @@ -21,8 +22,8 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): - def get_rejection_reason(self, event_id): - return self.db_pool.simple_select_one_onecol( + async def get_rejection_reason(self, event_id: str) -> Optional[str]: + return await self.db_pool.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7d3ac47261..97ecdb16e4 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore): self.config = hs.config - def get_room(self, room_id): + async def get_room(self, room_id: str) -> dict: """Retrieve a room. Args: - room_id (str): The ID of the room to retrieve. + room_id: The ID of the room to retrieve. Returns: A dict containing the room information, or None if the room is unknown. """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore): return ret_val @cached(max_entries=10000) - def is_room_blocked(self, room_id): - return self.db_pool.simple_select_one_onecol( + async def is_room_blocked(self, room_id: str) -> Optional[bool]: + return await self.db_pool.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", @@ -1129,7 +1129,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "store_room_txn", store_room_txn, next_id ) @@ -1196,7 +1196,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) @@ -1276,7 +1276,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 991233a9bc..458f169617 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return event.content.get("canonical_alias") @cached(max_entries=50000) - def _get_state_group_for_event(self, event_id): - return self.db_pool.simple_select_one_onecol( + async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: + return await self.db_pool.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 802c9019b9..9fe97af56a 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore): return len(rooms_to_work_on) - def get_stats_positions(self): + async def get_stats_positions(self) -> int: """ Returns the stats processor positions. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore): return slice_list @cached() - def get_earliest_token_for_stats(self, stats_type, id): + async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int: """ Fetch the "earliest token". This is used by the room stats delta processor to ignore deltas that have been processed between the @@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore): being calculated. Returns: - Deferred[int] + The earliest token. """ table, id_col = TYPE_TO_TABLE[stats_type] - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index ade7abc927..0c34bbf21a 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index af21fe457a..20cbcd851c 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -15,6 +15,7 @@ import logging import re +from typing import Any, Dict, Optional from synapse.api.constants import EventTypes, JoinRules from synapse.storage.database import DatabasePool @@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) @cached() - def get_user_in_directory(self, user_id): - return self.db_pool.simple_select_one( + async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - def get_user_directory_stream_pos(self): - return self.db_pool.simple_select_one_onecol( + async def get_user_directory_stream_pos(self) -> int: + return await self.db_pool.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 0bf772d4d1..5b07847773 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -14,9 +14,10 @@ # limitations under the License. import contextlib +import heapq import threading from collections import deque -from typing import Dict, Set +from typing import Dict, List, Set from typing_extensions import Deque @@ -80,7 +81,7 @@ class StreamIdGenerator(object): upwards, -1 to grow downwards. Usage: - with stream_id_gen.get_next() as stream_id: + with await stream_id_gen.get_next() as stream_id: # ... persist event ... """ @@ -95,10 +96,10 @@ class StreamIdGenerator(object): ) self._unfinished_ids = deque() # type: Deque[int] - def get_next(self): + async def get_next(self): """ Usage: - with stream_id_gen.get_next() as stream_id: + with await stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -117,10 +118,10 @@ class StreamIdGenerator(object): return manager() - def get_next_mult(self, n): + async def get_next_mult(self, n): """ Usage: - with stream_id_gen.get_next(n) as stream_ids: + with await stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -210,6 +211,23 @@ class MultiWriterIdGenerator: # should be less than the minimum of this set (if not empty). self._unfinished_ids = set() # type: Set[int] + # We track the max position where we know everything before has been + # persisted. This is done by a) looking at the min across all instances + # and b) noting that if we have seen a run of persisted positions + # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7). + # + # Note: There is no guarentee that the IDs generated by the sequence + # will be gapless; gaps can form when e.g. a transaction was rolled + # back. This means that sometimes we won't be able to skip forward the + # position even though everything has been persisted. However, since + # gaps should be relatively rare it's still worth doing the book keeping + # that allows us to skip forwards when there are gapless runs of + # positions. + self._persisted_upto_position = ( + min(self._current_positions.values()) if self._current_positions else 0 + ) + self._known_persisted_positions = [] # type: List[int] + self._sequence_gen = PostgresSequenceGenerator(sequence_name) def _load_current_ids( @@ -234,9 +252,12 @@ class MultiWriterIdGenerator: return current_positions - def _load_next_id_txn(self, txn): + def _load_next_id_txn(self, txn) -> int: return self._sequence_gen.get_next_id_txn(txn) + def _load_next_mult_id_txn(self, txn, n: int) -> List[int]: + return self._sequence_gen.get_next_mult_txn(txn, n) + async def get_next(self): """ Usage: @@ -262,6 +283,34 @@ class MultiWriterIdGenerator: return manager() + async def get_next_mult(self, n: int): + """ + Usage: + with await stream_id_gen.get_next_mult(5) as stream_ids: + # ... persist events ... + """ + next_ids = await self._db.runInteraction( + "_load_next_mult_id", self._load_next_mult_id_txn, n + ) + + # Assert the fetched ID is actually greater than any ID we've already + # seen. If not, then the sequence and table have got out of sync + # somehow. + assert max(self.get_positions().values(), default=0) < min(next_ids) + + with self._lock: + self._unfinished_ids.update(next_ids) + + @contextlib.contextmanager + def manager(): + try: + yield next_ids + finally: + for i in next_ids: + self._mark_id_as_finished(i) + + return manager() + def get_next_txn(self, txn: LoggingTransaction): """ Usage: @@ -326,3 +375,53 @@ class MultiWriterIdGenerator: self._current_positions[instance_name] = max( new_id, self._current_positions.get(instance_name, 0) ) + + self._add_persisted_position(new_id) + + def get_persisted_upto_position(self) -> int: + """Get the max position where all previous positions have been + persisted. + + Note: In the worst case scenario this will be equal to the minimum + position across writers. This means that the returned position here can + lag if one writer doesn't write very often. + """ + + with self._lock: + return self._persisted_upto_position + + def _add_persisted_position(self, new_id: int): + """Record that we have persisted a position. + + This is used to keep the `_current_positions` up to date. + """ + + # We require that the lock is locked by caller + assert self._lock.locked() + + heapq.heappush(self._known_persisted_positions, new_id) + + # We move the current min position up if the minimum current positions + # of all instances is higher (since by definition all positions less + # that that have been persisted). + min_curr = min(self._current_positions.values()) + self._persisted_upto_position = max(min_curr, self._persisted_upto_position) + + # We now iterate through the seen positions, discarding those that are + # less than the current min positions, and incrementing the min position + # if its exactly one greater. + # + # This is also where we discard items from `_known_persisted_positions` + # (to ensure the list doesn't infinitely grow). + while self._known_persisted_positions: + if self._known_persisted_positions[0] <= self._persisted_upto_position: + heapq.heappop(self._known_persisted_positions) + elif ( + self._known_persisted_positions[0] == self._persisted_upto_position + 1 + ): + heapq.heappop(self._known_persisted_positions) + self._persisted_upto_position += 1 + else: + # There was a gap in seen positions, so there is nothing more to + # do. + break diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 63dfea4220..ffc1894748 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -14,7 +14,7 @@ # limitations under the License. import abc import threading -from typing import Callable, Optional +from typing import Callable, List, Optional from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.storage.types import Cursor @@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator): txn.execute("SELECT nextval(?)", (self._sequence_name,)) return txn.fetchone()[0] + def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: + txn.execute( + "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n) + ) + return [i for (i,) in txn] + GetFirstCallbackType = Callable[[Cursor], int] |