Diffstat (limited to 'synapse/storage/databases')
 synapse/storage/databases/main/appservice.py         |   6
 synapse/storage/databases/main/devices.py            |  50
 synapse/storage/databases/main/event_federation.py   |   4
 synapse/storage/databases/main/event_push_actions.py |  19
 synapse/storage/databases/main/events.py             | 107
 synapse/storage/databases/main/events_worker.py      | 581
 synapse/storage/databases/main/purge_events.py       |   2
 synapse/storage/databases/main/push_rule.py          |  11
 synapse/storage/databases/main/registration.py       |  28
 synapse/storage/databases/main/roommember.py         |   4
 synapse/storage/databases/main/stream.py             |  15
 synapse/storage/databases/main/transactions.py       |  70
 12 files changed, 180 insertions, 717 deletions
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py

index 4a883dc166..baec35ee27 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -143,7 +143,7 @@ class ApplicationServiceTransactionWorkerStore(
             A list of ApplicationServices, which may be empty.
         """
         results = await self.db_pool.simple_select_list(
-            "application_services_state", {"state": state.value}, ["as_id"]
+            "application_services_state", {"state": state}, ["as_id"]
         )
         # NB: This assumes this class is linked with ApplicationServiceStore
         as_list = self.get_app_services()
@@ -173,7 +173,7 @@ class ApplicationServiceTransactionWorkerStore(
             desc="get_appservice_state",
         )
         if result:
-            return ApplicationServiceState(result.get("state"))
+            return result.get("state")
         return None
 
     async def set_appservice_state(
@@ -186,7 +186,7 @@ class ApplicationServiceTransactionWorkerStore(
             state: The connectivity state to apply.
         """
         await self.db_pool.simple_upsert(
-            "application_services_state", {"as_id": service.id}, {"state": state.value}
+            "application_services_state", {"as_id": service.id}, {"state": state}
         )
 
     async def create_appservice_txn(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d5a4a661cd..9ccc66e589 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -139,27 +139,6 @@ class DeviceWorkerStore(SQLBaseStore): return {d["device_id"]: d for d in devices} - async def get_devices_by_auth_provider_session_id( - self, auth_provider_id: str, auth_provider_session_id: str - ) -> List[Dict[str, Any]]: - """Retrieve the list of devices associated with a SSO IdP session ID. - - Args: - auth_provider_id: The SSO IdP ID as defined in the server config - auth_provider_session_id: The session ID within the IdP - Returns: - A list of dicts containing the device_id and the user_id of each device - """ - return await self.db_pool.simple_select_list( - table="device_auth_providers", - keyvalues={ - "auth_provider_id": auth_provider_id, - "auth_provider_session_id": auth_provider_session_id, - }, - retcols=("user_id", "device_id"), - desc="get_devices_by_auth_provider_session_id", - ) - @trace async def get_device_updates_by_remote( self, destination: str, from_stream_id: int, limit: int @@ -1091,12 +1070,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def store_device( - self, - user_id: str, - device_id: str, - initial_device_display_name: Optional[str], - auth_provider_id: Optional[str] = None, - auth_provider_session_id: Optional[str] = None, + self, user_id: str, device_id: str, initial_device_display_name: Optional[str] ) -> bool: """Ensure the given device is known; add it to the store if not @@ -1105,8 +1079,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_id: id of device initial_device_display_name: initial displayname of the device. Ignored if device exists. - auth_provider_id: The SSO IdP the user used, if any. - auth_provider_session_id: The session ID (sid) got from a OIDC login. Returns: Whether the device was inserted or an existing device existed with that ID. @@ -1143,18 +1115,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) - if auth_provider_id and auth_provider_session_id: - await self.db_pool.simple_insert( - "device_auth_providers", - values={ - "user_id": user_id, - "device_id": device_id, - "auth_provider_id": auth_provider_id, - "auth_provider_session_id": auth_provider_session_id, - }, - desc="store_device_auth_provider", - ) - self.device_id_exists_cache.set(key, True) return inserted except StoreError: @@ -1208,14 +1168,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): keyvalues={"user_id": user_id}, ) - self.db_pool.simple_delete_many_txn( - txn, - table="device_auth_providers", - column="device_id", - values=device_ids, - keyvalues={"user_id": user_id}, - ) - await self.db_pool.runInteraction("delete_devices", _delete_devices_txn) for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 9580a40785..ef5d1ef01e 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1552,9 +1552,9 @@ class EventFederationStore(EventFederationWorkerStore):
         DELETE FROM event_auth
         WHERE event_id IN (
             SELECT event_id FROM events
-            LEFT JOIN state_events AS se USING (room_id, event_id)
+            LEFT JOIN state_events USING (room_id, event_id)
             WHERE ? <= stream_ordering AND stream_ordering < ?
-                AND se.state_key IS null
+                AND state_key IS null
         )
         """
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3efdd0c920..d957e770dc 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -16,7 +16,6 @@ import logging
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import attr
-from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -38,20 +37,6 @@ DEFAULT_HIGHLIGHT_ACTION = [
 ]
 
 
-class BasePushAction(TypedDict):
-    event_id: str
-    actions: List[Union[dict, str]]
-
-
-class HttpPushAction(BasePushAction):
-    room_id: str
-    stream_ordering: int
-
-
-class EmailPushAction(HttpPushAction):
-    received_ts: Optional[int]
-
-
 def _serialize_action(actions, is_highlight):
     """Custom serializer for actions. This allows us to "compress" common actions.
 
@@ -236,7 +221,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         min_stream_ordering: int,
         max_stream_ordering: int,
         limit: int = 20,
-    ) -> List[HttpPushAction]:
+    ) -> List[dict]:
         """Get a list of the most recent unread push actions for a given
         user, within the given stream ordering range. Called by the httppusher.
 
@@ -341,7 +326,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         min_stream_ordering: int,
         max_stream_ordering: int,
         limit: int = 20,
-    ) -> List[EmailPushAction]:
+    ) -> List[dict]:
         """Get a list of the most recent unread push actions for a given
         user, within the given stream ordering range. Called by the
         emailpusher
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 4e528612ea..06832221ad 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -15,7 +15,7 @@ # limitations under the License. import itertools import logging -from collections import OrderedDict +from collections import OrderedDict, namedtuple from typing import ( TYPE_CHECKING, Any, @@ -41,10 +41,9 @@ from synapse.events.snapshot import EventContext # noqa: F401 from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction -from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry from synapse.storage.types import Connection -from synapse.storage.util.id_generators import AbstractStreamIdGenerator +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder @@ -65,6 +64,9 @@ event_counter = Counter( ) +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + @attr.s(slots=True) class DeltaState: """Deltas to use to update the `current_state_events` table. @@ -106,30 +108,23 @@ class PersistEventsStore: self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id + # Ideally we'd move these ID gens here, unfortunately some other ID + # generators are chained off them so doing so is a bit of a PITA. + self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen + self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen + # This should only exist on instances that are configured to write assert ( hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" - # Since we have been configured to write, we ought to have id generators, - # rather than id trackers. - assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator) - assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator) - - # Ideally we'd move these ID gens here, unfortunately some other ID - # generators are chained off them so doing so is a bit of a PITA. - self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen - self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen - async def _persist_events_and_state_updates( self, events_and_contexts: List[Tuple[EventBase, EventContext]], - *, current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], new_forward_extremeties: Dict[str, List[str]], - use_negative_stream_ordering: bool = False, - inhibit_local_membership_updates: bool = False, + backfilled: bool = False, ) -> None: """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -142,14 +137,7 @@ class PersistEventsStore: room state new_forward_extremities: Map from room_id to list of event IDs that are the new forward extremities of the room. - use_negative_stream_ordering: Whether to start stream_ordering on - the negative side and decrement. This should be set as True - for backfilled events because backfilled events get a negative - stream ordering so they don't come down incremental `/sync`. - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. 
+ backfilled Returns: Resolves when the events have been persisted @@ -171,7 +159,7 @@ class PersistEventsStore: # # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. - if use_negative_stream_ordering: + if backfilled: stream_ordering_manager = self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) @@ -188,13 +176,13 @@ class PersistEventsStore: "persist_events", self._persist_events_txn, events_and_contexts=events_and_contexts, - inhibit_local_membership_updates=inhibit_local_membership_updates, + backfilled=backfilled, state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc(len(events_and_contexts)) - if stream < 0: + if not backfilled: # backfilled events have negative stream orderings, so we don't # want to set the event_persisted_position to that. synapse.metrics.event_persisted_position.set( @@ -328,9 +316,8 @@ class PersistEventsStore: def _persist_events_txn( self, txn: LoggingTransaction, - *, events_and_contexts: List[Tuple[EventBase, EventContext]], - inhibit_local_membership_updates: bool = False, + backfilled: bool, state_delta_for_room: Optional[Dict[str, DeltaState]] = None, new_forward_extremeties: Optional[Dict[str, List[str]]] = None, ): @@ -343,10 +330,7 @@ class PersistEventsStore: Args: txn events_and_contexts: events to persist - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. + backfilled: True if the events were backfilled delete_existing True to purge existing table rows for the events from the database. This is useful when retrying due to IntegrityError. @@ -379,7 +363,9 @@ class PersistEventsStore: events_and_contexts ) - self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts) + self._update_room_depths_txn( + txn, events_and_contexts=events_and_contexts, backfilled=backfilled + ) # _update_outliers_txn filters out any events which have already been # persisted, and returns the filtered list. @@ -412,7 +398,7 @@ class PersistEventsStore: txn, events_and_contexts=events_and_contexts, all_events_and_contexts=all_events_and_contexts, - inhibit_local_membership_updates=inhibit_local_membership_updates, + backfilled=backfilled, ) # We call this last as it assumes we've inserted the events into @@ -575,9 +561,9 @@ class PersistEventsStore: # fetch their auth event info. 
while missing_auth_chains: sql = """ - SELECT event_id, events.type, se.state_key, chain_id, sequence_number + SELECT event_id, events.type, state_key, chain_id, sequence_number FROM events - INNER JOIN state_events AS se USING (event_id) + INNER JOIN state_events USING (event_id) LEFT JOIN event_auth_chains USING (event_id) WHERE """ @@ -1214,6 +1200,7 @@ class PersistEventsStore: self, txn, events_and_contexts: List[Tuple[EventBase, EventContext]], + backfilled: bool, ): """Update min_depth for each room @@ -1221,18 +1208,13 @@ class PersistEventsStore: txn (twisted.enterprise.adbapi.Connection): db connection events_and_contexts (list[(EventBase, EventContext)]): events we are persisting + backfilled (bool): True if the events were backfilled """ depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids txn.call_after(self.store._invalidate_get_event_cache, event.event_id) - # Then update the `stream_ordering` position to mark the latest - # event as the front of the room. This should not be done for - # backfilled events because backfilled events have negative - # stream_ordering and happened in the past so we know that we don't - # need to update the stream_ordering tip/front for the room. - assert event.internal_metadata.stream_ordering is not None - if event.internal_metadata.stream_ordering >= 0: + if not backfilled: txn.call_after( self.store._events_stream_cache.entity_has_changed, event.room_id, @@ -1445,12 +1427,7 @@ class PersistEventsStore: return [ec for ec in events_and_contexts if ec[0] not in to_remove] def _update_metadata_tables_txn( - self, - txn, - *, - events_and_contexts, - all_events_and_contexts, - inhibit_local_membership_updates: bool = False, + self, txn, events_and_contexts, all_events_and_contexts, backfilled ): """Update all the miscellaneous tables for new events @@ -1462,10 +1439,7 @@ class PersistEventsStore: events that we were going to persist. This includes events we've already persisted, etc, that wouldn't appear in events_and_context. - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. + backfilled (bool): True if the events were backfilled """ # Insert all the push actions into the event_push_actions table. @@ -1539,7 +1513,7 @@ class PersistEventsStore: for event, _ in events_and_contexts if event.type == EventTypes.Member ], - inhibit_local_membership_updates=inhibit_local_membership_updates, + backfilled=backfilled, ) # Insert event_reference_hashes table. @@ -1579,13 +1553,11 @@ class PersistEventsStore: for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: - to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) + to_prefill.append(_EventCacheEntry(event=event, redacted_event=None)) def prefill(): for cache_entry in to_prefill: - self.store._get_event_cache.set( - (cache_entry.event.event_id,), cache_entry - ) + self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry) txn.call_after(prefill) @@ -1666,19 +1638,8 @@ class PersistEventsStore: txn, table="event_reference_hashes", values=vals ) - def _store_room_members_txn( - self, txn, events, *, inhibit_local_membership_updates: bool = False - ): - """ - Store a room member in the database. - Args: - txn: The transaction to use. - events: List of events to store. 
- inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. - """ + def _store_room_members_txn(self, txn, events, backfilled): + """Store a room member in the database.""" def non_null_str_or_none(val: Any) -> Optional[str]: return val if isinstance(val, str) and "\u0000" not in val else None @@ -1721,7 +1682,7 @@ class PersistEventsStore: # band membership", like a remote invite or a rejection of a remote invite. if ( self.is_mine_id(event.state_key) - and not inhibit_local_membership_updates + and not backfilled and event.internal_metadata.is_outlier() and event.internal_metadata.is_out_of_band_membership() ): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c7b660ac5a..c6bf316d5b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,18 +15,14 @@ import logging import threading from typing import ( - TYPE_CHECKING, - Any, Collection, Container, Dict, Iterable, List, - NoReturn, Optional, Set, Tuple, - cast, overload, ) @@ -42,7 +38,6 @@ from synapse.api.errors import NotFoundError, SynapseError from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, - RoomVersion, RoomVersions, ) from synapse.events import EventBase, make_event_from_dict @@ -61,18 +56,10 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import ( - DatabasePool, - LoggingDatabaseConnection, - LoggingTransaction, -) +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine -from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import ( - AbstractStreamIdTracker, - MultiWriterIdGenerator, - StreamIdGenerator, -) +from synapse.storage.types import Connection +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError @@ -82,13 +69,10 @@ from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure -if TYPE_CHECKING: - from synapse.server import HomeServer - logger = logging.getLogger(__name__) -# These values are used in the `enqueue_event` and `_fetch_loop` methods to +# These values are used in the `enqueus_event` and `_do_fetch` methods to # control how we batch/bulk fetch events from the database. # The values are plucked out of thing air to make initial sync run faster # on jki.re @@ -105,7 +89,7 @@ event_fetch_ongoing_gauge = Gauge( @attr.s(slots=True, auto_attribs=True) -class EventCacheEntry: +class _EventCacheEntry: event: EventBase redacted_event: Optional[EventBase] @@ -145,7 +129,7 @@ class _EventRow: json: str internal_metadata: str format_version: Optional[int] - room_version_id: Optional[str] + room_version_id: Optional[int] rejected_reason: Optional[str] redactions: List[str] outlier: bool @@ -169,16 +153,9 @@ class EventsWorkerStore(SQLBaseStore): # options controlling this. USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): + def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self._stream_id_gen: AbstractStreamIdTracker - self._backfill_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): # If we're using Postgres than we can use `MultiWriterIdGenerator` # regardless of whether this process writes to the streams or not. @@ -237,7 +214,7 @@ class EventsWorkerStore(SQLBaseStore): 5 * 60 * 1000, ) - self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache( + self._get_event_cache = LruCache( cache_name="*getEvent*", max_size=hs.config.caches.event_cache_size, ) @@ -246,21 +223,19 @@ class EventsWorkerStore(SQLBaseStore): # ID to cache entry. Note that the returned dict may not have the # requested event in it if the event isn't in the DB. 
self._current_event_fetches: Dict[ - str, ObservableDeferred[Dict[str, EventCacheEntry]] + str, ObservableDeferred[Dict[str, _EventCacheEntry]] ] = {} self._event_fetch_lock = threading.Condition() - self._event_fetch_list: List[ - Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"] - ] = [] + self._event_fetch_list = [] self._event_fetch_ongoing = 0 event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) # We define this sequence here so that it can be referenced from both # the DataStore and PersistEventStore. - def get_chain_id_txn(txn: Cursor) -> int: + def get_chain_id_txn(txn): txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") - return cast(Tuple[int], txn.fetchone())[0] + return txn.fetchone()[0] self.event_chain_id_gen = build_sequence_generator( db_conn, @@ -271,13 +246,7 @@ class EventsWorkerStore(SQLBaseStore): id_column="chain_id", ) - def process_replication_rows( - self, - stream_name: str, - instance_name: str, - token: int, - rows: Iterable[Any], - ) -> None: + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == EventsStream.NAME: self._stream_id_gen.advance(instance_name, token) elif stream_name == BackfillStream.NAME: @@ -311,10 +280,10 @@ class EventsWorkerStore(SQLBaseStore): self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = ..., - allow_rejected: bool = ..., - allow_none: Literal[False] = ..., - check_room_id: Optional[str] = ..., + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[False] = False, + check_room_id: Optional[str] = None, ) -> EventBase: ... @@ -323,10 +292,10 @@ class EventsWorkerStore(SQLBaseStore): self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = ..., - allow_rejected: bool = ..., - allow_none: Literal[True] = ..., - check_room_id: Optional[str] = ..., + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[True] = False, + check_room_id: Optional[str] = None, ) -> Optional[EventBase]: ... @@ -388,7 +357,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_events( self, - event_ids: Collection[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, @@ -575,7 +544,7 @@ class EventsWorkerStore(SQLBaseStore): async def _get_events_from_cache_or_db( self, event_ids: Iterable[str], allow_rejected: bool = False - ) -> Dict[str, EventCacheEntry]: + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -609,7 +578,7 @@ class EventsWorkerStore(SQLBaseStore): # same dict into itself N times). already_fetching_ids: Set[str] = set() already_fetching_deferreds: Set[ - ObservableDeferred[Dict[str, EventCacheEntry]] + ObservableDeferred[Dict[str, _EventCacheEntry]] ] = set() for event_id in missing_events_ids: @@ -632,8 +601,8 @@ class EventsWorkerStore(SQLBaseStore): # function returning more events than requested, but that can happen # already due to `_get_events_from_db`). 
fetching_deferred: ObservableDeferred[ - Dict[str, EventCacheEntry] - ] = ObservableDeferred(defer.Deferred(), consumeErrors=True) + Dict[str, _EventCacheEntry] + ] = ObservableDeferred(defer.Deferred()) for event_id in missing_events_ids: self._current_event_fetches[event_id] = fetching_deferred @@ -689,12 +658,12 @@ class EventsWorkerStore(SQLBaseStore): return event_entry_map - def _invalidate_get_event_cache(self, event_id: str) -> None: + def _invalidate_get_event_cache(self, event_id): self._get_event_cache.invalidate((event_id,)) def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True - ) -> Dict[str, EventCacheEntry]: + ) -> Dict[str, _EventCacheEntry]: """Fetch events from the caches. May return rejected events. @@ -767,123 +736,38 @@ class EventsWorkerStore(SQLBaseStore): for e in state_to_include.values() ] - def _maybe_start_fetch_thread(self) -> None: - """Starts an event fetch thread if we are not yet at the maximum number.""" - with self._event_fetch_lock: - if ( - self._event_fetch_list - and self._event_fetch_ongoing < EVENT_QUEUE_THREADS - ): - self._event_fetch_ongoing += 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) - # `_event_fetch_ongoing` is decremented in `_fetch_thread`. - should_start = True - else: - should_start = False - - if should_start: - run_as_background_process("fetch_events", self._fetch_thread) - - async def _fetch_thread(self) -> None: - """Services requests for events from `_event_fetch_list`.""" - exc = None - try: - await self.db_pool.runWithConnection(self._fetch_loop) - except BaseException as e: - exc = e - raise - finally: - should_restart = False - event_fetches_to_fail = [] - with self._event_fetch_lock: - self._event_fetch_ongoing -= 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) - - # There may still be work remaining in `_event_fetch_list` if we - # failed, or it was added in between us deciding to exit and - # decrementing `_event_fetch_ongoing`. - if self._event_fetch_list: - if exc is None: - # We decided to exit, but then some more work was added - # before `_event_fetch_ongoing` was decremented. - # If a new event fetch thread was not started, we should - # restart ourselves since the remaining event fetch threads - # may take a while to get around to the new work. - # - # Unfortunately it is not possible to tell whether a new - # event fetch thread was started, so we restart - # unconditionally. If we are unlucky, we will end up with - # an idle fetch thread, but it will time out after - # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds - # in any case. - # - # Note that multiple fetch threads may run down this path at - # the same time. - should_restart = True - elif isinstance(exc, Exception): - if self._event_fetch_ongoing == 0: - # We were the last remaining fetcher and failed. - # Fail any outstanding fetches since no one else will - # handle them. - event_fetches_to_fail = self._event_fetch_list - self._event_fetch_list = [] - else: - # We weren't the last remaining fetcher, so another - # fetcher will pick up the work. This will either happen - # after their existing work, however long that takes, - # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if - # they are idle. - pass - else: - # The exception is a `SystemExit`, `KeyboardInterrupt` or - # `GeneratorExit`. Don't try to do anything clever here. - pass - - if should_restart: - # We exited cleanly but noticed more work. 
- self._maybe_start_fetch_thread() - - if event_fetches_to_fail: - # We were the last remaining fetcher and failed. - # Fail any outstanding fetches since no one else will handle them. - assert exc is not None - with PreserveLoggingContext(): - for _, deferred in event_fetches_to_fail: - deferred.errback(exc) - - def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None: + def _do_fetch(self, conn: Connection) -> None: """Takes a database connection and waits for requests for events from the _event_fetch_list queue. """ - i = 0 - while True: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: - # There are no requests waiting. If we haven't yet reached the - # maximum iteration limit, wait for some more requests to turn up. - # Otherwise, bail out. - single_threaded = self.database_engine.single_threaded - if ( - not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING - or single_threaded - or i > EVENT_QUEUE_ITERATIONS - ): - return - - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 + try: + i = 0 + while True: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if ( + not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING + or single_threaded + or i > EVENT_QUEUE_ITERATIONS + ): + break + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 - self._fetch_event_list(conn, event_list) + self._fetch_event_list(conn, event_list) + finally: + self._event_fetch_ongoing -= 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) def _fetch_event_list( - self, - conn: LoggingDatabaseConnection, - event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]], + self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]] ) -> None: """Handle a load of requests from the _event_fetch_list queue @@ -910,7 +794,7 @@ class EventsWorkerStore(SQLBaseStore): ) # We only want to resolve deferreds from the main thread - def fire() -> None: + def fire(): for _, d in event_list: d.callback(row_dict) @@ -920,16 +804,18 @@ class EventsWorkerStore(SQLBaseStore): logger.exception("do_fetch") # We only want to resolve deferreds from the main thread - def fire_errback(exc: Exception) -> None: - for _, d in event_list: - d.errback(exc) + def fire(evs, exc): + for _, d in evs: + if not d.called: + with PreserveLoggingContext(): + d.errback(exc) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire_errback, e) + self.hs.get_reactor().callFromThread(fire, event_list, e) async def _get_events_from_db( - self, event_ids: Collection[str] - ) -> Dict[str, EventCacheEntry]: + self, event_ids: Iterable[str] + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the database. May return rejected events. @@ -945,29 +831,29 @@ class EventsWorkerStore(SQLBaseStore): map from event id to result. May return extra events which weren't asked for. 
""" - fetched_event_ids: Set[str] = set() - fetched_events: Dict[str, _EventRow] = {} + fetched_events = {} events_to_fetch = event_ids while events_to_fetch: row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events - redaction_ids: Set[str] = set() + redaction_ids = set() for event_id in events_to_fetch: row = row_map.get(event_id) - fetched_event_ids.add(event_id) + fetched_events[event_id] = row if row: - fetched_events[event_id] = row redaction_ids.update(row.redactions) - events_to_fetch = redaction_ids.difference(fetched_event_ids) + events_to_fetch = redaction_ids.difference(fetched_events.keys()) if events_to_fetch: logger.debug("Also fetching redaction events %s", events_to_fetch) # build a map from event_id to EventBase - event_map: Dict[str, EventBase] = {} + event_map = {} for event_id, row in fetched_events.items(): + if not row: + continue assert row.event_id == event_id rejected_reason = row.rejected_reason @@ -995,7 +881,6 @@ class EventsWorkerStore(SQLBaseStore): room_version_id = row.room_version_id - room_version: Optional[RoomVersion] if not room_version_id: # this should only happen for out-of-band membership events which # arrived before #6983 landed. For all other events, we should have @@ -1066,14 +951,14 @@ class EventsWorkerStore(SQLBaseStore): # finally, we can decide whether each one needs redacting, and build # the cache entries. - result_map: Dict[str, EventCacheEntry] = {} + result_map = {} for event_id, original_ev in event_map.items(): redactions = fetched_events[event_id].redactions redacted_event = self._maybe_redact_event_row( original_ev, redactions, event_map ) - cache_entry = EventCacheEntry( + cache_entry = _EventCacheEntry( event=original_ev, redacted_event=redacted_event ) @@ -1082,7 +967,7 @@ class EventsWorkerStore(SQLBaseStore): return result_map - async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]: + async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]: """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -1095,12 +980,23 @@ class EventsWorkerStore(SQLBaseStore): that weren't requested. """ - events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred() + events_d = defer.Deferred() with self._event_fetch_lock: self._event_fetch_list.append((events, events_d)) + self._event_fetch_lock.notify() - self._maybe_start_fetch_thread() + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + should_start = True + else: + should_start = False + + if should_start: + run_as_background_process( + "fetch_events", self.db_pool.runWithConnection, self._do_fetch + ) logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): @@ -1250,7 +1146,7 @@ class EventsWorkerStore(SQLBaseStore): # no valid redaction found for this event return None - async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]: + async def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ @@ -1279,7 +1175,7 @@ class EventsWorkerStore(SQLBaseStore): event_ids: events we are looking for Returns: - The set of events we have already seen. + set[str]: The events we have already seen. 
""" res = await self._have_seen_events_dict( (room_id, event_id) for event_id in event_ids @@ -1302,9 +1198,7 @@ class EventsWorkerStore(SQLBaseStore): } results = {x: True for x in cache_results} - def have_seen_events_txn( - txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...] - ) -> None: + def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]): # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1330,14 +1224,12 @@ class EventsWorkerStore(SQLBaseStore): return results @cached(max_entries=100000, tree=True) - async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn: + async def have_seen_event(self, room_id: str, event_id: str): # this only exists for the benefit of the @cachedList descriptor on # _have_seen_events_dict raise NotImplementedError() - def _get_current_state_event_counts_txn( - self, txn: LoggingTransaction, room_id: str - ) -> int: + def _get_current_state_event_counts_txn(self, txn, room_id): """ See get_current_state_event_counts. """ @@ -1362,7 +1254,7 @@ class EventsWorkerStore(SQLBaseStore): room_id, ) - async def get_room_complexity(self, room_id: str) -> Dict[str, float]: + async def get_room_complexity(self, room_id): """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -1370,10 +1262,10 @@ class EventsWorkerStore(SQLBaseStore): more resources. Args: - room_id: The room ID to query. + room_id (str) Returns: - dict[str:float] of complexity version to complexity. + dict[str:int] of complexity version to complexity. """ state_events = await self.get_current_state_event_counts(room_id) @@ -1383,13 +1275,13 @@ class EventsWorkerStore(SQLBaseStore): return {"v1": complexity_v1} - def get_current_events_token(self) -> int: + def get_current_events_token(self): """The current maximum token that events have reached""" return self._stream_id_gen.get_current_token() async def get_all_new_forward_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple]: """Returns new events, for the Events replication stream Args: @@ -1403,15 +1295,13 @@ class EventsWorkerStore(SQLBaseStore): EventsStreamRow. """ - def get_all_new_forward_event_rows( - txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + def get_all_new_forward_event_rows(txn): sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " LEFT JOIN room_memberships USING (event_id)" " LEFT JOIN rejections USING (event_id)" @@ -1421,9 +1311,7 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" 
) txn.execute(sql, (last_id, current_id, instance_name, limit)) - return cast( - List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() - ) + return txn.fetchall() return await self.db_pool.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows @@ -1431,7 +1319,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_ex_outlier_stream_rows( self, instance_name: str, last_id: int, current_id: int - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple]: """Returns de-outliered events, for the Events replication stream Args: @@ -1444,16 +1332,14 @@ class EventsWorkerStore(SQLBaseStore): EventsStreamRow. """ - def get_ex_outlier_stream_rows_txn( - txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + def get_ex_outlier_stream_rows_txn(txn): sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " LEFT JOIN room_memberships USING (event_id)" " LEFT JOIN rejections USING (event_id)" @@ -1464,9 +1350,7 @@ class EventsWorkerStore(SQLBaseStore): ) txn.execute(sql, (last_id, current_id, instance_name)) - return cast( - List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() - ) + return txn.fetchall() return await self.db_pool.runInteraction( "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn @@ -1474,7 +1358,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_all_new_backfill_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: + ) -> Tuple[List[Tuple[int, list]], int, bool]: """Get updates for backfill replication stream, including all new backfilled events and events that have gone from being outliers to not. @@ -1502,15 +1386,13 @@ class EventsWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_new_backfill_event_rows( - txn: LoggingTransaction, - ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: + def get_all_new_backfill_event_rows(txn): sql = ( "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > stream_ordering AND stream_ordering >= ?" " AND instance_name = ?" @@ -1518,15 +1400,7 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, instance_name, limit)) - new_event_updates: List[ - Tuple[int, Tuple[str, str, str, str, str, str]] - ] = [] - row: Tuple[int, str, str, str, str, str, str] - # Type safety: iterating over `txn` yields `Tuple`, i.e. - # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a - # variadic tuple to a fixed length tuple and flags it up as an error. 
- for row in txn: # type: ignore[assignment] - new_event_updates.append((row[0], row[1:])) + new_event_updates = [(row[0], row[1:]) for row in txn] limited = False if len(new_event_updates) == limit: @@ -1537,11 +1411,11 @@ class EventsWorkerStore(SQLBaseStore): sql = ( "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" @@ -1549,11 +1423,7 @@ class EventsWorkerStore(SQLBaseStore): " ORDER BY event_stream_ordering DESC" ) txn.execute(sql, (-last_id, -upper_bound, instance_name)) - # Type safety: iterating over `txn` yields `Tuple`, i.e. - # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a - # variadic tuple to a fixed length tuple and flags it up as an error. - for row in txn: # type: ignore[assignment] - new_event_updates.append((row[0], row[1:])) + new_event_updates.extend((row[0], row[1:]) for row in txn) if len(new_event_updates) >= limit: upper_bound = new_event_updates[-1][0] @@ -1567,7 +1437,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_all_updated_current_state_deltas( self, instance_name: str, from_token: int, to_token: int, target_row_count: int - ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]: + ) -> Tuple[List[Tuple], int, bool]: """Fetch updates from current_state_delta_stream Args: @@ -1587,9 +1457,7 @@ class EventsWorkerStore(SQLBaseStore): * `limited` is whether there are more updates to fetch. """ - def get_all_updated_current_state_deltas_txn( - txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str]]: + def get_all_updated_current_state_deltas_txn(txn): sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream @@ -1598,23 +1466,21 @@ class EventsWorkerStore(SQLBaseStore): ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (from_token, to_token, instance_name, target_row_count)) - return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) + return txn.fetchall() - def get_deltas_for_stream_id_txn( - txn: LoggingTransaction, stream_id: int - ) -> List[Tuple[int, str, str, str, str]]: + def get_deltas_for_stream_id_txn(txn, stream_id): sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream WHERE stream_id = ? """ txn.execute(sql, [stream_id]) - return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) + return txn.fetchall() # we need to make sure that, for every stream id in the results, we get *all* # the rows with that stream id. 
- rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction( + rows: List[Tuple] = await self.db_pool.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) @@ -1643,14 +1509,14 @@ class EventsWorkerStore(SQLBaseStore): return rows, to_token, True - async def is_event_after(self, event_id1: str, event_id2: str) -> bool: + async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream""" to_1, so_1 = await self.get_event_ordering(event_id1) to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cached(max_entries=5000) - async def get_event_ordering(self, event_id: str) -> Tuple[int, int]: + async def get_event_ordering(self, event_id): res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], @@ -1673,9 +1539,7 @@ class EventsWorkerStore(SQLBaseStore): None otherwise. """ - def get_next_event_to_expire_txn( - txn: LoggingTransaction, - ) -> Optional[Tuple[str, int]]: + def get_next_event_to_expire_txn(txn): txn.execute( """ SELECT event_id, expiry_ts FROM event_expiry @@ -1683,7 +1547,7 @@ class EventsWorkerStore(SQLBaseStore): """ ) - return cast(Optional[Tuple[str, int]], txn.fetchone()) + return txn.fetchone() return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn @@ -1747,10 +1611,10 @@ class EventsWorkerStore(SQLBaseStore): return mapping @wrap_as_background_process("_cleanup_old_transaction_ids") - async def _cleanup_old_transaction_ids(self) -> None: + async def _cleanup_old_transaction_ids(self): """Cleans out transaction id mappings older than 24hrs.""" - def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None: + def _cleanup_old_transaction_ids_txn(txn): sql = """ DELETE FROM event_txn_id WHERE inserted_ts < ? @@ -1762,198 +1626,3 @@ class EventsWorkerStore(SQLBaseStore): "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn, ) - - async def is_event_next_to_backward_gap(self, event: EventBase) -> bool: - """Check if the given event is next to a backward gap of missing events. - <latest messages> A(False)--->B(False)--->C(True)---> <gap, unknown events> <oldest messages> - - Args: - room_id: room where the event lives - event_id: event to check - - Returns: - Boolean indicating whether it's an extremity - """ - - def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: - # If the event in question has any of its prev_events listed as a - # backward extremity, it's next to a gap. - # - # We can't just check the backward edges in `event_edges` because - # when we persist events, we will also record the prev_events as - # edges to the event in question regardless of whether we have those - # prev_events yet. We need to check whether those prev_events are - # backward extremities, also known as gaps, that need to be - # backfilled. - backward_extremity_query = """ - SELECT 1 FROM event_backward_extremities - WHERE - room_id = ? - AND %s - LIMIT 1 - """ - - # If the event in question is a backward extremity or has any of its - # prev_events listed as a backward extremity, it's next to a - # backward gap. 
- clause, args = make_in_list_sql_clause( - self.database_engine, - "event_id", - [event.event_id] + list(event.prev_event_ids()), - ) - - txn.execute(backward_extremity_query % (clause,), [event.room_id] + args) - backward_extremities = txn.fetchall() - - # We consider any backward extremity as a backward gap - if len(backward_extremities): - return True - - return False - - return await self.db_pool.runInteraction( - "is_event_next_to_backward_gap_txn", - is_event_next_to_backward_gap_txn, - ) - - async def is_event_next_to_forward_gap(self, event: EventBase) -> bool: - """Check if the given event is next to a forward gap of missing events. - The gap in front of the latest events is not considered a gap. - <latest messages> A(False)--->B(False)--->C(False)---> <gap, unknown events> <oldest messages> - <latest messages> A(False)--->B(False)---> <gap, unknown events> --->D(True)--->E(False) <oldest messages> - - Args: - room_id: room where the event lives - event_id: event to check - - Returns: - Boolean indicating whether it's an extremity - """ - - def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: - # If the event in question is a forward extremity, we will just - # consider any potential forward gap as not a gap since it's one of - # the latest events in the room. - # - # `event_forward_extremities` does not include backfilled or outlier - # events so we can't rely on it to find forward gaps. We can only - # use it to determine whether a message is the latest in the room. - # - # We can't combine this query with the `forward_edge_query` below - # because if the event in question has no forward edges (isn't - # referenced by any other event's prev_events) but is in - # `event_forward_extremities`, we don't want to return 0 rows and - # say it's next to a gap. - forward_extremity_query = """ - SELECT 1 FROM event_forward_extremities - WHERE - room_id = ? - AND event_id = ? - LIMIT 1 - """ - - # Check to see whether the event in question is already referenced - # by another event. If we don't see any edges, we're next to a - # forward gap. - forward_edge_query = """ - SELECT 1 FROM event_edges - /* Check to make sure the event referencing our event in question is not rejected */ - LEFT JOIN rejections ON event_edges.event_id == rejections.event_id - WHERE - event_edges.room_id = ? - AND event_edges.prev_event_id = ? - /* It's not a valid edge if the event referencing our event in - * question is rejected. - */ - AND rejections.event_id IS NULL - LIMIT 1 - """ - - # We consider any forward extremity as the latest in the room and - # not a forward gap. - # - # To expand, even though there is technically a gap at the front of - # the room where the forward extremities are, we consider those the - # latest messages in the room so asking other homeservers for more - # is useless. The new latest messages will just be federated as - # usual. - txn.execute(forward_extremity_query, (event.room_id, event.event_id)) - forward_extremities = txn.fetchall() - if len(forward_extremities): - return False - - # If there are no forward edges to the event in question (another - # event hasn't referenced this event in their prev_events), then we - # assume there is a forward gap in the history. 
- txn.execute(forward_edge_query, (event.room_id, event.event_id)) - forward_edges = txn.fetchall() - if not len(forward_edges): - return True - - return False - - return await self.db_pool.runInteraction( - "is_event_next_to_gap_txn", - is_event_next_to_gap_txn, - ) - - async def get_event_id_for_timestamp( - self, room_id: str, timestamp: int, direction: str - ) -> Optional[str]: - """Find the closest event to the given timestamp in the given direction. - - Args: - room_id: Room to fetch the event from - timestamp: The point in time (inclusive) we should navigate from in - the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward - or backward from the given timestamp to find the closest event. - - Returns: - The closest event_id otherwise None if we can't find any event in - the given direction. - """ - - sql_template = """ - SELECT event_id FROM events - LEFT JOIN rejections USING (event_id) - WHERE - origin_server_ts %s ? - AND room_id = ? - /* Make sure event is not rejected */ - AND rejections.event_id IS NULL - ORDER BY origin_server_ts %s - LIMIT 1; - """ - - def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: - if direction == "b": - # Find closest event *before* a given timestamp. We use descending - # (which gives values largest to smallest) because we want the - # largest possible timestamp *before* the given timestamp. - comparison_operator = "<=" - order = "DESC" - else: - # Find closest event *after* a given timestamp. We use ascending - # (which gives values smallest to largest) because we want the - # closest possible timestamp *after* the given timestamp. - comparison_operator = ">=" - order = "ASC" - - txn.execute( - sql_template % (comparison_operator, order), (timestamp, room_id) - ) - row = txn.fetchone() - if row: - (event_id,) = row - return event_id - - return None - - if direction not in ("f", "b"): - raise ValueError("Unknown direction: %s" % (direction,)) - - return await self.db_pool.runInteraction( - "get_event_id_for_timestamp_txn", - get_event_id_for_timestamp_txn, - ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 91b0576b85..3eb30944bf 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -118,7 +118,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
 
         logger.info("[purge] looking for events to delete")
 
