summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
author Sean Quah <seanq@element.io> 2021-12-07 16:38:29 +0000
committer Sean Quah <seanq@element.io> 2021-12-07 16:47:31 +0000
commit 158d73ebdd61eef33831ae5f6990acf07244fc55 (patch)
tree 723f79596374042e349d55a6195cbe2b5eea29eb /synapse/storage
parent Sort internal changes in changelog (diff)
download synapse-158d73ebdd61eef33831ae5f6990acf07244fc55.tar.xz
Revert accidental fast-forward merge from v1.49.0rc1
Revert "Sort internal changes in changelog"
Revert "Update CHANGES.md"
Revert "1.49.0rc1"
Revert "Revert "Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common` (#11505) (#11527)""
Revert "Refactors in `_generate_sync_entry_for_rooms` (#11515)"
Revert "Correctly register shutdown handler for presence workers (#11518)"
Revert "Fix `ModuleApi.looping_background_call` for non-async functions (#11524)"
Revert "Fix 'delete room' admin api to work on incomplete rooms (#11523)"
Revert "Correctly ignore invites from ignored users (#11511)"
Revert "Fix the test breakage introduced by #11435 as a result of concurrent PRs (#11522)"
Revert "Stabilise support for MSC2918 refresh tokens as they have now been merged into the Matrix specification. (#11435)"
Revert "Save the OIDC session ID (sid) with the device on login (#11482)"
Revert "Add admin API to get some information about federation status (#11407)"
Revert "Include bundled aggregations in /sync and related fixes (#11478)"
Revert "Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common` (#11505)"
Revert "Update backward extremity docs to make it clear that it does not indicate whether we have fetched an events' `prev_events` (#11469)"
Revert "Support configuring the lifetime of non-refreshable access tokens separately to refreshable access tokens. (#11445)"
Revert "Add type hints to `synapse/tests/rest/admin` (#11501)"
Revert "Revert accidental commits to develop."
Revert "Newsfile"
Revert "Give `tests.server.setup_test_homeserver` (nominally!) the same behaviour"
Revert "Move `tests.utils.setup_test_homeserver` to `tests.server`"
Revert "Convert one of the `setup_test_homeserver`s to `make_test_homeserver_synchronous`"
Revert "Disambiguate queries on `state_key` (#11497)"
Revert "Comments on the /sync tentacles (#11494)"
Revert "Clean up tests.storage.test_appservice (#11492)"
Revert "Clean up `tests.storage.test_main` to remove use of legacy code. (#11493)"
Revert "Clean up `tests.test_visibility` to remove legacy code. (#11495)"
Revert "Minor cleanup on recently ported doc pages  (#11466)"
Revert "Add most of the missing type hints to `synapse.federation`. (#11483)"
Revert "Avoid waiting for zombie processes in `synctl stop` (#11490)"
Revert "Fix media repository failing when media store path contains symlinks (#11446)"
Revert "Add type annotations to `tests.storage.test_appservice`. (#11488)"
Revert "`scripts-dev/sign_json`: support for signing events (#11486)"
Revert "Add MSC3030 experimental client and federation API endpoints to get the closest event to a given timestamp (#9445)"
Revert "Port wiki pages to documentation website (#11402)"
Revert "Add a license header and comment. (#11479)"
Revert "Clean-up get_version_string (#11468)"
Revert "Link background update controller docs to summary (#11475)"
Revert "Additional type hints for config module. (#11465)"
Revert "Register the login redirect endpoint for v3. (#11451)"
Revert "Update openid.md"
Revert "Remove mention of OIDC certification from Dex (#11470)"
Revert "Add a note about huge pages to our Postgres doc (#11467)"
Revert "Don't start Synapse master process if `worker_app` is set (#11416)"
Revert "Expose worker & homeserver as entrypoints in `setup.py` (#11449)"
Revert "Bundle relations of relations into the `/relations` result. (#11284)"
Revert "Fix `LruCache` corruption bug with a `size_callback` that can return 0 (#11454)"
Revert "Eliminate a few `Any`s in `LruCache` type hints (#11453)"
Revert "Remove unnecessary `json.dumps` from `tests.rest.admin` (#11461)"
Revert "Merge branch 'master' into develop"

This reverts commit 26b5d2320f62b5eb6262c7614fbdfc364a4dfc02.
This reverts commit bce4220f387bf5448387f0ed7d14ed1e41e40747.
This reverts commit 966b5d0fa0893c3b628c942dfc232e285417f46d.
This reverts commit 088d748f2cb51f03f3bcacc0fb3af1e0f9607737.
This reverts commit 14d593f72d10b4d8cb67e3288bb3131ee30ccf59.
This reverts commit 2a3ec6facf79f6aae011d9fb6f9ed5e43c7b6bec.
This reverts commit eccc49d7554d1fab001e1fefb0fda8ffb254b630.
This reverts commit b1ecd19c5d19815b69e425d80f442bf2877cab76.
This reverts commit 9c55dedc8c4484e6269451a8c3c10b3e314aeb4a.
This reverts commit 2d42e586a8c54be1a83643148358b1651c1ca666.
This reverts commit 2f053f3f82ca174cc1c858c75afffae51af8ce0d.
This reverts commit a15a893df8428395df7cb95b729431575001c38a.
This reverts commit 8b4b153c9e86c04c7db8c74fde4b6a04becbc461.
This reverts commit 494ebd7347ba52d702802fba4c3bb13e7bfbc2cf.
This reverts commit a77c36989785c0d5565ab9a1169f4f88e512ce8a.
This reverts commit 4eb77965cd016181d2111f37d93526e9bb0434f0.
This reverts commit 637df95de63196033a6da4a6e286e1d58ea517b6.
This reverts commit e5f426cd54609e7f05f8241d845e6e36c5f10d9a.
This reverts commit 8cd68b8102eeab1b525712097c1b2e9679c11896.
This reverts commit 6cae125e20865c52d770b24278bb7ab8fde5bc0d.
This reverts commit 7be88fbf48156b36b6daefb228e1258e7d48cae4.
This reverts commit b3fd99b74a3f6f42a9afd1b19ee4c60e38e8e91a.
This reverts commit f7ec6e7d9e0dc360d9fb41f3a1afd7bdba1475c7.
This reverts commit 5640992d176a499204a0756b1677c9b1575b0a49.
This reverts commit d26808dd854006bd26a2366c675428ce0737238c.
This reverts commit f91624a5950e14ba9007eed9bfa1c828676d4745.
This reverts commit 16d39a5490ce74c901c7a8dbb990c6e83c379207.
This reverts commit 8a4c2969874c0b7d72003f2523883eba8a348e83.
This reverts commit 49e1356ee3d5d72929c91f778b3a231726c1413c.
This reverts commit d2279f471ba8f44d9f578e62b286897a338d8aa1.
This reverts commit b50e39df578adc3f86c5efa16bee9035cfdab61b.
This reverts commit 858d80bf0f9f656a03992794874081b806e49222.
This reverts commit 435f04480728c5d982e1a63c1b2777784bf9cd26.
This reverts commit f61462e1be36a51dbf571076afa8e1930cb182f4.
This reverts commit a6f1a3abecf8e8fd3e1bff439a06b853df18f194.
This reverts commit 84dc50e160a2ec6590813374b5a1e58b97f7a18d.
This reverts commit ed635d32853ee0a3e5ec1078679b27e7844a4ac7.
This reverts commit 7b62791e001d6a4f8897ed48b3232d7f8fe6aa48.
This reverts commit 153194c7717d8016b0eb974c81b1baee7dc1917d.
This reverts commit f44d729d4ccae61bc0cdd5774acb3233eb5f7c13.
This reverts commit a265fbd397ae72b2d3ea4c9310591ff1d0f3e05c.
This reverts commit b9fef1a7cdfcc128fa589a32160e6aa7ed8964d7.
This reverts commit b0eb64ff7bf6bde42046e091f8bdea9b7aab5f04.
This reverts commit f1795463bf503a6fca909d77f598f641f9349f56.
This reverts commit 70cbb1a5e311f609b624e3fae1a1712db639c51e.
This reverts commit 42bf0204635213e2c75188b19ee66dc7e7d8a35e.
This reverts commit 379f2650cf875f50c59524147ec0e33cfd5ef60c.
This reverts commit 7ff22d6da41cd5ca80db95c18b409aea38e49fcd.
This reverts commit 5a0b652d36ae4b6d423498c1f2c82c97a49c6f75.
This reverts commit 432a174bc192740ac7a0a755009f6099b8363ad9.
This reverts commit b14f8a1baf6f500997ae4c1d6a6d72094ce14270, reversing
changes made to e713855dca17a7605bae99ea8d71bc7f8657e4b8.
Diffstat (limited to 'synapse/storage')
-rw-r--r-- synapse/storage/_base.py 4
-rw-r--r-- synapse/storage/background_updates.py 192
-rw-r--r-- synapse/storage/databases/main/appservice.py 6
-rw-r--r-- synapse/storage/databases/main/devices.py 50
-rw-r--r-- synapse/storage/databases/main/event_federation.py 4
-rw-r--r-- synapse/storage/databases/main/event_push_actions.py 19
-rw-r--r-- synapse/storage/databases/main/events.py 107
-rw-r--r-- synapse/storage/databases/main/events_worker.py 581
-rw-r--r-- synapse/storage/databases/main/purge_events.py 2
-rw-r--r-- synapse/storage/databases/main/push_rule.py 11
-rw-r--r-- synapse/storage/databases/main/registration.py 28
-rw-r--r-- synapse/storage/databases/main/roommember.py 4
-rw-r--r-- synapse/storage/databases/main/stream.py 15
-rw-r--r-- synapse/storage/databases/main/transactions.py 70
-rw-r--r-- synapse/storage/persist_events.py 3
-rw-r--r-- synapse/storage/schema/__init__.py 6
-rw-r--r-- synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql 28
-rw-r--r-- synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql 27
-rw-r--r-- synapse/storage/util/id_generators.py 116
19 files changed, 261 insertions, 1012 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py

index 3056e64ff5..0623da9aa1 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py
@@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import DatabasePool from synapse.storage.types import Connection -from synapse.types import get_domain_from_id +from synapse.types import StreamToken, get_domain_from_id from synapse.util import json_decoder if TYPE_CHECKING: @@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta): self, stream_name: str, instance_name: str, - token: int, + token: StreamToken, rows: Iterable[Any], ) -> None: pass diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index d64910aded..bc8364400d 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py
@@ -12,22 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import ( - TYPE_CHECKING, - AsyncContextManager, - Awaitable, - Callable, - Dict, - Iterable, - Optional, -) - -import attr +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.types import Connection from synapse.types import JsonDict -from synapse.util import Clock, json_encoder +from synapse.util import json_encoder from . import engines @@ -38,45 +28,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]] -DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]] -MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]] - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _BackgroundUpdateHandler: - """A handler for a given background update. - - Attributes: - callback: The function to call to make progress on the background - update. - oneshot: Wether the update is likely to happen all in one go, ignoring - the supplied target duration, e.g. index creation. This is used by - the update controller to help correctly schedule the update. 
- """ - - callback: Callable[[JsonDict, int], Awaitable[int]] - oneshot: bool = False - - -class _BackgroundUpdateContextManager: - BACKGROUND_UPDATE_INTERVAL_MS = 1000 - BACKGROUND_UPDATE_DURATION_MS = 100 - - def __init__(self, sleep: bool, clock: Clock): - self._sleep = sleep - self._clock = clock - - async def __aenter__(self) -> int: - if self._sleep: - await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000) - - return self.BACKGROUND_UPDATE_DURATION_MS - - async def __aexit__(self, *exc) -> None: - pass - - class BackgroundUpdatePerformance: """Tracks the how long a background update is taking to update its items""" @@ -133,22 +84,20 @@ class BackgroundUpdater: MINIMUM_BACKGROUND_BATCH_SIZE = 1 DEFAULT_BACKGROUND_BATCH_SIZE = 100 + BACKGROUND_UPDATE_INTERVAL_MS = 1000 + BACKGROUND_UPDATE_DURATION_MS = 100 def __init__(self, hs: "HomeServer", database: "DatabasePool"): self._clock = hs.get_clock() self.db_pool = database - self._database_name = database.name() - # if a background update is currently running, its name. self._current_background_update: Optional[str] = None - self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None - self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None - self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None - self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {} - self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {} + self._background_update_handlers: Dict[ + str, Callable[[JsonDict, int], Awaitable[int]] + ] = {} self._all_done = False # Whether we're currently running updates @@ -158,83 +107,6 @@ class BackgroundUpdater: # enable/disable background updates via the admin API. 
self.enabled = True - def register_update_controller_callbacks( - self, - on_update: ON_UPDATE_CALLBACK, - default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None, - min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None, - ) -> None: - """Register callbacks from a module for each hook.""" - if self._on_update_callback is not None: - logger.warning( - "More than one module tried to register callbacks for controlling" - " background updates. Only the callbacks registered by the first module" - " (in order of appearance in Synapse's configuration file) that tried to" - " do so will be called." - ) - - return - - self._on_update_callback = on_update - - if default_batch_size is not None: - self._default_batch_size_callback = default_batch_size - - if min_batch_size is not None: - self._min_batch_size_callback = min_batch_size - - def _get_context_manager_for_update( - self, - sleep: bool, - update_name: str, - database_name: str, - oneshot: bool, - ) -> AsyncContextManager[int]: - """Get a context manager to run a background update with. - - If a module has registered a `update_handler` callback, use the context manager - it returns. - - Otherwise, returns a context manager that will return a default value, optionally - sleeping if needed. - - Args: - sleep: Whether we can sleep between updates. - update_name: The name of the update. - database_name: The name of the database the update is being run on. - oneshot: Whether the update will complete all in one go, e.g. index creation. - In such cases the returned target duration is ignored. - - Returns: - The target duration in milliseconds that the background update should run for. - - Note: this is a *target*, and an iteration may take substantially longer or - shorter. 
- """ - if self._on_update_callback is not None: - return self._on_update_callback(update_name, database_name, oneshot) - - return _BackgroundUpdateContextManager(sleep, self._clock) - - async def _default_batch_size(self, update_name: str, database_name: str) -> int: - """The batch size to use for the first iteration of a new background - update. - """ - if self._default_batch_size_callback is not None: - return await self._default_batch_size_callback(update_name, database_name) - - return self.DEFAULT_BACKGROUND_BATCH_SIZE - - async def _min_batch_size(self, update_name: str, database_name: str) -> int: - """A lower bound on the batch size of a new background update. - - Used to ensure that progress is always made. Must be greater than 0. - """ - if self._min_batch_size_callback is not None: - return await self._min_batch_size_callback(update_name, database_name) - - return self.MINIMUM_BACKGROUND_BATCH_SIZE - def get_current_update(self) -> Optional[BackgroundUpdatePerformance]: """Returns the current background update, if any.""" @@ -263,8 +135,13 @@ class BackgroundUpdater: try: logger.info("Starting background schema updates") while self.enabled: + if sleep: + await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0) + try: - result = await self.do_next_background_update(sleep) + result = await self.do_next_background_update( + self.BACKGROUND_UPDATE_DURATION_MS + ) except Exception: logger.exception("Error doing update") else: @@ -326,15 +203,13 @@ class BackgroundUpdater: return not update_exists - async def do_next_background_update(self, sleep: bool = True) -> bool: + async def do_next_background_update(self, desired_duration_ms: float) -> bool: """Does some amount of work on the next queued background update Returns once some amount of work is done. Args: - sleep: Whether to limit how quickly we run background updates or - not. - + desired_duration_ms: How long we want to spend updating. 
Returns: True if we have finished running all the background updates, otherwise False """ @@ -377,19 +252,7 @@ class BackgroundUpdater: self._current_background_update = upd["update_name"] - # We have a background update to run, otherwise we would have returned - # early. - assert self._current_background_update is not None - update_info = self._background_update_handlers[self._current_background_update] - - async with self._get_context_manager_for_update( - sleep=sleep, - update_name=self._current_background_update, - database_name=self._database_name, - oneshot=update_info.oneshot, - ) as desired_duration_ms: - await self._do_background_update(desired_duration_ms) - + await self._do_background_update(desired_duration_ms) return False async def _do_background_update(self, desired_duration_ms: float) -> int: @@ -397,7 +260,7 @@ class BackgroundUpdater: update_name = self._current_background_update logger.info("Starting update batch on background update '%s'", update_name) - update_handler = self._background_update_handlers[update_name].callback + update_handler = self._background_update_handlers[update_name] performance = self._background_update_performance.get(update_name) @@ -410,14 +273,9 @@ class BackgroundUpdater: if items_per_ms is not None: batch_size = int(desired_duration_ms * items_per_ms) # Clamp the batch size so that we always make progress - batch_size = max( - batch_size, - await self._min_batch_size(update_name, self._database_name), - ) + batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE) else: - batch_size = await self._default_batch_size( - update_name, self._database_name - ) + batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE progress_json = await self.db_pool.simple_select_one_onecol( "background_updates", @@ -436,8 +294,6 @@ class BackgroundUpdater: duration_ms = time_stop - time_start - performance.update(items_updated, duration_ms) - logger.info( "Running background update %r. Processed %r items in %rms." 
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)", @@ -450,6 +306,8 @@ class BackgroundUpdater: batch_size, ) + performance.update(items_updated, duration_ms) + return len(self._background_update_performance) def register_background_update_handler( @@ -473,9 +331,7 @@ class BackgroundUpdater: update_name: The name of the update that this code handles. update_handler: The function that does the update. """ - self._background_update_handlers[update_name] = _BackgroundUpdateHandler( - update_handler - ) + self._background_update_handlers[update_name] = update_handler def register_noop_background_update(self, update_name: str) -> None: """Register a noop handler for a background update. @@ -597,9 +453,7 @@ class BackgroundUpdater: await self._end_background_update(update_name) return 1 - self._background_update_handlers[update_name] = _BackgroundUpdateHandler( - updater, oneshot=True - ) + self.register_background_update_handler(update_name, updater) async def _end_background_update(self, update_name: str) -> None: """Removes a completed background update task from the queue. diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 4a883dc166..baec35ee27 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py
@@ -143,7 +143,7 @@ class ApplicationServiceTransactionWorkerStore( A list of ApplicationServices, which may be empty. """ results = await self.db_pool.simple_select_list( - "application_services_state", {"state": state.value}, ["as_id"] + "application_services_state", {"state": state}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore as_list = self.get_app_services() @@ -173,7 +173,7 @@ class ApplicationServiceTransactionWorkerStore( desc="get_appservice_state", ) if result: - return ApplicationServiceState(result.get("state")) + return result.get("state") return None async def set_appservice_state( @@ -186,7 +186,7 @@ class ApplicationServiceTransactionWorkerStore( state: The connectivity state to apply. """ await self.db_pool.simple_upsert( - "application_services_state", {"as_id": service.id}, {"state": state.value} + "application_services_state", {"as_id": service.id}, {"state": state} ) async def create_appservice_txn( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d5a4a661cd..9ccc66e589 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -139,27 +139,6 @@ class DeviceWorkerStore(SQLBaseStore): return {d["device_id"]: d for d in devices} - async def get_devices_by_auth_provider_session_id( - self, auth_provider_id: str, auth_provider_session_id: str - ) -> List[Dict[str, Any]]: - """Retrieve the list of devices associated with a SSO IdP session ID. - - Args: - auth_provider_id: The SSO IdP ID as defined in the server config - auth_provider_session_id: The session ID within the IdP - Returns: - A list of dicts containing the device_id and the user_id of each device - """ - return await self.db_pool.simple_select_list( - table="device_auth_providers", - keyvalues={ - "auth_provider_id": auth_provider_id, - "auth_provider_session_id": auth_provider_session_id, - }, - retcols=("user_id", "device_id"), - desc="get_devices_by_auth_provider_session_id", - ) - @trace async def get_device_updates_by_remote( self, destination: str, from_stream_id: int, limit: int @@ -1091,12 +1070,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def store_device( - self, - user_id: str, - device_id: str, - initial_device_display_name: Optional[str], - auth_provider_id: Optional[str] = None, - auth_provider_session_id: Optional[str] = None, + self, user_id: str, device_id: str, initial_device_display_name: Optional[str] ) -> bool: """Ensure the given device is known; add it to the store if not @@ -1105,8 +1079,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_id: id of device initial_device_display_name: initial displayname of the device. Ignored if device exists. - auth_provider_id: The SSO IdP the user used, if any. - auth_provider_session_id: The session ID (sid) got from a OIDC login. Returns: Whether the device was inserted or an existing device existed with that ID. 
@@ -1143,18 +1115,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) - if auth_provider_id and auth_provider_session_id: - await self.db_pool.simple_insert( - "device_auth_providers", - values={ - "user_id": user_id, - "device_id": device_id, - "auth_provider_id": auth_provider_id, - "auth_provider_session_id": auth_provider_session_id, - }, - desc="store_device_auth_provider", - ) - self.device_id_exists_cache.set(key, True) return inserted except StoreError: @@ -1208,14 +1168,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): keyvalues={"user_id": user_id}, ) - self.db_pool.simple_delete_many_txn( - txn, - table="device_auth_providers", - column="device_id", - values=device_ids, - keyvalues={"user_id": user_id}, - ) - await self.db_pool.runInteraction("delete_devices", _delete_devices_txn) for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 9580a40785..ef5d1ef01e 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -1552,9 +1552,9 @@ class EventFederationStore(EventFederationWorkerStore): DELETE FROM event_auth WHERE event_id IN ( SELECT event_id FROM events - LEFT JOIN state_events AS se USING (room_id, event_id) + LEFT JOIN state_events USING (room_id, event_id) WHERE ? <= stream_ordering AND stream_ordering < ? - AND se.state_key IS null + AND state_key IS null ) """ diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3efdd0c920..d957e770dc 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -16,7 +16,6 @@ import logging from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import attr -from typing_extensions import TypedDict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json @@ -38,20 +37,6 @@ DEFAULT_HIGHLIGHT_ACTION = [ ] -class BasePushAction(TypedDict): - event_id: str - actions: List[Union[dict, str]] - - -class HttpPushAction(BasePushAction): - room_id: str - stream_ordering: int - - -class EmailPushAction(HttpPushAction): - received_ts: Optional[int] - - def _serialize_action(actions, is_highlight): """Custom serializer for actions. This allows us to "compress" common actions. @@ -236,7 +221,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): min_stream_ordering: int, max_stream_ordering: int, limit: int = 20, - ) -> List[HttpPushAction]: + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the httppusher. @@ -341,7 +326,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): min_stream_ordering: int, max_stream_ordering: int, limit: int = 20, - ) -> List[EmailPushAction]: + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the emailpusher diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 4e528612ea..06832221ad 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -15,7 +15,7 @@ # limitations under the License. import itertools import logging -from collections import OrderedDict +from collections import OrderedDict, namedtuple from typing import ( TYPE_CHECKING, Any, @@ -41,10 +41,9 @@ from synapse.events.snapshot import EventContext # noqa: F401 from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction -from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry from synapse.storage.types import Connection -from synapse.storage.util.id_generators import AbstractStreamIdGenerator +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder @@ -65,6 +64,9 @@ event_counter = Counter( ) +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + @attr.s(slots=True) class DeltaState: """Deltas to use to update the `current_state_events` table. @@ -106,30 +108,23 @@ class PersistEventsStore: self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id + # Ideally we'd move these ID gens here, unfortunately some other ID + # generators are chained off them so doing so is a bit of a PITA. + self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen + self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen + # This should only exist on instances that are configured to write assert ( hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" - # Since we have been configured to write, we ought to have id generators, - # rather than id trackers. 
- assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator) - assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator) - - # Ideally we'd move these ID gens here, unfortunately some other ID - # generators are chained off them so doing so is a bit of a PITA. - self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen - self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen - async def _persist_events_and_state_updates( self, events_and_contexts: List[Tuple[EventBase, EventContext]], - *, current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], new_forward_extremeties: Dict[str, List[str]], - use_negative_stream_ordering: bool = False, - inhibit_local_membership_updates: bool = False, + backfilled: bool = False, ) -> None: """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -142,14 +137,7 @@ class PersistEventsStore: room state new_forward_extremities: Map from room_id to list of event IDs that are the new forward extremities of the room. - use_negative_stream_ordering: Whether to start stream_ordering on - the negative side and decrement. This should be set as True - for backfilled events because backfilled events get a negative - stream ordering so they don't come down incremental `/sync`. - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. + backfilled Returns: Resolves when the events have been persisted @@ -171,7 +159,7 @@ class PersistEventsStore: # # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. 
- if use_negative_stream_ordering: + if backfilled: stream_ordering_manager = self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) @@ -188,13 +176,13 @@ class PersistEventsStore: "persist_events", self._persist_events_txn, events_and_contexts=events_and_contexts, - inhibit_local_membership_updates=inhibit_local_membership_updates, + backfilled=backfilled, state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc(len(events_and_contexts)) - if stream < 0: + if not backfilled: # backfilled events have negative stream orderings, so we don't # want to set the event_persisted_position to that. synapse.metrics.event_persisted_position.set( @@ -328,9 +316,8 @@ class PersistEventsStore: def _persist_events_txn( self, txn: LoggingTransaction, - *, events_and_contexts: List[Tuple[EventBase, EventContext]], - inhibit_local_membership_updates: bool = False, + backfilled: bool, state_delta_for_room: Optional[Dict[str, DeltaState]] = None, new_forward_extremeties: Optional[Dict[str, List[str]]] = None, ): @@ -343,10 +330,7 @@ class PersistEventsStore: Args: txn events_and_contexts: events to persist - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. + backfilled: True if the events were backfilled delete_existing True to purge existing table rows for the events from the database. This is useful when retrying due to IntegrityError. @@ -379,7 +363,9 @@ class PersistEventsStore: events_and_contexts ) - self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts) + self._update_room_depths_txn( + txn, events_and_contexts=events_and_contexts, backfilled=backfilled + ) # _update_outliers_txn filters out any events which have already been # persisted, and returns the filtered list. 
@@ -412,7 +398,7 @@ class PersistEventsStore: txn, events_and_contexts=events_and_contexts, all_events_and_contexts=all_events_and_contexts, - inhibit_local_membership_updates=inhibit_local_membership_updates, + backfilled=backfilled, ) # We call this last as it assumes we've inserted the events into @@ -575,9 +561,9 @@ class PersistEventsStore: # fetch their auth event info. while missing_auth_chains: sql = """ - SELECT event_id, events.type, se.state_key, chain_id, sequence_number + SELECT event_id, events.type, state_key, chain_id, sequence_number FROM events - INNER JOIN state_events AS se USING (event_id) + INNER JOIN state_events USING (event_id) LEFT JOIN event_auth_chains USING (event_id) WHERE """ @@ -1214,6 +1200,7 @@ class PersistEventsStore: self, txn, events_and_contexts: List[Tuple[EventBase, EventContext]], + backfilled: bool, ): """Update min_depth for each room @@ -1221,18 +1208,13 @@ class PersistEventsStore: txn (twisted.enterprise.adbapi.Connection): db connection events_and_contexts (list[(EventBase, EventContext)]): events we are persisting + backfilled (bool): True if the events were backfilled """ depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids txn.call_after(self.store._invalidate_get_event_cache, event.event_id) - # Then update the `stream_ordering` position to mark the latest - # event as the front of the room. This should not be done for - # backfilled events because backfilled events have negative - # stream_ordering and happened in the past so we know that we don't - # need to update the stream_ordering tip/front for the room. 
- assert event.internal_metadata.stream_ordering is not None - if event.internal_metadata.stream_ordering >= 0: + if not backfilled: txn.call_after( self.store._events_stream_cache.entity_has_changed, event.room_id, @@ -1445,12 +1427,7 @@ class PersistEventsStore: return [ec for ec in events_and_contexts if ec[0] not in to_remove] def _update_metadata_tables_txn( - self, - txn, - *, - events_and_contexts, - all_events_and_contexts, - inhibit_local_membership_updates: bool = False, + self, txn, events_and_contexts, all_events_and_contexts, backfilled ): """Update all the miscellaneous tables for new events @@ -1462,10 +1439,7 @@ class PersistEventsStore: events that we were going to persist. This includes events we've already persisted, etc, that wouldn't appear in events_and_context. - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. + backfilled (bool): True if the events were backfilled """ # Insert all the push actions into the event_push_actions table. @@ -1539,7 +1513,7 @@ class PersistEventsStore: for event, _ in events_and_contexts if event.type == EventTypes.Member ], - inhibit_local_membership_updates=inhibit_local_membership_updates, + backfilled=backfilled, ) # Insert event_reference_hashes table. 
@@ -1579,13 +1553,11 @@ class PersistEventsStore: for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: - to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) + to_prefill.append(_EventCacheEntry(event=event, redacted_event=None)) def prefill(): for cache_entry in to_prefill: - self.store._get_event_cache.set( - (cache_entry.event.event_id,), cache_entry - ) + self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry) txn.call_after(prefill) @@ -1666,19 +1638,8 @@ class PersistEventsStore: txn, table="event_reference_hashes", values=vals ) - def _store_room_members_txn( - self, txn, events, *, inhibit_local_membership_updates: bool = False - ): - """ - Store a room member in the database. - Args: - txn: The transaction to use. - events: List of events to store. - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. - """ + def _store_room_members_txn(self, txn, events, backfilled): + """Store a room member in the database.""" def non_null_str_or_none(val: Any) -> Optional[str]: return val if isinstance(val, str) and "\u0000" not in val else None @@ -1721,7 +1682,7 @@ class PersistEventsStore: # band membership", like a remote invite or a rejection of a remote invite. if ( self.is_mine_id(event.state_key) - and not inhibit_local_membership_updates + and not backfilled and event.internal_metadata.is_outlier() and event.internal_metadata.is_out_of_band_membership() ): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c7b660ac5a..c6bf316d5b 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -15,18 +15,14 @@ import logging import threading from typing import ( - TYPE_CHECKING, - Any, Collection, Container, Dict, Iterable, List, - NoReturn, Optional, Set, Tuple, - cast, overload, ) @@ -42,7 +38,6 @@ from synapse.api.errors import NotFoundError, SynapseError from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, - RoomVersion, RoomVersions, ) from synapse.events import EventBase, make_event_from_dict @@ -61,18 +56,10 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import ( - DatabasePool, - LoggingDatabaseConnection, - LoggingTransaction, -) +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine -from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import ( - AbstractStreamIdTracker, - MultiWriterIdGenerator, - StreamIdGenerator, -) +from synapse.storage.types import Connection +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError @@ -82,13 +69,10 @@ from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure -if TYPE_CHECKING: - from synapse.server import HomeServer - logger = logging.getLogger(__name__) -# These values are used in the `enqueue_event` and `_fetch_loop` methods to +# These values are used in the `enqueus_event` and `_do_fetch` methods to # control how we batch/bulk fetch events from the database. 
# The values are plucked out of thing air to make initial sync run faster # on jki.re @@ -105,7 +89,7 @@ event_fetch_ongoing_gauge = Gauge( @attr.s(slots=True, auto_attribs=True) -class EventCacheEntry: +class _EventCacheEntry: event: EventBase redacted_event: Optional[EventBase] @@ -145,7 +129,7 @@ class _EventRow: json: str internal_metadata: str format_version: Optional[int] - room_version_id: Optional[str] + room_version_id: Optional[int] rejected_reason: Optional[str] redactions: List[str] outlier: bool @@ -169,16 +153,9 @@ class EventsWorkerStore(SQLBaseStore): # options controlling this. USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): + def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self._stream_id_gen: AbstractStreamIdTracker - self._backfill_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): # If we're using Postgres than we can use `MultiWriterIdGenerator` # regardless of whether this process writes to the streams or not. @@ -237,7 +214,7 @@ class EventsWorkerStore(SQLBaseStore): 5 * 60 * 1000, ) - self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache( + self._get_event_cache = LruCache( cache_name="*getEvent*", max_size=hs.config.caches.event_cache_size, ) @@ -246,21 +223,19 @@ class EventsWorkerStore(SQLBaseStore): # ID to cache entry. Note that the returned dict may not have the # requested event in it if the event isn't in the DB. 
self._current_event_fetches: Dict[ - str, ObservableDeferred[Dict[str, EventCacheEntry]] + str, ObservableDeferred[Dict[str, _EventCacheEntry]] ] = {} self._event_fetch_lock = threading.Condition() - self._event_fetch_list: List[ - Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"] - ] = [] + self._event_fetch_list = [] self._event_fetch_ongoing = 0 event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) # We define this sequence here so that it can be referenced from both # the DataStore and PersistEventStore. - def get_chain_id_txn(txn: Cursor) -> int: + def get_chain_id_txn(txn): txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") - return cast(Tuple[int], txn.fetchone())[0] + return txn.fetchone()[0] self.event_chain_id_gen = build_sequence_generator( db_conn, @@ -271,13 +246,7 @@ class EventsWorkerStore(SQLBaseStore): id_column="chain_id", ) - def process_replication_rows( - self, - stream_name: str, - instance_name: str, - token: int, - rows: Iterable[Any], - ) -> None: + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == EventsStream.NAME: self._stream_id_gen.advance(instance_name, token) elif stream_name == BackfillStream.NAME: @@ -311,10 +280,10 @@ class EventsWorkerStore(SQLBaseStore): self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = ..., - allow_rejected: bool = ..., - allow_none: Literal[False] = ..., - check_room_id: Optional[str] = ..., + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[False] = False, + check_room_id: Optional[str] = None, ) -> EventBase: ... 
@@ -323,10 +292,10 @@ class EventsWorkerStore(SQLBaseStore): self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = ..., - allow_rejected: bool = ..., - allow_none: Literal[True] = ..., - check_room_id: Optional[str] = ..., + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[True] = False, + check_room_id: Optional[str] = None, ) -> Optional[EventBase]: ... @@ -388,7 +357,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_events( self, - event_ids: Collection[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, @@ -575,7 +544,7 @@ class EventsWorkerStore(SQLBaseStore): async def _get_events_from_cache_or_db( self, event_ids: Iterable[str], allow_rejected: bool = False - ) -> Dict[str, EventCacheEntry]: + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -609,7 +578,7 @@ class EventsWorkerStore(SQLBaseStore): # same dict into itself N times). already_fetching_ids: Set[str] = set() already_fetching_deferreds: Set[ - ObservableDeferred[Dict[str, EventCacheEntry]] + ObservableDeferred[Dict[str, _EventCacheEntry]] ] = set() for event_id in missing_events_ids: @@ -632,8 +601,8 @@ class EventsWorkerStore(SQLBaseStore): # function returning more events than requested, but that can happen # already due to `_get_events_from_db`). 
fetching_deferred: ObservableDeferred[ - Dict[str, EventCacheEntry] - ] = ObservableDeferred(defer.Deferred(), consumeErrors=True) + Dict[str, _EventCacheEntry] + ] = ObservableDeferred(defer.Deferred()) for event_id in missing_events_ids: self._current_event_fetches[event_id] = fetching_deferred @@ -689,12 +658,12 @@ class EventsWorkerStore(SQLBaseStore): return event_entry_map - def _invalidate_get_event_cache(self, event_id: str) -> None: + def _invalidate_get_event_cache(self, event_id): self._get_event_cache.invalidate((event_id,)) def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True - ) -> Dict[str, EventCacheEntry]: + ) -> Dict[str, _EventCacheEntry]: """Fetch events from the caches. May return rejected events. @@ -767,123 +736,38 @@ class EventsWorkerStore(SQLBaseStore): for e in state_to_include.values() ] - def _maybe_start_fetch_thread(self) -> None: - """Starts an event fetch thread if we are not yet at the maximum number.""" - with self._event_fetch_lock: - if ( - self._event_fetch_list - and self._event_fetch_ongoing < EVENT_QUEUE_THREADS - ): - self._event_fetch_ongoing += 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) - # `_event_fetch_ongoing` is decremented in `_fetch_thread`. - should_start = True - else: - should_start = False - - if should_start: - run_as_background_process("fetch_events", self._fetch_thread) - - async def _fetch_thread(self) -> None: - """Services requests for events from `_event_fetch_list`.""" - exc = None - try: - await self.db_pool.runWithConnection(self._fetch_loop) - except BaseException as e: - exc = e - raise - finally: - should_restart = False - event_fetches_to_fail = [] - with self._event_fetch_lock: - self._event_fetch_ongoing -= 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) - - # There may still be work remaining in `_event_fetch_list` if we - # failed, or it was added in between us deciding to exit and - # decrementing `_event_fetch_ongoing`. 
- if self._event_fetch_list: - if exc is None: - # We decided to exit, but then some more work was added - # before `_event_fetch_ongoing` was decremented. - # If a new event fetch thread was not started, we should - # restart ourselves since the remaining event fetch threads - # may take a while to get around to the new work. - # - # Unfortunately it is not possible to tell whether a new - # event fetch thread was started, so we restart - # unconditionally. If we are unlucky, we will end up with - # an idle fetch thread, but it will time out after - # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds - # in any case. - # - # Note that multiple fetch threads may run down this path at - # the same time. - should_restart = True - elif isinstance(exc, Exception): - if self._event_fetch_ongoing == 0: - # We were the last remaining fetcher and failed. - # Fail any outstanding fetches since no one else will - # handle them. - event_fetches_to_fail = self._event_fetch_list - self._event_fetch_list = [] - else: - # We weren't the last remaining fetcher, so another - # fetcher will pick up the work. This will either happen - # after their existing work, however long that takes, - # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if - # they are idle. - pass - else: - # The exception is a `SystemExit`, `KeyboardInterrupt` or - # `GeneratorExit`. Don't try to do anything clever here. - pass - - if should_restart: - # We exited cleanly but noticed more work. - self._maybe_start_fetch_thread() - - if event_fetches_to_fail: - # We were the last remaining fetcher and failed. - # Fail any outstanding fetches since no one else will handle them. 
- assert exc is not None - with PreserveLoggingContext(): - for _, deferred in event_fetches_to_fail: - deferred.errback(exc) - - def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None: + def _do_fetch(self, conn: Connection) -> None: """Takes a database connection and waits for requests for events from the _event_fetch_list queue. """ - i = 0 - while True: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: - # There are no requests waiting. If we haven't yet reached the - # maximum iteration limit, wait for some more requests to turn up. - # Otherwise, bail out. - single_threaded = self.database_engine.single_threaded - if ( - not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING - or single_threaded - or i > EVENT_QUEUE_ITERATIONS - ): - return - - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 + try: + i = 0 + while True: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if ( + not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING + or single_threaded + or i > EVENT_QUEUE_ITERATIONS + ): + break + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 - self._fetch_event_list(conn, event_list) + self._fetch_event_list(conn, event_list) + finally: + self._event_fetch_ongoing -= 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) def _fetch_event_list( - self, - conn: LoggingDatabaseConnection, - event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]], + self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]] ) -> None: """Handle a load of requests from the _event_fetch_list queue @@ -910,7 +794,7 @@ class EventsWorkerStore(SQLBaseStore): ) # We only want to resolve deferreds from the main thread - def fire() -> None: + def fire(): for _, d in 
event_list: d.callback(row_dict) @@ -920,16 +804,18 @@ class EventsWorkerStore(SQLBaseStore): logger.exception("do_fetch") # We only want to resolve deferreds from the main thread - def fire_errback(exc: Exception) -> None: - for _, d in event_list: - d.errback(exc) + def fire(evs, exc): + for _, d in evs: + if not d.called: + with PreserveLoggingContext(): + d.errback(exc) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire_errback, e) + self.hs.get_reactor().callFromThread(fire, event_list, e) async def _get_events_from_db( - self, event_ids: Collection[str] - ) -> Dict[str, EventCacheEntry]: + self, event_ids: Iterable[str] + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the database. May return rejected events. @@ -945,29 +831,29 @@ class EventsWorkerStore(SQLBaseStore): map from event id to result. May return extra events which weren't asked for. """ - fetched_event_ids: Set[str] = set() - fetched_events: Dict[str, _EventRow] = {} + fetched_events = {} events_to_fetch = event_ids while events_to_fetch: row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events - redaction_ids: Set[str] = set() + redaction_ids = set() for event_id in events_to_fetch: row = row_map.get(event_id) - fetched_event_ids.add(event_id) + fetched_events[event_id] = row if row: - fetched_events[event_id] = row redaction_ids.update(row.redactions) - events_to_fetch = redaction_ids.difference(fetched_event_ids) + events_to_fetch = redaction_ids.difference(fetched_events.keys()) if events_to_fetch: logger.debug("Also fetching redaction events %s", events_to_fetch) # build a map from event_id to EventBase - event_map: Dict[str, EventBase] = {} + event_map = {} for event_id, row in fetched_events.items(): + if not row: + continue assert row.event_id == event_id rejected_reason = row.rejected_reason @@ -995,7 +881,6 @@ class EventsWorkerStore(SQLBaseStore): room_version_id = row.room_version_id 
- room_version: Optional[RoomVersion] if not room_version_id: # this should only happen for out-of-band membership events which # arrived before #6983 landed. For all other events, we should have @@ -1066,14 +951,14 @@ class EventsWorkerStore(SQLBaseStore): # finally, we can decide whether each one needs redacting, and build # the cache entries. - result_map: Dict[str, EventCacheEntry] = {} + result_map = {} for event_id, original_ev in event_map.items(): redactions = fetched_events[event_id].redactions redacted_event = self._maybe_redact_event_row( original_ev, redactions, event_map ) - cache_entry = EventCacheEntry( + cache_entry = _EventCacheEntry( event=original_ev, redacted_event=redacted_event ) @@ -1082,7 +967,7 @@ class EventsWorkerStore(SQLBaseStore): return result_map - async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]: + async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]: """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -1095,12 +980,23 @@ class EventsWorkerStore(SQLBaseStore): that weren't requested. 
""" - events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred() + events_d = defer.Deferred() with self._event_fetch_lock: self._event_fetch_list.append((events, events_d)) + self._event_fetch_lock.notify() - self._maybe_start_fetch_thread() + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + should_start = True + else: + should_start = False + + if should_start: + run_as_background_process( + "fetch_events", self.db_pool.runWithConnection, self._do_fetch + ) logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): @@ -1250,7 +1146,7 @@ class EventsWorkerStore(SQLBaseStore): # no valid redaction found for this event return None - async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]: + async def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ @@ -1279,7 +1175,7 @@ class EventsWorkerStore(SQLBaseStore): event_ids: events we are looking for Returns: - The set of events we have already seen. + set[str]: The events we have already seen. """ res = await self._have_seen_events_dict( (room_id, event_id) for event_id in event_ids @@ -1302,9 +1198,7 @@ class EventsWorkerStore(SQLBaseStore): } results = {x: True for x in cache_results} - def have_seen_events_txn( - txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...] - ) -> None: + def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]): # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. 
# @@ -1330,14 +1224,12 @@ class EventsWorkerStore(SQLBaseStore): return results @cached(max_entries=100000, tree=True) - async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn: + async def have_seen_event(self, room_id: str, event_id: str): # this only exists for the benefit of the @cachedList descriptor on # _have_seen_events_dict raise NotImplementedError() - def _get_current_state_event_counts_txn( - self, txn: LoggingTransaction, room_id: str - ) -> int: + def _get_current_state_event_counts_txn(self, txn, room_id): """ See get_current_state_event_counts. """ @@ -1362,7 +1254,7 @@ class EventsWorkerStore(SQLBaseStore): room_id, ) - async def get_room_complexity(self, room_id: str) -> Dict[str, float]: + async def get_room_complexity(self, room_id): """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -1370,10 +1262,10 @@ class EventsWorkerStore(SQLBaseStore): more resources. Args: - room_id: The room ID to query. + room_id (str) Returns: - dict[str:float] of complexity version to complexity. + dict[str:int] of complexity version to complexity. """ state_events = await self.get_current_state_event_counts(room_id) @@ -1383,13 +1275,13 @@ class EventsWorkerStore(SQLBaseStore): return {"v1": complexity_v1} - def get_current_events_token(self) -> int: + def get_current_events_token(self): """The current maximum token that events have reached""" return self._stream_id_gen.get_current_token() async def get_all_new_forward_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple]: """Returns new events, for the Events replication stream Args: @@ -1403,15 +1295,13 @@ class EventsWorkerStore(SQLBaseStore): EventsStreamRow. 
""" - def get_all_new_forward_event_rows( - txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + def get_all_new_forward_event_rows(txn): sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " LEFT JOIN room_memberships USING (event_id)" " LEFT JOIN rejections USING (event_id)" @@ -1421,9 +1311,7 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (last_id, current_id, instance_name, limit)) - return cast( - List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() - ) + return txn.fetchall() return await self.db_pool.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows @@ -1431,7 +1319,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_ex_outlier_stream_rows( self, instance_name: str, last_id: int, current_id: int - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple]: """Returns de-outliered events, for the Events replication stream Args: @@ -1444,16 +1332,14 @@ class EventsWorkerStore(SQLBaseStore): EventsStreamRow. 
""" - def get_ex_outlier_stream_rows_txn( - txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + def get_ex_outlier_stream_rows_txn(txn): sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " LEFT JOIN room_memberships USING (event_id)" " LEFT JOIN rejections USING (event_id)" @@ -1464,9 +1350,7 @@ class EventsWorkerStore(SQLBaseStore): ) txn.execute(sql, (last_id, current_id, instance_name)) - return cast( - List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() - ) + return txn.fetchall() return await self.db_pool.runInteraction( "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn @@ -1474,7 +1358,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_all_new_backfill_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: + ) -> Tuple[List[Tuple[int, list]], int, bool]: """Get updates for backfill replication stream, including all new backfilled events and events that have gone from being outliers to not. 
@@ -1502,15 +1386,13 @@ class EventsWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_new_backfill_event_rows( - txn: LoggingTransaction, - ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: + def get_all_new_backfill_event_rows(txn): sql = ( "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > stream_ordering AND stream_ordering >= ?" " AND instance_name = ?" @@ -1518,15 +1400,7 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, instance_name, limit)) - new_event_updates: List[ - Tuple[int, Tuple[str, str, str, str, str, str]] - ] = [] - row: Tuple[int, str, str, str, str, str, str] - # Type safety: iterating over `txn` yields `Tuple`, i.e. - # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a - # variadic tuple to a fixed length tuple and flags it up as an error. - for row in txn: # type: ignore[assignment] - new_event_updates.append((row[0], row[1:])) + new_event_updates = [(row[0], row[1:]) for row in txn] limited = False if len(new_event_updates) == limit: @@ -1537,11 +1411,11 @@ class EventsWorkerStore(SQLBaseStore): sql = ( "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" 
@@ -1549,11 +1423,7 @@ class EventsWorkerStore(SQLBaseStore): " ORDER BY event_stream_ordering DESC" ) txn.execute(sql, (-last_id, -upper_bound, instance_name)) - # Type safety: iterating over `txn` yields `Tuple`, i.e. - # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a - # variadic tuple to a fixed length tuple and flags it up as an error. - for row in txn: # type: ignore[assignment] - new_event_updates.append((row[0], row[1:])) + new_event_updates.extend((row[0], row[1:]) for row in txn) if len(new_event_updates) >= limit: upper_bound = new_event_updates[-1][0] @@ -1567,7 +1437,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_all_updated_current_state_deltas( self, instance_name: str, from_token: int, to_token: int, target_row_count: int - ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]: + ) -> Tuple[List[Tuple], int, bool]: """Fetch updates from current_state_delta_stream Args: @@ -1587,9 +1457,7 @@ class EventsWorkerStore(SQLBaseStore): * `limited` is whether there are more updates to fetch. """ - def get_all_updated_current_state_deltas_txn( - txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str]]: + def get_all_updated_current_state_deltas_txn(txn): sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream @@ -1598,23 +1466,21 @@ class EventsWorkerStore(SQLBaseStore): ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (from_token, to_token, instance_name, target_row_count)) - return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) + return txn.fetchall() - def get_deltas_for_stream_id_txn( - txn: LoggingTransaction, stream_id: int - ) -> List[Tuple[int, str, str, str, str]]: + def get_deltas_for_stream_id_txn(txn, stream_id): sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream WHERE stream_id = ? 
""" txn.execute(sql, [stream_id]) - return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) + return txn.fetchall() # we need to make sure that, for every stream id in the results, we get *all* # the rows with that stream id. - rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction( + rows: List[Tuple] = await self.db_pool.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) @@ -1643,14 +1509,14 @@ class EventsWorkerStore(SQLBaseStore): return rows, to_token, True - async def is_event_after(self, event_id1: str, event_id2: str) -> bool: + async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream""" to_1, so_1 = await self.get_event_ordering(event_id1) to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cached(max_entries=5000) - async def get_event_ordering(self, event_id: str) -> Tuple[int, int]: + async def get_event_ordering(self, event_id): res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], @@ -1673,9 +1539,7 @@ class EventsWorkerStore(SQLBaseStore): None otherwise. 
""" - def get_next_event_to_expire_txn( - txn: LoggingTransaction, - ) -> Optional[Tuple[str, int]]: + def get_next_event_to_expire_txn(txn): txn.execute( """ SELECT event_id, expiry_ts FROM event_expiry @@ -1683,7 +1547,7 @@ class EventsWorkerStore(SQLBaseStore): """ ) - return cast(Optional[Tuple[str, int]], txn.fetchone()) + return txn.fetchone() return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn @@ -1747,10 +1611,10 @@ class EventsWorkerStore(SQLBaseStore): return mapping @wrap_as_background_process("_cleanup_old_transaction_ids") - async def _cleanup_old_transaction_ids(self) -> None: + async def _cleanup_old_transaction_ids(self): """Cleans out transaction id mappings older than 24hrs.""" - def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None: + def _cleanup_old_transaction_ids_txn(txn): sql = """ DELETE FROM event_txn_id WHERE inserted_ts < ? @@ -1762,198 +1626,3 @@ class EventsWorkerStore(SQLBaseStore): "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn, ) - - async def is_event_next_to_backward_gap(self, event: EventBase) -> bool: - """Check if the given event is next to a backward gap of missing events. - <latest messages> A(False)--->B(False)--->C(True)---> <gap, unknown events> <oldest messages> - - Args: - room_id: room where the event lives - event_id: event to check - - Returns: - Boolean indicating whether it's an extremity - """ - - def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: - # If the event in question has any of its prev_events listed as a - # backward extremity, it's next to a gap. - # - # We can't just check the backward edges in `event_edges` because - # when we persist events, we will also record the prev_events as - # edges to the event in question regardless of whether we have those - # prev_events yet. 
We need to check whether those prev_events are - # backward extremities, also known as gaps, that need to be - # backfilled. - backward_extremity_query = """ - SELECT 1 FROM event_backward_extremities - WHERE - room_id = ? - AND %s - LIMIT 1 - """ - - # If the event in question is a backward extremity or has any of its - # prev_events listed as a backward extremity, it's next to a - # backward gap. - clause, args = make_in_list_sql_clause( - self.database_engine, - "event_id", - [event.event_id] + list(event.prev_event_ids()), - ) - - txn.execute(backward_extremity_query % (clause,), [event.room_id] + args) - backward_extremities = txn.fetchall() - - # We consider any backward extremity as a backward gap - if len(backward_extremities): - return True - - return False - - return await self.db_pool.runInteraction( - "is_event_next_to_backward_gap_txn", - is_event_next_to_backward_gap_txn, - ) - - async def is_event_next_to_forward_gap(self, event: EventBase) -> bool: - """Check if the given event is next to a forward gap of missing events. - The gap in front of the latest events is not considered a gap. - <latest messages> A(False)--->B(False)--->C(False)---> <gap, unknown events> <oldest messages> - <latest messages> A(False)--->B(False)---> <gap, unknown events> --->D(True)--->E(False) <oldest messages> - - Args: - room_id: room where the event lives - event_id: event to check - - Returns: - Boolean indicating whether it's an extremity - """ - - def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: - # If the event in question is a forward extremity, we will just - # consider any potential forward gap as not a gap since it's one of - # the latest events in the room. - # - # `event_forward_extremities` does not include backfilled or outlier - # events so we can't rely on it to find forward gaps. We can only - # use it to determine whether a message is the latest in the room. 
- # - # We can't combine this query with the `forward_edge_query` below - # because if the event in question has no forward edges (isn't - # referenced by any other event's prev_events) but is in - # `event_forward_extremities`, we don't want to return 0 rows and - # say it's next to a gap. - forward_extremity_query = """ - SELECT 1 FROM event_forward_extremities - WHERE - room_id = ? - AND event_id = ? - LIMIT 1 - """ - - # Check to see whether the event in question is already referenced - # by another event. If we don't see any edges, we're next to a - # forward gap. - forward_edge_query = """ - SELECT 1 FROM event_edges - /* Check to make sure the event referencing our event in question is not rejected */ - LEFT JOIN rejections ON event_edges.event_id == rejections.event_id - WHERE - event_edges.room_id = ? - AND event_edges.prev_event_id = ? - /* It's not a valid edge if the event referencing our event in - * question is rejected. - */ - AND rejections.event_id IS NULL - LIMIT 1 - """ - - # We consider any forward extremity as the latest in the room and - # not a forward gap. - # - # To expand, even though there is technically a gap at the front of - # the room where the forward extremities are, we consider those the - # latest messages in the room so asking other homeservers for more - # is useless. The new latest messages will just be federated as - # usual. - txn.execute(forward_extremity_query, (event.room_id, event.event_id)) - forward_extremities = txn.fetchall() - if len(forward_extremities): - return False - - # If there are no forward edges to the event in question (another - # event hasn't referenced this event in their prev_events), then we - # assume there is a forward gap in the history. 
- txn.execute(forward_edge_query, (event.room_id, event.event_id)) - forward_edges = txn.fetchall() - if not len(forward_edges): - return True - - return False - - return await self.db_pool.runInteraction( - "is_event_next_to_gap_txn", - is_event_next_to_gap_txn, - ) - - async def get_event_id_for_timestamp( - self, room_id: str, timestamp: int, direction: str - ) -> Optional[str]: - """Find the closest event to the given timestamp in the given direction. - - Args: - room_id: Room to fetch the event from - timestamp: The point in time (inclusive) we should navigate from in - the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward - or backward from the given timestamp to find the closest event. - - Returns: - The closest event_id otherwise None if we can't find any event in - the given direction. - """ - - sql_template = """ - SELECT event_id FROM events - LEFT JOIN rejections USING (event_id) - WHERE - origin_server_ts %s ? - AND room_id = ? - /* Make sure event is not rejected */ - AND rejections.event_id IS NULL - ORDER BY origin_server_ts %s - LIMIT 1; - """ - - def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: - if direction == "b": - # Find closest event *before* a given timestamp. We use descending - # (which gives values largest to smallest) because we want the - # largest possible timestamp *before* the given timestamp. - comparison_operator = "<=" - order = "DESC" - else: - # Find closest event *after* a given timestamp. We use ascending - # (which gives values smallest to largest) because we want the - # closest possible timestamp *after* the given timestamp. 
- comparison_operator = ">=" - order = "ASC" - - txn.execute( - sql_template % (comparison_operator, order), (timestamp, room_id) - ) - row = txn.fetchone() - if row: - (event_id,) = row - return event_id - - return None - - if direction not in ("f", "b"): - raise ValueError("Unknown direction: %s" % (direction,)) - - return await self.db_pool.runInteraction( - "get_event_id_for_timestamp_txn", - get_event_id_for_timestamp_txn, - ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 91b0576b85..3eb30944bf 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py
@@ -118,7 +118,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): logger.info("[purge] looking for events to delete") - should_delete_expr = "state_events.state_key IS NULL" + should_delete_expr = "state_key IS NULL" should_delete_params: Tuple[Any, ...] = () if not delete_local_events: should_delete_expr += " AND event_id NOT LIKE ?" diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 3b63267395..fa782023d4 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -28,10 +28,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException -from synapse.storage.util.id_generators import ( - AbstractStreamIdTracker, - StreamIdGenerator, -) +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -85,9 +82,9 @@ class PushRulesWorkerStore( super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, "push_rules_stream", "stream_id" - ) + self._push_rules_stream_id_gen: Union[ + StreamIdGenerator, SlavedIdTracker + ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id") else: self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e1ddf06916..0e8c168667 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -106,15 +106,6 @@ class RefreshTokenLookupResult: has_next_access_token_been_used: bool """True if the next access token was already used at least once.""" - expiry_ts: Optional[int] - """The time at which the refresh token expires and can not be used. - If None, the refresh token doesn't expire.""" - - ultimate_session_expiry_ts: Optional[int] - """The time at which the session comes to an end and can no longer be - refreshed. - If None, the session can be refreshed indefinitely.""" - class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( @@ -1635,10 +1626,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): rt.user_id, rt.device_id, rt.next_token_id, - (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed, - at.used AS has_next_access_token_been_used, - rt.expiry_ts, - rt.ultimate_session_expiry_ts + (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed, + at.used has_next_access_token_been_used FROM refresh_tokens rt LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id @@ -1659,8 +1648,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): has_next_refresh_token_been_refreshed=row[4], # This column is nullable, ensure it's a boolean has_next_access_token_been_used=(row[5] or False), - expiry_ts=row[6], - ultimate_session_expiry_ts=row[7], ) return await self.db_pool.runInteraction( @@ -1928,8 +1915,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): user_id: str, token: str, device_id: Optional[str], - expiry_ts: Optional[int], - ultimate_session_expiry_ts: Optional[int], ) -> int: """Adds a refresh token for the given user. @@ -1937,13 +1922,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): user_id: The user ID. token: The new access token to add. device_id: ID of the device to associate with the refresh token. 
- expiry_ts (milliseconds since the epoch): Time after which the - refresh token cannot be used. - If None, the refresh token never expires until it has been used. - ultimate_session_expiry_ts (milliseconds since the epoch): - Time at which the session will end and can not be extended any - further. - If None, the session can be refreshed indefinitely. Raises: StoreError if there was a problem adding this. Returns: @@ -1959,8 +1937,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "device_id": device_id, "token": token, "next_token_id": None, - "expiry_ts": expiry_ts, - "ultimate_session_expiry_ts": ultimate_session_expiry_ts, }, desc="add_refresh_token_to_user", ) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 6b2a8d06a6..033a9831d6 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -476,7 +476,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): INNER JOIN events AS e USING (room_id, event_id) WHERE c.type = 'm.room.member' - AND c.state_key = ? + AND state_key = ? AND c.membership = ? """ else: @@ -487,7 +487,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): INNER JOIN events AS e USING (room_id, event_id) WHERE c.type = 'm.room.member' - AND c.state_key = ? + AND state_key = ? AND m.membership = ? """ diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 57aab55259..42dc807d17 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py
@@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): oldest `limit` events. Returns: - The list of events (in ascending stream order) and the token from the start + The list of events (in ascending order) and the token from the start of the chunk of events returned. """ if from_key == to_key: @@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if not has_changed: return [], from_key - def f(txn: LoggingTransaction) -> List[_EventDictReturn]: + def f(txn): # To handle tokens with a non-empty instance_map we fetch more # results than necessary and then filter down min_from_id = from_key.stream @@ -565,13 +565,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: - """Fetch membership events for a given user. - - All such events whose stream ordering `s` lies in the range - `from_key < s <= to_key` are returned. Events are ordered by ascending stream - order. - """ - # Start by ruling out cases where a DB query is not necessary. if from_key == to_key: return [] @@ -582,7 +575,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if not has_changed: return [] - def f(txn: LoggingTransaction) -> List[_EventDictReturn]: + def f(txn): # To handle tokens with a non-empty instance_map we fetch more # results than necessary and then filter down min_from_id = from_key.stream @@ -641,7 +634,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): Returns: A list of events and a token pointing to the start of the returned - events. The events returned are in ascending topological order. + events. The events returned are in ascending order. 
""" rows, token = await self.get_recent_event_ids_for_room( diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 1622822552..d7dc1f73ac 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,6 @@ import logging from collections import namedtuple -from enum import Enum from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple import attr @@ -45,16 +44,6 @@ _UpdateTransactionRow = namedtuple( ) -class DestinationSortOrder(Enum): - """Enum to define the sorting method used when returning destinations.""" - - DESTINATION = "destination" - RETRY_LAST_TS = "retry_last_ts" - RETTRY_INTERVAL = "retry_interval" - FAILURE_TS = "failure_ts" - LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering" - - @attr.s(slots=True, frozen=True, auto_attribs=True) class DestinationRetryTimings: """The current destination retry timing info for a remote server.""" @@ -491,62 +480,3 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): destinations = [row[0] for row in txn] return destinations - - async def get_destinations_paginate( - self, - start: int, - limit: int, - destination: Optional[str] = None, - order_by: str = DestinationSortOrder.DESTINATION.value, - direction: str = "f", - ) -> Tuple[List[JsonDict], int]: - """Function to retrieve a paginated list of destinations. - This will return a json list of destinations and the - total number of destinations matching the filter criteria. - - Args: - start: start number to begin the query from - limit: number of rows to retrieve - destination: search string in destination - order_by: the sort order of the returned list - direction: sort ascending or descending - Returns: - A tuple of a list of mappings from destination to information - and a count of total destinations. - """ - - def get_destinations_paginate_txn( - txn: LoggingTransaction, - ) -> Tuple[List[JsonDict], int]: - order_by_column = DestinationSortOrder(order_by).value - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - args = [] - where_statement = "" - if destination: - args.extend(["%" + destination.lower() + "%"]) - where_statement = "WHERE LOWER(destination) LIKE ?" 
- - sql_base = f"FROM destinations {where_statement} " - sql = f"SELECT COUNT(*) as total_destinations {sql_base}" - txn.execute(sql, args) - count = txn.fetchone()[0] - - sql = f""" - SELECT destination, retry_last_ts, retry_interval, failure_ts, - last_successful_stream_ordering - {sql_base} - ORDER BY {order_by_column} {order}, destination ASC - LIMIT ? OFFSET ? - """ - txn.execute(sql, args + [limit, start]) - destinations = self.db_pool.cursor_to_dict(txn) - return destinations, count - - return await self.db_pool.runInteraction( - "get_destinations_paginate_txn", get_destinations_paginate_txn - ) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 428d66a617..402f134d89 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py
@@ -583,8 +583,7 @@ class EventsPersistenceStorage: current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, - use_negative_stream_ordering=backfilled, - inhibit_local_membership_updates=backfilled, + backfilled=backfilled, ) await self._handle_potentially_left_users(potentially_left_users) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 50d08094d5..3a00ed6835 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 66 # remember to update the list below when updating +SCHEMA_VERSION = 65 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -46,10 +46,6 @@ Changes in SCHEMA_VERSION = 65: - MSC2716: Remove unique event_id constraint from insertion_event_edges because an insertion event can have multiple edges. - Remove unused tables `user_stats_historical` and `room_stats_historical`. - -Changes in SCHEMA_VERSION = 66: - - Queries on state_key columns are now disambiguated (ie, the codebase can handle - the `events` table having a `state_key` column). """ diff --git a/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql deleted file mode 100644
index bdc491c817..0000000000 --- a/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql +++ /dev/null
@@ -1,28 +0,0 @@ -/* Copyright 2021 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -ALTER TABLE refresh_tokens - -- We add an expiry_ts column (in milliseconds since the Epoch) to refresh tokens. - -- They may not be used after they have expired. - -- If null, then the refresh token's lifetime is unlimited. - ADD COLUMN expiry_ts BIGINT DEFAULT NULL; - -ALTER TABLE refresh_tokens - -- We also add an ultimate session expiry time (in milliseconds since the Epoch). - -- No matter how much the access and refresh tokens are refreshed, they cannot - -- be extended past this time. - -- If null, then the session length is unlimited. - ADD COLUMN ultimate_session_expiry_ts BIGINT DEFAULT NULL; diff --git a/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql deleted file mode 100644
index a65bfb520d..0000000000 --- a/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql +++ /dev/null
@@ -1,27 +0,0 @@ -/* Copyright 2021 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Track the auth provider used by each login as well as the session ID -CREATE TABLE device_auth_providers ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - auth_provider_id TEXT NOT NULL, - auth_provider_session_id TEXT NOT NULL -); - -CREATE INDEX device_auth_providers_devices - ON device_auth_providers (user_id, device_id); -CREATE INDEX device_auth_providers_sessions - ON device_auth_providers (auth_provider_id, auth_provider_session_id); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 4ff3013908..ac56bc9a05 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py
@@ -89,77 +89,31 @@ def _load_current_id( return (max if step > 0 else min)(current_id, step) -class AbstractStreamIdTracker(metaclass=abc.ABCMeta): - """Tracks the "current" stream ID of a stream that may have multiple writers. - - Stream IDs are monotonically increasing or decreasing integers representing write - transactions. The "current" stream ID is the stream ID such that all transactions - with equal or smaller stream IDs have completed. Since transactions may complete out - of order, this is not the same as the stream ID of the last completed transaction. - - Completed transactions include both committed transactions and transactions that - have been rolled back. - """ - - @abc.abstractmethod - def advance(self, instance_name: str, new_id: int) -> None: - """Advance the position of the named writer to the given ID, if greater - than existing entry. - """ - raise NotImplementedError() - +class AbstractStreamIdGenerator(metaclass=abc.ABCMeta): @abc.abstractmethod - def get_current_token(self) -> int: - """Returns the maximum stream id such that all stream ids less than or - equal to it have been successfully persisted. - - Returns: - The maximum stream id. - """ + def get_next(self) -> AsyncContextManager[int]: raise NotImplementedError() @abc.abstractmethod - def get_current_token_for_writer(self, instance_name: str) -> int: - """Returns the position of the given writer. - - For streams with single writers this is equivalent to `get_current_token`. - """ + def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: raise NotImplementedError() - -class AbstractStreamIdGenerator(AbstractStreamIdTracker): - """Generates stream IDs for a stream that may have multiple writers. - - Each stream ID represents a write transaction, whose completion is tracked - so that the "current" stream ID of the stream can be determined. - - See `AbstractStreamIdTracker` for more details. 
- """ - @abc.abstractmethod - def get_next(self) -> AsyncContextManager[int]: - """ - Usage: - async with stream_id_gen.get_next() as stream_id: - # ... persist event ... - """ + def get_current_token(self) -> int: raise NotImplementedError() @abc.abstractmethod - def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: - """ - Usage: - async with stream_id_gen.get_next(n) as stream_ids: - # ... persist events ... - """ + def get_current_token_for_writer(self, instance_name: str) -> int: raise NotImplementedError() class StreamIdGenerator(AbstractStreamIdGenerator): - """Generates and tracks stream IDs for a stream with a single writer. + """Used to generate new stream ids when persisting events while keeping + track of which transactions have been completed. - This class must only be used when the current Synapse process is the sole - writer for a stream. + This allows us to get the "current" stream id, i.e. the stream id such that + all ids less than or equal to it have completed. This handles the fact that + persistence of events can complete out of order. Args: db_conn(connection): A database connection to use to fetch the @@ -203,12 +157,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator): # The key and values are the same, but we never look at the values. self._unfinished_ids: OrderedDict[int, int] = OrderedDict() - def advance(self, instance_name: str, new_id: int) -> None: - # `StreamIdGenerator` should only be used when there is a single writer, - # so replication should never happen. - raise Exception("Replication is not supported by StreamIdGenerator") - def get_next(self) -> AsyncContextManager[int]: + """ + Usage: + async with stream_id_gen.get_next() as stream_id: + # ... persist event ... 
+ """ with self._lock: self._current += self._step next_id = self._current @@ -226,6 +180,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: + """ + Usage: + async with stream_id_gen.get_next(n) as stream_ids: + # ... persist events ... + """ with self._lock: next_ids = range( self._current + self._step, @@ -249,6 +208,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. + + Returns: + The maximum stream id. + """ with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step @@ -256,11 +221,16 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return self._current def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to + `get_current_token`. + """ return self.get_current_token() class MultiWriterIdGenerator(AbstractStreamIdGenerator): - """Generates and tracks stream IDs for a stream with multiple writers. + """An ID generator that tracks a stream that can have multiple writers. Uses a Postgres sequence to coordinate ID assignment, but positions of other writers will only get updated when `advance` is called (by replication). @@ -505,6 +475,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): return stream_ids def get_next(self) -> AsyncContextManager[int]: + """ + Usage: + async with stream_id_gen.get_next() as stream_id: + # ... persist event ... + """ + # If we have a list of instances that are allowed to write to this # stream, make sure we're in it. 
if self._writers and self._instance_name not in self._writers: @@ -516,6 +492,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): return cast(AsyncContextManager[int], _MultiWriterCtxManager(self)) def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]: + """ + Usage: + async with stream_id_gen.get_next_mult(5) as stream_ids: + # ... persist events ... + """ + # If we have a list of instances that are allowed to write to this # stream, make sure we're in it. if self._writers and self._instance_name not in self._writers: @@ -615,9 +597,15 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): self._add_persisted_position(next_id) def get_current_token(self) -> int: + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. + """ + return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer.""" + # If we don't have an entry for the given instance name, we assume it's a # new writer. # @@ -643,6 +631,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): } def advance(self, instance_name: str, new_id: int) -> None: + """Advance the position of the named writer to the given ID, if greater + than existing entry. + """ + new_id *= self._return_factor with self._lock: