summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/__init__.py11
-rw-r--r--synapse/storage/databases/main/__init__.py31
-rw-r--r--synapse/storage/databases/main/account_data.py4
-rw-r--r--synapse/storage/databases/main/appservice.py7
-rw-r--r--synapse/storage/databases/main/cache.py4
-rw-r--r--synapse/storage/databases/main/deviceinbox.py4
-rw-r--r--synapse/storage/databases/main/devices.py13
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py43
-rw-r--r--synapse/storage/databases/main/event_federation.py83
-rw-r--r--synapse/storage/databases/main/event_push_actions.py9
-rw-r--r--synapse/storage/databases/main/events.py37
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py46
-rw-r--r--synapse/storage/databases/main/events_worker.py329
-rw-r--r--synapse/storage/databases/main/group_server.py9
-rw-r--r--synapse/storage/databases/main/keys.py28
-rw-r--r--synapse/storage/databases/main/presence.py34
-rw-r--r--synapse/storage/databases/main/profile.py132
-rw-r--r--synapse/storage/databases/main/push_rule.py132
-rw-r--r--synapse/storage/databases/main/pusher.py108
-rw-r--r--synapse/storage/databases/main/receipts.py90
-rw-r--r--synapse/storage/databases/main/registration.py158
-rw-r--r--synapse/storage/databases/main/room.py37
-rw-r--r--synapse/storage/databases/main/roommember.py21
-rw-r--r--synapse/storage/databases/main/schema/delta/48/profiles_batch.sql36
-rw-r--r--synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql23
-rw-r--r--synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql16
-rw-r--r--synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql25
-rw-r--r--synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql18
-rw-r--r--synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql17
-rw-r--r--synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres15
-rw-r--r--synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite4
-rw-r--r--synapse/storage/databases/main/state.py5
-rw-r--r--synapse/storage/databases/main/stream.py401
-rw-r--r--synapse/storage/databases/main/tags.py11
-rw-r--r--synapse/storage/databases/main/ui_auth.py61
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py26
36 files changed, 1132 insertions, 896 deletions
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py

