summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorSean Quah <seanq@element.io>2021-12-07 16:38:29 +0000
committerSean Quah <seanq@element.io>2021-12-07 16:47:31 +0000
commit158d73ebdd61eef33831ae5f6990acf07244fc55 (patch)
tree723f79596374042e349d55a6195cbe2b5eea29eb /synapse/storage/databases
parentSort internal changes in changelog (diff)
downloadsynapse-158d73ebdd61eef33831ae5f6990acf07244fc55.tar.xz
Revert accidental fast-forward merge from v1.49.0rc1
Revert "Sort internal changes in changelog"
Revert "Update CHANGES.md"
Revert "1.49.0rc1"
Revert "Revert "Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common` (#11505) (#11527)"
Revert "Refactors in `_generate_sync_entry_for_rooms` (#11515)"
Revert "Correctly register shutdown handler for presence workers (#11518)"
Revert "Fix `ModuleApi.looping_background_call` for non-async functions (#11524)"
Revert "Fix 'delete room' admin api to work on incomplete rooms (#11523)"
Revert "Correctly ignore invites from ignored users (#11511)"
Revert "Fix the test breakage introduced by #11435 as a result of concurrent PRs (#11522)"
Revert "Stabilise support for MSC2918 refresh tokens as they have now been merged into the Matrix specification. (#11435)"
Revert "Save the OIDC session ID (sid) with the device on login (#11482)"
Revert "Add admin API to get some information about federation status (#11407)"
Revert "Include bundled aggregations in /sync and related fixes (#11478)"
Revert "Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common` (#11505)"
Revert "Update backward extremity docs to make it clear that it does not indicate whether we have fetched an events' `prev_events` (#11469)"
Revert "Support configuring the lifetime of non-refreshable access tokens separately to refreshable access tokens. (#11445)"
Revert "Add type hints to `synapse/tests/rest/admin` (#11501)"
Revert "Revert accidental commits to develop."
Revert "Newsfile"
Revert "Give `tests.server.setup_test_homeserver` (nominally!) the same behaviour"
Revert "Move `tests.utils.setup_test_homeserver` to `tests.server`"
Revert "Convert one of the `setup_test_homeserver`s to `make_test_homeserver_synchronous`"
Revert "Disambiguate queries on `state_key` (#11497)"
Revert "Comments on the /sync tentacles (#11494)"
Revert "Clean up tests.storage.test_appservice (#11492)"
Revert "Clean up `tests.storage.test_main` to remove use of legacy code. (#11493)"
Revert "Clean up `tests.test_visibility` to remove legacy code. (#11495)"
Revert "Minor cleanup on recently ported doc pages  (#11466)"
Revert "Add most of the missing type hints to `synapse.federation`. (#11483)"
Revert "Avoid waiting for zombie processes in `synctl stop` (#11490)"
Revert "Fix media repository failing when media store path contains symlinks (#11446)"
Revert "Add type annotations to `tests.storage.test_appservice`. (#11488)"
Revert "`scripts-dev/sign_json`: support for signing events (#11486)"
Revert "Add MSC3030 experimental client and federation API endpoints to get the closest event to a given timestamp (#9445)"
Revert "Port wiki pages to documentation website (#11402)"
Revert "Add a license header and comment. (#11479)"
Revert "Clean-up get_version_string (#11468)"
Revert "Link background update controller docs to summary (#11475)"
Revert "Additional type hints for config module. (#11465)"
Revert "Register the login redirect endpoint for v3. (#11451)"
Revert "Update openid.md"
Revert "Remove mention of OIDC certification from Dex (#11470)"
Revert "Add a note about huge pages to our Postgres doc (#11467)"
Revert "Don't start Synapse master process if `worker_app` is set (#11416)"
Revert "Expose worker & homeserver as entrypoints in `setup.py` (#11449)"
Revert "Bundle relations of relations into the `/relations` result. (#11284)"
Revert "Fix `LruCache` corruption bug with a `size_callback` that can return 0 (#11454)"
Revert "Eliminate a few `Any`s in `LruCache` type hints (#11453)"
Revert "Remove unnecessary `json.dumps` from `tests.rest.admin` (#11461)"
Revert "Merge branch 'master' into develop"

This reverts commit 26b5d2320f62b5eb6262c7614fbdfc364a4dfc02.
This reverts commit bce4220f387bf5448387f0ed7d14ed1e41e40747.
This reverts commit 966b5d0fa0893c3b628c942dfc232e285417f46d.
This reverts commit 088d748f2cb51f03f3bcacc0fb3af1e0f9607737.
This reverts commit 14d593f72d10b4d8cb67e3288bb3131ee30ccf59.
This reverts commit 2a3ec6facf79f6aae011d9fb6f9ed5e43c7b6bec.
This reverts commit eccc49d7554d1fab001e1fefb0fda8ffb254b630.
This reverts commit b1ecd19c5d19815b69e425d80f442bf2877cab76.
This reverts commit 9c55dedc8c4484e6269451a8c3c10b3e314aeb4a.
This reverts commit 2d42e586a8c54be1a83643148358b1651c1ca666.
This reverts commit 2f053f3f82ca174cc1c858c75afffae51af8ce0d.
This reverts commit a15a893df8428395df7cb95b729431575001c38a.
This reverts commit 8b4b153c9e86c04c7db8c74fde4b6a04becbc461.
This reverts commit 494ebd7347ba52d702802fba4c3bb13e7bfbc2cf.
This reverts commit a77c36989785c0d5565ab9a1169f4f88e512ce8a.
This reverts commit 4eb77965cd016181d2111f37d93526e9bb0434f0.
This reverts commit 637df95de63196033a6da4a6e286e1d58ea517b6.
This reverts commit e5f426cd54609e7f05f8241d845e6e36c5f10d9a.
This reverts commit 8cd68b8102eeab1b525712097c1b2e9679c11896.
This reverts commit 6cae125e20865c52d770b24278bb7ab8fde5bc0d.
This reverts commit 7be88fbf48156b36b6daefb228e1258e7d48cae4.
This reverts commit b3fd99b74a3f6f42a9afd1b19ee4c60e38e8e91a.
This reverts commit f7ec6e7d9e0dc360d9fb41f3a1afd7bdba1475c7.
This reverts commit 5640992d176a499204a0756b1677c9b1575b0a49.
This reverts commit d26808dd854006bd26a2366c675428ce0737238c.
This reverts commit f91624a5950e14ba9007eed9bfa1c828676d4745.
This reverts commit 16d39a5490ce74c901c7a8dbb990c6e83c379207.
This reverts commit 8a4c2969874c0b7d72003f2523883eba8a348e83.
This reverts commit 49e1356ee3d5d72929c91f778b3a231726c1413c.
This reverts commit d2279f471ba8f44d9f578e62b286897a338d8aa1.
This reverts commit b50e39df578adc3f86c5efa16bee9035cfdab61b.
This reverts commit 858d80bf0f9f656a03992794874081b806e49222.
This reverts commit 435f04480728c5d982e1a63c1b2777784bf9cd26.
This reverts commit f61462e1be36a51dbf571076afa8e1930cb182f4.
This reverts commit a6f1a3abecf8e8fd3e1bff439a06b853df18f194.
This reverts commit 84dc50e160a2ec6590813374b5a1e58b97f7a18d.
This reverts commit ed635d32853ee0a3e5ec1078679b27e7844a4ac7.
This reverts commit 7b62791e001d6a4f8897ed48b3232d7f8fe6aa48.
This reverts commit 153194c7717d8016b0eb974c81b1baee7dc1917d.
This reverts commit f44d729d4ccae61bc0cdd5774acb3233eb5f7c13.
This reverts commit a265fbd397ae72b2d3ea4c9310591ff1d0f3e05c.
This reverts commit b9fef1a7cdfcc128fa589a32160e6aa7ed8964d7.
This reverts commit b0eb64ff7bf6bde42046e091f8bdea9b7aab5f04.
This reverts commit f1795463bf503a6fca909d77f598f641f9349f56.
This reverts commit 70cbb1a5e311f609b624e3fae1a1712db639c51e.
This reverts commit 42bf0204635213e2c75188b19ee66dc7e7d8a35e.
This reverts commit 379f2650cf875f50c59524147ec0e33cfd5ef60c.
This reverts commit 7ff22d6da41cd5ca80db95c18b409aea38e49fcd.
This reverts commit 5a0b652d36ae4b6d423498c1f2c82c97a49c6f75.
This reverts commit 432a174bc192740ac7a0a755009f6099b8363ad9.
This reverts commit b14f8a1baf6f500997ae4c1d6a6d72094ce14270, reversing
changes made to e713855dca17a7605bae99ea8d71bc7f8657e4b8.
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/appservice.py6
-rw-r--r--synapse/storage/databases/main/devices.py50
-rw-r--r--synapse/storage/databases/main/event_federation.py4
-rw-r--r--synapse/storage/databases/main/event_push_actions.py19
-rw-r--r--synapse/storage/databases/main/events.py107
-rw-r--r--synapse/storage/databases/main/events_worker.py581
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/push_rule.py11
-rw-r--r--synapse/storage/databases/main/registration.py28
-rw-r--r--synapse/storage/databases/main/roommember.py4
-rw-r--r--synapse/storage/databases/main/stream.py15
-rw-r--r--synapse/storage/databases/main/transactions.py70
12 files changed, 180 insertions, 717 deletions
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py

index 4a883dc166..baec35ee27 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py
@@ -143,7 +143,7 @@ class ApplicationServiceTransactionWorkerStore( A list of ApplicationServices, which may be empty. """ results = await self.db_pool.simple_select_list( - "application_services_state", {"state": state.value}, ["as_id"] + "application_services_state", {"state": state}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore as_list = self.get_app_services() @@ -173,7 +173,7 @@ class ApplicationServiceTransactionWorkerStore( desc="get_appservice_state", ) if result: - return ApplicationServiceState(result.get("state")) + return result.get("state") return None async def set_appservice_state( @@ -186,7 +186,7 @@ class ApplicationServiceTransactionWorkerStore( state: The connectivity state to apply. """ await self.db_pool.simple_upsert( - "application_services_state", {"as_id": service.id}, {"state": state.value} + "application_services_state", {"as_id": service.id}, {"state": state} ) async def create_appservice_txn( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d5a4a661cd..9ccc66e589 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -139,27 +139,6 @@ class DeviceWorkerStore(SQLBaseStore): return {d["device_id"]: d for d in devices} - async def get_devices_by_auth_provider_session_id( - self, auth_provider_id: str, auth_provider_session_id: str - ) -> List[Dict[str, Any]]: - """Retrieve the list of devices associated with a SSO IdP session ID. - - Args: - auth_provider_id: The SSO IdP ID as defined in the server config - auth_provider_session_id: The session ID within the IdP - Returns: - A list of dicts containing the device_id and the user_id of each device - """ - return await self.db_pool.simple_select_list( - table="device_auth_providers", - keyvalues={ - "auth_provider_id": auth_provider_id, - "auth_provider_session_id": auth_provider_session_id, - }, - retcols=("user_id", "device_id"), - desc="get_devices_by_auth_provider_session_id", - ) - @trace async def get_device_updates_by_remote( self, destination: str, from_stream_id: int, limit: int @@ -1091,12 +1070,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def store_device( - self, - user_id: str, - device_id: str, - initial_device_display_name: Optional[str], - auth_provider_id: Optional[str] = None, - auth_provider_session_id: Optional[str] = None, + self, user_id: str, device_id: str, initial_device_display_name: Optional[str] ) -> bool: """Ensure the given device is known; add it to the store if not @@ -1105,8 +1079,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_id: id of device initial_device_display_name: initial displayname of the device. Ignored if device exists. - auth_provider_id: The SSO IdP the user used, if any. - auth_provider_session_id: The session ID (sid) got from a OIDC login. Returns: Whether the device was inserted or an existing device existed with that ID. @@ -1143,18 +1115,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) - if auth_provider_id and auth_provider_session_id: - await self.db_pool.simple_insert( - "device_auth_providers", - values={ - "user_id": user_id, - "device_id": device_id, - "auth_provider_id": auth_provider_id, - "auth_provider_session_id": auth_provider_session_id, - }, - desc="store_device_auth_provider", - ) - self.device_id_exists_cache.set(key, True) return inserted except StoreError: @@ -1208,14 +1168,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): keyvalues={"user_id": user_id}, ) - self.db_pool.simple_delete_many_txn( - txn, - table="device_auth_providers", - column="device_id", - values=device_ids, - keyvalues={"user_id": user_id}, - ) - await self.db_pool.runInteraction("delete_devices", _delete_devices_txn) for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 9580a40785..ef5d1ef01e 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -1552,9 +1552,9 @@ class EventFederationStore(EventFederationWorkerStore): DELETE FROM event_auth WHERE event_id IN ( SELECT event_id FROM events - LEFT JOIN state_events AS se USING (room_id, event_id) + LEFT JOIN state_events USING (room_id, event_id) WHERE ? <= stream_ordering AND stream_ordering < ? - AND se.state_key IS null + AND state_key IS null ) """ diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3efdd0c920..d957e770dc 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -16,7 +16,6 @@ import logging from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import attr -from typing_extensions import TypedDict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json @@ -38,20 +37,6 @@ DEFAULT_HIGHLIGHT_ACTION = [ ] -class BasePushAction(TypedDict): - event_id: str - actions: List[Union[dict, str]] - - -class HttpPushAction(BasePushAction): - room_id: str - stream_ordering: int - - -class EmailPushAction(HttpPushAction): - received_ts: Optional[int] - - def _serialize_action(actions, is_highlight): """Custom serializer for actions. This allows us to "compress" common actions. @@ -236,7 +221,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): min_stream_ordering: int, max_stream_ordering: int, limit: int = 20, - ) -> List[HttpPushAction]: + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the httppusher. @@ -341,7 +326,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): min_stream_ordering: int, max_stream_ordering: int, limit: int = 20, - ) -> List[EmailPushAction]: + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the emailpusher diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 4e528612ea..06832221ad 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -15,7 +15,7 @@ # limitations under the License. import itertools import logging -from collections import OrderedDict +from collections import OrderedDict, namedtuple from typing import ( TYPE_CHECKING, Any, @@ -41,10 +41,9 @@ from synapse.events.snapshot import EventContext # noqa: F401 from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction -from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry from synapse.storage.types import Connection -from synapse.storage.util.id_generators import AbstractStreamIdGenerator +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder @@ -65,6 +64,9 @@ event_counter = Counter( ) +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + @attr.s(slots=True) class DeltaState: """Deltas to use to update the `current_state_events` table. @@ -106,30 +108,23 @@ class PersistEventsStore: self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id + # Ideally we'd move these ID gens here, unfortunately some other ID + # generators are chained off them so doing so is a bit of a PITA. + self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen + self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen + # This should only exist on instances that are configured to write assert ( hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" - # Since we have been configured to write, we ought to have id generators, - # rather than id trackers. - assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator) - assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator) - - # Ideally we'd move these ID gens here, unfortunately some other ID - # generators are chained off them so doing so is a bit of a PITA. - self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen - self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen - async def _persist_events_and_state_updates( self, events_and_contexts: List[Tuple[EventBase, EventContext]], - *, current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], new_forward_extremeties: Dict[str, List[str]], - use_negative_stream_ordering: bool = False, - inhibit_local_membership_updates: bool = False, + backfilled: bool = False, ) -> None: """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -142,14 +137,7 @@ class PersistEventsStore: room state new_forward_extremities: Map from room_id to list of event IDs that are the new forward extremities of the room. - use_negative_stream_ordering: Whether to start stream_ordering on - the negative side and decrement. This should be set as True - for backfilled events because backfilled events get a negative - stream ordering so they don't come down incremental `/sync`. - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. + backfilled Returns: Resolves when the events have been persisted @@ -171,7 +159,7 @@ class PersistEventsStore: # # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. - if use_negative_stream_ordering: + if backfilled: stream_ordering_manager = self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) @@ -188,13 +176,13 @@ class PersistEventsStore: "persist_events", self._persist_events_txn, events_and_contexts=events_and_contexts, - inhibit_local_membership_updates=inhibit_local_membership_updates, + backfilled=backfilled, state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc(len(events_and_contexts)) - if stream < 0: + if not backfilled: # backfilled events have negative stream orderings, so we don't # want to set the event_persisted_position to that. synapse.metrics.event_persisted_position.set( @@ -328,9 +316,8 @@ class PersistEventsStore: def _persist_events_txn( self, txn: LoggingTransaction, - *, events_and_contexts: List[Tuple[EventBase, EventContext]], - inhibit_local_membership_updates: bool = False, + backfilled: bool, state_delta_for_room: Optional[Dict[str, DeltaState]] = None, new_forward_extremeties: Optional[Dict[str, List[str]]] = None, ): @@ -343,10 +330,7 @@ class PersistEventsStore: Args: txn events_and_contexts: events to persist - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. + backfilled: True if the events were backfilled delete_existing True to purge existing table rows for the events from the database. This is useful when retrying due to IntegrityError. @@ -379,7 +363,9 @@ class PersistEventsStore: events_and_contexts ) - self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts) + self._update_room_depths_txn( + txn, events_and_contexts=events_and_contexts, backfilled=backfilled + ) # _update_outliers_txn filters out any events which have already been # persisted, and returns the filtered list. @@ -412,7 +398,7 @@ class PersistEventsStore: txn, events_and_contexts=events_and_contexts, all_events_and_contexts=all_events_and_contexts, - inhibit_local_membership_updates=inhibit_local_membership_updates, + backfilled=backfilled, ) # We call this last as it assumes we've inserted the events into @@ -575,9 +561,9 @@ class PersistEventsStore: # fetch their auth event info. while missing_auth_chains: sql = """ - SELECT event_id, events.type, se.state_key, chain_id, sequence_number + SELECT event_id, events.type, state_key, chain_id, sequence_number FROM events - INNER JOIN state_events AS se USING (event_id) + INNER JOIN state_events USING (event_id) LEFT JOIN event_auth_chains USING (event_id) WHERE """ @@ -1214,6 +1200,7 @@ class PersistEventsStore: self, txn, events_and_contexts: List[Tuple[EventBase, EventContext]], + backfilled: bool, ): """Update min_depth for each room @@ -1221,18 +1208,13 @@ class PersistEventsStore: txn (twisted.enterprise.adbapi.Connection): db connection events_and_contexts (list[(EventBase, EventContext)]): events we are persisting + backfilled (bool): True if the events were backfilled """ depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids txn.call_after(self.store._invalidate_get_event_cache, event.event_id) - # Then update the `stream_ordering` position to mark the latest - # event as the front of the room. This should not be done for - # backfilled events because backfilled events have negative - # stream_ordering and happened in the past so we know that we don't - # need to update the stream_ordering tip/front for the room. - assert event.internal_metadata.stream_ordering is not None - if event.internal_metadata.stream_ordering >= 0: + if not backfilled: txn.call_after( self.store._events_stream_cache.entity_has_changed, event.room_id, @@ -1445,12 +1427,7 @@ class PersistEventsStore: return [ec for ec in events_and_contexts if ec[0] not in to_remove] def _update_metadata_tables_txn( - self, - txn, - *, - events_and_contexts, - all_events_and_contexts, - inhibit_local_membership_updates: bool = False, + self, txn, events_and_contexts, all_events_and_contexts, backfilled ): """Update all the miscellaneous tables for new events @@ -1462,10 +1439,7 @@ class PersistEventsStore: events that we were going to persist. This includes events we've already persisted, etc, that wouldn't appear in events_and_context. - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. + backfilled (bool): True if the events were backfilled """ # Insert all the push actions into the event_push_actions table. @@ -1539,7 +1513,7 @@ class PersistEventsStore: for event, _ in events_and_contexts if event.type == EventTypes.Member ], - inhibit_local_membership_updates=inhibit_local_membership_updates, + backfilled=backfilled, ) # Insert event_reference_hashes table. @@ -1579,13 +1553,11 @@ class PersistEventsStore: for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: - to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) + to_prefill.append(_EventCacheEntry(event=event, redacted_event=None)) def prefill(): for cache_entry in to_prefill: - self.store._get_event_cache.set( - (cache_entry.event.event_id,), cache_entry - ) + self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry) txn.call_after(prefill) @@ -1666,19 +1638,8 @@ class PersistEventsStore: txn, table="event_reference_hashes", values=vals ) - def _store_room_members_txn( - self, txn, events, *, inhibit_local_membership_updates: bool = False - ): - """ - Store a room member in the database. - Args: - txn: The transaction to use. - events: List of events to store. - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. - """ + def _store_room_members_txn(self, txn, events, backfilled): + """Store a room member in the database.""" def non_null_str_or_none(val: Any) -> Optional[str]: return val if isinstance(val, str) and "\u0000" not in val else None @@ -1721,7 +1682,7 @@ class PersistEventsStore: # band membership", like a remote invite or a rejection of a remote invite. if ( self.is_mine_id(event.state_key) - and not inhibit_local_membership_updates + and not backfilled and event.internal_metadata.is_outlier() and event.internal_metadata.is_out_of_band_membership() ): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c7b660ac5a..c6bf316d5b 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -15,18 +15,14 @@ import logging import threading from typing import ( - TYPE_CHECKING, - Any, Collection, Container, Dict, Iterable, List, - NoReturn, Optional, Set, Tuple, - cast, overload, ) @@ -42,7 +38,6 @@ from synapse.api.errors import NotFoundError, SynapseError from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, - RoomVersion, RoomVersions, ) from synapse.events import EventBase, make_event_from_dict @@ -61,18 +56,10 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import ( - DatabasePool, - LoggingDatabaseConnection, - LoggingTransaction, -) +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine -from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import ( - AbstractStreamIdTracker, - MultiWriterIdGenerator, - StreamIdGenerator, -) +from synapse.storage.types import Connection +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError @@ -82,13 +69,10 @@ from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure -if TYPE_CHECKING: - from synapse.server import HomeServer - logger = logging.getLogger(__name__) -# These values are used in the `enqueue_event` and `_fetch_loop` methods to +# These values are used in the `enqueus_event` and `_do_fetch` methods to # control how we batch/bulk fetch events from the database. # The values are plucked out of thing air to make initial sync run faster # on jki.re @@ -105,7 +89,7 @@ event_fetch_ongoing_gauge = Gauge( @attr.s(slots=True, auto_attribs=True) -class EventCacheEntry: +class _EventCacheEntry: event: EventBase redacted_event: Optional[EventBase] @@ -145,7 +129,7 @@ class _EventRow: json: str internal_metadata: str format_version: Optional[int] - room_version_id: Optional[str] + room_version_id: Optional[int] rejected_reason: Optional[str] redactions: List[str] outlier: bool @@ -169,16 +153,9 @@ class EventsWorkerStore(SQLBaseStore): # options controlling this. USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): + def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self._stream_id_gen: AbstractStreamIdTracker - self._backfill_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): # If we're using Postgres than we can use `MultiWriterIdGenerator` # regardless of whether this process writes to the streams or not. @@ -237,7 +214,7 @@ class EventsWorkerStore(SQLBaseStore): 5 * 60 * 1000, ) - self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache( + self._get_event_cache = LruCache( cache_name="*getEvent*", max_size=hs.config.caches.event_cache_size, ) @@ -246,21 +223,19 @@ class EventsWorkerStore(SQLBaseStore): # ID to cache entry. Note that the returned dict may not have the # requested event in it if the event isn't in the DB. self._current_event_fetches: Dict[ - str, ObservableDeferred[Dict[str, EventCacheEntry]] + str, ObservableDeferred[Dict[str, _EventCacheEntry]] ] = {} self._event_fetch_lock = threading.Condition() - self._event_fetch_list: List[ - Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"] - ] = [] + self._event_fetch_list = [] self._event_fetch_ongoing = 0 event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) # We define this sequence here so that it can be referenced from both # the DataStore and PersistEventStore. - def get_chain_id_txn(txn: Cursor) -> int: + def get_chain_id_txn(txn): txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") - return cast(Tuple[int], txn.fetchone())[0] + return txn.fetchone()[0] self.event_chain_id_gen = build_sequence_generator( db_conn, @@ -271,13 +246,7 @@ class EventsWorkerStore(SQLBaseStore): id_column="chain_id", ) - def process_replication_rows( - self, - stream_name: str, - instance_name: str, - token: int, - rows: Iterable[Any], - ) -> None: + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == EventsStream.NAME: self._stream_id_gen.advance(instance_name, token) elif stream_name == BackfillStream.NAME: @@ -311,10 +280,10 @@ class EventsWorkerStore(SQLBaseStore): self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = ..., - allow_rejected: bool = ..., - allow_none: Literal[False] = ..., - check_room_id: Optional[str] = ..., + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[False] = False, + check_room_id: Optional[str] = None, ) -> EventBase: ... @@ -323,10 +292,10 @@ class EventsWorkerStore(SQLBaseStore): self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = ..., - allow_rejected: bool = ..., - allow_none: Literal[True] = ..., - check_room_id: Optional[str] = ..., + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[True] = False, + check_room_id: Optional[str] = None, ) -> Optional[EventBase]: ... @@ -388,7 +357,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_events( self, - event_ids: Collection[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, @@ -575,7 +544,7 @@ class EventsWorkerStore(SQLBaseStore): async def _get_events_from_cache_or_db( self, event_ids: Iterable[str], allow_rejected: bool = False - ) -> Dict[str, EventCacheEntry]: + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -609,7 +578,7 @@ class EventsWorkerStore(SQLBaseStore): # same dict into itself N times). already_fetching_ids: Set[str] = set() already_fetching_deferreds: Set[ - ObservableDeferred[Dict[str, EventCacheEntry]] + ObservableDeferred[Dict[str, _EventCacheEntry]] ] = set() for event_id in missing_events_ids: @@ -632,8 +601,8 @@ class EventsWorkerStore(SQLBaseStore): # function returning more events than requested, but that can happen # already due to `_get_events_from_db`). fetching_deferred: ObservableDeferred[ - Dict[str, EventCacheEntry] - ] = ObservableDeferred(defer.Deferred(), consumeErrors=True) + Dict[str, _EventCacheEntry] + ] = ObservableDeferred(defer.Deferred()) for event_id in missing_events_ids: self._current_event_fetches[event_id] = fetching_deferred @@ -689,12 +658,12 @@ class EventsWorkerStore(SQLBaseStore): return event_entry_map - def _invalidate_get_event_cache(self, event_id: str) -> None: + def _invalidate_get_event_cache(self, event_id): self._get_event_cache.invalidate((event_id,)) def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True - ) -> Dict[str, EventCacheEntry]: + ) -> Dict[str, _EventCacheEntry]: """Fetch events from the caches. May return rejected events. @@ -767,123 +736,38 @@ class EventsWorkerStore(SQLBaseStore): for e in state_to_include.values() ] - def _maybe_start_fetch_thread(self) -> None: - """Starts an event fetch thread if we are not yet at the maximum number.""" - with self._event_fetch_lock: - if ( - self._event_fetch_list - and self._event_fetch_ongoing < EVENT_QUEUE_THREADS - ): - self._event_fetch_ongoing += 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) - # `_event_fetch_ongoing` is decremented in `_fetch_thread`. - should_start = True - else: - should_start = False - - if should_start: - run_as_background_process("fetch_events", self._fetch_thread) - - async def _fetch_thread(self) -> None: - """Services requests for events from `_event_fetch_list`.""" - exc = None - try: - await self.db_pool.runWithConnection(self._fetch_loop) - except BaseException as e: - exc = e - raise - finally: - should_restart = False - event_fetches_to_fail = [] - with self._event_fetch_lock: - self._event_fetch_ongoing -= 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) - - # There may still be work remaining in `_event_fetch_list` if we - # failed, or it was added in between us deciding to exit and - # decrementing `_event_fetch_ongoing`. - if self._event_fetch_list: - if exc is None: - # We decided to exit, but then some more work was added - # before `_event_fetch_ongoing` was decremented. - # If a new event fetch thread was not started, we should - # restart ourselves since the remaining event fetch threads - # may take a while to get around to the new work. - # - # Unfortunately it is not possible to tell whether a new - # event fetch thread was started, so we restart - # unconditionally. If we are unlucky, we will end up with - # an idle fetch thread, but it will time out after - # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds - # in any case. - # - # Note that multiple fetch threads may run down this path at - # the same time. - should_restart = True - elif isinstance(exc, Exception): - if self._event_fetch_ongoing == 0: - # We were the last remaining fetcher and failed. - # Fail any outstanding fetches since no one else will - # handle them. - event_fetches_to_fail = self._event_fetch_list - self._event_fetch_list = [] - else: - # We weren't the last remaining fetcher, so another - # fetcher will pick up the work. This will either happen - # after their existing work, however long that takes, - # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if - # they are idle. - pass - else: - # The exception is a `SystemExit`, `KeyboardInterrupt` or - # `GeneratorExit`. Don't try to do anything clever here. - pass - - if should_restart: - # We exited cleanly but noticed more work. - self._maybe_start_fetch_thread() - - if event_fetches_to_fail: - # We were the last remaining fetcher and failed. - # Fail any outstanding fetches since no one else will handle them. - assert exc is not None - with PreserveLoggingContext(): - for _, deferred in event_fetches_to_fail: - deferred.errback(exc) - - def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None: + def _do_fetch(self, conn: Connection) -> None: """Takes a database connection and waits for requests for events from the _event_fetch_list queue. """ - i = 0 - while True: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: - # There are no requests waiting. If we haven't yet reached the - # maximum iteration limit, wait for some more requests to turn up. - # Otherwise, bail out. - single_threaded = self.database_engine.single_threaded - if ( - not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING - or single_threaded - or i > EVENT_QUEUE_ITERATIONS - ): - return - - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 + try: + i = 0 + while True: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if ( + not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING + or single_threaded + or i > EVENT_QUEUE_ITERATIONS + ): + break + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 - self._fetch_event_list(conn, event_list) + self._fetch_event_list(conn, event_list) + finally: + self._event_fetch_ongoing -= 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) def _fetch_event_list( - self, - conn: LoggingDatabaseConnection, - event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]], + self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]] ) -> None: """Handle a load of requests from the _event_fetch_list queue @@ -910,7 +794,7 @@ class EventsWorkerStore(SQLBaseStore): ) # We only want to resolve deferreds from the main thread - def fire() -> None: + def fire(): for _, d in event_list: d.callback(row_dict) @@ -920,16 +804,18 @@ class EventsWorkerStore(SQLBaseStore): logger.exception("do_fetch") # We only want to resolve deferreds from the main thread - def fire_errback(exc: Exception) -> None: - for _, d in event_list: - d.errback(exc) + def fire(evs, exc): + for _, d in evs: + if not d.called: + with PreserveLoggingContext(): + d.errback(exc) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire_errback, e) + self.hs.get_reactor().callFromThread(fire, event_list, e) async def _get_events_from_db( - self, event_ids: Collection[str] - ) -> Dict[str, EventCacheEntry]: + self, event_ids: Iterable[str] + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the database. May return rejected events. @@ -945,29 +831,29 @@ class EventsWorkerStore(SQLBaseStore): map from event id to result. May return extra events which weren't asked for. """ - fetched_event_ids: Set[str] = set() - fetched_events: Dict[str, _EventRow] = {} + fetched_events = {} events_to_fetch = event_ids while events_to_fetch: row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events - redaction_ids: Set[str] = set() + redaction_ids = set() for event_id in events_to_fetch: row = row_map.get(event_id) - fetched_event_ids.add(event_id) + fetched_events[event_id] = row if row: - fetched_events[event_id] = row redaction_ids.update(row.redactions) - events_to_fetch = redaction_ids.difference(fetched_event_ids) + events_to_fetch = redaction_ids.difference(fetched_events.keys()) if events_to_fetch: logger.debug("Also fetching redaction events %s", events_to_fetch) # build a map from event_id to EventBase - event_map: Dict[str, EventBase] = {} + event_map = {} for event_id, row in fetched_events.items(): + if not row: + continue assert row.event_id == event_id rejected_reason = row.rejected_reason @@ -995,7 +881,6 @@ class EventsWorkerStore(SQLBaseStore): room_version_id = row.room_version_id - room_version: Optional[RoomVersion] if not room_version_id: # this should only happen for out-of-band membership events which # arrived before #6983 landed. For all other events, we should have @@ -1066,14 +951,14 @@ class EventsWorkerStore(SQLBaseStore): # finally, we can decide whether each one needs redacting, and build # the cache entries. - result_map: Dict[str, EventCacheEntry] = {} + result_map = {} for event_id, original_ev in event_map.items(): redactions = fetched_events[event_id].redactions redacted_event = self._maybe_redact_event_row( original_ev, redactions, event_map ) - cache_entry = EventCacheEntry( + cache_entry = _EventCacheEntry( event=original_ev, redacted_event=redacted_event ) @@ -1082,7 +967,7 @@ class EventsWorkerStore(SQLBaseStore): return result_map - async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]: + async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]: """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -1095,12 +980,23 @@ class EventsWorkerStore(SQLBaseStore): that weren't requested. """ - events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred() + events_d = defer.Deferred() with self._event_fetch_lock: self._event_fetch_list.append((events, events_d)) + self._event_fetch_lock.notify() - self._maybe_start_fetch_thread() + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + should_start = True + else: + should_start = False + + if should_start: + run_as_background_process( + "fetch_events", self.db_pool.runWithConnection, self._do_fetch + ) logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): @@ -1250,7 +1146,7 @@ class EventsWorkerStore(SQLBaseStore): # no valid redaction found for this event return None - async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]: + async def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ @@ -1279,7 +1175,7 @@ class EventsWorkerStore(SQLBaseStore): event_ids: events we are looking for Returns: - The set of events we have already seen. + set[str]: The events we have already seen. """ res = await self._have_seen_events_dict( (room_id, event_id) for event_id in event_ids @@ -1302,9 +1198,7 @@ class EventsWorkerStore(SQLBaseStore): } results = {x: True for x in cache_results} - def have_seen_events_txn( - txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...] - ) -> None: + def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]): # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1330,14 +1224,12 @@ class EventsWorkerStore(SQLBaseStore): return results @cached(max_entries=100000, tree=True) - async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn: + async def have_seen_event(self, room_id: str, event_id: str): # this only exists for the benefit of the @cachedList descriptor on # _have_seen_events_dict raise NotImplementedError() - def _get_current_state_event_counts_txn( - self, txn: LoggingTransaction, room_id: str - ) -> int: + def _get_current_state_event_counts_txn(self, txn, room_id): """ See get_current_state_event_counts. """ @@ -1362,7 +1254,7 @@ class EventsWorkerStore(SQLBaseStore): room_id, ) - async def get_room_complexity(self, room_id: str) -> Dict[str, float]: + async def get_room_complexity(self, room_id): """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -1370,10 +1262,10 @@ class EventsWorkerStore(SQLBaseStore): more resources. Args: - room_id: The room ID to query. + room_id (str) Returns: - dict[str:float] of complexity version to complexity. + dict[str:int] of complexity version to complexity. """ state_events = await self.get_current_state_event_counts(room_id) @@ -1383,13 +1275,13 @@ class EventsWorkerStore(SQLBaseStore): return {"v1": complexity_v1} - def get_current_events_token(self) -> int: + def get_current_events_token(self): """The current maximum token that events have reached""" return self._stream_id_gen.get_current_token() async def get_all_new_forward_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple]: """Returns new events, for the Events replication stream Args: @@ -1403,15 +1295,13 @@ class EventsWorkerStore(SQLBaseStore): EventsStreamRow. """ - def get_all_new_forward_event_rows( - txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + def get_all_new_forward_event_rows(txn): sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " LEFT JOIN room_memberships USING (event_id)" " LEFT JOIN rejections USING (event_id)" @@ -1421,9 +1311,7 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (last_id, current_id, instance_name, limit)) - return cast( - List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() - ) + return txn.fetchall() return await self.db_pool.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows @@ -1431,7 +1319,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_ex_outlier_stream_rows( self, instance_name: str, last_id: int, current_id: int - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple]: """Returns de-outliered events, for the Events replication stream Args: @@ -1444,16 +1332,14 @@ class EventsWorkerStore(SQLBaseStore): EventsStreamRow. """ - def get_ex_outlier_stream_rows_txn( - txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + def get_ex_outlier_stream_rows_txn(txn): sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " LEFT JOIN room_memberships USING (event_id)" " LEFT JOIN rejections USING (event_id)" @@ -1464,9 +1350,7 @@ class EventsWorkerStore(SQLBaseStore): ) txn.execute(sql, (last_id, current_id, instance_name)) - return cast( - List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() - ) + return txn.fetchall() return await self.db_pool.runInteraction( "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn @@ -1474,7 +1358,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_all_new_backfill_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: + ) -> Tuple[List[Tuple[int, list]], int, bool]: """Get updates for backfill replication stream, including all new backfilled events and events that have gone from being outliers to not. @@ -1502,15 +1386,13 @@ class EventsWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_new_backfill_event_rows( - txn: LoggingTransaction, - ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: + def get_all_new_backfill_event_rows(txn): sql = ( "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > stream_ordering AND stream_ordering >= ?" " AND instance_name = ?" @@ -1518,15 +1400,7 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, instance_name, limit)) - new_event_updates: List[ - Tuple[int, Tuple[str, str, str, str, str, str]] - ] = [] - row: Tuple[int, str, str, str, str, str, str] - # Type safety: iterating over `txn` yields `Tuple`, i.e. - # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a - # variadic tuple to a fixed length tuple and flags it up as an error. - for row in txn: # type: ignore[assignment] - new_event_updates.append((row[0], row[1:])) + new_event_updates = [(row[0], row[1:]) for row in txn] limited = False if len(new_event_updates) == limit: @@ -1537,11 +1411,11 @@ class EventsWorkerStore(SQLBaseStore): sql = ( "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" @@ -1549,11 +1423,7 @@ class EventsWorkerStore(SQLBaseStore): " ORDER BY event_stream_ordering DESC" ) txn.execute(sql, (-last_id, -upper_bound, instance_name)) - # Type safety: iterating over `txn` yields `Tuple`, i.e. - # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a - # variadic tuple to a fixed length tuple and flags it up as an error. - for row in txn: # type: ignore[assignment] - new_event_updates.append((row[0], row[1:])) + new_event_updates.extend((row[0], row[1:]) for row in txn) if len(new_event_updates) >= limit: upper_bound = new_event_updates[-1][0] @@ -1567,7 +1437,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_all_updated_current_state_deltas( self, instance_name: str, from_token: int, to_token: int, target_row_count: int - ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]: + ) -> Tuple[List[Tuple], int, bool]: """Fetch updates from current_state_delta_stream Args: @@ -1587,9 +1457,7 @@ class EventsWorkerStore(SQLBaseStore): * `limited` is whether there are more updates to fetch. """ - def get_all_updated_current_state_deltas_txn( - txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str]]: + def get_all_updated_current_state_deltas_txn(txn): sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream @@ -1598,23 +1466,21 @@ class EventsWorkerStore(SQLBaseStore): ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (from_token, to_token, instance_name, target_row_count)) - return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) + return txn.fetchall() - def get_deltas_for_stream_id_txn( - txn: LoggingTransaction, stream_id: int - ) -> List[Tuple[int, str, str, str, str]]: + def get_deltas_for_stream_id_txn(txn, stream_id): sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream WHERE stream_id = ? """ txn.execute(sql, [stream_id]) - return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) + return txn.fetchall() # we need to make sure that, for every stream id in the results, we get *all* # the rows with that stream id. - rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction( + rows: List[Tuple] = await self.db_pool.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) @@ -1643,14 +1509,14 @@ class EventsWorkerStore(SQLBaseStore): return rows, to_token, True - async def is_event_after(self, event_id1: str, event_id2: str) -> bool: + async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream""" to_1, so_1 = await self.get_event_ordering(event_id1) to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cached(max_entries=5000) - async def get_event_ordering(self, event_id: str) -> Tuple[int, int]: + async def get_event_ordering(self, event_id): res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], @@ -1673,9 +1539,7 @@ class EventsWorkerStore(SQLBaseStore): None otherwise. """ - def get_next_event_to_expire_txn( - txn: LoggingTransaction, - ) -> Optional[Tuple[str, int]]: + def get_next_event_to_expire_txn(txn): txn.execute( """ SELECT event_id, expiry_ts FROM event_expiry @@ -1683,7 +1547,7 @@ class EventsWorkerStore(SQLBaseStore): """ ) - return cast(Optional[Tuple[str, int]], txn.fetchone()) + return txn.fetchone() return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn @@ -1747,10 +1611,10 @@ class EventsWorkerStore(SQLBaseStore): return mapping @wrap_as_background_process("_cleanup_old_transaction_ids") - async def _cleanup_old_transaction_ids(self) -> None: + async def _cleanup_old_transaction_ids(self): """Cleans out transaction id mappings older than 24hrs.""" - def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None: + def _cleanup_old_transaction_ids_txn(txn): sql = """ DELETE FROM event_txn_id WHERE inserted_ts < ? @@ -1762,198 +1626,3 @@ class EventsWorkerStore(SQLBaseStore): "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn, ) - - async def is_event_next_to_backward_gap(self, event: EventBase) -> bool: - """Check if the given event is next to a backward gap of missing events. - <latest messages> A(False)--->B(False)--->C(True)---> <gap, unknown events> <oldest messages> - - Args: - room_id: room where the event lives - event_id: event to check - - Returns: - Boolean indicating whether it's an extremity - """ - - def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: - # If the event in question has any of its prev_events listed as a - # backward extremity, it's next to a gap. - # - # We can't just check the backward edges in `event_edges` because - # when we persist events, we will also record the prev_events as - # edges to the event in question regardless of whether we have those - # prev_events yet. We need to check whether those prev_events are - # backward extremities, also known as gaps, that need to be - # backfilled. - backward_extremity_query = """ - SELECT 1 FROM event_backward_extremities - WHERE - room_id = ? - AND %s - LIMIT 1 - """ - - # If the event in question is a backward extremity or has any of its - # prev_events listed as a backward extremity, it's next to a - # backward gap. - clause, args = make_in_list_sql_clause( - self.database_engine, - "event_id", - [event.event_id] + list(event.prev_event_ids()), - ) - - txn.execute(backward_extremity_query % (clause,), [event.room_id] + args) - backward_extremities = txn.fetchall() - - # We consider any backward extremity as a backward gap - if len(backward_extremities): - return True - - return False - - return await self.db_pool.runInteraction( - "is_event_next_to_backward_gap_txn", - is_event_next_to_backward_gap_txn, - ) - - async def is_event_next_to_forward_gap(self, event: EventBase) -> bool: - """Check if the given event is next to a forward gap of missing events. - The gap in front of the latest events is not considered a gap. - <latest messages> A(False)--->B(False)--->C(False)---> <gap, unknown events> <oldest messages> - <latest messages> A(False)--->B(False)---> <gap, unknown events> --->D(True)--->E(False) <oldest messages> - - Args: - room_id: room where the event lives - event_id: event to check - - Returns: - Boolean indicating whether it's an extremity - """ - - def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: - # If the event in question is a forward extremity, we will just - # consider any potential forward gap as not a gap since it's one of - # the latest events in the room. - # - # `event_forward_extremities` does not include backfilled or outlier - # events so we can't rely on it to find forward gaps. We can only - # use it to determine whether a message is the latest in the room. - # - # We can't combine this query with the `forward_edge_query` below - # because if the event in question has no forward edges (isn't - # referenced by any other event's prev_events) but is in - # `event_forward_extremities`, we don't want to return 0 rows and - # say it's next to a gap. - forward_extremity_query = """ - SELECT 1 FROM event_forward_extremities - WHERE - room_id = ? - AND event_id = ? - LIMIT 1 - """ - - # Check to see whether the event in question is already referenced - # by another event. If we don't see any edges, we're next to a - # forward gap. - forward_edge_query = """ - SELECT 1 FROM event_edges - /* Check to make sure the event referencing our event in question is not rejected */ - LEFT JOIN rejections ON event_edges.event_id == rejections.event_id - WHERE - event_edges.room_id = ? - AND event_edges.prev_event_id = ? - /* It's not a valid edge if the event referencing our event in - * question is rejected. - */ - AND rejections.event_id IS NULL - LIMIT 1 - """ - - # We consider any forward extremity as the latest in the room and - # not a forward gap. - # - # To expand, even though there is technically a gap at the front of - # the room where the forward extremities are, we consider those the - # latest messages in the room so asking other homeservers for more - # is useless. The new latest messages will just be federated as - # usual. - txn.execute(forward_extremity_query, (event.room_id, event.event_id)) - forward_extremities = txn.fetchall() - if len(forward_extremities): - return False - - # If there are no forward edges to the event in question (another - # event hasn't referenced this event in their prev_events), then we - # assume there is a forward gap in the history. - txn.execute(forward_edge_query, (event.room_id, event.event_id)) - forward_edges = txn.fetchall() - if not len(forward_edges): - return True - - return False - - return await self.db_pool.runInteraction( - "is_event_next_to_gap_txn", - is_event_next_to_gap_txn, - ) - - async def get_event_id_for_timestamp( - self, room_id: str, timestamp: int, direction: str - ) -> Optional[str]: - """Find the closest event to the given timestamp in the given direction. - - Args: - room_id: Room to fetch the event from - timestamp: The point in time (inclusive) we should navigate from in - the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward - or backward from the given timestamp to find the closest event. - - Returns: - The closest event_id otherwise None if we can't find any event in - the given direction. - """ - - sql_template = """ - SELECT event_id FROM events - LEFT JOIN rejections USING (event_id) - WHERE - origin_server_ts %s ? - AND room_id = ? - /* Make sure event is not rejected */ - AND rejections.event_id IS NULL - ORDER BY origin_server_ts %s - LIMIT 1; - """ - - def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: - if direction == "b": - # Find closest event *before* a given timestamp. We use descending - # (which gives values largest to smallest) because we want the - # largest possible timestamp *before* the given timestamp. - comparison_operator = "<=" - order = "DESC" - else: - # Find closest event *after* a given timestamp. We use ascending - # (which gives values smallest to largest) because we want the - # closest possible timestamp *after* the given timestamp. - comparison_operator = ">=" - order = "ASC" - - txn.execute( - sql_template % (comparison_operator, order), (timestamp, room_id) - ) - row = txn.fetchone() - if row: - (event_id,) = row - return event_id - - return None - - if direction not in ("f", "b"): - raise ValueError("Unknown direction: %s" % (direction,)) - - return await self.db_pool.runInteraction( - "get_event_id_for_timestamp_txn", - get_event_id_for_timestamp_txn, - ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 91b0576b85..3eb30944bf 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py
@@ -118,7 +118,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): logger.info("[purge] looking for events to delete") - should_delete_expr = "state_events.state_key IS NULL" + should_delete_expr = "state_key IS NULL" should_delete_params: Tuple[Any, ...] = () if not delete_local_events: should_delete_expr += " AND event_id NOT LIKE ?" diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 3b63267395..fa782023d4 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -28,10 +28,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException -from synapse.storage.util.id_generators import ( - AbstractStreamIdTracker, - StreamIdGenerator, -) +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -85,9 +82,9 @@ class PushRulesWorkerStore( super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, "push_rules_stream", "stream_id" - ) + self._push_rules_stream_id_gen: Union[ + StreamIdGenerator, SlavedIdTracker + ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id") else: self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e1ddf06916..0e8c168667 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -106,15 +106,6 @@ class RefreshTokenLookupResult: has_next_access_token_been_used: bool """True if the next access token was already used at least once.""" - expiry_ts: Optional[int] - """The time at which the refresh token expires and can not be used. - If None, the refresh token doesn't expire.""" - - ultimate_session_expiry_ts: Optional[int] - """The time at which the session comes to an end and can no longer be - refreshed. - If None, the session can be refreshed indefinitely.""" - class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( @@ -1635,10 +1626,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): rt.user_id, rt.device_id, rt.next_token_id, - (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed, - at.used AS has_next_access_token_been_used, - rt.expiry_ts, - rt.ultimate_session_expiry_ts + (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed, + at.used has_next_access_token_been_used FROM refresh_tokens rt LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id @@ -1659,8 +1648,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): has_next_refresh_token_been_refreshed=row[4], # This column is nullable, ensure it's a boolean has_next_access_token_been_used=(row[5] or False), - expiry_ts=row[6], - ultimate_session_expiry_ts=row[7], ) return await self.db_pool.runInteraction( @@ -1928,8 +1915,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): user_id: str, token: str, device_id: Optional[str], - expiry_ts: Optional[int], - ultimate_session_expiry_ts: Optional[int], ) -> int: """Adds a refresh token for the given user. @@ -1937,13 +1922,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): user_id: The user ID. token: The new access token to add. device_id: ID of the device to associate with the refresh token. - expiry_ts (milliseconds since the epoch): Time after which the - refresh token cannot be used. - If None, the refresh token never expires until it has been used. - ultimate_session_expiry_ts (milliseconds since the epoch): - Time at which the session will end and can not be extended any - further. - If None, the session can be refreshed indefinitely. Raises: StoreError if there was a problem adding this. Returns: @@ -1959,8 +1937,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "device_id": device_id, "token": token, "next_token_id": None, - "expiry_ts": expiry_ts, - "ultimate_session_expiry_ts": ultimate_session_expiry_ts, }, desc="add_refresh_token_to_user", ) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 6b2a8d06a6..033a9831d6 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -476,7 +476,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): INNER JOIN events AS e USING (room_id, event_id) WHERE c.type = 'm.room.member' - AND c.state_key = ? + AND state_key = ? AND c.membership = ? """ else: @@ -487,7 +487,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): INNER JOIN events AS e USING (room_id, event_id) WHERE c.type = 'm.room.member' - AND c.state_key = ? + AND state_key = ? AND m.membership = ? """ diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 57aab55259..42dc807d17 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py
@@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): oldest `limit` events. Returns: - The list of events (in ascending stream order) and the token from the start + The list of events (in ascending order) and the token from the start of the chunk of events returned. """ if from_key == to_key: @@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if not has_changed: return [], from_key - def f(txn: LoggingTransaction) -> List[_EventDictReturn]: + def f(txn): # To handle tokens with a non-empty instance_map we fetch more # results than necessary and then filter down min_from_id = from_key.stream @@ -565,13 +565,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: - """Fetch membership events for a given user. - - All such events whose stream ordering `s` lies in the range - `from_key < s <= to_key` are returned. Events are ordered by ascending stream - order. - """ - # Start by ruling out cases where a DB query is not necessary. if from_key == to_key: return [] @@ -582,7 +575,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if not has_changed: return [] - def f(txn: LoggingTransaction) -> List[_EventDictReturn]: + def f(txn): # To handle tokens with a non-empty instance_map we fetch more # results than necessary and then filter down min_from_id = from_key.stream @@ -641,7 +634,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): Returns: A list of events and a token pointing to the start of the returned - events. The events returned are in ascending topological order. + events. The events returned are in ascending order. """ rows, token = await self.get_recent_event_ids_for_room( diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 1622822552..d7dc1f73ac 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,6 @@ import logging from collections import namedtuple -from enum import Enum from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple import attr @@ -45,16 +44,6 @@ _UpdateTransactionRow = namedtuple( ) -class DestinationSortOrder(Enum): - """Enum to define the sorting method used when returning destinations.""" - - DESTINATION = "destination" - RETRY_LAST_TS = "retry_last_ts" - RETTRY_INTERVAL = "retry_interval" - FAILURE_TS = "failure_ts" - LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering" - - @attr.s(slots=True, frozen=True, auto_attribs=True) class DestinationRetryTimings: """The current destination retry timing info for a remote server.""" @@ -491,62 +480,3 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): destinations = [row[0] for row in txn] return destinations - - async def get_destinations_paginate( - self, - start: int, - limit: int, - destination: Optional[str] = None, - order_by: str = DestinationSortOrder.DESTINATION.value, - direction: str = "f", - ) -> Tuple[List[JsonDict], int]: - """Function to retrieve a paginated list of destinations. - This will return a json list of destinations and the - total number of destinations matching the filter criteria. - - Args: - start: start number to begin the query from - limit: number of rows to retrieve - destination: search string in destination - order_by: the sort order of the returned list - direction: sort ascending or descending - Returns: - A tuple of a list of mappings from destination to information - and a count of total destinations. - """ - - def get_destinations_paginate_txn( - txn: LoggingTransaction, - ) -> Tuple[List[JsonDict], int]: - order_by_column = DestinationSortOrder(order_by).value - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - args = [] - where_statement = "" - if destination: - args.extend(["%" + destination.lower() + "%"]) - where_statement = "WHERE LOWER(destination) LIKE ?" - - sql_base = f"FROM destinations {where_statement} " - sql = f"SELECT COUNT(*) as total_destinations {sql_base}" - txn.execute(sql, args) - count = txn.fetchone()[0] - - sql = f""" - SELECT destination, retry_last_ts, retry_interval, failure_ts, - last_successful_stream_ordering - {sql_base} - ORDER BY {order_by_column} {order}, destination ASC - LIMIT ? OFFSET ? - """ - txn.execute(sql, args + [limit, start]) - destinations = self.db_pool.cursor_to_dict(txn) - return destinations, count - - return await self.db_pool.runInteraction( - "get_destinations_paginate_txn", get_destinations_paginate_txn - )