summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/__init__.py6
-rw-r--r--synapse/storage/databases/main/directory.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py188
-rw-r--r--synapse/storage/databases/main/event_federation.py218
-rw-r--r--synapse/storage/databases/main/events.py86
-rw-r--r--synapse/storage/databases/main/events_worker.py161
-rw-r--r--synapse/storage/databases/main/presence.py23
-rw-r--r--synapse/storage/databases/main/purge_events.py1
-rw-r--r--synapse/storage/databases/main/pusher.py73
-rw-r--r--synapse/storage/databases/main/registration.py393
-rw-r--r--synapse/storage/databases/main/room.py358
-rw-r--r--synapse/storage/databases/main/roommember.py22
-rw-r--r--synapse/storage/databases/main/session.py145
-rw-r--r--synapse/storage/databases/main/ui_auth.py43
-rw-r--r--synapse/storage/databases/main/user_directory.py2
15 files changed, 1355 insertions, 368 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py

index 8d9f07111d..00a644e8f7 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -63,6 +63,7 @@ from .relations import RelationsStore from .room import RoomStore from .roommember import RoomMemberStore from .search import SearchStore +from .session import SessionStore from .signatures import SignatureStore from .state import StateStore from .stats import StatsStore @@ -121,15 +122,13 @@ class DataStore( ServerMetricsStore, EventForwardExtremitiesStore, LockStore, + SessionStore, ): def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine - self._public_room_id_gen = StreamIdGenerator( - db_conn, "public_room_list_stream", "stream_id" - ) self._device_list_id_gen = StreamIdGenerator( db_conn, "device_lists_stream", @@ -170,6 +169,7 @@ class DataStore( sequence_name="cache_invalidation_stream_seq", writers=[], ) + else: self._cache_id_gen = None diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 86075bc55b..6daf8b8ffb 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py
@@ -75,8 +75,6 @@ class DirectoryWorkerStore(SQLBaseStore): desc="get_aliases_for_room", ) - -class DirectoryStore(DirectoryWorkerStore): async def create_room_alias_association( self, room_alias: RoomAlias, @@ -126,6 +124,8 @@ class DirectoryStore(DirectoryWorkerStore): 409, "Room alias %s already exists" % room_alias.to_string() ) + +class DirectoryStore(DirectoryWorkerStore): async def delete_room_alias(self, room_alias: RoomAlias) -> str: room_id = await self.db_pool.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1edc96042b..1f0a39eac4 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -755,81 +755,145 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): """ @trace - def _claim_e2e_one_time_keys(txn): - sql = ( - "SELECT key_id, key_json FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " LIMIT 1" + def _claim_e2e_one_time_key_simple( + txn, user_id: str, device_id: str, algorithm: str + ) -> Optional[Tuple[str, str]]: + """Claim OTK for device for DBs that don't support RETURNING. + + Returns: + A tuple of key name (algorithm + key ID) and key JSON, if an + OTK was found. + """ + + sql = """ + SELECT key_id, key_json FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + LIMIT 1 + """ + + txn.execute(sql, (user_id, device_id, algorithm)) + otk_row = txn.fetchone() + if otk_row is None: + return None + + key_id, key_json = otk_row + + self.db_pool.simple_delete_one_txn( + txn, + table="e2e_one_time_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + }, ) - fallback_sql = ( - "SELECT key_id, key_json, used FROM e2e_fallback_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " LIMIT 1" + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - result = {} - delete = [] - used_fallbacks = [] - for user_id, device_id, algorithm in query_list: - user_result = result.setdefault(user_id, {}) - device_result = user_result.setdefault(device_id, {}) - txn.execute(sql, (user_id, device_id, algorithm)) - otk_row = txn.fetchone() - if otk_row is not None: - key_id, key_json = otk_row - device_result[algorithm + ":" + key_id] = key_json - delete.append((user_id, device_id, algorithm, key_id)) - else: - # no one-time key available, so see if there's a fallback - # key - txn.execute(fallback_sql, (user_id, device_id, algorithm)) - fallback_row = txn.fetchone() - if fallback_row is not None: - key_id, key_json, used = fallback_row - device_result[algorithm + ":" + key_id] = key_json - if not used: - used_fallbacks.append( - (user_id, device_id, algorithm, key_id) - ) - - # drop any one-time keys that were claimed - sql = ( - "DELETE FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " AND key_id = ?" + + return f"{algorithm}:{key_id}", key_json + + @trace + def _claim_e2e_one_time_key_returning( + txn, user_id: str, device_id: str, algorithm: str + ) -> Optional[Tuple[str, str]]: + """Claim OTK for device for DBs that support RETURNING. + + Returns: + A tuple of key name (algorithm + key ID) and key JSON, if an + OTK was found. + """ + + # We can use RETURNING to do the fetch and DELETE in once step. + sql = """ + DELETE FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + AND key_id IN ( + SELECT key_id FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + LIMIT 1 + ) + RETURNING key_id, key_json + """ + + txn.execute( + sql, (user_id, device_id, algorithm, user_id, device_id, algorithm) ) - for user_id, device_id, algorithm, key_id in delete: - log_kv( - { - "message": "Executing claim e2e_one_time_keys transaction on database." - } - ) - txn.execute(sql, (user_id, device_id, algorithm, key_id)) - log_kv({"message": "finished executing and invalidating cache"}) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) + otk_row = txn.fetchone() + if otk_row is None: + return None + + key_id, key_json = otk_row + return f"{algorithm}:{key_id}", key_json + + results = {} + for user_id, device_id, algorithm in query_list: + if self.database_engine.supports_returning: + # If we support RETURNING clause we can use a single query that + # allows us to use autocommit mode. + _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning + db_autocommit = True + else: + _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple + db_autocommit = False + + row = await self.db_pool.runInteraction( + "claim_e2e_one_time_keys", + _claim_e2e_one_time_key, + user_id, + device_id, + algorithm, + db_autocommit=db_autocommit, + ) + if row: + device_results = results.setdefault(user_id, {}).setdefault( + device_id, {} ) - # mark fallback keys as used - for user_id, device_id, algorithm, key_id in used_fallbacks: - self.db_pool.simple_update_txn( - txn, - "e2e_fallback_keys_json", - { + device_results[row[0]] = row[1] + continue + + # No one-time key available, so see if there's a fallback + # key + row = await self.db_pool.simple_select_one( + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + retcols=("key_id", "key_json", "used"), + desc="_get_fallback_key", + allow_none=True, + ) + if row is None: + continue + + key_id = row["key_id"] + key_json = row["key_json"] + used = row["used"] + + # Mark fallback key as used if not already. + if not used: + await self.db_pool.simple_update_one( + table="e2e_fallback_keys_json", + keyvalues={ "user_id": user_id, "device_id": device_id, "algorithm": algorithm, "key_id": key_id, }, - {"used": True}, + updatevalues={"used": True}, + desc="_get_fallback_key_set_used", ) - self._invalidate_cache_and_stream( - txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + await self.invalidate_cache_and_stream( + "get_e2e_unused_fallback_key_types", (user_id, device_id) ) - return result + device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) + device_results[f"{algorithm}:{key_id}"] = key_json - return await self.db_pool.runInteraction( - "claim_e2e_one_time_keys", _claim_e2e_one_time_keys - ) + return results class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 547e43ab98..bddf5ef192 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -16,11 +16,11 @@ import logging from queue import Empty, PriorityQueue from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple -from prometheus_client import Gauge +from prometheus_client import Counter, Gauge from synapse.api.constants import MAX_DEPTH from synapse.api.errors import StoreError -from synapse.api.room_versions import RoomVersion +from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.events import EventBase, make_event_from_dict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -44,6 +44,12 @@ number_pdus_in_federation_queue = Gauge( "The total number of events in the inbound federation staging", ) +pdus_pruned_from_federation_queue = Counter( + "synapse_federation_server_number_inbound_pdu_pruned", + "The number of events in the inbound federation staging that have been " + "pruned due to the queue getting too long", +) + logger = logging.getLogger(__name__) @@ -665,27 +671,97 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} - async def get_oldest_events_with_depth_in_room(self, room_id): + async def get_oldest_event_ids_with_depth_in_room(self, room_id) -> Dict[str, int]: + """Gets the oldest events(backwards extremities) in the room along with the + aproximate depth. + + We use this function so that we can compare and see if someones current + depth at their current scrollback is within pagination range of the + event extremeties. If the current depth is close to the depth of given + oldest event, we can trigger a backfill. + + Args: + room_id: Room where we want to find the oldest events + + Returns: + Map from event_id to depth + """ + + def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id): + # Assemble a dictionary with event_id -> depth for the oldest events + # we know of in the room. Backwards extremeties are the oldest + # events we know of in the room but we only know of them because + # some other event referenced them by prev_event and aren't peristed + # in our database yet (meaning we don't know their depth + # specifically). So we need to look for the aproximate depth from + # the events connected to the current backwards extremeties. + sql = """ + SELECT b.event_id, MAX(e.depth) FROM events as e + /** + * Get the edge connections from the event_edges table + * so we can see whether this event's prev_events points + * to a backward extremity in the next join. + */ + INNER JOIN event_edges as g + ON g.event_id = e.event_id + /** + * We find the "oldest" events in the room by looking for + * events connected to backwards extremeties (oldest events + * in the room that we know of so far). + */ + INNER JOIN event_backward_extremities as b + ON g.prev_event_id = b.event_id + WHERE b.room_id = ? AND g.is_state is ? + GROUP BY b.event_id + """ + + txn.execute(sql, (room_id, False)) + + return dict(txn) + return await self.db_pool.runInteraction( - "get_oldest_events_with_depth_in_room", - self.get_oldest_events_with_depth_in_room_txn, + "get_oldest_event_ids_with_depth_in_room", + get_oldest_event_ids_with_depth_in_room_txn, room_id, ) - def get_oldest_events_with_depth_in_room_txn(self, txn, room_id): - sql = ( - "SELECT b.event_id, MAX(e.depth) FROM events as e" - " INNER JOIN event_edges as g" - " ON g.event_id = e.event_id" - " INNER JOIN event_backward_extremities as b" - " ON g.prev_event_id = b.event_id" - " WHERE b.room_id = ? AND g.is_state is ?" - " GROUP BY b.event_id" - ) + async def get_insertion_event_backwards_extremities_in_room( + self, room_id + ) -> Dict[str, int]: + """Get the insertion events we know about that we haven't backfilled yet. - txn.execute(sql, (room_id, False)) + We use this function so that we can compare and see if someones current + depth at their current scrollback is within pagination range of the + insertion event. If the current depth is close to the depth of given + insertion event, we can trigger a backfill. - return dict(txn) + Args: + room_id: Room where we want to find the oldest events + + Returns: + Map from event_id to depth + """ + + def get_insertion_event_backwards_extremities_in_room_txn(txn, room_id): + sql = """ + SELECT b.event_id, MAX(e.depth) FROM insertion_events as i + /* We only want insertion events that are also marked as backwards extremities */ + INNER JOIN insertion_event_extremities as b USING (event_id) + /* Get the depth of the insertion event from the events table */ + INNER JOIN events AS e USING (event_id) + WHERE b.room_id = ? + GROUP BY b.event_id + """ + + txn.execute(sql, (room_id,)) + + return dict(txn) + + return await self.db_pool.runInteraction( + "get_insertion_event_backwards_extremities_in_room", + get_insertion_event_backwards_extremities_in_room_txn, + room_id, + ) async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: """Returns the event ID and depth for the event that has the max depth from a set of event IDs @@ -1035,7 +1111,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas if row[1] not in event_results: queue.put((-row[0], row[1])) - # Navigate up the DAG by prev_event txn.execute(query, (event_id, False, limit - len(event_results))) prev_event_id_results = txn.fetchall() logger.debug( @@ -1130,6 +1205,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas _delete_old_forward_extrem_cache_txn, ) + async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None: + await self.db_pool.simple_upsert( + table="insertion_event_extremities", + keyvalues={"event_id": event_id}, + values={ + "event_id": event_id, + "room_id": room_id, + }, + insertion_values={}, + desc="insert_insertion_extremity", + lock=False, + ) + async def insert_received_event_to_staging( self, origin: str, event: EventBase ) -> None: @@ -1277,6 +1365,100 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return origin, event + async def prune_staged_events_in_room( + self, + room_id: str, + room_version: RoomVersion, + ) -> bool: + """Checks if there are lots of staged events for the room, and if so + prune them down. + + Returns: + Whether any events were pruned + """ + + # First check the size of the queue. + count = await self.db_pool.simple_select_one_onecol( + table="federation_inbound_events_staging", + keyvalues={"room_id": room_id}, + retcol="COALESCE(COUNT(*), 0)", + desc="prune_staged_events_in_room_count", + ) + + if count < 100: + return False + + # If the queue is too large, then we want clear the entire queue, + # keeping only the forward extremities (i.e. the events not referenced + # by other events in the queue). We do this so that we can always + # backpaginate in all the events we have dropped. + rows = await self.db_pool.simple_select_list( + table="federation_inbound_events_staging", + keyvalues={"room_id": room_id}, + retcols=("event_id", "event_json"), + desc="prune_staged_events_in_room_fetch", + ) + + # Find the set of events referenced by those in the queue, as well as + # collecting all the event IDs in the queue. + referenced_events: Set[str] = set() + seen_events: Set[str] = set() + for row in rows: + event_id = row["event_id"] + seen_events.add(event_id) + event_d = db_to_json(row["event_json"]) + + # We don't bother parsing the dicts into full blown event objects, + # as that is needlessly expensive. + + # We haven't checked that the `prev_events` have the right format + # yet, so we check as we go. + prev_events = event_d.get("prev_events", []) + if not isinstance(prev_events, list): + logger.info("Invalid prev_events for %s", event_id) + continue + + if room_version.event_format == EventFormatVersions.V1: + for prev_event_tuple in prev_events: + if not isinstance(prev_event_tuple, list) or len(prev_events) != 2: + logger.info("Invalid prev_events for %s", event_id) + break + + prev_event_id = prev_event_tuple[0] + if not isinstance(prev_event_id, str): + logger.info("Invalid prev_events for %s", event_id) + break + + referenced_events.add(prev_event_id) + else: + for prev_event_id in prev_events: + if not isinstance(prev_event_id, str): + logger.info("Invalid prev_events for %s", event_id) + break + + referenced_events.add(prev_event_id) + + to_delete = referenced_events & seen_events + if not to_delete: + return False + + pdus_pruned_from_federation_queue.inc(len(to_delete)) + logger.info( + "Pruning %d events in room %s from federation queue", + len(to_delete), + room_id, + ) + + await self.db_pool.simple_delete_many( + table="federation_inbound_events_staging", + keyvalues={"room_id": room_id}, + iterable=to_delete, + column="event_id", + desc="prune_staged_events_in_room_delete", + ) + + return True + async def get_all_rooms_with_staged_incoming_events(self) -> List[str]: """Get the room IDs of all events currently staged.""" return await self.db_pool.simple_select_onecol( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2046b8e276..6a30aa6f81 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -578,7 +578,13 @@ class PersistEventsStore: missing_auth_chains.clear() - for auth_id, event_type, state_key, chain_id, sequence_number in txn: + for ( + auth_id, + event_type, + state_key, + chain_id, + sequence_number, + ) in txn.fetchall(): event_to_types[auth_id] = (event_type, state_key) if chain_id is None: @@ -1382,18 +1388,18 @@ class PersistEventsStore: # If we're persisting an unredacted event we go and ensure # that we mark any redactions that reference this event as # requiring censoring. - sql = "UPDATE redactions SET have_censored = ? WHERE redacts = ?" - txn.execute_batch( - sql, - ( - ( - False, - event.event_id, - ) - for event, _ in events_and_contexts - if not event.internal_metadata.is_redacted() - ), + unredacted_events = [ + event.event_id + for event, _ in events_and_contexts + if not event.internal_metadata.is_redacted() + ] + sql = "UPDATE redactions SET have_censored = ? WHERE " + clause, args = make_in_list_sql_clause( + self.database_engine, + "redacts", + unredacted_events, ) + txn.execute(sql + clause, [False] + args) state_events_and_contexts = [ ec for ec in events_and_contexts if ec[0].is_state() @@ -1773,10 +1779,21 @@ class PersistEventsStore: # Not a insertion event return - # Skip processing a insertion event if the room version doesn't - # support it. + # Skip processing an insertion event if the room version doesn't + # support it or the event is not from the room creator. room_version = self.store.get_room_version_txn(txn, event.room_id) - if not room_version.msc2716_historical: + room_creator = self.db_pool.simple_select_one_onecol_txn( + txn, + table="rooms", + keyvalues={"room_id": event.room_id}, + retcol="creator", + allow_none=True, + ) + if ( + not room_version.msc2716_historical + or not self.hs.config.experimental.msc2716_enabled + or event.sender != room_creator + ): return next_chunk_id = event.content.get(EventContentFields.MSC2716_NEXT_CHUNK_ID) @@ -1825,9 +1842,20 @@ class PersistEventsStore: return # Skip processing a chunk event if the room version doesn't - # support it. + # support it or the event is not from the room creator. room_version = self.store.get_room_version_txn(txn, event.room_id) - if not room_version.msc2716_historical: + room_creator = self.db_pool.simple_select_one_onecol_txn( + txn, + table="rooms", + keyvalues={"room_id": event.room_id}, + retcol="creator", + allow_none=True, + ) + if ( + not room_version.msc2716_historical + or not self.hs.config.experimental.msc2716_enabled + or event.sender != room_creator + ): return chunk_id = event.content.get(EventContentFields.MSC2716_CHUNK_ID) @@ -1848,6 +1876,18 @@ class PersistEventsStore: }, ) + # When we receive an event with a `chunk_id` referencing the + # `next_chunk_id` of the insertion event, we can remove it from the + # `insertion_event_extremities` table. + sql = """ + DELETE FROM insertion_event_extremities WHERE event_id IN ( + SELECT event_id FROM insertion_events + WHERE next_chunk_id = ? + ) + """ + + txn.execute(sql, (chunk_id,)) + def _handle_redaction(self, txn, redacted_event_id): """Handles receiving a redaction and checking whether we need to remove any redacted relations from the database. @@ -2104,15 +2144,17 @@ class PersistEventsStore: Forward extremities are handled when we first start persisting the events. """ + # From the events passed in, add all of the prev events as backwards extremities. + # Ignore any events that are already backwards extrems or outliers. query = ( "INSERT INTO event_backward_extremities (event_id, room_id)" " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" " )" " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" " )" ) @@ -2126,6 +2168,8 @@ class PersistEventsStore: ], ) + # Delete all these events that we've already fetched and now know that their + # prev events are the new backwards extremeties. query = ( "DELETE FROM event_backward_extremities" " WHERE event_id = ? AND room_id = ?" diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 3c86adab56..9501f00f3b 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -14,7 +14,6 @@ import logging import threading -from collections import namedtuple from typing import ( Collection, Container, @@ -27,6 +26,7 @@ from typing import ( overload, ) +import attr from constantly import NamedConstant, Names from typing_extensions import Literal @@ -42,7 +42,11 @@ from synapse.api.room_versions import ( from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event -from synapse.logging.context import PreserveLoggingContext, current_context +from synapse.logging.context import ( + PreserveLoggingContext, + current_context, + make_deferred_yieldable, +) from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -56,6 +60,8 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id +from synapse.util import unwrapFirstError +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter @@ -74,7 +80,10 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) +@attr.s(slots=True, auto_attribs=True) +class _EventCacheEntry: + event: EventBase + redacted_event: Optional[EventBase] class EventRedactBehaviour(Names): @@ -161,6 +170,13 @@ class EventsWorkerStore(SQLBaseStore): max_size=hs.config.caches.event_cache_size, ) + # Map from event ID to a deferred that will result in a map from event + # ID to cache entry. Note that the returned dict may not have the + # requested event in it if the event isn't in the DB. + self._current_event_fetches: Dict[ + str, ObservableDeferred[Dict[str, _EventCacheEntry]] + ] = {} + self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 @@ -476,7 +492,9 @@ class EventsWorkerStore(SQLBaseStore): return events - async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db( + self, event_ids: Iterable[str], allow_rejected: bool = False + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -485,53 +503,124 @@ class EventsWorkerStore(SQLBaseStore): Args: - event_ids (Iterable[str]): The event_ids of the events to fetch + event_ids: The event_ids of the events to fetch - allow_rejected (bool): Whether to include rejected events. If False, + allow_rejected: Whether to include rejected events. If False, rejected events are omitted from the response. Returns: - Dict[str, _EventCacheEntry]: - map from event id to result + map from event id to result """ event_entry_map = self._get_events_from_cache( - event_ids, allow_rejected=allow_rejected + event_ids, ) - missing_events_ids = [e for e in event_ids if e not in event_entry_map] + missing_events_ids = {e for e in event_ids if e not in event_entry_map} + + # We now look up if we're already fetching some of the events in the DB, + # if so we wait for those lookups to finish instead of pulling the same + # events out of the DB multiple times. + # + # Note: we might get the same `ObservableDeferred` back for multiple + # events we're already fetching, so we deduplicate the deferreds to + # avoid extraneous work (if we don't do this we can end up in a n^2 mode + # when we wait on the same Deferred N times, then try and merge the + # same dict into itself N times). + already_fetching_ids: Set[str] = set() + already_fetching_deferreds: Set[ + ObservableDeferred[Dict[str, _EventCacheEntry]] + ] = set() + + for event_id in missing_events_ids: + deferred = self._current_event_fetches.get(event_id) + if deferred is not None: + # We're already pulling the event out of the DB. Add the deferred + # to the collection of deferreds to wait on. + already_fetching_ids.add(event_id) + already_fetching_deferreds.add(deferred) + + missing_events_ids.difference_update(already_fetching_ids) if missing_events_ids: log_ctx = current_context() log_ctx.record_event_fetch(len(missing_events_ids)) + # Add entries to `self._current_event_fetches` for each event we're + # going to pull from the DB. We use a single deferred that resolves + # to all the events we pulled from the DB (this will result in this + # function returning more events than requested, but that can happen + # already due to `_get_events_from_db`). + fetching_deferred: ObservableDeferred[ + Dict[str, _EventCacheEntry] + ] = ObservableDeferred(defer.Deferred()) + for event_id in missing_events_ids: + self._current_event_fetches[event_id] = fetching_deferred + # Note that _get_events_from_db is also responsible for turning db rows # into FrozenEvents (via _get_event_from_row), which involves seeing if # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = await self._get_events_from_db( - missing_events_ids, allow_rejected=allow_rejected - ) + try: + missing_events = await self._get_events_from_db( + missing_events_ids, + ) + + event_entry_map.update(missing_events) + except Exception as e: + with PreserveLoggingContext(): + fetching_deferred.errback(e) + raise e + finally: + # Ensure that we mark these events as no longer being fetched. + for event_id in missing_events_ids: + self._current_event_fetches.pop(event_id, None) + + with PreserveLoggingContext(): + fetching_deferred.callback(missing_events) + + if already_fetching_deferreds: + # Wait for the other event requests to finish and add their results + # to ours. + results = await make_deferred_yieldable( + defer.gatherResults( + (d.observe() for d in already_fetching_deferreds), + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) + + for result in results: + # We filter out events that we haven't asked for as we might get + # a *lot* of superfluous events back, and there is no point + # going through and inserting them all (which can take time). + event_entry_map.update( + (event_id, entry) + for event_id, entry in result.items() + if event_id in already_fetching_ids + ) - event_entry_map.update(missing_events) + if not allow_rejected: + event_entry_map = { + event_id: entry + for event_id, entry in event_entry_map.items() + if not entry.event.rejected_reason + } return event_entry_map def _invalidate_get_event_cache(self, event_id): self._get_event_cache.invalidate((event_id,)) - def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): - """Fetch events from the caches + def _get_events_from_cache( + self, events: Iterable[str], update_metrics: bool = True + ) -> Dict[str, _EventCacheEntry]: + """Fetch events from the caches. - Args: - events (Iterable[str]): list of event_ids to fetch - allow_rejected (bool): Whether to return events that were rejected - update_metrics (bool): Whether to update the cache hit ratio metrics + May return rejected events. - Returns: - dict of event_id -> _EventCacheEntry for each event_id in cache. If - allow_rejected is `False` then there will still be an entry but it - will be `None` + Args: + events: list of event_ids to fetch + update_metrics: Whether to update the cache hit ratio metrics """ event_map = {} @@ -542,10 +631,7 @@ class EventsWorkerStore(SQLBaseStore): if not ret: continue - if allow_rejected or not ret.event.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None + event_map[event_id] = ret return event_map @@ -672,23 +758,23 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - async def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db( + self, event_ids: Iterable[str] + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the database. + May return rejected events. + Returned events will be added to the cache for future lookups. Unknown events are omitted from the response. Args: - event_ids (Iterable[str]): The event_ids of the events to fetch - - allow_rejected (bool): Whether to include rejected events. If False, - rejected events are omitted from the response. + event_ids: The event_ids of the events to fetch Returns: - Dict[str, _EventCacheEntry]: - map from event id to result. May return extra events which - weren't asked for. + map from event id to result. May return extra events which + weren't asked for. """ fetched_events = {} events_to_fetch = event_ids @@ -717,9 +803,6 @@ class EventsWorkerStore(SQLBaseStore): rejected_reason = row["rejected_reason"] - if not allow_rejected and rejected_reason: - continue - # If the event or metadata cannot be parsed, log the error and act # as if the event is unknown. try: diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 1388771c40..12cf6995eb 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py
@@ -29,7 +29,26 @@ if TYPE_CHECKING: from synapse.server import HomeServer -class PresenceStore(SQLBaseStore): +class PresenceBackgroundUpdateStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: Connection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + # Used by `PresenceStore._get_active_presence()` + self.db_pool.updates.register_background_index_update( + "presence_stream_not_offline_index", + index_name="presence_stream_state_not_offline_idx", + table="presence_stream", + columns=["state"], + where_clause="state != 'offline'", + ) + + +class PresenceStore(PresenceBackgroundUpdateStore): def __init__( self, database: DatabasePool, @@ -332,6 +351,8 @@ class PresenceStore(SQLBaseStore): the appropriate time outs. """ + # The `presence_stream_state_not_offline_idx` index should be used for this + # query. sql = ( "SELECT user_id, state, last_active_ts, last_federation_update_ts," " last_user_sync_ts, status_msg, currently_active FROM presence_stream" diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 664c65dac5..bccff5e5b9 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py
@@ -295,6 +295,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): self._invalidate_cache_and_stream( txn, self.have_seen_event, (room_id, event_id) ) + self._invalidate_get_event_cache(event_id) logger.info("[purge] done") diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b48fe086d4..63ac09c61d 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py
@@ -48,6 +48,11 @@ class PusherWorkerStore(SQLBaseStore): self._remove_stale_pushers, ) + self.db_pool.updates.register_background_update_handler( + "remove_deleted_email_pushers", + self._remove_deleted_email_pushers, + ) + def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]: """JSON-decode the data in the rows returned from the `pushers` table @@ -388,6 +393,74 @@ class PusherWorkerStore(SQLBaseStore): return number_deleted + async def _remove_deleted_email_pushers( + self, progress: dict, batch_size: int + ) -> int: + """A background update that deletes all pushers for deleted email addresses. + + In previous versions of synapse, when users deleted their email address, it didn't + also delete all the pushers for that email address. This background update removes + those to prevent unwanted emails. This should only need to be run once (when users + upgrade to v1.42.0 + + Args: + progress: dict used to store progress of this background update + batch_size: the maximum number of rows to retrieve in a single select query + + Returns: + The number of deleted rows + """ + + last_pusher = progress.get("last_pusher", 0) + + def _delete_pushers(txn) -> int: + + sql = """ + SELECT p.id, p.user_name, p.app_id, p.pushkey + FROM pushers AS p + LEFT JOIN user_threepids AS t + ON t.user_id = p.user_name + AND t.medium = 'email' + AND t.address = p.pushkey + WHERE t.user_id is NULL + AND p.app_id = 'm.email' + AND p.id > ? + ORDER BY p.id ASC + LIMIT ? + """ + + txn.execute(sql, (last_pusher, batch_size)) + rows = txn.fetchall() + + last = None + num_deleted = 0 + for row in rows: + last = row[0] + num_deleted += 1 + self.db_pool.simple_delete_txn( + txn, + "pushers", + {"user_name": row[1], "app_id": row[2], "pushkey": row[3]}, + ) + + if last is not None: + self.db_pool.updates._background_update_progress_txn( + txn, "remove_deleted_email_pushers", {"last_pusher": last} + ) + + return num_deleted + + number_deleted = await self.db_pool.runInteraction( + "_remove_deleted_email_pushers", _delete_pushers + ) + + if number_deleted < batch_size: + await self.db_pool.updates._end_background_update( + "remove_deleted_email_pushers" + ) + + return number_deleted + class PusherStore(PusherWorkerStore): def get_pushers_stream_token(self) -> int: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 6ad1a0cf7f..a6517962f6 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -29,7 +29,7 @@ from synapse.storage.databases.main.stats import StatsStore from synapse.storage.types import Connection, Cursor from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import UserID +from synapse.types import UserID, UserInfo from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -146,6 +146,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + """Deprecated: use get_userinfo_by_id instead""" return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, @@ -166,6 +167,33 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="get_user_by_id", ) + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get a UserInfo object for a user by user ID. + + Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed, + this method should be cached. + + Args: + user_id: The user to fetch user info for. + Returns: + `UserInfo` object if user found, otherwise `None`. + """ + user_data = await self.get_user_by_id(user_id) + if not user_data: + return None + return UserInfo( + appservice_id=user_data["appservice_id"], + consent_server_notice_sent=user_data["consent_server_notice_sent"], + consent_version=user_data["consent_version"], + creation_ts=user_data["creation_ts"], + is_admin=bool(user_data["admin"]), + is_deactivated=bool(user_data["deactivated"]), + is_guest=bool(user_data["is_guest"]), + is_shadow_banned=bool(user_data["shadow_banned"]), + user_id=UserID.from_string(user_data["name"]), + user_type=user_data["user_type"], + ) + async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first N days of registration defined by `mau_trial_days` config @@ -571,6 +599,28 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="record_user_external_id", ) + async def remove_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> None: + """Remove a mapping from an external user id to a mxid + + If the mapping is not found, this method does nothing. + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + await self.db_pool.simple_delete( + table="user_external_ids", + keyvalues={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="remove_user_external_id", + ) + async def get_user_by_external_id( self, auth_provider: str, external_id: str ) -> Optional[str]: @@ -704,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) return user_id - def get_user_id_by_threepid_txn(self, txn, medium, address): + def get_user_id_by_threepid_txn( + self, txn, medium: str, address: str + ) -> Optional[str]: """Returns user id from threepid Args: txn (cursor): - medium (str): threepid medium e.g. email - address (str): threepid address e.g. me@example.com + medium: threepid medium e.g. email + address: threepid address e.g. me@example.com Returns: - str|None: user id or None if no user id/threepid mapping exists + user id, or None if no user id/threepid mapping exists """ ret = self.db_pool.simple_select_one_txn( txn, @@ -726,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return ret["user_id"] return None - async def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + async def user_add_threepid( + self, + user_id: str, + medium: str, + address: str, + validated_at: int, + added_at: int, + ) -> None: await self.db_pool.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, ) - async def user_get_threepids(self, user_id): + async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]: return await self.db_pool.simple_select_list( "user_threepids", {"user_id": user_id}, @@ -741,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "user_get_threepids", ) - async def user_delete_threepid(self, user_id, medium, address) -> None: + async def user_delete_threepid( + self, user_id: str, medium: str, address: str + ) -> None: await self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, @@ -1107,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="update_access_token_last_validated", ) + async def registration_token_is_valid(self, token: str) -> bool: + """Checks if a token can be used to authenticate a registration. + + Args: + token: The registration token to be checked + Returns: + True if the token is valid, False otherwise. + """ + res = await self.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["uses_allowed", "pending", "completed", "expiry_time"], + allow_none=True, + ) + + # Check if the token exists + if res is None: + return False + + # Check if the token has expired + now = self._clock.time_msec() + if res["expiry_time"] and res["expiry_time"] < now: + return False + + # Check if the token has been used up + if ( + res["uses_allowed"] + and res["pending"] + res["completed"] >= res["uses_allowed"] + ): + return False + + # Otherwise, the token is valid + return True + + async def set_registration_token_pending(self, token: str) -> None: + """Increment the pending registrations counter for a token. + + Args: + token: The registration token pending use + """ + + def _set_registration_token_pending_txn(txn): + pending = self.db_pool.simple_select_one_onecol_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues={"pending": pending + 1}, + ) + + return await self.db_pool.runInteraction( + "set_registration_token_pending", _set_registration_token_pending_txn + ) + + async def use_registration_token(self, token: str) -> None: + """Complete a use of the given registration token. + + The `pending` counter will be decremented, and the `completed` + counter will be incremented. + + Args: + token: The registration token to be 'used' + """ + + def _use_registration_token_txn(txn): + # Normally, res is Optional[Dict[str, Any]]. + # Override type because the return type is only optional if + # allow_none is True, and we don't want mypy throwing errors + # about None not being indexable. + res: Dict[str, Any] = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) # type: ignore + + # Decrement pending and increment completed + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues={ + "completed": res["completed"] + 1, + "pending": res["pending"] - 1, + }, + ) + + return await self.db_pool.runInteraction( + "use_registration_token", _use_registration_token_txn + ) + + async def get_registration_tokens( + self, valid: Optional[bool] = None + ) -> List[Dict[str, Any]]: + """List all registration tokens. Used by the admin API. + + Args: + valid: If True, only valid tokens are returned. + If False, only invalid tokens are returned. + Default is None: return all tokens regardless of validity. + + Returns: + A list of dicts, each containing details of a token. + """ + + def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]): + if valid is None: + # Return all tokens regardless of validity + txn.execute("SELECT * FROM registration_tokens") + + elif valid: + # Select valid tokens only + sql = ( + "SELECT * FROM registration_tokens WHERE " + "(uses_allowed > pending + completed OR uses_allowed IS NULL) " + "AND (expiry_time > ? OR expiry_time IS NULL)" + ) + txn.execute(sql, [now]) + + else: + # Select invalid tokens only + sql = ( + "SELECT * FROM registration_tokens WHERE " + "uses_allowed <= pending + completed OR expiry_time <= ?" + ) + txn.execute(sql, [now]) + + return self.db_pool.cursor_to_dict(txn) + + return await self.db_pool.runInteraction( + "select_registration_tokens", + select_registration_tokens_txn, + self._clock.time_msec(), + valid, + ) + + async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]: + """Get info about the given registration token. Used by the admin API. + + Args: + token: The token to retrieve information about. + + Returns: + A dict, or None if token doesn't exist. + """ + return await self.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], + allow_none=True, + desc="get_one_registration_token", + ) + + async def generate_registration_token( + self, length: int, chars: str + ) -> Optional[str]: + """Generate a random registration token. Used by the admin API. + + Args: + length: The length of the token to generate. + chars: A string of the characters allowed in the generated token. + + Returns: + The generated token. + + Raises: + SynapseError if a unique registration token could still not be + generated after a few tries. + """ + # Make a few attempts at generating a unique token of the required + # length before failing. + for _i in range(3): + # Generate token + token = "".join(random.choices(chars, k=length)) + + # Check if the token already exists + existing_token = await self.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="token", + allow_none=True, + desc="check_if_registration_token_exists", + ) + + if existing_token is None: + # The generated token doesn't exist yet, return it + return token + + raise SynapseError( + 500, + "Unable to generate a unique registration token. Try again with a greater length", + Codes.UNKNOWN, + ) + + async def create_registration_token( + self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int] + ) -> bool: + """Create a new registration token. Used by the admin API. + + Args: + token: The token to create. + uses_allowed: The number of times the token can be used to complete + a registration before it becomes invalid. A value of None indicates + unlimited uses. + expiry_time: The latest time the token is valid. Given as the + number of milliseconds since 1970-01-01 00:00:00 UTC. A value of + None indicates that the token does not expire. + + Returns: + Whether the row was inserted or not. + """ + + def _create_registration_token_txn(txn): + row = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["token"], + allow_none=True, + ) + + if row is not None: + # Token already exists + return False + + self.db_pool.simple_insert_txn( + txn, + "registration_tokens", + values={ + "token": token, + "uses_allowed": uses_allowed, + "pending": 0, + "completed": 0, + "expiry_time": expiry_time, + }, + ) + + return True + + return await self.db_pool.runInteraction( + "create_registration_token", _create_registration_token_txn + ) + + async def update_registration_token( + self, token: str, updatevalues: Dict[str, Optional[int]] + ) -> Optional[Dict[str, Any]]: + """Update a registration token. Used by the admin API. + + Args: + token: The token to update. + updatevalues: A dict with the fields to update. E.g.: + `{"uses_allowed": 3}` to update just uses_allowed, or + `{"uses_allowed": 3, "expiry_time": None}` to update both. + This is passed straight to simple_update_one. + + Returns: + A dict with all info about the token, or None if token doesn't exist. + """ + + def _update_registration_token_txn(txn): + try: + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues=updatevalues, + ) + except StoreError: + # Update failed because token does not exist + return None + + # Get all info about the token so it can be sent in the response + return self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=[ + "token", + "uses_allowed", + "pending", + "completed", + "expiry_time", + ], + allow_none=True, + ) + + return await self.db_pool.runInteraction( + "update_registration_token", _update_registration_token_txn + ) + + async def delete_registration_token(self, token: str) -> bool: + """Delete a registration token. Used by the admin API. + + Args: + token: The token to delete. + + Returns: + Whether the token was successfully deleted or not. + """ + try: + await self.db_pool.simple_delete_one( + "registration_tokens", + keyvalues={"token": token}, + desc="delete_registration_token", + ) + except StoreError: + # Deletion failed because token does not exist + return False + + return True + @cached() async def mark_access_token_as_used(self, token_id: int) -> None: """ diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 443e5f3315..6e7312266d 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -19,9 +19,10 @@ from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from synapse.api.constants import EventTypes, JoinRules +from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchStore @@ -73,6 +74,40 @@ class RoomWorkerStore(SQLBaseStore): self.config = hs.config + async def store_room( + self, + room_id: str, + room_creator_user_id: str, + is_public: bool, + room_version: RoomVersion, + ): + """Stores a room. + + Args: + room_id: The desired room ID, can be None. + room_creator_user_id: The user ID of the room creator. + is_public: True to indicate that this room should appear in + public room lists. + room_version: The version of the room + Raises: + StoreError if the room could not be stored. + """ + try: + await self.db_pool.simple_insert( + "rooms", + { + "room_id": room_id, + "creator": room_creator_user_id, + "is_public": is_public, + "room_version": room_version.identifier, + "has_auth_chain_index": True, + }, + desc="store_room", + ) + except Exception as e: + logger.error("store_room with room_id=%s failed: %s", room_id, e) + raise StoreError(500, "Problem creating room.") + async def get_room(self, room_id: str) -> dict: """Retrieve a room. @@ -890,55 +925,6 @@ class RoomWorkerStore(SQLBaseStore): return total_media_quarantined - async def get_all_new_public_rooms( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for public rooms replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - if last_id == current_id: - return [], current_id, False - - def get_all_new_public_rooms(txn): - sql = """ - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - - txn.execute(sql, (last_id, current_id, limit)) - updates = [(row[0], row[1:]) for row in txn] - limited = False - upto_token = current_id - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return await self.db_pool.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) - async def get_rooms_for_retention_period_in_range( self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False ) -> Dict[str, dict]: @@ -1028,6 +1014,7 @@ class _BackgroundUpdates: ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2" REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth" + POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column" _REPLACE_ROOM_DEPTH_SQL_COMMANDS = ( @@ -1069,6 +1056,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self._background_replace_room_depth_min_depth, ) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN, + self._background_populate_rooms_creator_column, + ) + async def _background_insert_retention(self, progress, batch_size): """Retrieves a list of all rooms within a range and inserts an entry for each of them into the room_retention table. @@ -1288,7 +1280,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): keyvalues={"room_id": room_id}, retcol="MAX(stream_ordering)", allow_none=True, - desc="upsert_room_on_join", + desc="has_auth_chain_index_fallback", ) return max_ordering is None @@ -1358,6 +1350,65 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return 0 + async def _background_populate_rooms_creator_column( + self, progress: dict, batch_size: int + ): + """Background update to go and add creator information to `rooms` + table from `current_state_events` table. + """ + + last_room_id = progress.get("room_id", "") + + def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction): + sql = """ + SELECT room_id, json FROM event_json + INNER JOIN rooms AS room USING (room_id) + INNER JOIN current_state_events AS state_event USING (room_id, event_id) + WHERE room_id > ? AND (room.creator IS NULL OR room.creator = '') AND state_event.type = 'm.room.create' AND state_event.state_key = '' + ORDER BY room_id + LIMIT ? + """ + + txn.execute(sql, (last_room_id, batch_size)) + room_id_to_create_event_results = txn.fetchall() + + new_last_room_id = "" + for room_id, event_json in room_id_to_create_event_results: + event_dict = db_to_json(event_json) + + creator = event_dict.get("content").get(EventContentFields.ROOM_CREATOR) + + self.db_pool.simple_update_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"creator": creator}, + ) + new_last_room_id = room_id + + if new_last_room_id == "": + return True + + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN, + {"room_id": new_last_room_id}, + ) + + return False + + end = await self.db_pool.runInteraction( + "_background_populate_rooms_creator_column", + _background_populate_rooms_creator_column_txn, + ) + + if end: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN + ) + + return batch_size + class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -1365,7 +1416,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self.config = hs.config - async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion): + async def upsert_room_on_join( + self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase] + ): """Ensure that the room is stored in the table Called when we join a room over federation, and overwrites any room version @@ -1376,6 +1429,24 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): # mark the room as having an auth chain cover index. has_auth_chain_index = await self.has_auth_chain_index(room_id) + create_event = None + for e in auth_events: + if (e.type, e.state_key) == (EventTypes.Create, ""): + create_event = e + break + + if create_event is None: + # If the state doesn't have a create event then the room is + # invalid, and it would fail auth checks anyway. + raise StoreError(400, "No create event in state") + + room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) + + if not isinstance(room_creator, str): + # If the create event does not have a creator then the room is + # invalid, and it would fail auth checks anyway. + raise StoreError(400, "No creator defined on the create event") + await self.db_pool.simple_upsert( desc="upsert_room_on_join", table="rooms", @@ -1383,7 +1454,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): values={"room_version": room_version.identifier}, insertion_values={ "is_public": False, - "creator": "", + "creator": room_creator, "has_auth_chain_index": has_auth_chain_index, }, # rooms has a unique constraint on room_id, so no need to lock when doing an @@ -1391,57 +1462,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - async def store_room( - self, - room_id: str, - room_creator_user_id: str, - is_public: bool, - room_version: RoomVersion, - ): - """Stores a room. - - Args: - room_id: The desired room ID, can be None. - room_creator_user_id: The user ID of the room creator. - is_public: True to indicate that this room should appear in - public room lists. - room_version: The version of the room - Raises: - StoreError if the room could not be stored. - """ - try: - - def store_room_txn(txn, next_id): - self.db_pool.simple_insert_txn( - txn, - "rooms", - { - "room_id": room_id, - "creator": room_creator_user_id, - "is_public": is_public, - "room_version": room_version.identifier, - "has_auth_chain_index": True, - }, - ) - if is_public: - self.db_pool.simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - }, - ) - - async with self._public_room_id_gen.get_next() as next_id: - await self.db_pool.runInteraction( - "store_room_txn", store_room_txn, next_id - ) - except Exception as e: - logger.error("store_room with room_id=%s failed: %s", room_id, e) - raise StoreError(500, "Problem creating room.") - async def maybe_store_room_on_outlier_membership( self, room_id: str, room_version: RoomVersion ): @@ -1462,6 +1482,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): insertion_values={ "room_version": room_version.identifier, "is_public": False, + # We don't worry about setting the `creator` here because + # we don't process any messages in a room while a user is + # invited (only after the join). "creator": "", "has_auth_chain_index": has_auth_chain_index, }, @@ -1470,49 +1493,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - async def set_room_is_public(self, room_id, is_public): - def set_room_is_public_txn(txn, next_id): - self.db_pool.simple_update_one_txn( - txn, - table="rooms", - keyvalues={"room_id": room_id}, - updatevalues={"is_public": is_public}, - ) - - entries = self.db_pool.simple_select_list_txn( - txn, - table="public_room_list_stream", - keyvalues={ - "room_id": room_id, - "appservice_id": None, - "network_id": None, - }, - retcols=("stream_id", "visibility"), - ) - - entries.sort(key=lambda r: r["stream_id"]) - - add_to_stream = True - if entries: - add_to_stream = bool(entries[-1]["visibility"]) != is_public - - if add_to_stream: - self.db_pool.simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - "appservice_id": None, - "network_id": None, - }, - ) + async def set_room_is_public(self, room_id: str, is_public: bool) -> None: + await self.db_pool.simple_update_one( + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"is_public": is_public}, + desc="set_room_is_public", + ) - async with self._public_room_id_gen.get_next() as next_id: - await self.db_pool.runInteraction( - "set_room_is_public", set_room_is_public_txn, next_id - ) self.hs.get_notifier().on_new_replication_data() async def set_room_is_public_appservice( @@ -1533,68 +1521,33 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): list. """ - def set_room_is_public_appservice_txn(txn, next_id): - if is_public: - try: - self.db_pool.simple_insert_txn( - txn, - table="appservice_room_list", - values={ - "appservice_id": appservice_id, - "network_id": network_id, - "room_id": room_id, - }, - ) - except self.database_engine.module.IntegrityError: - # We've already inserted, nothing to do. - return - else: - self.db_pool.simple_delete_txn( - txn, - table="appservice_room_list", - keyvalues={ - "appservice_id": appservice_id, - "network_id": network_id, - "room_id": room_id, - }, - ) - - entries = self.db_pool.simple_select_list_txn( - txn, - table="public_room_list_stream", + if is_public: + await self.db_pool.simple_upsert( + table="appservice_room_list", keyvalues={ + "appservice_id": appservice_id, + "network_id": network_id, "room_id": room_id, + }, + values={}, + insertion_values={ "appservice_id": appservice_id, "network_id": network_id, + "room_id": room_id, }, - retcols=("stream_id", "visibility"), + desc="set_room_is_public_appservice_true", ) - - entries.sort(key=lambda r: r["stream_id"]) - - add_to_stream = True - if entries: - add_to_stream = bool(entries[-1]["visibility"]) != is_public - - if add_to_stream: - self.db_pool.simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - "appservice_id": appservice_id, - "network_id": network_id, - }, - ) - - async with self._public_room_id_gen.get_next() as next_id: - await self.db_pool.runInteraction( - "set_room_is_public_appservice", - set_room_is_public_appservice_txn, - next_id, + else: + await self.db_pool.simple_delete( + table="appservice_room_list", + keyvalues={ + "appservice_id": appservice_id, + "network_id": network_id, + "room_id": room_id, + }, + desc="set_room_is_public_appservice_false", ) + self.hs.get_notifier().on_new_replication_data() async def add_event_report( @@ -1787,9 +1740,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "get_event_reports_paginate", _get_event_reports_paginate_txn ) - def get_current_public_room_stream_id(self): - return self._public_room_id_gen.get_current_token() - async def block_room(self, room_id: str, user_id: str) -> None: """Marks the room as blocked. Can be called multiple times. diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 68f1b40ea6..c58a4b8690 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -307,7 +307,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached() - async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser: + async def get_invited_rooms_for_local_user( + self, user_id: str + ) -> List[RoomsForUser]: """Get all the rooms the *local* user is invited to. Args: @@ -384,9 +386,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) sql = """ - SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering + SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering, r.room_version FROM local_current_membership AS c INNER JOIN events AS e USING (room_id, event_id) + INNER JOIN rooms AS r USING (room_id) WHERE user_id = ? AND %s @@ -395,7 +398,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)] + results = [RoomsForUser(*r) for r in txn] return results @@ -445,7 +448,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): Returns: Returns the rooms the user is in currently, along with the stream - ordering of the most recent join for that user and room. + ordering of the most recent join for that user and room, along with + the room version of the room. """ return await self.db_pool.runInteraction( "get_rooms_for_user_with_stream_ordering", @@ -522,7 +526,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): _get_users_server_still_shares_room_with_txn, ) - async def get_rooms_for_user(self, user_id: str, on_invalidate=None): + async def get_rooms_for_user( + self, user_id: str, on_invalidate=None + ) -> FrozenSet[str]: """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently @@ -629,14 +635,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # We don't update the event cache hit ratio as it completely throws off # the hit ratio counts. After all, we don't populate the cache if we # miss it here - event_map = self._get_events_from_cache( - member_event_ids, allow_rejected=False, update_metrics=False - ) + event_map = self._get_events_from_cache(member_event_ids, update_metrics=False) missing_member_event_ids = [] for event_id in member_event_ids: ev_entry = event_map.get(event_id) - if ev_entry: + if ev_entry and not ev_entry.event.rejected_reason: if ev_entry.event.membership == Membership.JOIN: users_in_room[ev_entry.event.state_key] = ProfileInfo( display_name=ev_entry.event.content.get("displayname", None), diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py new file mode 100644
index 0000000000..172f27d109 --- /dev/null +++ b/synapse/storage/databases/main/session.py
@@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +import synapse.util.stringutils as stringutils +from synapse.api.errors import StoreError +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.types import JsonDict +from synapse.util import json_encoder + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class SessionStore(SQLBaseStore): + """ + A store for generic session data. + + Each type of session should provide a unique type (to separate sessions). + + Sessions are automatically removed when they expire. + """ + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + # Create a background job for culling expired sessions. + if hs.config.run_background_tasks: + self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000) + + async def create_session( + self, session_type: str, value: JsonDict, expiry_ms: int + ) -> str: + """ + Creates a new pagination session for the room hierarchy endpoint. + + Args: + session_type: The type for this session. + value: The value to store. + expiry_ms: How long before an item is evicted from the cache + in milliseconds. Default is 0, indicating items never get + evicted based on time. + + Returns: + The newly created session ID. + + Raises: + StoreError if a unique session ID cannot be generated. + """ + # autogen a session ID and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually. + attempts = 0 + while attempts < 5: + session_id = stringutils.random_string(24) + + try: + await self.db_pool.simple_insert( + table="sessions", + values={ + "session_id": session_id, + "session_type": session_type, + "value": json_encoder.encode(value), + "expiry_time_ms": self.hs.get_clock().time_msec() + expiry_ms, + }, + desc="create_session", + ) + + return session_id + except self.db_pool.engine.module.IntegrityError: + attempts += 1 + raise StoreError(500, "Couldn't generate a session ID.") + + async def get_session(self, session_type: str, session_id: str) -> JsonDict: + """ + Retrieve data stored with create_session + + Args: + session_type: The type for this session. + session_id: The session ID returned from create_session. + + Raises: + StoreError if the session cannot be found. + """ + + def _get_session( + txn: LoggingTransaction, session_type: str, session_id: str, ts: int + ) -> JsonDict: + # This includes the expiry time since items are only periodically + # deleted, not upon expiry. + select_sql = """ + SELECT value FROM sessions WHERE + session_type = ? AND session_id = ? AND expiry_time_ms > ? + """ + txn.execute(select_sql, [session_type, session_id, ts]) + row = txn.fetchone() + + if not row: + raise StoreError(404, "No session") + + return db_to_json(row[0]) + + return await self.db_pool.runInteraction( + "get_session", + _get_session, + session_type, + session_id, + self._clock.time_msec(), + ) + + @wrap_as_background_process("delete_expired_sessions") + async def _delete_expired_sessions(self) -> None: + """Remove sessions with expiry dates that have passed.""" + + def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None: + sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?" + txn.execute(sql, (ts,)) + + await self.db_pool.runInteraction( + "delete_expired_sessions", + _delete_expired_sessions_txn, + self._clock.time_msec(), + ) diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 38bfdf5dad..4d6bbc94c7 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py
@@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import attr +from synapse.api.constants import LoginType from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction @@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore): keyvalues={}, ) + # If a registration token was used, decrement the pending counter + # before deleting the session. + rows = self.db_pool.simple_select_many_txn( + txn, + table="ui_auth_sessions_credentials", + column="session_id", + iterable=session_ids, + keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN}, + retcols=["result"], + ) + + # Get the tokens used and how much pending needs to be decremented by. + token_counts: Dict[str, int] = {} + for r in rows: + # If registration was successfully completed, the result of the + # registration token stage for that session will be True. + # If a token was used to authenticate, but registration was + # never completed, the result will be the token used. + token = db_to_json(r["result"]) + if isinstance(token, str): + token_counts[token] = token_counts.get(token, 0) + 1 + + # Update the `pending` counters. + if len(token_counts) > 0: + token_rows = self.db_pool.simple_select_many_txn( + txn, + table="registration_tokens", + column="token", + iterable=list(token_counts.keys()), + keyvalues={}, + retcols=["token", "pending"], + ) + for token_row in token_rows: + token = token_row["token"] + new_pending = token_row["pending"] - token_counts[token] + self.db_pool.simple_update_one_txn( + txn, + table="registration_tokens", + keyvalues={"token": token}, + updatevalues={"pending": new_pending}, + ) + # Delete the corresponding completed credentials. self.db_pool.simple_delete_many_txn( txn, diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 9d28d69ac7..65dde67ae9 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py
@@ -365,7 +365,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return False async def update_profile_in_user_dir( - self, user_id: str, display_name: str, avatar_url: str + self, user_id: str, display_name: Optional[str], avatar_url: Optional[str] ) -> None: """ Update or add a user's profile in the user directory.