index 4406e58273..0ac854aee2 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py
@@ -87,12 +87,21 @@ class Databases(object): logger.info("Database %r prepared", db_name) + # Closing the context manager doesn't close the connection. + # psycopg will close the connection when the object gets GCed, but *only* + # if the PID is the same as when the connection was opened [1], and + # it may not be if we fork in the meantime. + # + # [1]: https://github.com/psycopg/psycopg2/blob/2_8_5/psycopg/connection_type.c#L1378 + + db_conn.close() + # Sanity check that we have actually configured all the required stores. if not main: raise Exception("No 'main' data store configured") if not state: - raise Exception("No 'main' data store configured") + raise Exception("No 'state' data store configured") # We use local variables here to ensure that the databases do not have # optional types. diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 17fa470919..0934ae276c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -498,7 +498,7 @@ class DataStore( ) def get_users_paginate( - self, start, limit, name=None, guests=True, deactivated=False + self, start, limit, user_id=None, name=None, guests=True, deactivated=False ): """Function to retrieve a paginated list of users from users list. This will return a json list of users and the @@ -507,7 +507,8 @@ class DataStore( Args: start (int): start number to begin the query from limit (int): number of rows to retrieve - name (string): filter for user names + user_id (string): search for user_id. ignored if name is not None + name (string): search for local part of user_id or display name guests (bool): whether to in include guest users deactivated (bool): whether to include deactivated users Returns: @@ -516,11 +517,14 @@ class DataStore( def get_users_paginate_txn(txn): filters = [] - args = [] + args = [self.hs.config.server_name] if name: + filters.append("(name LIKE ? OR displayname LIKE ?)") + args.extend(["@%" + name + "%:%", "%" + name + "%"]) + elif user_id: filters.append("name LIKE ?") - args.append("%" + name + "%") + args.extend(["%" + user_id + "%"]) if not guests: filters.append("is_guest = 0") @@ -530,20 +534,23 @@ class DataStore( where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" - sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause) - txn.execute(sql, args) - count = txn.fetchone()[0] - - args = [self.hs.config.server_name] + args + [limit, start] - sql = """ - SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url + sql_base = """ FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? {} - ORDER BY u.name LIMIT ? OFFSET ? """.format( where_clause ) + sql = "SELECT COUNT(*) as total_users " + sql_base + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = ( + "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url " + + sql_base + + " ORDER BY u.name LIMIT ? OFFSET ?" 
+ ) + args += [limit, start] txn.execute(sql, args) users = self.db_pool.cursor_to_dict(txn) return users, count diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 82aac2bbf3..04042a2c98 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py
@@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. @@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5cf1a88399..77723f7d4d 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py
@@ -16,13 +16,12 @@ import logging import re -from canonicaljson import json - from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util import json_encoder logger = logging.getLogger(__name__) @@ -169,7 +168,7 @@ class ApplicationServiceTransactionWorkerStore( service(ApplicationService): The service whose state to set. state(ApplicationServiceState): The connectivity state to apply. Returns: - A Deferred which resolves when the state was set successfully. + An Awaitable which resolves when the state was set successfully. """ return self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} @@ -204,7 +203,7 @@ class ApplicationServiceTransactionWorkerStore( new_txn_id = max(highest_txn_id, last_txn_id) + 1 # Insert new txn into txn table - event_ids = json.dumps([e.event_id for e in events]) + event_ids = json_encoder.encode([e.event_id for e in events]) txn.execute( "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " "VALUES(?,?,?)", diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 10de446065..1e7637a6f5 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py
@@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): }, ) - def get_cache_stream_token(self, instance_name): + def get_cache_stream_token_for_writer(self, instance_name: str) -> int: if self._cache_id_gen: - return self._cache_id_gen.get_current_token(instance_name) + return self._cache_id_gen.get_current_token_for_writer(instance_name) else: return 0 diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1f6e995c4f..bb85637a95 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) - with self._device_inbox_id_gen.get_next() as stream_id: + with await self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id @@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) txn, stream_id, local_messages_by_user_then_device ) - with self._device_inbox_id_gen.get_next() as stream_id: + with await self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 2b33060480..03b45dbc4d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore): THe new stream ID. """ - with self._device_list_id_gen.get_next() as stream_id: + with await self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -671,10 +671,9 @@ class DeviceWorkerStore(SQLBaseStore): @cachedList( cached_method_name="get_device_list_last_stream_id_for_remote", list_name="user_ids", - inlineCallbacks=True, ) - def get_device_list_last_stream_id_for_remotes(self, user_ids: str): - rows = yield self.db_pool.simple_select_many_batch( + async def get_device_list_last_stream_id_for_remotes(self, user_ids: str): + rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", iterable=user_ids, @@ -1147,7 +1146,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return - with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: + with await self._device_list_id_gen.get_next_mult( + len(device_ids) + ) as stream_ids: await self.db_pool.runInteraction( "add_device_change_to_stream", self._add_device_change_to_stream_txn, @@ -1160,7 +1161,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] context = get_active_span_text_map() - with self._device_list_id_gen.get_next_mult( + with await self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f93e0d320d..385868bdab 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): + def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id): """Set a user's cross-signing key. Args: @@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key key (dict): the key data + stream_id (int) """ # the 'key' dict will look something like: # { @@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) # and finally, store the key itself - with self._cross_signing_id_gen.get_next() as stream_id: - self.db_pool.simple_insert_txn( - txn, - "e2e_cross_signing_keys", - values={ - "user_id": user_id, - "keytype": key_type, - "keydata": json_encoder.encode(key), - "stream_id": stream_id, - }, - ) + self.db_pool.simple_insert_txn( + txn, + "e2e_cross_signing_keys", + values={ + "user_id": user_id, + "keytype": key_type, + "keydata": json_encoder.encode(key), + "stream_id": stream_id, + }, + ) self._invalidate_cache_and_stream( txn, self._get_bare_e2e_cross_signing_keys, (user_id,) ) - def set_e2e_cross_signing_key(self, user_id, key_type, key): + async def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. 
Args: @@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key_type (str): the type of cross-signing key to set key (dict): the key data """ - return self.db_pool.runInteraction( - "add_e2e_cross_signing_key", - self._set_e2e_cross_signing_key_txn, - user_id, - key_type, - key, - ) + + with await self._cross_signing_id_gen.get_next() as stream_id: + return await self.db_pool.runInteraction( + "add_e2e_cross_signing_key", + self._set_e2e_cross_signing_key_txn, + user_id, + key_type, + key, + stream_id, + ) def store_e2e_cross_signing_signatures(self, user_id, signatures): """Stores cross-signing signatures. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 484875f989..e6a97b018c 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -15,14 +15,16 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import Dict, Iterable, List, Set, Tuple from synapse.api.errors import StoreError +from synapse.events import EventBase from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.types import Collection from synapse.util.caches.descriptors import cached from synapse.util.iterutils import batch_iter @@ -30,57 +32,51 @@ logger = logging.getLogger(__name__) class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def get_auth_chain(self, event_ids, include_given=False): + async def get_auth_chain( + self, event_ids: Collection[str], include_given: bool = False + ) -> List[EventBase]: """Get auth events for given event_ids. The events *must* be state events. Args: - event_ids (list): state events - include_given (bool): include the given events in result + event_ids: state events + include_given: include the given events in result Returns: list of events """ - return self.get_auth_chain_ids( + event_ids = await self.get_auth_chain_ids( event_ids, include_given=include_given - ).addCallback(self.get_events_as_list) - - def get_auth_chain_ids( - self, - event_ids: List[str], - include_given: bool = False, - ignore_events: Optional[Set[str]] = None, - ): + ) + return await self.get_events_as_list(event_ids) + + async def get_auth_chain_ids( + self, event_ids: Collection[str], include_given: bool = False, + ) -> List[str]: """Get auth events for given event_ids. The events *must* be state events. 
Args: event_ids: state events include_given: include the given events in result - ignore_events: Set of events to exclude from the returned auth - chain. This is useful if the caller will just discard the - given events anyway, and saves us from figuring out their auth - chains if not required. Returns: list of event_ids """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given, - ignore_events, ) - def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events): - if ignore_events is None: - ignore_events = set() - + def _get_auth_chain_ids_txn( + self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool + ) -> List[str]: if include_given: results = set(event_ids) else: results = set() - base_sql = "SELECT auth_id FROM event_auth WHERE " + base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE " front = set(event_ids) while front: @@ -92,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(base_sql + clause, args) new_front.update(r[0] for r in txn) - new_front -= ignore_events new_front -= results front = new_front @@ -257,11 +252,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Return all events where not all sets can reach them. 
return {eid for eid, n in event_to_missing_sets.items() if n} - def get_oldest_events_in_room(self, room_id): - return self.db_pool.runInteraction( - "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id - ) - def get_oldest_events_with_depth_in_room(self, room_id): return self.db_pool.runInteraction( "get_oldest_events_with_depth_in_room", @@ -303,14 +293,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas else: return max(row["depth"] for row in rows) - def _get_oldest_events_in_room_txn(self, txn, room_id): - return self.db_pool.simple_select_onecol_txn( - txn, - table="event_backward_extremities", - keyvalues={"room_id": room_id}, - retcol="event_id", - ) - def get_prev_events_for_room(self, room_id: str): """ Gets a subset of the current forward extremities in the given room. @@ -472,7 +454,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. 
Return a list of max size `limit` @@ -482,17 +464,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_list (list) limit (int) """ - return ( - self.db_pool.runInteraction( - "get_backfill_events", - self._get_backfill_events, - room_id, - event_list, - limit, - ) - .addCallback(self.get_events_as_list) - .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) + event_ids = await self.db_pool.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + event_list, + limit, ) + events = await self.get_events_as_list(event_ids) + return sorted(events, key=lambda e: -e.depth) def _get_backfill_events(self, txn, room_id, event_list, limit): logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) @@ -553,8 +533,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = await self.get_events_as_list(ids) - return events + return await self.get_events_as_list(ids) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 7c246d3e4c..e8834b2162 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.util import json_encoder -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -86,18 +86,17 @@ class EventPushActionsWorkerStore(SQLBaseStore): self._rotate_delay = 3 self._rotate_count = 10000 - @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) - def get_unread_event_push_actions_by_room_for_user( + @cached(num_args=3, tree=True, max_entries=5000) + async def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): - ret = yield self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_unread_event_push_actions_by_room", self._get_unread_counts_by_receipt_txn, room_id, user_id, last_read_event_id, ) - return ret def _get_unread_counts_by_receipt_txn( self, txn, room_id, user_id, last_read_event_id diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1a68bf32cb..6313b41eef 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -17,13 +17,11 @@ import itertools import logging from collections import OrderedDict, namedtuple -from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple import attr from prometheus_client import Counter -from twisted.internet import defer - import synapse.metrics from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.room_versions import RoomVersions @@ -113,15 +111,14 @@ class PersistEventsStore: hs.config.worker.writers.events == hs.get_instance_name() ), "Can only instantiate EventsStore on master" - @defer.inlineCallbacks - def _persist_events_and_state_updates( + async def _persist_events_and_state_updates( self, events_and_contexts: List[Tuple[EventBase, EventContext]], current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], new_forward_extremeties: Dict[str, List[str]], backfilled: bool = False, - ): + ) -> None: """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -136,7 +133,7 @@ class PersistEventsStore: backfilled Returns: - Deferred: resolves when the events have been persisted + Resolves when the events have been persisted """ # We want to calculate the stream orderings as late as possible, as @@ -156,11 +153,11 @@ class PersistEventsStore: # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. 
if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( + stream_ordering_manager = await self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) else: - stream_ordering_manager = self._stream_id_gen.get_next_mult( + stream_ordering_manager = await self._stream_id_gen.get_next_mult( len(events_and_contexts) ) @@ -168,7 +165,7 @@ class PersistEventsStore: for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "persist_events", self._persist_events_txn, events_and_contexts=events_and_contexts, @@ -206,16 +203,15 @@ class PersistEventsStore: (room_id,), list(latest_event_ids) ) - @defer.inlineCallbacks - def _get_events_which_are_prevs(self, event_ids): + async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]: """Filter the supplied list of event_ids to get those which are prev_events of existing (non-outlier/rejected) events. Args: - event_ids (Iterable[str]): event ids to filter + event_ids: event ids to filter Returns: - Deferred[List[str]]: filtered event ids + Filtered event ids """ results = [] @@ -240,14 +236,13 @@ class PersistEventsStore: results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed")) for chunk in batch_iter(event_ids, 100): - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk ) return results - @defer.inlineCallbacks - def _get_prevs_before_rejected(self, event_ids): + async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]: """Get soft-failed ancestors to remove from the extremities. Given a set of events, find all those that have been soft-failed or @@ -259,11 +254,11 @@ class PersistEventsStore: are separated by soft failed events. Args: - event_ids (Iterable[str]): Events to find prev events for. 
Note - that these must have already been persisted. + event_ids: Events to find prev events for. Note that these must have + already been persisted. Returns: - Deferred[set[str]] + The previous events. """ # The set of event_ids to return. This includes all soft-failed events @@ -304,7 +299,7 @@ class PersistEventsStore: existing_prevs.add(prev_event_id) for chunk in batch_iter(event_ids, 100): - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk ) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 35a0e09e3c..e53c6373a8 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import EventContentFields from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool @@ -94,8 +92,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): where_clause="NOT have_censored", ) - @defer.inlineCallbacks - def _background_reindex_fields_sender(self, progress, batch_size): + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) @@ -155,19 +152,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return len(rows) - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn ) if not result: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME ) return result - @defer.inlineCallbacks - def _background_reindex_origin_server_ts(self, progress, batch_size): + async def _background_reindex_origin_server_ts(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) @@ -234,19 +230,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return len(rows_to_update) - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn ) if not result: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_ORIGIN_SERVER_TS_NAME ) return result - @defer.inlineCallbacks - def _cleanup_extremities_bg_update(self, progress, batch_size): + async def 
_cleanup_extremities_bg_update(self, progress, batch_size): """Background update to clean out extremities that should have been deleted previously. @@ -414,26 +409,25 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return len(original_set) - num_handled = yield self.db_pool.runInteraction( + num_handled = await self.db_pool.runInteraction( "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn ) if not num_handled: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.DELETE_SOFT_FAILED_EXTREMITIES ) def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_cleanup_extremities_bg_update_drop_table", _drop_table_txn ) return num_handled - @defer.inlineCallbacks - def _redactions_received_ts(self, progress, batch_size): + async def _redactions_received_ts(self, progress, batch_size): """Handles filling out the `received_ts` column in redactions. """ last_event_id = progress.get("last_event_id", "") @@ -480,17 +474,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return len(rows) - count = yield self.db_pool.runInteraction( + count = await self.db_pool.runInteraction( "_redactions_received_ts", _redactions_received_ts_txn ) if not count: - yield self.db_pool.updates._end_background_update("redactions_received_ts") + await self.db_pool.updates._end_background_update("redactions_received_ts") return count - @defer.inlineCallbacks - def _event_fix_redactions_bytes(self, progress, batch_size): + async def _event_fix_redactions_bytes(self, progress, batch_size): """Undoes hex encoded censored redacted event JSON. 
""" @@ -511,16 +504,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): txn.execute("DROP INDEX redactions_censored_redacts") - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn ) - yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes") + await self.db_pool.updates._end_background_update("event_fix_redactions_bytes") return 1 - @defer.inlineCallbacks - def _event_store_labels(self, progress, batch_size): + async def _event_store_labels(self, progress, batch_size): """Background update handler which will store labels for existing events.""" last_event_id = progress.get("last_event_id", "") @@ -575,11 +567,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return nbrows - num_rows = yield self.db_pool.runInteraction( + num_rows = await self.db_pool.runInteraction( desc="event_store_labels", func=_event_store_labels_txn ) if not num_rows: - yield self.db_pool.updates._end_background_update("event_store_labels") + await self.db_pool.updates._end_background_update("event_store_labels") return num_rows diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 755b7a2a85..e1241a724b 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -19,9 +19,10 @@ import itertools import logging import threading from collections import namedtuple -from typing import List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, overload from constantly import NamedConstant, Names +from typing_extensions import Literal from twisted.internet import defer @@ -32,7 +33,7 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersions, ) -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process @@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks +from synapse.types import Collection, get_domain_from_id +from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -137,44 +138,33 @@ class EventsWorkerStore(SQLBaseStore): desc="get_received_ts", ) - def get_received_ts_by_stream_pos(self, stream_ordering): - """Given a stream ordering get an approximate timestamp of when it - happened. - - This is done by simply taking the received ts of the first event that - has a stream ordering greater than or equal to the given stream pos. - If none exists returns the current time, on the assumption that it must - have happened recently. 
- - Args: - stream_ordering (int) - - Returns: - Deferred[int] - """ - - def _get_approximate_received_ts_txn(txn): - sql = """ - SELECT received_ts FROM events - WHERE stream_ordering >= ? - LIMIT 1 - """ - - txn.execute(sql, (stream_ordering,)) - row = txn.fetchone() - if row and row[0]: - ts = row[0] - else: - ts = self.clock.time_msec() - - return ts + # Inform mypy that if allow_none is False (the default) then get_event + # always returns an EventBase. + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[False] = False, + check_room_id: Optional[str] = None, + ) -> EventBase: + ... - return self.db_pool.runInteraction( - "get_approximate_received_ts", _get_approximate_received_ts_txn - ) + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[True] = False, + check_room_id: Optional[str] = None, + ) -> Optional[EventBase]: + ... - @defer.inlineCallbacks - def get_event( + async def get_event( self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, @@ -182,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore): allow_rejected: bool = False, allow_none: bool = False, check_room_id: Optional[str] = None, - ): + ) -> Optional[EventBase]: """Get an event from the database by event_id. Args: @@ -207,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - Deferred[EventBase|None] + The event, or None if the event was not found. 
""" if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [event_id], redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -230,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore): return event - @defer.inlineCallbacks - def get_events( + async def get_events( self, - event_ids: List[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> Dict[str, EventBase]: """Get events from the database Args: @@ -256,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore): omits rejeted events from the response. Returns: - Deferred : Dict from event_id to event. + A mapping from event_id to event. """ - events = yield self.get_events_as_list( + events = await self.get_events_as_list( event_ids, redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -267,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore): return {e.event_id: e for e in events} - @defer.inlineCallbacks - def get_events_as_list( + async def get_events_as_list( self, - event_ids: List[str], + event_ids: Collection[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> List[EventBase]: """Get events from the database and return in a list in the same order as given by `event_ids` arg. @@ -295,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore): omits rejected events from the response. Returns: - Deferred[list[EventBase]]: List of events fetched from the database. The - events are in the same order as `event_ids` arg. + List of events fetched from the database. The events are in the same + order as `event_ids` arg. Note that the returned list may be smaller than the list of event IDs if not all events could be fetched. 
@@ -306,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = yield self._get_events_from_cache_or_db( + event_entry_map = await self._get_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -341,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self._get_events_from_cache_or_db([redacted_event_id]) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -407,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore): if get_prev_content: if "replaces_state" in event.unsigned: - prev = yield self.get_event( + prev = await self.get_event( event.unsigned["replaces_state"], get_prev_content=False, allow_none=True, @@ -419,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore): return events - @defer.inlineCallbacks - def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -435,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result """ event_entry_map = self._get_events_from_cache( @@ -453,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore): # the events have been redacted, and if so pulling the redaction event out # of the database to check it. 
# - missing_events = yield self._get_events_from_db( + missing_events = await self._get_events_from_db( missing_events_ids, allow_rejected=allow_rejected ) @@ -561,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - @defer.inlineCallbacks - def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the database. Returned events will be added to the cache for future lookups. @@ -576,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result. May return extra events which weren't asked for. """ @@ -584,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore): events_to_fetch = event_ids while events_to_fetch: - row_map = yield self._enqueue_events(events_to_fetch) + row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events redaction_ids = set() @@ -610,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore): if not allow_rejected and rejected_reason: continue - d = db_to_json(row["json"]) - internal_metadata = db_to_json(row["internal_metadata"]) + # If the event or metadata cannot be parsed, log the error and act + # as if the event is unknown. 
+ try: + d = db_to_json(row["json"]) + except ValueError: + logger.error("Unable to parse json from event: %s", event_id) + continue + try: + internal_metadata = db_to_json(row["internal_metadata"]) + except ValueError: + logger.error( + "Unable to parse internal_metadata from event: %s", event_id + ) + continue format_version = row["format_version"] if format_version is None: @@ -622,19 +620,38 @@ class EventsWorkerStore(SQLBaseStore): room_version_id = row["room_version_id"] if not room_version_id: - # this should only happen for out-of-band membership events - if not internal_metadata.get("out_of_band_membership"): - logger.warning( - "Room %s for event %s is unknown", d["room_id"], event_id + # this should only happen for out-of-band membership events which + # arrived before #6983 landed. For all other events, we should have + # an entry in the 'rooms' table. + # + # However, the 'out_of_band_membership' flag is unreliable for older + # invites, so just accept it for all membership events. + # + if d["type"] != EventTypes.Member: + raise Exception( + "Room %s for event %s is unknown" % (d["room_id"], event_id) ) - continue - # take a wild stab at the room version based on the event format + # so, assuming this is an out-of-band-invite that arrived before #6983 + # landed, we know that the room version must be v5 or earlier (because + # v6 hadn't been invented at that point, so invites from such rooms + # would have been rejected.) + # + # The main reason we need to know the room version here (other than + # choosing the right python Event class) is in case the event later has + # to be redacted - and all the room versions up to v5 used the same + # redaction algorithm. + # + # So, the following approximations should be adequate. 
+ if format_version == EventFormatVersions.V1: + # if it's event format v1 then it must be room v1 or v2 room_version = RoomVersions.V1 elif format_version == EventFormatVersions.V2: + # if it's event format v2 then it must be room v3 room_version = RoomVersions.V3 else: + # if it's event format v3 then it must be room v4 or v5 room_version = RoomVersions.V5 else: room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) @@ -686,8 +703,7 @@ class EventsWorkerStore(SQLBaseStore): return result_map - @defer.inlineCallbacks - def _enqueue_events(self, events): + async def _enqueue_events(self, events): """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -696,7 +712,7 @@ class EventsWorkerStore(SQLBaseStore): events (Iterable[str]): events to be fetched. Returns: - Deferred[Dict[str, Dict]]: map from event id to row data from the database. + Dict[str, Dict]: map from event id to row data from the database. May contain events that weren't requested. """ @@ -719,7 +735,7 @@ class EventsWorkerStore(SQLBaseStore): logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): - row_map = yield events_d + row_map = await events_d logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) return row_map @@ -878,12 +894,11 @@ class EventsWorkerStore(SQLBaseStore): # no valid redaction found for this event return None - @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): + async def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. 
""" - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="events", retcols=("event_id",), column="event_id", @@ -894,15 +909,14 @@ class EventsWorkerStore(SQLBaseStore): return {r["event_id"] for r in rows} - @defer.inlineCallbacks - def have_seen_events(self, event_ids): + async def have_seen_events(self, event_ids): """Given a list of event ids, check if we have already processed them. Args: event_ids (iterable[str]): Returns: - Deferred[set[str]]: The events we have already seen. + set[str]: The events we have already seen. """ results = set() @@ -918,41 +932,11 @@ class EventsWorkerStore(SQLBaseStore): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) return results - def _get_total_state_event_counts_txn(self, txn, room_id): - """ - See get_total_state_event_counts. - """ - # We join against the events table as that has an index on room_id - sql = """ - SELECT COUNT(*) FROM state_events - INNER JOIN events USING (room_id, event_id) - WHERE room_id=? - """ - txn.execute(sql, (room_id,)) - row = txn.fetchone() - return row[0] if row else 0 - - def get_total_state_event_counts(self, room_id): - """ - Gets the total number of state events in a room. - - Args: - room_id (str) - - Returns: - Deferred[int] - """ - return self.db_pool.runInteraction( - "get_total_state_event_counts", - self._get_total_state_event_counts_txn, - room_id, - ) - def _get_current_state_event_counts_txn(self, txn, room_id): """ See get_current_state_event_counts. @@ -978,8 +962,7 @@ class EventsWorkerStore(SQLBaseStore): room_id, ) - @defer.inlineCallbacks - def get_room_complexity(self, room_id): + async def get_room_complexity(self, room_id): """ Get a rough approximation of the complexity of the room. 
This is used by remote servers to decide whether they wish to join the room or not. @@ -990,9 +973,9 @@ class EventsWorkerStore(SQLBaseStore): room_id (str) Returns: - Deferred[dict[str:int]] of complexity version to complexity. + dict[str:int] of complexity version to complexity. """ - state_events = yield self.get_current_state_event_counts(room_id) + state_events = await self.get_current_state_event_counts(room_id) # Call this one "v1", so we can introduce new ones as we want to develop # it. @@ -1222,97 +1205,6 @@ class EventsWorkerStore(SQLBaseStore): return rows, to_token, True - @cached(num_args=5, max_entries=10) - def get_all_new_events( - self, - last_backfill_id, - last_forward_id, - current_backfill_id, - current_forward_id, - limit, - ): - """Get all the new events that have arrived at the server either as - new events or as backfilled events""" - have_backfill_events = last_backfill_id != current_backfill_id - have_forward_events = last_forward_id != current_forward_id - - if not have_backfill_events and not have_forward_events: - return defer.succeed(AllNewEventsResult([], [], [], [], [])) - - def get_all_new_events_txn(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - if have_forward_events: - txn.execute(sql, (last_forward_id, current_forward_id, limit)) - new_forward_events = txn.fetchall() - - if len(new_forward_events) == limit: - upper_bound = new_forward_events[-1][0] - else: - upper_bound = current_forward_id - - sql = ( - "SELECT event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" 
- " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_forward_id, upper_bound)) - forward_ex_outliers = txn.fetchall() - else: - new_forward_events = [] - forward_ex_outliers = [] - - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - if have_backfill_events: - txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) - new_backfill_events = txn.fetchall() - - if len(new_backfill_events) == limit: - upper_bound = new_backfill_events[-1][0] - else: - upper_bound = current_backfill_id - - sql = ( - "SELECT -event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_backfill_id, -upper_bound)) - backward_ex_outliers = txn.fetchall() - else: - new_backfill_events = [] - backward_ex_outliers = [] - - return AllNewEventsResult( - new_forward_events, - new_backfill_events, - forward_ex_outliers, - backward_ex_outliers, - ) - - return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn) - async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ @@ -1320,9 +1212,9 @@ class EventsWorkerStore(SQLBaseStore): to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) - @cachedInlineCallbacks(max_entries=5000) - def get_event_ordering(self, event_id): - res = yield self.db_pool.simple_select_one( + @cached(max_entries=5000) + async def get_event_ordering(self, event_id): + res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, @@ -1357,14 
+1249,3 @@ class EventsWorkerStore(SQLBaseStore): return self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - - -AllNewEventsResult = namedtuple( - "AllNewEventsResult", - [ - "new_forward_events", - "new_backfill_events", - "forward_ex_outliers", - "backward_ex_outliers", - ], -) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 380db3a3f3..a488e0924b 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py
@@ -341,14 +341,15 @@ class GroupServerWorkerStore(SQLBaseStore): "get_users_for_summary_by_role", _get_users_for_summary_txn ) - def is_user_in_group(self, user_id, group_id): - return self.db_pool.simple_select_one_onecol( + async def is_user_in_group(self, user_id: str, group_id: str) -> bool: + result = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", allow_none=True, desc="is_user_in_group", - ).addCallback(lambda r: bool(r)) + ) + return bool(result) def is_user_admin_in_group(self, group_id, user_id): return self.db_pool.simple_select_one_onecol( @@ -1181,7 +1182,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - with self._group_updates_id_gen.get_next() as next_id: + with await self._group_updates_id_gen.get_next() as next_id: res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 384e9c5eb0..fadcad51e7 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py
@@ -16,6 +16,7 @@ import itertools import logging +from typing import Iterable, Tuple from signedjson.key import decode_verify_key_bytes @@ -88,12 +89,17 @@ class KeyStore(SQLBaseStore): return self.db_pool.runInteraction("get_server_verify_keys", _txn) - def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): + async def store_server_verify_keys( + self, + from_server: str, + ts_added_ms: int, + verify_keys: Iterable[Tuple[str, str, FetchKeyResult]], + ) -> None: """Stores NACL verification keys for remote servers. Args: - from_server (str): Where the verification keys were looked up - ts_added_ms (int): The time to record that the key was added - verify_keys (iterable[tuple[str, str, FetchKeyResult]]): + from_server: Where the verification keys were looked up + ts_added_ms: The time to record that the key was added + verify_keys: keys to be stored. Each entry is a triplet of (server_name, key_id, key). """ @@ -115,13 +121,7 @@ class KeyStore(SQLBaseStore): # param, which is itself the 2-tuple (server_name, key_id). invalidations.append((server_name, key_id)) - def _invalidate(res): - f = self._get_server_verify_key.invalidate - for i in invalidations: - f((i,)) - return res - - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "store_server_verify_keys", self.db_pool.simple_upsert_many_txn, table="server_signature_keys", @@ -134,7 +134,11 @@ class KeyStore(SQLBaseStore): "verify_key", ), value_values=value_values, - ).addCallback(_invalidate) + ) + + invalidate = self._get_server_verify_key.invalidate + for i in invalidations: + invalidate((i,)) def store_server_keys_json( self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 59ba12820a..c9f655dfb7 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py
@@ -15,15 +15,15 @@ from typing import List, Tuple +from synapse.api.presence import UserPresenceState from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.presence import UserPresenceState from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): async def update_presence(self, presence_states): - stream_ordering_manager = self._presence_id_gen.get_next_mult( + stream_ordering_manager = await self._presence_id_gen.get_next_mult( len(presence_states) ) @@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore): raise NotImplementedError() @cachedList( - cached_method_name="_get_presence_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, + cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1, ) - def get_presence_for_users(self, user_ids): - rows = yield self.db_pool.simple_select_many_batch( + async def get_presence_for_users(self, user_ids): + rows = await self.db_pool.simple_select_many_batch( table="presence_stream", column="user_id", iterable=user_ids, @@ -160,24 +157,3 @@ class PresenceStore(SQLBaseStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() - - def allow_presence_visible(self, observed_localpart, observer_userid): - return self.db_pool.simple_insert( - table="presence_allow_inbound", - values={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="allow_presence_visible", - or_ignore=True, - ) - - def disallow_presence_visible(self, observed_localpart, observer_userid): - return self.db_pool.simple_delete_one( - table="presence_allow_inbound", - keyvalues={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="disallow_presence_visible", - ) diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index b8261357d4..086cfbeed4 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,9 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple + from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.databases.main.roommember import ProfileInfo +from synapse.types import UserID +from synapse.util.caches.descriptors import cached + +BATCH_SIZE = 100 class ProfileWorkerStore(SQLBaseStore): @@ -38,6 +45,7 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) + @cached(max_entries=5000) def get_profile_displayname(self, user_localpart): return self.db_pool.simple_select_one_onecol( table="profiles", @@ -46,6 +54,7 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_displayname", ) + @cached(max_entries=5000) def get_profile_avatar_url(self, user_localpart): return self.db_pool.simple_select_one_onecol( table="profiles", @@ -54,6 +63,56 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_avatar_url", ) + def get_latest_profile_replication_batch_number(self): + def f(txn): + txn.execute("SELECT MAX(batch) as maxbatch FROM profiles") + rows = self.db_pool.cursor_to_dict(txn) + return rows[0]["maxbatch"] + + return self.db_pool.runInteraction( + "get_latest_profile_replication_batch_number", f + ) + + def get_profile_batch(self, batchnum): + return self.db_pool.simple_select_list( + table="profiles", + keyvalues={"batch": batchnum}, + retcols=("user_id", "displayname", "avatar_url", "active"), + desc="get_profile_batch", + ) + + def assign_profile_batch(self): + def f(txn): + sql = ( + "UPDATE profiles SET batch = " + "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) " + "WHERE user_id 
in (" + " SELECT user_id FROM profiles WHERE batch is NULL limit ?" + ")" + ) + txn.execute(sql, (BATCH_SIZE,)) + return txn.rowcount + + return self.db_pool.runInteraction("assign_profile_batch", f) + + def get_replication_hosts(self): + def f(txn): + txn.execute( + "SELECT host, last_synced_batch FROM profile_replication_status" + ) + rows = self.db_pool.cursor_to_dict(txn) + return {r["host"]: r["last_synced_batch"] for r in rows} + + return self.db_pool.runInteraction("get_replication_hosts", f) + + def update_replication_batch_for_host(self, host, last_synced_batch): + return self.db_pool.simple_upsert( + table="profile_replication_status", + keyvalues={"host": host}, + values={"last_synced_batch": last_synced_batch}, + desc="update_replication_batch_for_host", + ) + def get_from_remote_profile_cache(self, user_id): return self.db_pool.simple_select_one( table="remote_profile_cache", @@ -68,24 +127,83 @@ class ProfileWorkerStore(SQLBaseStore): table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) - def set_profile_displayname(self, user_localpart, new_displayname): - return self.db_pool.simple_update_one( + def set_profile_displayname(self, user_localpart, new_displayname, batchnum): + # Invalidate the read cache for this user + self.get_profile_displayname.invalidate((user_localpart,)) + + return self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - updatevalues={"displayname": new_displayname}, + values={"displayname": new_displayname, "batch": batchnum}, desc="set_profile_displayname", + lock=False, # we can do this because user_id has a unique index ) - def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self.db_pool.simple_update_one( + def set_profile_avatar_url(self, user_localpart, new_avatar_url, batchnum): + # Invalidate the read cache for this user + self.get_profile_avatar_url.invalidate((user_localpart,)) + + return self.db_pool.simple_upsert( table="profiles", 
keyvalues={"user_id": user_localpart}, - updatevalues={"avatar_url": new_avatar_url}, + values={"avatar_url": new_avatar_url, "batch": batchnum}, desc="set_profile_avatar_url", + lock=False, # we can do this because user_id has a unique index + ) + + def set_profiles_active( + self, users: List[UserID], active: bool, hide: bool, batchnum: int, + ): + """Given a set of users, set active and hidden flags on them. + + Args: + users: A list of UserIDs + active: Whether to set the users to active or inactive + hide: Whether to hide the users (withold from replication). If + False and active is False, users will have their profiles + erased + batchnum: The batch number, used for profile replication + + Returns: + Deferred + """ + # Convert list of localparts to list of tuples containing localparts + user_localparts = [(user.localpart,) for user in users] + + # Generate list of value tuples for each user + value_names = ("active", "batch") + values = [(int(active), batchnum) for _ in user_localparts] # type: List[Tuple] + + if not active and not hide: + # we are deactivating for real (not in hide mode) + # so clear the profile information + value_names += ("avatar_url", "displayname") + values = [v + (None, None) for v in values] + + return self.db_pool.runInteraction( + "set_profiles_active", + self.db_pool.simple_upsert_many_txn, + table="profiles", + key_names=("user_id",), + key_values=user_localparts, + value_names=value_names, + value_values=values, ) class ProfileStore(ProfileWorkerStore): + def __init__(self, database, db_conn, hs): + + super(ProfileStore, self).__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + "profile_replication_status_host_index", + index_name="profile_replication_status_idx", + table="profile_replication_status", + columns=["host"], + unique=True, + ) + def add_remote_profile_cache(self, user_id, displayname, avatar_url): """Ensure we are caching the remote user's profiles. 
@@ -104,7 +222,7 @@ class ProfileStore(ProfileWorkerStore): ) def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self.db_pool.simple_update( + return self.db_pool.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, updatevalues={ diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 6562db5c2b..2fb5b02d7d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -30,9 +30,9 @@ from synapse.storage.databases.main.pusher import PusherWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException -from synapse.storage.util.id_generators import ChainedIdGenerator +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util import json_encoder -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) @@ -82,9 +82,9 @@ class PushRulesWorkerStore( super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen = ChainedIdGenerator( - self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" - ) # type: Union[ChainedIdGenerator, SlavedIdTracker] + self._push_rules_stream_id_gen = StreamIdGenerator( + db_conn, "push_rules_stream", "stream_id" + ) # type: Union[StreamIdGenerator, SlavedIdTracker] else: self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" @@ -115,9 +115,9 @@ class PushRulesWorkerStore( """ raise NotImplementedError() - @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_for_user(self, user_id): - rows = yield self.db_pool.simple_select_list( + @cached(max_entries=5000) + async def get_push_rules_for_user(self, user_id): + rows = await self.db_pool.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, retcols=( @@ -133,17 +133,15 @@ class PushRulesWorkerStore( rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) - enabled_map = yield self.get_push_rules_enabled_for_user(user_id) + enabled_map = await self.get_push_rules_enabled_for_user(user_id) 
use_new_defaults = user_id in self._users_new_default_push_rules - rules = _load_rules(rows, enabled_map, use_new_defaults) - - return rules + return _load_rules(rows, enabled_map, use_new_defaults) - @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_enabled_for_user(self, user_id): - results = yield self.db_pool.simple_select_list( + @cached(max_entries=5000) + async def get_push_rules_enabled_for_user(self, user_id): + results = await self.db_pool.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), @@ -170,18 +168,15 @@ class PushRulesWorkerStore( ) @cachedList( - cached_method_name="get_push_rules_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, + cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1, ) - def bulk_get_push_rules(self, user_ids): + async def bulk_get_push_rules(self, user_ids): if not user_ids: return {} results = {user_id: [] for user_id in user_ids} - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, @@ -194,7 +189,7 @@ class PushRulesWorkerStore( for row in rows: results.setdefault(row["user_name"], []).append(row) - enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) + enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids) for user_id, rules in results.items(): use_new_defaults = user_id in self._users_new_default_push_rules @@ -205,14 +200,15 @@ class PushRulesWorkerStore( return results - @defer.inlineCallbacks - def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule): + async def copy_push_rule_from_room_to_room( + self, new_room_id: str, user_id: str, rule: dict + ) -> None: """Copy a single push rule from one room to another for a specific user. Args: - new_room_id (str): ID of the new room. 
- user_id (str): ID of user the push rule belongs to. - rule (Dict): A push rule. + new_room_id: ID of the new room. + user_id : ID of user the push rule belongs to. + rule: A push rule. """ # Create new rule id rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) @@ -224,7 +220,7 @@ class PushRulesWorkerStore( condition["pattern"] = new_room_id # Add the rule for the new room - yield self.add_push_rule( + await self.add_push_rule( user_id=user_id, rule_id=new_rule_id, priority_class=rule["priority_class"], @@ -232,20 +228,19 @@ class PushRulesWorkerStore( actions=rule["actions"], ) - @defer.inlineCallbacks - def copy_push_rules_from_room_to_room_for_user( - self, old_room_id, new_room_id, user_id - ): + async def copy_push_rules_from_room_to_room_for_user( + self, old_room_id: str, new_room_id: str, user_id: str + ) -> None: """Copy all of the push rules from one room to another for a specific user. Args: - old_room_id (str): ID of the old room. - new_room_id (str): ID of the new room. - user_id (str): ID of user to copy push rules for. + old_room_id: ID of the old room. + new_room_id: ID of the new room. + user_id: ID of user to copy push rules for. 
""" # Retrieve push rules for this user - user_push_rules = yield self.get_push_rules_for_user(user_id) + user_push_rules = await self.get_push_rules_for_user(user_id) # Get rules relating to the old room and copy them to the new room for rule in user_push_rules: @@ -254,21 +249,20 @@ class PushRulesWorkerStore( (c.get("key") == "room_id" and c.get("pattern") == old_room_id) for c in conditions ): - yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) + await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) @cachedList( cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids", num_args=1, - inlineCallbacks=True, ) - def bulk_get_push_rules_enabled(self, user_ids): + async def bulk_get_push_rules_enabled(self, user_ids): if not user_ids: return {} results = {user_id: {} for user_id in user_ids} - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, @@ -332,8 +326,7 @@ class PushRulesWorkerStore( class PushRuleStore(PushRulesWorkerStore): - @defer.inlineCallbacks - def add_push_rule( + async def add_push_rule( self, user_id, rule_id, @@ -342,13 +335,14 @@ class PushRuleStore(PushRulesWorkerStore): actions, before=None, after=None, - ): + ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids + with await self._push_rules_stream_id_gen.get_next() as stream_id: + event_stream_ordering = self._stream_id_gen.get_current_token() + if before or after: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_add_push_rule_relative_txn", self._add_push_rule_relative_txn, stream_id, @@ -362,7 +356,7 @@ class PushRuleStore(PushRulesWorkerStore): after, ) else: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( 
"_add_push_rule_highest_priority_txn", self._add_push_rule_highest_priority_txn, stream_id, @@ -546,16 +540,15 @@ class PushRuleStore(PushRulesWorkerStore): }, ) - @defer.inlineCallbacks - def delete_push_rule(self, user_id, rule_id): + async def delete_push_rule(self, user_id: str, rule_id: str) -> None: """ Delete a push rule. Args specify the row to be deleted and can be any of the columns in the push_rule table, but below are the standard ones Args: - user_id (str): The matrix ID of the push rule owner - rule_id (str): The rule_id of the rule to be deleted + user_id: The matrix ID of the push rule owner + rule_id: The rule_id of the rule to be deleted """ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): @@ -567,20 +560,21 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.db_pool.runInteraction( + with await self._push_rules_stream_id_gen.get_next() as stream_id: + event_stream_ordering = self._stream_id_gen.get_current_token() + + await self.db_pool.runInteraction( "delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering, ) - @defer.inlineCallbacks - def set_push_rule_enabled(self, user_id, rule_id, enabled): - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.db_pool.runInteraction( + async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None: + with await self._push_rules_stream_id_gen.get_next() as stream_id: + event_stream_ordering = self._stream_id_gen.get_current_token() + + await self.db_pool.runInteraction( "_set_push_rule_enabled_txn", self._set_push_rule_enabled_txn, stream_id, @@ -611,8 +605,9 @@ class PushRuleStore(PushRulesWorkerStore): op="ENABLE" if enabled else "DISABLE", ) - @defer.inlineCallbacks - def set_push_rule_actions(self, user_id, rule_id, actions, 
is_default_rule): + async def set_push_rule_actions( + self, user_id, rule_id, actions, is_default_rule + ) -> None: actions_json = json_encoder.encode(actions) def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): @@ -651,9 +646,10 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.db_pool.runInteraction( + with await self._push_rules_stream_id_gen.get_next() as stream_id: + event_stream_ordering = self._stream_id_gen.get_current_token() + + await self.db_pool.runInteraction( "set_push_rule_actions", set_push_rule_actions_txn, stream_id, @@ -681,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore): self.push_rules_stream_cache.entity_has_changed, user_id, stream_id ) - def get_push_rules_stream_token(self): - """Get the position of the push rules stream. - Returns a pair of a stream id for the push_rules stream and the - room stream ordering it corresponds to.""" - return self._push_rules_stream_id_gen.get_current_token() - def get_max_push_rules_stream_id(self): - return self.get_push_rules_stream_token()[0] + return self._push_rules_stream_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b5200fbe79..c388468273 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py
@@ -19,10 +19,8 @@ from typing import Iterable, Iterator, List, Tuple from canonicaljson import encode_canonical_json -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import cached, cachedList logger = logging.getLogger(__name__) @@ -34,23 +32,22 @@ class PusherWorkerStore(SQLBaseStore): Drops any rows whose data cannot be decoded """ for r in rows: - dataJson = r["data"] + data_json = r["data"] try: - r["data"] = db_to_json(dataJson) + r["data"] = db_to_json(data_json) except Exception as e: logger.warning( "Invalid JSON in data for pusher %d: %s, %s", r["id"], - dataJson, + data_json, e.args[0], ) continue yield r - @defer.inlineCallbacks - def user_has_pusher(self, user_id): - ret = yield self.db_pool.simple_select_one_onecol( + async def user_has_pusher(self, user_id): + ret = await self.db_pool.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None @@ -61,9 +58,8 @@ class PusherWorkerStore(SQLBaseStore): def get_pushers_by_user_id(self, user_id): return self.get_pushers_by({"user_name": user_id}) - @defer.inlineCallbacks - def get_pushers_by(self, keyvalues): - ret = yield self.db_pool.simple_select_list( + async def get_pushers_by(self, keyvalues): + ret = await self.db_pool.simple_select_list( "pushers", keyvalues, [ @@ -87,16 +83,14 @@ class PusherWorkerStore(SQLBaseStore): ) return self._decode_pushers_rows(ret) - @defer.inlineCallbacks - def get_all_pushers(self): + async def get_all_pushers(self): def get_pushers(txn): txn.execute("SELECT * FROM pushers") rows = self.db_pool.cursor_to_dict(txn) return self._decode_pushers_rows(rows) - rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers) - return rows + return await self.db_pool.runInteraction("get_all_pushers", get_pushers) async def get_all_updated_pushers_rows( self, 
instance_name: str, last_id: int, current_id: int, limit: int @@ -164,19 +158,16 @@ class PusherWorkerStore(SQLBaseStore): "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn ) - @cachedInlineCallbacks(num_args=1, max_entries=15000) - def get_if_user_has_pusher(self, user_id): + @cached(num_args=1, max_entries=15000) + async def get_if_user_has_pusher(self, user_id): # This only exists for the cachedList decorator raise NotImplementedError() @cachedList( - cached_method_name="get_if_user_has_pusher", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, + cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1, ) - def get_if_users_have_pushers(self, user_ids): - rows = yield self.db_pool.simple_select_many_batch( + async def get_if_users_have_pushers(self, user_ids): + rows = await self.db_pool.simple_select_many_batch( table="pushers", column="user_name", iterable=user_ids, @@ -189,34 +180,38 @@ class PusherWorkerStore(SQLBaseStore): return result - @defer.inlineCallbacks - def update_pusher_last_stream_ordering( + async def update_pusher_last_stream_ordering( self, app_id, pushkey, user_id, last_stream_ordering - ): - yield self.db_pool.simple_update_one( + ) -> None: + await self.db_pool.simple_update_one( "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, {"last_stream_ordering": last_stream_ordering}, desc="update_pusher_last_stream_ordering", ) - @defer.inlineCallbacks - def update_pusher_last_stream_ordering_and_success( - self, app_id, pushkey, user_id, last_stream_ordering, last_success - ): + async def update_pusher_last_stream_ordering_and_success( + self, + app_id: str, + pushkey: str, + user_id: str, + last_stream_ordering: int, + last_success: int, + ) -> bool: """Update the last stream ordering position we've processed up to for the given pusher. 
Args: - app_id (str) - pushkey (str) - last_stream_ordering (int) - last_success (int) + app_id + pushkey + user_id + last_stream_ordering + last_success Returns: - Deferred[bool]: True if the pusher still exists; False if it has been deleted. + True if the pusher still exists; False if it has been deleted. """ - updated = yield self.db_pool.simple_update( + updated = await self.db_pool.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={ @@ -228,18 +223,18 @@ class PusherWorkerStore(SQLBaseStore): return bool(updated) - @defer.inlineCallbacks - def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self.db_pool.simple_update( + async def update_pusher_failing_since( + self, app_id, pushkey, user_id, failing_since + ) -> None: + await self.db_pool.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={"failing_since": failing_since}, desc="update_pusher_failing_since", ) - @defer.inlineCallbacks - def get_throttle_params_by_room(self, pusher_id): - res = yield self.db_pool.simple_select_list( + async def get_throttle_params_by_room(self, pusher_id): + res = await self.db_pool.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, ["room_id", "last_sent_ts", "throttle_ms"], @@ -255,11 +250,10 @@ class PusherWorkerStore(SQLBaseStore): return params_by_room - @defer.inlineCallbacks - def set_throttle_params(self, pusher_id, room_id, params): + async def set_throttle_params(self, pusher_id, room_id, params) -> None: # no need to lock because `pusher_throttle` has a primary key on # (pusher, room_id) so simple_upsert will retry - yield self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, @@ -272,8 +266,7 @@ class PusherStore(PusherWorkerStore): def get_pushers_stream_token(self): return 
self._pushers_id_gen.get_current_token() - @defer.inlineCallbacks - def add_pusher( + async def add_pusher( self, user_id, access_token, @@ -287,11 +280,11 @@ class PusherStore(PusherWorkerStore): data, last_stream_ordering, profile_tag="", - ): - with self._pushers_id_gen.get_next() as stream_id: + ) -> None: + with await self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry - yield self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, values={ @@ -316,15 +309,16 @@ class PusherStore(PusherWorkerStore): if user_has_pusher is not True: # invalidate, since we the user might not have had a pusher before - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_pusher", self._invalidate_cache_and_stream, self.get_if_user_has_pusher, (user_id,), ) - @defer.inlineCallbacks - def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): + async def delete_pusher_by_app_id_pushkey_user_id( + self, app_id, pushkey, user_id + ) -> None: def delete_pusher_txn(txn, stream_id): self._invalidate_cache_and_stream( txn, self.get_if_user_has_pusher, (user_id,) @@ -350,7 +344,7 @@ class PusherStore(PusherWorkerStore): }, ) - with self._pushers_id_gen.get_next() as stream_id: - yield self.db_pool.runInteraction( + with await self._pushers_id_gen.get_next() as stream_id: + await self.db_pool.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1920a8a152..6821476ee0 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -16,7 +16,7 @@ import abc import logging -from typing import List, Tuple +from typing import List, Optional, Tuple from twisted.internet import defer @@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) @@ -56,9 +56,9 @@ class ReceiptsWorkerStore(SQLBaseStore): """ raise NotImplementedError() - @cachedInlineCallbacks() - def get_users_with_read_receipts_in_room(self, room_id): - receipts = yield self.get_receipts_for_room(room_id, "m.read") + @cached() + async def get_users_with_read_receipts_in_room(self, room_id): + receipts = await self.get_receipts_for_room(room_id, "m.read") return {r["user_id"] for r in receipts} @cached(num_args=2) @@ -84,9 +84,9 @@ class ReceiptsWorkerStore(SQLBaseStore): allow_none=True, ) - @cachedInlineCallbacks(num_args=2) - def get_receipts_for_user(self, user_id, receipt_type): - rows = yield self.db_pool.simple_select_list( + @cached(num_args=2) + async def get_receipts_for_user(self, user_id, receipt_type): + rows = await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, retcols=("room_id", "event_id"), @@ -95,8 +95,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return {row["room_id"]: row["event_id"] for row in rows} - @defer.inlineCallbacks - def get_receipts_for_user_with_orderings(self, user_id, receipt_type): + async def get_receipts_for_user_with_orderings(self, user_id, receipt_type): def f(txn): sql = ( "SELECT rl.room_id, rl.event_id," @@ -110,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (user_id,)) 
return txn.fetchall() - rows = yield self.db_pool.runInteraction( + rows = await self.db_pool.runInteraction( "get_receipts_for_user_with_orderings", f ) return { @@ -122,56 +121,61 @@ class ReceiptsWorkerStore(SQLBaseStore): for row in rows } - @defer.inlineCallbacks - def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + async def get_linearized_receipts_for_rooms( + self, room_ids: List[str], to_key: int, from_key: Optional[int] = None + ) -> List[dict]: """Get receipts for multiple rooms for sending to clients. Args: - room_ids (list): List of room_ids. - to_key (int): Max stream id to fetch receipts upto. - from_key (int): Min stream id to fetch receipts from. None fetches + room_id: List of room_ids. + to_key: Max stream id to fetch receipts upto. + from_key: Min stream id to fetch receipts from. None fetches from the start. Returns: - list: A list of receipts. + A list of receipts. """ room_ids = set(room_ids) if from_key is not None: # Only ask the database about rooms where there have been new # receipts added since `from_key` - room_ids = yield self._receipts_stream_cache.get_entities_changed( + room_ids = self._receipts_stream_cache.get_entities_changed( room_ids, from_key ) - results = yield self._get_linearized_receipts_for_rooms( + results = await self._get_linearized_receipts_for_rooms( room_ids, to_key, from_key=from_key ) return [ev for res in results.values() for ev in res] - def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): + async def get_linearized_receipts_for_room( + self, room_id: str, to_key: int, from_key: Optional[int] = None + ) -> List[dict]: """Get receipts for a single room for sending to clients. Args: - room_ids (str): The room id. - to_key (int): Max stream id to fetch receipts upto. - from_key (int): Min stream id to fetch receipts from. None fetches + room_ids: The room id. + to_key: Max stream id to fetch receipts upto. + from_key: Min stream id to fetch receipts from. 
None fetches from the start. Returns: - Deferred[list]: A list of receipts. + A list of receipts. """ if from_key is not None: # Check the cache first to see if any new receipts have been added # since`from_key`. If not we can no-op. if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): - defer.succeed([]) + return [] - return self._get_linearized_receipts_for_room(room_id, to_key, from_key) + return await self._get_linearized_receipts_for_room(room_id, to_key, from_key) - @cachedInlineCallbacks(num_args=3, tree=True) - def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): + @cached(num_args=3, tree=True) + async def _get_linearized_receipts_for_room( + self, room_id: str, to_key: int, from_key: Optional[int] = None + ) -> List[dict]: """See get_linearized_receipts_for_room """ @@ -195,7 +199,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return rows - rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f) + rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f) if not rows: return [] @@ -212,9 +216,8 @@ class ReceiptsWorkerStore(SQLBaseStore): cached_method_name="_get_linearized_receipts_for_room", list_name="room_ids", num_args=3, - inlineCallbacks=True, ) - def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: return {} @@ -243,7 +246,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return self.db_pool.cursor_to_dict(txn) - txn_results = yield self.db_pool.runInteraction( + txn_results = await self.db_pool.runInteraction( "_get_linearized_receipts_for_rooms", f ) @@ -346,7 +349,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) def _invalidate_get_users_with_receipts_in_room( - self, room_id, receipt_type, user_id + self, room_id: str, receipt_type: str, user_id: str ): if receipt_type != "m.read": return @@ -472,15 +475,21 @@ class 
ReceiptsStore(ReceiptsWorkerStore): return rx_ts - @defer.inlineCallbacks - def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data): + async def insert_receipt( + self, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: dict, + ) -> Optional[Tuple[int, int]]: """Insert a receipt, either from local client or remote server. Automatically does conversion between linearized and graph representations. """ if not event_ids: - return + return None if len(event_ids) == 1: linearized_event_id = event_ids[0] @@ -507,13 +516,12 @@ class ReceiptsStore(ReceiptsWorkerStore): else: raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - linearized_event_id = yield self.db_pool.runInteraction( + linearized_event_id = await self.db_pool.runInteraction( "insert_receipt_conv", graph_to_linear ) - stream_id_manager = self._receipts_id_gen.get_next() - with stream_id_manager as stream_id: - event_ts = yield self.db_pool.runInteraction( + with await self._receipts_id_gen.get_next() as stream_id: + event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, room_id, @@ -535,7 +543,7 @@ class ReceiptsStore(ReceiptsWorkerStore): now - event_ts, ) - yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) + await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) max_persisted_id = self._receipts_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 402ae25571..336b578e23 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -17,9 +17,7 @@ import logging import re -from typing import Dict, List, Optional - -from twisted.internet.defer import Deferred +from typing import Awaitable, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -158,6 +156,37 @@ class RegistrationWorkerStore(SQLBaseStore): "set_account_validity_for_user", set_account_validity_for_user_txn ) + async def get_expired_users(self): + """Get UserIDs of all expired users. + + Users who are not active, or do not have profile information, are + excluded from the results. + + Returns: + Deferred[List[UserID]]: List of expired user IDs + """ + + def get_expired_users_txn(txn, now_ms): + # We need to use pattern matching as profiles.user_id is confusingly just the + # user's localpart, whereas account_validity.user_id is a full user ID + sql = """ + SELECT av.user_id from account_validity AS av + LEFT JOIN profiles as p + ON av.user_id LIKE '%%' || p.user_id || ':%%' + WHERE expiration_ts_ms <= ? 
+ AND p.active = 1 + """ + txn.execute(sql, (now_ms,)) + rows = txn.fetchall() + + return [UserID.from_string(row[0]) for row in rows] + + res = await self.db_pool.runInteraction( + "get_expired_users", get_expired_users_txn, self.clock.time_msec() + ) + + return res + async def set_renewal_token_for_user( self, user_id: str, renewal_token: str ) -> None: @@ -264,6 +293,54 @@ class RegistrationWorkerStore(SQLBaseStore): desc="delete_account_validity_for_user", ) + async def get_info_for_users( + self, user_ids: List[str], + ): + """Return the user info for a given set of users + + Args: + user_ids: A list of users to return information about + + Returns: + Deferred[Dict[str, bool]]: A dictionary mapping each user ID to + a dict with the following keys: + * expired - whether this is an expired user + * deactivated - whether this is a deactivated user + """ + # Get information of all our local users + def _get_info_for_users_txn(txn): + rows = [] + + for user_id in user_ids: + sql = """ + SELECT u.name, u.deactivated, av.expiration_ts_ms + FROM users as u + LEFT JOIN account_validity as av + ON av.user_id = u.name + WHERE u.name = ? + """ + + txn.execute(sql, (user_id,)) + row = txn.fetchone() + if row: + rows.append(row) + + return rows + + info_rows = await self.db_pool.runInteraction( + "get_info_for_users", _get_info_for_users_txn + ) + + return { + user_id: { + "expired": ( + expiration is not None and self.clock.time_msec() >= expiration + ), + "deactivated": deactivated == 1, + } + for user_id, deactivated, expiration in info_rows + } + async def is_server_admin(self, user: UserID) -> bool: """Determines if a user is an admin of this homeserver. 
@@ -304,7 +381,7 @@ class RegistrationWorkerStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id," + "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id," " access_tokens.device_id, access_tokens.valid_until_ms" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" @@ -563,7 +640,7 @@ class RegistrationWorkerStore(SQLBaseStore): id_server (str) Returns: - Deferred + Awaitable """ # We need to use an upsert, in case they user had already bound the # threepid @@ -891,6 +968,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity + self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors if self._account_validity.enabled: self._clock.call_later( @@ -952,6 +1030,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname=None, admin=False, user_type=None, + shadow_banned=False, ): """Attempts to register an account. @@ -968,6 +1047,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): admin (boolean): is an admin user? user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. + shadow_banned (bool): Whether the user is shadow-banned, + i.e. they may be told their requests succeeded but we ignore them. Raises: StoreError if the user_id could not be registered. 
@@ -986,6 +1067,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname, admin, user_type, + shadow_banned, ) def _register_user( @@ -999,6 +1081,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname, admin, user_type, + shadow_banned, ): user_id_obj = UserID.from_string(user_id) @@ -1028,6 +1111,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, + "shadow_banned": shadow_banned, }, ) else: @@ -1042,6 +1126,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, + "shadow_banned": shadow_banned, }, ) @@ -1077,7 +1162,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str - ) -> Deferred: + ) -> Awaitable: """Record a mapping from an external user id to a mxid Args: @@ -1297,15 +1382,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) if not row: - raise ThreepidValidationError(400, "Unknown session_id") + if self._ignore_unknown_session_error: + # If we need to inhibit the error caused by an incorrect session ID, + # use None as placeholder values for the client secret and the + # validation timestamp. + # It shouldn't be an issue because they're both only checked after + # the token check, which should fail. And if it doesn't for some + # reason, the next check is on the client secret, which is NOT NULL, + # so we don't have to worry about the client secret matching by + # accident. 
+ row = {"client_secret": None, "validated_at": None} + else: + raise ThreepidValidationError(400, "Unknown session_id") + retrieved_client_secret = row["client_secret"] validated_at = row["validated_at"] - if retrieved_client_secret != client_secret: - raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id" - ) - row = self.db_pool.simple_select_one_txn( txn, table="threepid_validation_token", @@ -1321,6 +1413,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): expires = row["expires"] next_link = row["next_link"] + if retrieved_client_secret != client_secret: + raise ThreepidValidationError( + 400, "This client_secret does not match the provided session_id" + ) + # If the session is already validated, no need to revalidate if validated_at: return next_link @@ -1345,43 +1442,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "validate_threepid_session_txn", validate_threepid_session_txn ) - def upsert_threepid_validation_session( - self, - medium, - address, - client_secret, - send_attempt, - session_id, - validated_at=None, - ): - """Upsert a threepid validation session - Args: - medium (str): The medium of the 3PID - address (str): The address of the 3PID - client_secret (str): A unique string provided by the client to - help identify this validation attempt - send_attempt (int): The latest send_attempt on this session - session_id (str): The id of this validation session - validated_at (int|None): The unix timestamp in milliseconds of - when the session was marked as valid - """ - insertion_values = { - "medium": medium, - "address": address, - "client_secret": client_secret, - } - - if validated_at: - insertion_values["validated_at"] = validated_at - - return self.db_pool.simple_upsert( - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - values={"last_send_attempt": send_attempt}, - insertion_values=insertion_values, - desc="upsert_threepid_validation_session", - 
) - def start_or_continue_validation_session( self, medium, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f4008e6221..99a8a9fab0 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -21,8 +21,6 @@ from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from canonicaljson import json - from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions @@ -30,15 +28,12 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchStore from synapse.types import ThirdPartyInstanceID +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) -OpsLevel = collections.namedtuple( - "OpsLevel", ("ban_level", "kick_level", "redact_level") -) - RatelimitOverride = collections.namedtuple( "RatelimitOverride", ("messages_per_second", "burst_count") ) @@ -344,6 +339,23 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) + async def is_room_published(self, room_id: str) -> bool: + """Check whether a room has been published in the local public room + directory. + + Args: + room_id + Returns: + Whether the room is currently published in the room directory + """ + # Get room information + room_info = await self.get_room(room_id) + if not room_info: + return False + + # Check the is_public value + return room_info.get("is_public", False) + async def get_rooms_paginate( self, start: int, @@ -552,6 +564,11 @@ class RoomWorkerStore(SQLBaseStore): Returns: dict[int, int]: "min_lifetime" and "max_lifetime" for this room. """ + # If the room retention feature is disabled, return a policy with no minimum nor + # maximum, in order not to filter out events we should filter out when sending to + # the client. 
+ if not self.config.retention_enabled: + return {"min_lifetime": None, "max_lifetime": None} def get_retention_policy_for_room_txn(txn): txn.execute( @@ -1134,7 +1151,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "store_room_txn", store_room_txn, next_id ) @@ -1201,7 +1218,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) @@ -1281,7 +1298,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, @@ -1314,7 +1331,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "event_id": event_id, "user_id": user_id, "reason": reason, - "content": json.dumps(content), + "content": json_encoder.encode(content), }, desc="add_event_report", ) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index b2fcfc9bfe..161edbeccb 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -17,8 +17,6 @@ import logging from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.events.snapshot import EventContext @@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): lambda: self._known_servers_count, ) - @defer.inlineCallbacks - def _count_known_servers(self): + async def _count_known_servers(self): """ Count the servers that this server knows about. @@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(query) return list(txn)[0][0] - count = yield self.db_pool.runInteraction("get_known_servers", _transact) + count = await self.db_pool.runInteraction("get_known_servers", _transact) # We always know about ourselves, even if we have nothing in # room_memberships (for example, the server is new). @@ -589,11 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): raise NotImplementedError() @cachedList( - cached_method_name="_get_joined_profile_from_event_id", - list_name="event_ids", - inlineCallbacks=True, + cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids", ) - def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): + async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): """For given set of member event_ids check if they point to a join event and if so return the associated user and profile info. @@ -601,11 +596,11 @@ class RoomMemberWorkerStore(EventsWorkerStore): event_ids: The member event IDs to lookup Returns: - Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID + dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID to `user_id` and ProfileInfo (or None if not join event). 
""" - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, @@ -772,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) - def get_membership_from_event_ids( + async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] ) -> List[dict]: """Get user_id and membership of a set of event IDs. """ - return self.db_pool.simple_select_many_batch( + return await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, diff --git a/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql new file mode 100644
index 0000000000..e744c02fe8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql
@@ -0,0 +1,36 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Add a batch number to track changes to profiles and the + * order they're made in so we can replicate user profiles + * to other hosts as they change + */ +ALTER TABLE profiles ADD COLUMN batch BIGINT DEFAULT NULL; + +/* + * Index on the batch number so we can get profiles + * by their batch + */ +CREATE INDEX profiles_batch_idx ON profiles(batch); + +/* + * A table to track what batch of user profiles has been + * synced to what profile replication target. + */ +CREATE TABLE profile_replication_status ( + host TEXT NOT NULL, + last_synced_batch BIGINT NOT NULL +); diff --git a/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql new file mode 100644
index 0000000000..96051ac179 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
@@ -0,0 +1,23 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * A flag saying whether the user owning the profile has been deactivated. + * This really belongs on the users table, not here, but the users table + * stores users by their full user_id and profiles stores them by localpart, + * so we can't easily join between the two tables. Plus, the batch number + * really ought to represent data in this table that has changed. + */ +ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL; \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql new file mode 100644
index 0000000000..7542ab8cbd --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
@@ -0,0 +1,16 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE UNIQUE INDEX profile_replication_status_idx ON profile_replication_status(host); \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql new file mode 100644
index 0000000000..4cc96a5341 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
@@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A table of the IP address and user-agent used to complete each step of a +-- user-interactive authentication session. +CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips( + session_id TEXT NOT NULL, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + UNIQUE (session_id, ip, user_agent), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); diff --git a/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql new file mode 100644
index 0000000000..260b009b48 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
@@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A shadow-banned user may be told that their requests succeeded when they were +-- actually ignored. +ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN; diff --git a/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql new file mode 100644
index 0000000000..15421b99ac --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
@@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This table is no longer used. +DROP TABLE IF EXISTS presence_allow_inbound; diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
index 889a9a0ce4..20c5af2eb7 100644 --- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres +++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
@@ -658,10 +658,19 @@ CREATE TABLE presence_stream ( +CREATE TABLE profile_replication_status ( + host text NOT NULL, + last_synced_batch bigint NOT NULL +); + + + CREATE TABLE profiles ( user_id text NOT NULL, displayname text, - avatar_url text + avatar_url text, + batch bigint, + active smallint DEFAULT 1 NOT NULL ); @@ -1788,6 +1797,10 @@ CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id); +CREATE INDEX profiles_batch_idx ON profiles USING btree (batch); + + + CREATE INDEX public_room_index ON rooms USING btree (is_public); diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..e28ec3fa45 100644 --- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite +++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
@@ -6,7 +6,7 @@ CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_us CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) ); CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) ); CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL ); -CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) ); +CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, batch BIGINT DEFAULT NULL, active SMALLINT DEFAULT 1 NOT NULL, UNIQUE(user_id) ); CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) ); CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER ); CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) ); @@ -202,6 +202,8 @@ CREATE INDEX group_users_u_idx ON group_users(user_id); CREATE INDEX group_invites_u_idx ON group_invites(user_id); CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); +CREATE INDEX profiles_batch_idx ON profiles(batch); +CREATE TABLE 
profile_replication_status ( host TEXT NOT NULL, last_synced_batch BIGINT NOT NULL ); CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL ); CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp); CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp); diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 96e0378e50..991233a9bc 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -273,12 +273,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): cached_method_name="_get_state_group_for_event", list_name="event_ids", num_args=1, - inlineCallbacks=True, ) - def _get_state_group_for_events(self, event_ids): + async def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="event_to_state_groups", column="event_id", iterable=event_ids, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index aaf225894e..497f607703 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py
@@ -39,15 +39,17 @@ what sort order was used: import abc import logging from collections import namedtuple -from typing import Optional +from typing import Dict, Iterable, List, Optional, Tuple from twisted.internet import defer +from synapse.api.filtering import Filter +from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.engines import PostgresEngine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -68,8 +70,12 @@ _EventDictReturn = namedtuple( def generate_pagination_where_clause( - direction, column_names, from_token, to_token, engine -): + direction: str, + column_names: Tuple[str, str], + from_token: Optional[Tuple[int, int]], + to_token: Optional[Tuple[int, int]], + engine: BaseDatabaseEngine, +) -> str: """Creates an SQL expression to bound the columns by the pagination tokens. @@ -90,21 +96,19 @@ def generate_pagination_where_clause( token, but include those that match the to token. Args: - direction (str): Whether we're paginating backwards("b") or - forwards ("f"). - column_names (tuple[str, str]): The column names to bound. Must *not* - be user defined as these get inserted directly into the SQL - statement without escapes. - from_token (tuple[int, int]|None): The start point for the pagination. - This is an exclusive minimum bound if direction is "f", and an - inclusive maximum bound if direction is "b". - to_token (tuple[int, int]|None): The endpoint point for the pagination. - This is an inclusive maximum bound if direction is "f", and an - exclusive minimum bound if direction is "b". 
+ direction: Whether we're paginating backwards("b") or forwards ("f"). + column_names: The column names to bound. Must *not* be user defined as + these get inserted directly into the SQL statement without escapes. + from_token: The start point for the pagination. This is an exclusive + minimum bound if direction is "f", and an inclusive maximum bound if + direction is "b". + to_token: The endpoint point for the pagination. This is an inclusive + maximum bound if direction is "f", and an exclusive minimum bound if + direction is "b". engine: The database engine to generate the clauses for Returns: - str: The sql expression + The sql expression """ assert direction in ("b", "f") @@ -132,7 +136,12 @@ def generate_pagination_where_clause( return " AND ".join(where_clause) -def _make_generic_sql_bound(bound, column_names, values, engine): +def _make_generic_sql_bound( + bound: str, + column_names: Tuple[str, str], + values: Tuple[Optional[int], int], + engine: BaseDatabaseEngine, +) -> str: """Create an SQL expression that bounds the given column names by the values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. @@ -142,18 +151,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine): out manually. Args: - bound (str): The comparison operator to use. One of ">", "<", ">=", + bound: The comparison operator to use. One of ">", "<", ">=", "<=", where the values are on the left and columns on the right. - names (tuple[str, str]): The column names. Must *not* be user defined + names: The column names. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. - values (tuple[int|None, int]): The values to bound the columns by. If + values: The values to bound the columns by. If the first value is None then only creates a bound on the second column. 
engine: The database engine to generate the SQL for Returns: - str + The SQL statement """ assert bound in (">", "<", ">=", "<=") @@ -193,7 +202,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine): ) -def filter_to_clause(event_filter): +def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. it may create # "room_id == X AND room_id != X", which postgres doesn't optimise. @@ -291,34 +300,35 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def get_room_min_stream_ordering(self): raise NotImplementedError() - @defer.inlineCallbacks - def get_room_events_stream_for_rooms( - self, room_ids, from_key, to_key, limit=0, order="DESC" - ): + async def get_room_events_stream_for_rooms( + self, + room_ids: Iterable[str], + from_key: str, + to_key: str, + limit: int = 0, + order: str = "DESC", + ) -> Dict[str, Tuple[List[EventBase], str]]: """Get new room events in stream ordering since `from_key`. Args: - room_id (str) - from_key (str): Token from which no events are returned before - to_key (str): Token from which no events are returned after. (This + room_ids + from_key: Token from which no events are returned before + to_key: Token from which no events are returned after. (This is typically the current stream token) - limit (int): Maximum number of events to return - order (str): Either "DESC" or "ASC". Determines which events are + limit: Maximum number of events to return + order: Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. Returns: - Deferred[dict[str,tuple[list[FrozenEvent], str]]] - A map from room id to a tuple containing: - - list of recent events in the room - - stream ordering key for the start of the chunk of events returned. 
+ A map from room id to a tuple containing: + - list of recent events in the room + - stream ordering key for the start of the chunk of events returned. """ from_id = RoomStreamToken.parse_stream_token(from_key).stream - room_ids = yield self._events_stream_cache.get_entities_changed( - room_ids, from_id - ) + room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id) if not room_ids: return {} @@ -326,7 +336,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): results = {} room_ids = list(room_ids) for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)): - res = yield make_deferred_yieldable( + res = await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -361,28 +371,30 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if self._events_stream_cache.has_entity_changed(room_id, from_key) } - @defer.inlineCallbacks - def get_room_events_stream_for_room( - self, room_id, from_key, to_key, limit=0, order="DESC" - ): - + async def get_room_events_stream_for_room( + self, + room_id: str, + from_key: str, + to_key: str, + limit: int = 0, + order: str = "DESC", + ) -> Tuple[List[EventBase], str]: """Get new room events in stream ordering since `from_key`. Args: - room_id (str) - from_key (str): Token from which no events are returned before - to_key (str): Token from which no events are returned after. (This + room_id + from_key: Token from which no events are returned before + to_key: Token from which no events are returned after. (This is typically the current stream token) - limit (int): Maximum number of events to return - order (str): Either "DESC" or "ASC". Determines which events are + limit: Maximum number of events to return + order: Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. 
Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns the list of - events (in ascending order) and the token from the start of - the chunk of events returned. + The list of events (in ascending order) and the token from the start + of the chunk of events returned. """ if from_key == to_key: return [], from_key @@ -390,9 +402,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream - has_changed = yield self._events_stream_cache.has_entity_changed( - room_id, from_id - ) + has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) if not has_changed: return [], from_key @@ -410,9 +420,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows - rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f) + rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) - ret = yield self.get_events_as_list( + ret = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -430,8 +440,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret, key - @defer.inlineCallbacks - def get_membership_changes_for_user(self, user_id, from_key, to_key): + async def get_membership_changes_for_user(self, user_id, from_key, to_key): from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream @@ -460,9 +469,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows - rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f) + rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f) - ret = yield self.get_events_as_list( + ret = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -470,27 +479,26 @@ class 
StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret - @defer.inlineCallbacks - def get_recent_events_for_room(self, room_id, limit, end_token): + async def get_recent_events_for_room( + self, room_id: str, limit: int, end_token: str + ) -> Tuple[List[EventBase], str]: """Get the most recent events in the room in topological ordering. Args: - room_id (str) - limit (int) - end_token (str): The stream token representing now. + room_id + limit + end_token: The stream token representing now. Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns a list of - events and a token pointing to the start of the returned - events. - The events returned are in ascending order. + A list of events and a token pointing to the start of the returned + events. The events returned are in ascending order. """ - rows, token = yield self.get_recent_event_ids_for_room( + rows, token = await self.get_recent_event_ids_for_room( room_id, limit, end_token ) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -498,20 +506,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return (events, token) - @defer.inlineCallbacks - def get_recent_event_ids_for_room(self, room_id, limit, end_token): + async def get_recent_event_ids_for_room( + self, room_id: str, limit: int, end_token: str + ) -> Tuple[List[_EventDictReturn], str]: """Get the most recent events in the room in topological ordering. Args: - room_id (str) - limit (int) - end_token (str): The stream token representing now. + room_id + limit + end_token: The stream token representing now. Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of - _EventDictReturn and a token pointing to the start of the returned - events. - The events returned are in ascending order. + A list of _EventDictReturn and a token pointing to the start of the + returned events. The events returned are in ascending order. 
""" # Allow a zero limit here, and no-op. if limit == 0: @@ -519,7 +526,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token = RoomStreamToken.parse(end_token) - rows, token = yield self.db_pool.runInteraction( + rows, token = await self.db_pool.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, @@ -532,12 +539,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, token - def get_room_event_before_stream_ordering(self, room_id, stream_ordering): + def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int): """Gets details of the first event in a room at or before a stream ordering Args: - room_id (str): - stream_ordering (int): + room_id: + stream_ordering: Returns: Deferred[(int, int, str)]: @@ -574,55 +581,67 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) return "t%d-%d" % (topo, token) - def get_stream_token_for_event(self, event_id): - """The stream token for an event + async def get_stream_id_for_event(self, event_id: str) -> int: + """The stream ID for an event Args: - event_id(str): The id of the event to look up a stream token for. + event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A deferred "s%d" stream token. + A stream ID. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" - ).addCallback(lambda row: "s%d" % (row,)) + ) - def get_topological_token_for_event(self, event_id): + async def get_stream_token_for_event(self, event_id: str) -> str: """The stream token for an event Args: - event_id(str): The id of the event to look up a stream token for. + event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A deferred "t%d-%d" topological token. 
+ A "s%d" stream token. """ - return self.db_pool.simple_select_one( + stream_id = await self.get_stream_id_for_event(event_id) + return "s%d" % (stream_id,) + + async def get_topological_token_for_event(self, event_id: str) -> str: + """The stream token for an event + Args: + event_id: The id of the event to look up a stream token for. + Raises: + StoreError if the event wasn't in the database. + Returns: + A "t%d-%d" topological token. + """ + row = await self.db_pool.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", - ).addCallback( - lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) ) + return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) - def get_max_topological_token(self, room_id, stream_key): + async def get_max_topological_token(self, room_id: str, stream_key: int) -> int: """Get the max topological token in a room before the given stream ordering. Args: - room_id (str) - stream_key (int) + room_id + stream_key Returns: - Deferred[int] + The maximum topological token. """ sql = ( "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) - return self.db_pool.execute( + row = await self.db_pool.execute( "get_max_topological_token", None, sql, room_id, stream_key - ).addCallback(lambda r: r[0][0] if r else 0) + ) + return row[0][0] if row else 0 def _get_max_topological_txn(self, txn, room_id): txn.execute( @@ -634,16 +653,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows[0][0] if rows else 0 @staticmethod - def _set_before_and_after(events, rows, topo_order=True): + def _set_before_and_after( + events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True + ): """Inserts ordering information to events' internal metadata from the DB rows. 
Args: - events (list[FrozenEvent]) - rows (list[_EventDictReturn]) - topo_order (bool): Whether the events were ordered topologically - or by stream ordering. If true then all rows should have a non - null topological_ordering. + events + rows + topo_order: Whether the events were ordered topologically or by stream + ordering. If true then all rows should have a non null + topological_ordering. """ for event, row in zip(events, rows): stream = row.stream_ordering @@ -656,25 +677,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): internal.after = str(RoomStreamToken(topo, stream)) internal.order = (int(topo) if topo else 0, int(stream)) - @defer.inlineCallbacks - def get_events_around( - self, room_id, event_id, before_limit, after_limit, event_filter=None - ): + async def get_events_around( + self, + room_id: str, + event_id: str, + before_limit: int, + after_limit: int, + event_filter: Optional[Filter] = None, + ) -> dict: """Retrieve events and pagination tokens around a given event in a room. 
- - Args: - room_id (str) - event_id (str) - before_limit (int) - after_limit (int) - event_filter (Filter|None) - - Returns: - dict """ - results = yield self.db_pool.runInteraction( + results = await self.db_pool.runInteraction( "get_events_around", self._get_events_around_txn, room_id, @@ -684,11 +699,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter, ) - events_before = yield self.get_events_as_list( + events_before = await self.get_events_as_list( list(results["before"]["event_ids"]), get_prev_content=True ) - events_after = yield self.get_events_as_list( + events_after = await self.get_events_as_list( list(results["after"]["event_ids"]), get_prev_content=True ) @@ -700,17 +715,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): } def _get_events_around_txn( - self, txn, room_id, event_id, before_limit, after_limit, event_filter - ): + self, + txn, + room_id: str, + event_id: str, + before_limit: int, + after_limit: int, + event_filter: Optional[Filter], + ) -> dict: """Retrieves event_ids and pagination tokens around a given event in a room. Args: - room_id (str) - event_id (str) - before_limit (int) - after_limit (int) - event_filter (Filter|None) + room_id + event_id + before_limit + after_limit + event_filter Returns: dict @@ -758,22 +779,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "after": {"event_ids": events_after, "token": end_token}, } - @defer.inlineCallbacks - def get_all_new_events_stream(self, from_id, current_id, limit): + async def get_all_new_events_stream( + self, from_id: int, current_id: int, limit: int + ) -> Tuple[int, List[EventBase]]: """Get all new events Returns all events with from_id < stream_ordering <= current_id. 
Args: - from_id (int): the stream_ordering of the last event we processed - current_id (int): the stream_ordering of the most recently processed event - limit (int): the maximum number of events to return + from_id: the stream_ordering of the last event we processed + current_id: the stream_ordering of the most recently processed event + limit: the maximum number of events to return Returns: - Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where - `next_id` is the next value to pass as `from_id` (it will either be the - stream_ordering of the last returned event, or, if fewer than `limit` events - were found, `current_id`. + A tuple of (next_id, events), where `next_id` is the next value to + pass as `from_id` (it will either be the stream_ordering of the + last returned event, or, if fewer than `limit` events were found, + the `current_id`). """ def get_all_new_events_stream_txn(txn): @@ -795,11 +817,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.db_pool.runInteraction( + upper_bound, event_ids = await self.db_pool.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return upper_bound, events @@ -817,21 +839,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): desc="get_federation_out_pos", ) - async def update_federation_out_pos(self, typ, stream_id): + async def update_federation_out_pos(self, typ: str, stream_id: int) -> None: if self._need_to_reset_federation_stream_positions: await self.db_pool.runInteraction( "_reset_federation_positions_txn", self._reset_federation_positions_txn ) self._need_to_reset_federation_stream_positions = False - return await self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="federation_stream_position", keyvalues={"type": typ, "instance_name": 
self._instance_name}, updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) - def _reset_federation_positions_txn(self, txn): + def _reset_federation_positions_txn(self, txn) -> None: """Fiddles with the `federation_stream_position` table to make it match the configured federation sender instances during start up. """ @@ -892,39 +914,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): values={"stream_id": stream_id}, ) - def has_room_changed_since(self, room_id, stream_id): + def has_room_changed_since(self, room_id: str, stream_id: int) -> bool: return self._events_stream_cache.has_entity_changed(room_id, stream_id) def _paginate_room_events_txn( self, txn, - room_id, - from_token, - to_token=None, - direction="b", - limit=-1, - event_filter=None, - ): + room_id: str, + from_token: RoomStreamToken, + to_token: Optional[RoomStreamToken] = None, + direction: str = "b", + limit: int = -1, + event_filter: Optional[Filter] = None, + ) -> Tuple[List[_EventDictReturn], str]: """Returns list of events before or after a given token. Args: txn - room_id (str) - from_token (RoomStreamToken): The token used to stream from - to_token (RoomStreamToken|None): A token which if given limits the - results to only those before - direction(char): Either 'b' or 'f' to indicate whether we are - paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. - event_filter (Filter|None): If provided filters the events to + room_id + from_token: The token used to stream from + to_token: A token which if given limits the results to only those before + direction: Either 'b' or 'f' to indicate whether we are paginating + forwards or backwards from `from_key`. + limit: The maximum number of events to return. + event_filter: If provided filters the events to those that match the filter. 
Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns the results - as a list of _EventDictReturn and a token that points to the end - of the result set. If no events are returned then the end of the - stream has been reached (i.e. there are no events between - `from_token` and `to_token`), or `limit` is zero. + A list of _EventDictReturn and a token that points to the end of the + result set. If no events are returned then the end of the stream has + been reached (i.e. there are no events between `from_token` and + `to_token`), or `limit` is zero. """ assert int(limit) >= 0 @@ -1008,35 +1028,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, str(next_token) - @defer.inlineCallbacks - def paginate_room_events( - self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None - ): + async def paginate_room_events( + self, + room_id: str, + from_key: str, + to_key: Optional[str] = None, + direction: str = "b", + limit: int = -1, + event_filter: Optional[Filter] = None, + ) -> Tuple[List[EventBase], str]: """Returns list of events before or after a given token. Args: - room_id (str) - from_key (str): The token used to stream from - to_key (str|None): A token which if given limits the results to - only those before - direction(char): Either 'b' or 'f' to indicate whether we are - paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. - event_filter (Filter|None): If provided filters the events to - those that match the filter. + room_id + from_key: The token used to stream from + to_key: A token which if given limits the results to only those before + direction: Either 'b' or 'f' to indicate whether we are paginating + forwards or backwards from `from_key`. + limit: The maximum number of events to return. + event_filter: If provided filters the events to those that match the filter. 
Returns: - tuple[list[FrozenEvent], str]: Returns the results as a list of - events and a token that points to the end of the result set. If no - events are returned then the end of the stream has been reached - (i.e. there are no events between `from_key` and `to_key`). + The results as a list of events and a token that points to the end + of the result set. If no events are returned then the end of the + stream has been reached (i.e. there are no events between `from_key` + and `to_key`). """ from_key = RoomStreamToken.parse(from_key) if to_key: to_key = RoomStreamToken.parse(to_key) - rows, token = yield self.db_pool.runInteraction( + rows, token = await self.db_pool.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, @@ -1047,7 +1070,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter, ) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -1057,8 +1080,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): class StreamStore(StreamWorkerStore): - def get_room_max_stream_ordering(self): + def get_room_max_stream_ordering(self) -> int: return self._stream_id_gen.get_current_token() - def get_room_min_stream_ordering(self): + def get_room_min_stream_ordering(self) -> int: return self._backfill_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index e4e0a0c433..0c34bbf21a 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py
@@ -17,11 +17,10 @@ import logging from typing import Dict, List, Tuple -from canonicaljson import json - from synapse.storage._base import db_to_json from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.types import JsonDict +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -98,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore): txn.execute(sql, (user_id, room_id)) tags = [] for tag, content in txn: - tags.append(json.dumps(tag) + ":" + content) + tags.append(json_encoder.encode(tag) + ":" + content) tag_json = "{" + ",".join(tags) + "}" results.append((stream_id, (user_id, room_id, tag_json))) @@ -200,7 +199,7 @@ class TagsStore(TagsWorkerStore): Returns: The next account data ID. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) def add_tag_txn(txn, next_id): self.db_pool.simple_upsert_txn( @@ -211,7 +210,7 @@ class TagsStore(TagsWorkerStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -233,7 +232,7 @@ class TagsStore(TagsWorkerStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 37276f73f8..9eef8e57c5 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py
@@ -12,15 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import attr -from canonicaljson import json from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict -from synapse.util import stringutils as stringutils +from synapse.util import json_encoder, stringutils @attr.s @@ -72,7 +72,7 @@ class UIAuthWorkerStore(SQLBaseStore): StoreError if a unique session ID cannot be generated. """ # The clientdict gets stored as JSON. - clientdict_json = json.dumps(clientdict) + clientdict_json = json_encoder.encode(clientdict) # autogen a session ID and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. @@ -143,7 +143,7 @@ class UIAuthWorkerStore(SQLBaseStore): await self.db_pool.simple_upsert( table="ui_auth_sessions_credentials", keyvalues={"session_id": session_id, "stage_type": stage_type}, - values={"result": json.dumps(result)}, + values={"result": json_encoder.encode(result)}, desc="mark_ui_auth_stage_complete", ) except self.db_pool.engine.module.IntegrityError: @@ -184,7 +184,7 @@ class UIAuthWorkerStore(SQLBaseStore): The dictionary from the client root level, not the 'auth' key. """ # The clientdict gets stored as JSON. 
- clientdict_json = json.dumps(clientdict) + clientdict_json = json_encoder.encode(clientdict) await self.db_pool.simple_update_one( table="ui_auth_sessions", @@ -214,14 +214,16 @@ class UIAuthWorkerStore(SQLBaseStore): value, ) - def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any): + def _set_ui_auth_session_data_txn( + self, txn: LoggingTransaction, session_id: str, key: str, value: Any + ): # Get the current value. result = self.db_pool.simple_select_one_txn( txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, retcols=("serverdict",), - ) + ) # type: Dict[str, Any] # type: ignore # Update it and add it back to the database. serverdict = db_to_json(result["serverdict"]) @@ -231,7 +233,7 @@ class UIAuthWorkerStore(SQLBaseStore): txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, - updatevalues={"serverdict": json.dumps(serverdict)}, + updatevalues={"serverdict": json_encoder.encode(serverdict)}, ) async def get_ui_auth_session_data( @@ -258,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore): return serverdict.get(key, default) + async def add_user_agent_ip_to_ui_auth_session( + self, session_id: str, user_agent: str, ip: str, + ): + """Add the given user agent / IP to the tracking table + """ + await self.db_pool.simple_upsert( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip}, + values={}, + desc="add_user_agent_ip_to_ui_auth_session", + ) + + async def get_user_agents_ips_to_ui_auth_session( + self, session_id: str, + ) -> List[Tuple[str, str]]: + """Get the given user agents / IPs used during the ui auth process + + Returns: + List of user_agent/ip pairs + """ + rows = await self.db_pool.simple_select_list( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id}, + retcols=("user_agent", "ip"), + desc="get_user_agents_ips_to_ui_auth_session", + ) + return [(row["user_agent"], row["ip"]) for row in rows] + class 
UIAuthStore(UIAuthWorkerStore): def delete_old_ui_auth_sessions(self, expiration_time: int): @@ -275,12 +305,23 @@ class UIAuthStore(UIAuthWorkerStore): expiration_time, ) - def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): + def _delete_old_ui_auth_sessions_txn( + self, txn: LoggingTransaction, expiration_time: int + ): # Get the expired sessions. sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" txn.execute(sql, [expiration_time]) session_ids = [r[0] for r in txn.fetchall()] + # Delete the corresponding IP/user agents. + self.db_pool.simple_delete_many_txn( + txn, + table="ui_auth_sessions_ips", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + # Delete the corresponding completed credentials. self.db_pool.simple_delete_many_txn( txn, diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index ab6cb2c1f6..e3547e53b3 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -13,35 +13,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -import operator - from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList class UserErasureWorkerStore(SQLBaseStore): @cached() - def is_user_erased(self, user_id): + async def is_user_erased(self, user_id: str) -> bool: """ Check if the given user id has requested erasure Args: - user_id (str): full user id to check + user_id: full user id to check Returns: - Deferred[bool]: True if the user has requested erasure + True if the user has requested erasure """ - return self.db_pool.simple_select_onecol( + result = await self.db_pool.simple_select_onecol( table="erased_users", keyvalues={"user_id": user_id}, retcol="1", desc="is_user_erased", - ).addCallback(operator.truth) + ) + return bool(result) - @cachedList( - cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True - ) - def are_users_erased(self, user_ids): + @cachedList(cached_method_name="is_user_erased", list_name="user_ids") + async def are_users_erased(self, user_ids): """ Checks which users in a list have requested erasure @@ -49,14 +46,14 @@ class UserErasureWorkerStore(SQLBaseStore): user_ids (iterable[str]): full user id to check Returns: - Deferred[dict[str, bool]]: + dict[str, bool]: for each user, whether the user has requested erasure. """ # this serves the dual purpose of (a) making sure we can do len and # iterate it multiple times, and (b) avoiding duplicates. 
user_ids = tuple(set(user_ids)) - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="erased_users", column="user_id", iterable=user_ids, @@ -65,8 +62,7 @@ class UserErasureWorkerStore(SQLBaseStore): ) erased_users = {row["user_id"] for row in rows} - res = {u: u in erased_users for u in user_ids} - return res + return {u: u in erased_users for u in user_ids} class UserErasureStore(UserErasureWorkerStore):