-        should_delete_expr = "state_events.state_key IS NULL"
+        should_delete_expr = "state_key IS NULL"
         should_delete_params: Tuple[Any, ...] = ()
         if not delete_local_events:
             should_delete_expr += " AND event_id NOT LIKE ?"
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 3b63267395..fa782023d4 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,10 +28,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdTracker,
-    StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -85,9 +82,9 @@ class PushRulesWorkerStore(
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
-                db_conn, "push_rules_stream", "stream_id"
-            )
+            self._push_rules_stream_id_gen: Union[
+                StreamIdGenerator, SlavedIdTracker
+            ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e1ddf06916..0e8c168667 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -106,15 +106,6 @@ class RefreshTokenLookupResult:
     has_next_access_token_been_used: bool
     """True if the next access token was already used at least once."""
 
-    expiry_ts: Optional[int]
-    """The time at which the refresh token expires and can not be used.
-    If None, the refresh token doesn't expire."""
-
-    ultimate_session_expiry_ts: Optional[int]
-    """The time at which the session comes to an end and can no longer be
-    refreshed.
-    If None, the session can be refreshed indefinitely."""
-
 
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
@@ -1635,10 +1626,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                     rt.user_id,
                     rt.device_id,
                     rt.next_token_id,
-                    (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
-                    at.used AS has_next_access_token_been_used,
-                    rt.expiry_ts,
-                    rt.ultimate_session_expiry_ts
+                    (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
+                    at.used has_next_access_token_been_used
                 FROM refresh_tokens rt
                 LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
                 LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
@@ -1659,8 +1648,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 has_next_refresh_token_been_refreshed=row[4],
                 # This column is nullable, ensure it's a boolean
                 has_next_access_token_been_used=(row[5] or False),
-                expiry_ts=row[6],
-                ultimate_session_expiry_ts=row[7],
             )
 
         return await self.db_pool.runInteraction(
@@ -1928,8 +1915,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         user_id: str,
         token: str,
         device_id: Optional[str],
-        expiry_ts: Optional[int],
-        ultimate_session_expiry_ts: Optional[int],
     ) -> int:
         """Adds a refresh token for the given user.
 
@@ -1937,13 +1922,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             user_id: The user ID.
             token: The new access token to add.
             device_id: ID of the device to associate with the refresh token.
-            expiry_ts (milliseconds since the epoch): Time after which the
-                refresh token cannot be used.
-                If None, the refresh token never expires until it has been used.
-            ultimate_session_expiry_ts (milliseconds since the epoch):
-                Time at which the session will end and can not be extended any
-                further.
-                If None, the session can be refreshed indefinitely.
         Raises:
             StoreError if there was a problem adding this.
         Returns:
@@ -1959,8 +1937,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 "device_id": device_id,
                 "token": token,
                 "next_token_id": None,
-                "expiry_ts": expiry_ts,
-                "ultimate_session_expiry_ts": ultimate_session_expiry_ts,
             },
             desc="add_refresh_token_to_user",
         )
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 6b2a8d06a6..033a9831d6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -476,7 +476,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
                     c.type = 'm.room.member'
-                    AND c.state_key = ?
+                    AND state_key = ?
                     AND c.membership = ?
             """
         else:
@@ -487,7 +487,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
                     c.type = 'm.room.member'
-                    AND c.state_key = ?
+                    AND state_key = ?
                     AND m.membership = ?
             """
 
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 57aab55259..42dc807d17 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             oldest `limit` events.
 
         Returns:
-            The list of events (in ascending stream order) and the token from the start
+            The list of events (in ascending order) and the token from the start
             of the chunk of events returned.
         """
         if from_key == to_key:
@@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         if not has_changed:
             return [], from_key
 
-        def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
+        def f(txn):
             # To handle tokens with a non-empty instance_map we fetch more
             # results than necessary and then filter down
             min_from_id = from_key.stream
@@ -565,13 +565,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     async def get_membership_changes_for_user(
         self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
-        """Fetch membership events for a given user.
-
-        All such events whose stream ordering `s` lies in the range
-        `from_key < s <= to_key` are returned. Events are ordered by ascending stream
-        order.
-        """
-
         # Start by ruling out cases where a DB query is not necessary.
         if from_key == to_key:
             return []
@@ -582,7 +575,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         if not has_changed:
             return []
 
-        def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
+        def f(txn):
             # To handle tokens with a non-empty instance_map we fetch more
             # results than necessary and then filter down
             min_from_id = from_key.stream
@@ -641,7 +634,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         Returns:
             A list of events and a token pointing to the start of the returned
-            events. The events returned are in ascending topological order.
+            events. The events returned are in ascending order.
         """
 
         rows, token = await self.get_recent_event_ids_for_room(
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 1622822552..d7dc1f73ac 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,6 @@
 
 import logging
 from collections import namedtuple
-from enum import Enum
 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 import attr
@@ -45,16 +44,6 @@ _UpdateTransactionRow = namedtuple(
 )
 
 
-class DestinationSortOrder(Enum):
-    """Enum to define the sorting method used when returning destinations."""
-
-    DESTINATION = "destination"
-    RETRY_LAST_TS = "retry_last_ts"
-    RETTRY_INTERVAL = "retry_interval"
-    FAILURE_TS = "failure_ts"
-    LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering"
-
-
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class DestinationRetryTimings:
     """The current destination retry timing info for a remote server."""
@@ -491,62 +480,3 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             destinations = [row[0] for row in txn]
 
             return destinations
-
-    async def get_destinations_paginate(
-        self,
-        start: int,
-        limit: int,
-        destination: Optional[str] = None,
-        order_by: str = DestinationSortOrder.DESTINATION.value,
-        direction: str = "f",
-    ) -> Tuple[List[JsonDict], int]:
-        """Function to retrieve a paginated list of destinations.
-        This will return a json list of destinations and the
-        total number of destinations matching the filter criteria.
-
-        Args:
-            start: start number to begin the query from
-            limit: number of rows to retrieve
-            destination: search string in destination
-            order_by: the sort order of the returned list
-            direction: sort ascending or descending
-        Returns:
-            A tuple of a list of mappings from destination to information
-            and a count of total destinations.
-        """
-
-        def get_destinations_paginate_txn(
-            txn: LoggingTransaction,
-        ) -> Tuple[List[JsonDict], int]:
-            order_by_column = DestinationSortOrder(order_by).value
-
-            if direction == "b":
-                order = "DESC"
-            else:
-                order = "ASC"
-
-            args = []
-            where_statement = ""
-            if destination:
-                args.extend(["%" + destination.lower() + "%"])
-                where_statement = "WHERE LOWER(destination) LIKE ?"
-
-            sql_base = f"FROM destinations {where_statement} "
-            sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
-            txn.execute(sql, args)
-            count = txn.fetchone()[0]
-
-            sql = f"""
-                SELECT destination, retry_last_ts, retry_interval, failure_ts,
-                last_successful_stream_ordering
-                {sql_base}
-                ORDER BY {order_by_column} {order}, destination ASC
-                LIMIT ? OFFSET ?
-            """
-            txn.execute(sql, args + [limit, start])
-            destinations = self.db_pool.cursor_to_dict(txn)
-            return destinations, count
-
-        return await self.db_pool.runInteraction(
-            "get_destinations_paginate_txn", get_destinations_paginate_txn
-        )