From bd4919fb72b2a75f1c0a7f0c78bd619fd2ae30e8 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 24 Jun 2021 15:33:20 +0200 Subject: MSC2918 Refresh tokens implementation (#9450) This implements refresh tokens, as defined by MSC2918 This MSC has been implemented client side in Hydrogen Web: vector-im/hydrogen-web#235 The basics of the MSC works: requesting refresh tokens on login, having the access tokens expire, and using the refresh token to get a new one. Signed-off-by: Quentin Gliech --- synapse/storage/databases/main/registration.py | 207 ++++++++++++++++++++++++- 1 file changed, 203 insertions(+), 4 deletions(-) (limited to 'synapse/storage/databases/main') diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e5c5cf8ff0..e31c5864ac 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -53,6 +53,9 @@ class TokenLookupResult: valid_until_ms: The timestamp the token expires, if any. token_owner: The "owner" of the token. This is either the same as the user, or a server admin who is logged in as the user. + token_used: True if this token was used at least once in a request. + This field can be out of date since `get_user_by_access_token` is + cached. """ user_id = attr.ib(type=str) @@ -62,6 +65,7 @@ class TokenLookupResult: device_id = attr.ib(type=Optional[str], default=None) valid_until_ms = attr.ib(type=Optional[int], default=None) token_owner = attr.ib(type=str) + token_used = attr.ib(type=bool, default=False) # Make the token owner default to the user ID, which is the common case. @token_owner.default @@ -69,6 +73,29 @@ class TokenLookupResult: return self.user_id +@attr.s(frozen=True, slots=True) +class RefreshTokenLookupResult: + """Result of looking up a refresh token.""" + + user_id = attr.ib(type=str) + """The user this token belongs to.""" + + device_id = attr.ib(type=str) + """The device associated with this refresh token.""" + + token_id = attr.ib(type=int) + """The ID of this refresh token.""" + + next_token_id = attr.ib(type=Optional[int]) + """The ID of the refresh token which replaced this one.""" + + has_next_refresh_token_been_refreshed = attr.ib(type=bool) + """True if the next refresh token was used for another refresh.""" + + has_next_access_token_been_used = attr.ib(type=bool) + """True if the next access token was already used at least once.""" + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -441,7 +468,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): access_tokens.id as token_id, access_tokens.device_id, access_tokens.valid_until_ms, - access_tokens.user_id as token_owner + access_tokens.user_id as token_owner, + access_tokens.used as token_used FROM users INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id) WHERE token = ? @@ -449,8 +477,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): txn.execute(sql, (token,)) rows = self.db_pool.cursor_to_dict(txn) + if rows: - return TokenLookupResult(**rows[0]) + row = rows[0] + + # This field is nullable, ensure it comes out as a boolean + if row["token_used"] is None: + row["token_used"] = False + + return TokenLookupResult(**row) return None @@ -1072,6 +1107,111 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="update_access_token_last_validated", ) + @cached() + async def mark_access_token_as_used(self, token_id: int) -> None: + """ + Mark the access token as used, which invalidates the refresh token used + to obtain it. + + Because get_user_by_access_token is cached, this function might be + called multiple times for the same token, effectively doing unnecessary + SQL updates. Because updating the `used` field only goes one way (from + False to True) it is safe to cache this function as well to avoid this + issue. + + Args: + token_id: The ID of the access token to update. + Raises: + StoreError if there was a problem updating this. + """ + await self.db_pool.simple_update_one( + "access_tokens", + {"id": token_id}, + {"used": True}, + desc="mark_access_token_as_used", + ) + + async def lookup_refresh_token( + self, token: str + ) -> Optional[RefreshTokenLookupResult]: + """Lookup a refresh token with hints about its validity.""" + + def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: + txn.execute( + """ + SELECT + rt.id token_id, + rt.user_id, + rt.device_id, + rt.next_token_id, + (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed, + at.used has_next_access_token_been_used + FROM refresh_tokens rt + LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id + LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id + WHERE rt.token = ? + """, + (token,), + ) + row = txn.fetchone() + + if row is None: + return None + + return RefreshTokenLookupResult( + token_id=row[0], + user_id=row[1], + device_id=row[2], + next_token_id=row[3], + has_next_refresh_token_been_refreshed=row[4], + # This column is nullable, ensure it's a boolean + has_next_access_token_been_used=(row[5] or False), + ) + + return await self.db_pool.runInteraction( + "lookup_refresh_token", _lookup_refresh_token_txn + ) + + async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None: + """ + Set the successor of a refresh token, removing the existing successor + if any. + + Args: + token_id: ID of the refresh token to update. + next_token_id: ID of its successor. + """ + + def _replace_refresh_token_txn(txn) -> None: + # First check if there was an existing refresh token + old_next_token_id = self.db_pool.simple_select_one_onecol_txn( + txn, + "refresh_tokens", + {"id": token_id}, + "next_token_id", + allow_none=True, + ) + + self.db_pool.simple_update_one_txn( + txn, + "refresh_tokens", + {"id": token_id}, + {"next_token_id": next_token_id}, + ) + + # Delete the old "next" token if it exists. This should cascade and + # delete the associated access_token + if old_next_token_id is not None: + self.db_pool.simple_delete_one_txn( + txn, + "refresh_tokens", + {"id": old_next_token_id}, + ) + + await self.db_pool.runInteraction( + "replace_refresh_token", _replace_refresh_token_txn + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__( @@ -1263,6 +1403,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") async def add_access_token_to_user( self, @@ -1271,14 +1412,18 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): device_id: Optional[str], valid_until_ms: Optional[int], puppets_user_id: Optional[str] = None, + refresh_token_id: Optional[int] = None, ) -> int: """Adds an access token for the given user. Args: user_id: The user ID. token: The new access token to add. - device_id: ID of the device to associate with the access token + device_id: ID of the device to associate with the access token. valid_until_ms: when the token is valid until. None for no expiry. + puppets_user_id + refresh_token_id: ID of the refresh token generated alongside this + access token. Raises: StoreError if there was a problem adding this. Returns: @@ -1297,12 +1442,47 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "valid_until_ms": valid_until_ms, "puppets_user_id": puppets_user_id, "last_validated": now, + "refresh_token_id": refresh_token_id, + "used": False, }, desc="add_access_token_to_user", ) return next_id + async def add_refresh_token_to_user( + self, + user_id: str, + token: str, + device_id: Optional[str], + ) -> int: + """Adds a refresh token for the given user. + + Args: + user_id: The user ID. + token: The new access token to add. + device_id: ID of the device to associate with the refresh token. + Raises: + StoreError if there was a problem adding this. + Returns: + The token ID + """ + next_id = self._refresh_tokens_id_gen.get_next() + + await self.db_pool.simple_insert( + "refresh_tokens", + { + "id": next_id, + "user_id": user_id, + "device_id": device_id, + "token": token, + "next_token_id": None, + }, + desc="add_refresh_token_to_user", + ) + + return next_id + def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, "access_tokens", {"token": token}, "device_id" @@ -1545,7 +1725,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): device_id: Optional[str] = None, ) -> List[Tuple[str, int, Optional[str]]]: """ - Invalidate access tokens belonging to a user + Invalidate access and refresh tokens belonging to a user Args: user_id: ID of user the tokens belong to @@ -1565,7 +1745,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): items = keyvalues.items() where_clause = " AND ".join(k + " = ?" for k, _ in items) values = [v for _, v in items] # type: List[Union[str, int]] + # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat + # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where + # clause and values before we handle that. This seems to be only used in the "set password" handler. + refresh_where_clause = where_clause + refresh_values = values.copy() if except_token_id: + # TODO: support that for refresh tokens where_clause += " AND id != ?" values.append(except_token_id) @@ -1583,6 +1769,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values) + txn.execute( + "DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause, + refresh_values, + ) + return tokens_and_devices return await self.db_pool.runInteraction("user_delete_access_tokens", f) @@ -1599,6 +1790,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): await self.db_pool.runInteraction("delete_access_token", f) + async def delete_refresh_token(self, refresh_token: str) -> None: + def f(txn): + self.db_pool.simple_delete_one_txn( + txn, table="refresh_tokens", keyvalues={"token": refresh_token} + ) + + await self.db_pool.runInteraction("delete_refresh_token", f) + async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're -- cgit 1.5.1 From 60efc51a2bbc31f18a71ad1338afc430bfa65597 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 29 Jun 2021 11:25:34 +0100 Subject: Migrate stream_ordering to a bigint (#10264) * Move background update names out to a separate class `EventsBackgroundUpdatesStore` gets inherited and we don't really want to further pollute the namespace. * Migrate stream_ordering to a bigint * changelog --- changelog.d/10264.bugfix | 1 + .../storage/databases/main/events_bg_updates.py | 136 ++++++++++++++++++--- synapse/storage/schema/__init__.py | 2 +- .../60/01recreate_stream_ordering.sql.postgres | 40 ++++++ 4 files changed, 163 insertions(+), 16 deletions(-) create mode 100644 changelog.d/10264.bugfix create mode 100644 synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/10264.bugfix b/changelog.d/10264.bugfix new file mode 100644 index 0000000000..7ebda7cdc2 --- /dev/null +++ b/changelog.d/10264.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would return errors after 231 events were handled by the server. diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index cbe4be1437..39aaee743c 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -29,6 +29,25 @@ from synapse.types import JsonDict logger = logging.getLogger(__name__) +_REPLACE_STREAM_ORDRING_SQL_COMMANDS = ( + # there should be no leftover rows without a stream_ordering2, but just in case... + "UPDATE events SET stream_ordering2 = stream_ordering WHERE stream_ordering2 IS NULL", + # finally, we can drop the rule and switch the columns + "DROP RULE populate_stream_ordering2 ON events", + "ALTER TABLE events DROP COLUMN stream_ordering", + "ALTER TABLE events RENAME COLUMN stream_ordering2 TO stream_ordering", +) + + +class _BackgroundUpdates: + EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" + EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" + DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" + POPULATE_STREAM_ORDERING2 = "populate_stream_ordering2" + INDEX_STREAM_ORDERING2 = "index_stream_ordering2" + REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column" + + @attr.s(slots=True, frozen=True) class _CalculateChainCover: """Return value for _calculate_chain_cover_txn.""" @@ -48,19 +67,15 @@ class _CalculateChainCover: class EventsBackgroundUpdatesStore(SQLBaseStore): - - EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" - EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" - DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( - self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts + _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, + self._background_reindex_origin_server_ts, ) self.db_pool.updates.register_background_update_handler( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, + _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, self._background_reindex_fields_sender, ) @@ -85,7 +100,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) self.db_pool.updates.register_background_update_handler( - self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update + _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES, + self._cleanup_extremities_bg_update, ) self.db_pool.updates.register_background_update_handler( @@ -139,6 +155,24 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self._purged_chain_cover_index, ) + # bg updates for replacing stream_ordering with a BIGINT + # (these only run on postgres.) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.POPULATE_STREAM_ORDERING2, + self._background_populate_stream_ordering2, + ) + self.db_pool.updates.register_background_index_update( + _BackgroundUpdates.INDEX_STREAM_ORDERING2, + index_name="events_stream_ordering", + table="events", + columns=["stream_ordering2"], + unique=True, + ) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.REPLACE_STREAM_ORDERING_COLUMN, + self._background_replace_stream_ordering_column, + ) + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -190,18 +224,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): } self.db_pool.updates._background_update_progress_txn( - txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress + txn, _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress ) return len(rows) result = await self.db_pool.runInteraction( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn + _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn ) if not result: await self.db_pool.updates._end_background_update( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME + _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME ) return result @@ -264,18 +298,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): } self.db_pool.updates._background_update_progress_txn( - txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress + txn, _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, progress ) return len(rows_to_update) result = await self.db_pool.runInteraction( - self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn + _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn ) if not result: await self.db_pool.updates._end_background_update( - self.EVENT_ORIGIN_SERVER_TS_NAME + _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME ) return result @@ -454,7 +488,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): if not num_handled: await self.db_pool.updates._end_background_update( - self.DELETE_SOFT_FAILED_EXTREMITIES + _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES ) def _drop_table_txn(txn): @@ -1009,3 +1043,75 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): await self.db_pool.updates._end_background_update("purged_chain_cover") return result + + async def _background_populate_stream_ordering2( + self, progress: JsonDict, batch_size: int + ) -> int: + """Populate events.stream_ordering2, then replace stream_ordering + + This is to deal with the fact that stream_ordering was initially created as a + 32-bit integer field. + """ + batch_size = max(batch_size, 1) + + def process(txn: Cursor) -> int: + # if this is the first pass, find the minimum stream ordering + last_stream = progress.get("last_stream") + if last_stream is None: + txn.execute( + """ + SELECT stream_ordering FROM events ORDER BY stream_ordering LIMIT 1 + """ + ) + rows = txn.fetchall() + if not rows: + return 0 + last_stream = rows[0][0] - 1 + + txn.execute( + """ + UPDATE events SET stream_ordering2=stream_ordering + WHERE stream_ordering > ? AND stream_ordering <= ? + """, + (last_stream, last_stream + batch_size), + ) + row_count = txn.rowcount + + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.POPULATE_STREAM_ORDERING2, + {"last_stream": last_stream + batch_size}, + ) + return row_count + + result = await self.db_pool.runInteraction( + "_background_populate_stream_ordering2", process + ) + + if result != 0: + return result + + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.POPULATE_STREAM_ORDERING2 + ) + return 0 + + async def _background_replace_stream_ordering_column( + self, progress: JsonDict, batch_size: int + ) -> int: + """Drop the old 'stream_ordering' column and rename 'stream_ordering2' into its place.""" + + def process(txn: Cursor) -> None: + for sql in _REPLACE_STREAM_ORDRING_SQL_COMMANDS: + logger.info("completing stream_ordering migration: %s", sql) + txn.execute(sql) + + await self.db_pool.runInteraction( + "_background_replace_stream_ordering_column", process + ) + + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.REPLACE_STREAM_ORDERING_COLUMN + ) + + return 0 diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index d36ba1d773..0a53b73ccc 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 59 +SCHEMA_VERSION = 60 """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the diff --git a/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres b/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres new file mode 100644 index 0000000000..88c9f8bd0d --- /dev/null +++ b/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres @@ -0,0 +1,40 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This migration handles the process of changing the type of `stream_ordering` to +-- a BIGINT. +-- +-- Note that this is only a problem on postgres as sqlite only has one "integer" type +-- which can cope with values up to 2^63. + +-- First add a new column to contain the bigger stream_ordering +ALTER TABLE events ADD COLUMN stream_ordering2 BIGINT; + +-- Create a rule which will populate it for new rows. +CREATE OR REPLACE RULE "populate_stream_ordering2" AS + ON INSERT TO events + DO UPDATE events SET stream_ordering2=NEW.stream_ordering WHERE stream_ordering=NEW.stream_ordering; + +-- Start a bg process to populate it for old events +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6001, 'populate_stream_ordering2', '{}'); + +-- ... and another to build an index on it +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (6001, 'index_stream_ordering2', '{}', 'populate_stream_ordering2'); + +-- ... and another to do the switcheroo +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (6001, 'replace_stream_ordering_column', '{}', 'index_stream_ordering2'); -- cgit 1.5.1 From 7647b0337fb5d936c88c5949fa92c07bf2137ad0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 29 Jun 2021 12:43:36 +0100 Subject: Fix `populate_stream_ordering2` background job (#10267) It was possible for us not to find any rows in a batch, and hence conclude that we had finished. Let's not do that. --- changelog.d/10267.bugfix | 1 + .../storage/databases/main/events_bg_updates.py | 28 ++++++++++------------ 2 files changed, 13 insertions(+), 16 deletions(-) create mode 100644 changelog.d/10267.bugfix (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/10267.bugfix b/changelog.d/10267.bugfix new file mode 100644 index 0000000000..7ebda7cdc2 --- /dev/null +++ b/changelog.d/10267.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would return errors after 231 events were handled by the server. diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 39aaee743c..da3a7df27b 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -1055,32 +1055,28 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): batch_size = max(batch_size, 1) def process(txn: Cursor) -> int: - # if this is the first pass, find the minimum stream ordering - last_stream = progress.get("last_stream") - if last_stream is None: - txn.execute( - """ - SELECT stream_ordering FROM events ORDER BY stream_ordering LIMIT 1 - """ - ) - rows = txn.fetchall() - if not rows: - return 0 - last_stream = rows[0][0] - 1 - + last_stream = progress.get("last_stream", -(1 << 31)) txn.execute( """ UPDATE events SET stream_ordering2=stream_ordering - WHERE stream_ordering > ? AND stream_ordering <= ? + WHERE stream_ordering IN ( + SELECT stream_ordering FROM events WHERE stream_ordering > ? + ORDER BY stream_ordering LIMIT ? + ) + RETURNING stream_ordering; """, - (last_stream, last_stream + batch_size), + (last_stream, batch_size), ) row_count = txn.rowcount + if row_count == 0: + return 0 + last_stream = max(row[0] for row in txn) + logger.info("populated stream_ordering2 up to %i", last_stream) self.db_pool.updates._background_update_progress_txn( txn, _BackgroundUpdates.POPULATE_STREAM_ORDERING2, - {"last_stream": last_stream + batch_size}, + {"last_stream": last_stream}, ) return row_count -- cgit 1.5.1 From 85d237eba789a667109ced140026d2494b210310 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Jun 2021 19:15:47 +0100 Subject: Add a distributed lock (#10269) This adds a simple best effort locking mechanism that works cross workers. --- changelog.d/10269.misc | 1 + synapse/app/generic_worker.py | 2 + synapse/storage/databases/main/__init__.py | 2 + synapse/storage/databases/main/lock.py | 334 +++++++++++++++++++++++ synapse/storage/schema/main/delta/59/15locks.sql | 37 +++ tests/storage/databases/main/test_lock.py | 100 +++++++ 6 files changed, 476 insertions(+) create mode 100644 changelog.d/10269.misc create mode 100644 synapse/storage/databases/main/lock.py create mode 100644 synapse/storage/schema/main/delta/59/15locks.sql create mode 100644 tests/storage/databases/main/test_lock.py (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/10269.misc b/changelog.d/10269.misc new file mode 100644 index 0000000000..23e590490c --- /dev/null +++ b/changelog.d/10269.misc @@ -0,0 +1 @@ +Add a distributed lock implementation. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index af8a1833f3..5b041fcaad 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -108,6 +108,7 @@ from synapse.server import HomeServer from synapse.storage.databases.main.censor_events import CensorEventsStore from synapse.storage.databases.main.client_ips import ClientIpWorkerStore from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore +from synapse.storage.databases.main.lock import LockStore from synapse.storage.databases.main.media_repository import MediaRepositoryStore from synapse.storage.databases.main.metrics import ServerMetricsStore from synapse.storage.databases.main.monthly_active_users import ( @@ -249,6 +250,7 @@ class GenericWorkerSlavedStore( ServerMetricsStore, SearchStore, TransactionWorkerStore, + LockStore, BaseSlavedStore, ): pass diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 9cce62ae6c..a3fddea042 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -46,6 +46,7 @@ from .events_forward_extremities import EventForwardExtremitiesStore from .filtering import FilteringStore from .group_server import GroupServerStore from .keys import KeyStore +from .lock import LockStore from .media_repository import MediaRepositoryStore from .metrics import ServerMetricsStore from .monthly_active_users import MonthlyActiveUsersStore @@ -119,6 +120,7 @@ class DataStore( CacheInvalidationWorkerStore, ServerMetricsStore, EventForwardExtremitiesStore, + LockStore, ): def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py new file mode 100644 index 0000000000..e76188328c --- /dev/null +++ b/synapse/storage/databases/main/lock.py @@ -0,0 +1,334 @@ +# Copyright 2021 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from types import TracebackType +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type + +from twisted.internet.interfaces import IReactorCore + +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.types import Connection +from synapse.util import Clock +from synapse.util.stringutils import random_string + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + + +# How often to renew an acquired lock by updating the `last_renewed_ts` time in +# the lock table. +_RENEWAL_INTERVAL_MS = 30 * 1000 + +# How long before an acquired lock times out. +_LOCK_TIMEOUT_MS = 2 * 60 * 1000 + + +class LockStore(SQLBaseStore): + """Provides a best effort distributed lock between worker instances. + + Locks are identified by a name and key. A lock is acquired by inserting into + the `worker_locks` table if a) there is no existing row for the name/key or + b) the existing row has a `last_renewed_ts` older than `_LOCK_TIMEOUT_MS`. + + When a lock is taken out the instance inserts a random `token`, the instance + that holds that token holds the lock until it drops (or times out). + + The instance that holds the lock should regularly update the + `last_renewed_ts` column with the current time. + """ + + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + super().__init__(database, db_conn, hs) + + self._reactor = hs.get_reactor() + self._instance_name = hs.get_instance_id() + + # A map from `(lock_name, lock_key)` to the token of any locks that we + # think we currently hold. + self._live_tokens: Dict[Tuple[str, str], str] = {} + + # When we shut down we want to remove the locks. Technically this can + # lead to a race, as we may drop the lock while we are still processing. + # However, a) it should be a small window, b) the lock is best effort + # anyway and c) we want to really avoid leaking locks when we restart. + hs.get_reactor().addSystemEventTrigger( + "before", + "shutdown", + self._on_shutdown, + ) + + @wrap_as_background_process("LockStore._on_shutdown") + async def _on_shutdown(self) -> None: + """Called when the server is shutting down""" + logger.info("Dropping held locks due to shutdown") + + for (lock_name, lock_key), token in self._live_tokens.items(): + await self._drop_lock(lock_name, lock_key, token) + + logger.info("Dropped locks due to shutdown") + + async def try_acquire_lock(self, lock_name: str, lock_key: str) -> Optional["Lock"]: + """Try to acquire a lock for the given name/key. Will return an async + context manager if the lock is successfully acquired, which *must* be + used (otherwise the lock will leak). + """ + + now = self._clock.time_msec() + token = random_string(6) + + if self.db_pool.engine.can_native_upsert: + + def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool: + # We take out the lock if either a) there is no row for the lock + # already or b) the existing row has timed out. + sql = """ + INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (lock_name, lock_key) + DO UPDATE + SET + token = EXCLUDED.token, + instance_name = EXCLUDED.instance_name, + last_renewed_ts = EXCLUDED.last_renewed_ts + WHERE + worker_locks.last_renewed_ts < ? + """ + txn.execute( + sql, + ( + lock_name, + lock_key, + self._instance_name, + token, + now, + now - _LOCK_TIMEOUT_MS, + ), + ) + + # We only acquired the lock if we inserted or updated the table. + return bool(txn.rowcount) + + did_lock = await self.db_pool.runInteraction( + "try_acquire_lock", + _try_acquire_lock_txn, + # We can autocommit here as we're executing a single query, this + # will avoid serialization errors. + db_autocommit=True, + ) + if not did_lock: + return None + + else: + # If we're on an old SQLite we emulate the above logic by first + # clearing out any existing stale locks and then upserting. + + def _try_acquire_lock_emulated_txn(txn: LoggingTransaction) -> bool: + sql = """ + DELETE FROM worker_locks + WHERE + lock_name = ? + AND lock_key = ? + AND last_renewed_ts < ? + """ + txn.execute( + sql, + (lock_name, lock_key, now - _LOCK_TIMEOUT_MS), + ) + + inserted = self.db_pool.simple_upsert_txn_emulated( + txn, + table="worker_locks", + keyvalues={ + "lock_name": lock_name, + "lock_key": lock_key, + }, + values={}, + insertion_values={ + "token": token, + "last_renewed_ts": self._clock.time_msec(), + "instance_name": self._instance_name, + }, + ) + + return inserted + + did_lock = await self.db_pool.runInteraction( + "try_acquire_lock_emulated", _try_acquire_lock_emulated_txn + ) + + if not did_lock: + return None + + self._live_tokens[(lock_name, lock_key)] = token + + return Lock( + self._reactor, + self._clock, + self, + lock_name=lock_name, + lock_key=lock_key, + token=token, + ) + + async def _is_lock_still_valid( + self, lock_name: str, lock_key: str, token: str + ) -> bool: + """Checks whether this instance still holds the lock.""" + last_renewed_ts = await self.db_pool.simple_select_one_onecol( + table="worker_locks", + keyvalues={ + "lock_name": lock_name, + "lock_key": lock_key, + "token": token, + }, + retcol="last_renewed_ts", + allow_none=True, + desc="is_lock_still_valid", + ) + return ( + last_renewed_ts is not None + and self._clock.time_msec() - _LOCK_TIMEOUT_MS < last_renewed_ts + ) + + async def _renew_lock(self, lock_name: str, lock_key: str, token: str) -> None: + """Attempt to renew the lock if we still hold it.""" + await self.db_pool.simple_update( + table="worker_locks", + keyvalues={ + "lock_name": lock_name, + "lock_key": lock_key, + "token": token, + }, + updatevalues={"last_renewed_ts": self._clock.time_msec()}, + desc="renew_lock", + ) + + async def _drop_lock(self, lock_name: str, lock_key: str, token: str) -> None: + """Attempt to drop the lock, if we still hold it""" + await self.db_pool.simple_delete( + table="worker_locks", + keyvalues={ + "lock_name": lock_name, + "lock_key": lock_key, + "token": token, + }, + desc="drop_lock", + ) + + self._live_tokens.pop((lock_name, lock_key), None) + + +class Lock: + """An async context manager that manages an acquired lock, ensuring it is + regularly renewed and dropping it when the context manager exits. + + The lock object has an `is_still_valid` method which can be used to + double-check the lock is still valid, if e.g. processing work in a loop. + + For example: + + lock = await self.store.try_acquire_lock(...) + if not lock: + return + + async with lock: + for item in work: + await process(item) + + if not await lock.is_still_valid(): + break + """ + + def __init__( + self, + reactor: IReactorCore, + clock: Clock, + store: LockStore, + lock_name: str, + lock_key: str, + token: str, + ) -> None: + self._reactor = reactor + self._clock = clock + self._store = store + self._lock_name = lock_name + self._lock_key = lock_key + + self._token = token + + self._looping_call = clock.looping_call( + self._renew, _RENEWAL_INTERVAL_MS, store, lock_name, lock_key, token + ) + + self._dropped = False + + @staticmethod + @wrap_as_background_process("Lock._renew") + async def _renew( + store: LockStore, + lock_name: str, + lock_key: str, + token: str, + ) -> None: + """Renew the lock. + + Note: this is a static method, rather than using self.*, so that we + don't end up with a reference to `self` in the reactor, which would stop + this from being cleaned up if we dropped the context manager. + """ + await store._renew_lock(lock_name, lock_key, token) + + async def is_still_valid(self) -> bool: + """Check if the lock is still held by us""" + return await self._store._is_lock_still_valid( + self._lock_name, self._lock_key, self._token + ) + + async def __aenter__(self) -> None: + if self._dropped: + raise Exception("Cannot reuse a Lock object") + + async def __aexit__( + self, + _exctype: Optional[Type[BaseException]], + _excinst: Optional[BaseException], + _exctb: Optional[TracebackType], + ) -> bool: + if self._looping_call.running: + self._looping_call.stop() + + await self._store._drop_lock(self._lock_name, self._lock_key, self._token) + self._dropped = True + + return False + + def __del__(self) -> None: + if not self._dropped: + # We should not be dropped without the lock being released (unless + # we're shutting down), but if we are then let's at least stop + # renewing the lock. + if self._looping_call.running: + self._looping_call.stop() + + if self._reactor.running: + logger.error( + "Lock for (%s, %s) dropped without being released", + self._lock_name, + self._lock_key, + ) diff --git a/synapse/storage/schema/main/delta/59/15locks.sql b/synapse/storage/schema/main/delta/59/15locks.sql new file mode 100644 index 0000000000..8b2999ff3e --- /dev/null +++ b/synapse/storage/schema/main/delta/59/15locks.sql @@ -0,0 +1,37 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A noddy implementation of a distributed lock across workers. While a worker +-- has taken a lock out they should regularly update the `last_renewed_ts` +-- column, a lock will be considered dropped if `last_renewed_ts` is from ages +-- ago. +CREATE TABLE worker_locks ( + lock_name TEXT NOT NULL, + lock_key TEXT NOT NULL, + -- We write the instance name to ease manual debugging, we don't ever read + -- from it. + -- Note: instance names aren't guarenteed to be unique. + instance_name TEXT NOT NULL, + -- A random string generated each time an instance takes out a lock. Used by + -- the instance to tell whether the lock is still held by it (e.g. in the + -- case where the process stalls for a long time the lock may time out and + -- be taken out by another instance, at which point the original instance + -- can tell it no longer holds the lock as the tokens no longer match). + token TEXT NOT NULL, + last_renewed_ts BIGINT NOT NULL +); + +CREATE UNIQUE INDEX worker_locks_key ON worker_locks (lock_name, lock_key); diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py new file mode 100644 index 0000000000..9ca70e7367 --- /dev/null +++ b/tests/storage/databases/main/test_lock.py @@ -0,0 +1,100 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.server import HomeServer +from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS + +from tests import unittest + + +class LockTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs: HomeServer): + self.store = hs.get_datastore() + + def test_simple_lock(self): + """Test that we can take out a lock and that while we hold it nobody + else can take it out. + """ + # First to acquire this lock, so it should complete + lock = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNotNone(lock) + + # Enter the context manager + self.get_success(lock.__aenter__()) + + # Attempting to acquire the lock again fails. + lock2 = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNone(lock2) + + # Calling `is_still_valid` reports true. + self.assertTrue(self.get_success(lock.is_still_valid())) + + # Drop the lock + self.get_success(lock.__aexit__(None, None, None)) + + # We can now acquire the lock again. + lock3 = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNotNone(lock3) + self.get_success(lock3.__aenter__()) + self.get_success(lock3.__aexit__(None, None, None)) + + def test_maintain_lock(self): + """Test that we don't time out locks while they're still active""" + + lock = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNotNone(lock) + + self.get_success(lock.__aenter__()) + + # Wait for ages with the lock, we should not be able to get the lock. + self.reactor.advance(5 * _LOCK_TIMEOUT_MS / 1000) + + lock2 = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNone(lock2) + + self.get_success(lock.__aexit__(None, None, None)) + + def test_timeout_lock(self): + """Test that we time out locks if they're not updated for ages""" + + lock = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNotNone(lock) + + self.get_success(lock.__aenter__()) + + # We simulate the process getting stuck by cancelling the looping call + # that keeps the lock active. + lock._looping_call.stop() + + # Wait for the lock to timeout. + self.reactor.advance(2 * _LOCK_TIMEOUT_MS / 1000) + + lock2 = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNotNone(lock2) + + self.assertFalse(self.get_success(lock.is_still_valid())) + + def test_drop(self): + """Test that dropping the context manager means we stop renewing the lock""" + + lock = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNotNone(lock) + + del lock + + # Wait for the lock to timeout. + self.reactor.advance(2 * _LOCK_TIMEOUT_MS / 1000) + + lock2 = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNotNone(lock2) -- cgit 1.5.1 From c54db67d0ea5b5967b7ea918c66a222a75b8ced1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Jun 2021 19:55:22 +0100 Subject: Handle inbound events from federation asynchronously (#10272) Fixes #9490 This will break a couple of SyTest that are expecting failures to be added to the response of a federation /send, which obviously doesn't happen now that things are asynchronous. Two drawbacks: Currently there is no logic to handle any events left in the staging area after restart, and so they'll only be handled on the next incoming event in that room. That can be fixed separately. We now only process one event per room at a time. This can be fixed up further down the line. --- changelog.d/10272.bugfix | 1 + synapse/federation/federation_server.py | 98 +++++++++++++++++- synapse/storage/databases/main/event_federation.py | 109 ++++++++++++++++++++- .../main/delta/59/16federation_inbound_staging.sql | 32 ++++++ sytest-blacklist | 6 ++ 5 files changed, 241 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10272.bugfix create mode 100644 synapse/storage/schema/main/delta/59/16federation_inbound_staging.sql (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/10272.bugfix b/changelog.d/10272.bugfix new file mode 100644 index 0000000000..3cefa05788 --- /dev/null +++ b/changelog.d/10272.bugfix @@ -0,0 +1 @@ +Handle inbound events from federation asynchronously. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2b07f18529..1d050e54e2 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -44,7 +44,7 @@ from synapse.api.errors import ( SynapseError, UnsupportedRoomVersionError, ) -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions @@ -57,10 +57,12 @@ from synapse.logging.context import ( ) from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace from synapse.logging.utils import log_function +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.replication.http.federation import ( ReplicationFederationSendEduRestServlet, ReplicationGetQueryRestServlet, ) +from synapse.storage.databases.main.lock import Lock from synapse.types import JsonDict from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute @@ -96,6 +98,11 @@ last_pdu_ts_metric = Gauge( ) +# The name of the lock to use when process events in a room received over +# federation. +_INBOUND_EVENT_HANDLING_LOCK_NAME = "federation_inbound_pdu" + + class FederationServer(FederationBase): def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -834,7 +841,94 @@ class FederationServer(FederationBase): except SynapseError as e: raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) + # Add the event to our staging area + await self.store.insert_received_event_to_staging(origin, pdu) + + # Try and acquire the processing lock for the room, if we get it start a + # background process for handling the events in the room. + lock = await self.store.try_acquire_lock( + _INBOUND_EVENT_HANDLING_LOCK_NAME, pdu.room_id + ) + if lock: + self._process_incoming_pdus_in_room_inner( + pdu.room_id, room_version, lock, origin, pdu + ) + + @wrap_as_background_process("_process_incoming_pdus_in_room_inner") + async def _process_incoming_pdus_in_room_inner( + self, + room_id: str, + room_version: RoomVersion, + lock: Lock, + latest_origin: str, + latest_event: EventBase, + ) -> None: + """Process events in the staging area for the given room. + + The latest_origin and latest_event args are the latest origin and event + received. + """ + + # The common path is for the event we just received be the only event in + # the room, so instead of pulling the event out of the DB and parsing + # the event we just pull out the next event ID and check if that matches. + next_origin, next_event_id = await self.store.get_next_staged_event_id_for_room( + room_id + ) + if next_origin == latest_origin and next_event_id == latest_event.event_id: + origin = latest_origin + event = latest_event + else: + next = await self.store.get_next_staged_event_for_room( + room_id, room_version + ) + if not next: + return + + origin, event = next + + # We loop round until there are no more events in the room in the + # staging area, or we fail to get the lock (which means another process + # has started processing). + while True: + async with lock: + try: + await self.handler.on_receive_pdu( + origin, event, sent_to_us_directly=True + ) + except FederationError as e: + # XXX: Ideally we'd inform the remote we failed to process + # the event, but we can't return an error in the transaction + # response (as we've already responded). + logger.warning("Error handling PDU %s: %s", event.event_id, e) + except Exception: + f = failure.Failure() + logger.error( + "Failed to handle PDU %s", + event.event_id, + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + ) + + await self.store.remove_received_event_from_staging( + origin, event.event_id + ) + + # We need to do this check outside the lock to avoid a race between + # a new event being inserted by another instance and it attempting + # to acquire the lock. + next = await self.store.get_next_staged_event_for_room( + room_id, room_version + ) + if not next: + break + + origin, event = next + + lock = await self.store.try_acquire_lock( + _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id + ) + if not lock: + return def __str__(self) -> str: return "" % self.server_name diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index c0ea445550..f23f8c6ecf 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -14,18 +14,20 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import Collection, Dict, Iterable, List, Set, Tuple +from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.constants import MAX_DEPTH from synapse.api.errors import StoreError -from synapse.events import EventBase +from synapse.api.room_versions import RoomVersion +from synapse.events import EventBase, make_event_from_dict from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter @@ -1044,6 +1046,107 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas _delete_old_forward_extrem_cache_txn, ) + async def insert_received_event_to_staging( + self, origin: str, event: EventBase + ) -> None: + """Insert a newly received event from federation into the staging area.""" + + # We use an upsert here to handle the case where we see the same event + # from the same server multiple times. + await self.db_pool.simple_upsert( + table="federation_inbound_events_staging", + keyvalues={ + "origin": origin, + "event_id": event.event_id, + }, + values={}, + insertion_values={ + "room_id": event.room_id, + "received_ts": self._clock.time_msec(), + "event_json": json_encoder.encode(event.get_dict()), + "internal_metadata": json_encoder.encode( + event.internal_metadata.get_dict() + ), + }, + desc="insert_received_event_to_staging", + ) + + async def remove_received_event_from_staging( + self, + origin: str, + event_id: str, + ) -> None: + """Remove the given event from the staging area""" + await self.db_pool.simple_delete( + table="federation_inbound_events_staging", + keyvalues={ + "origin": origin, + "event_id": event_id, + }, + desc="remove_received_event_from_staging", + ) + + async def get_next_staged_event_id_for_room( + self, + room_id: str, + ) -> Optional[Tuple[str, str]]: + """Get the next event ID in the staging area for the given room.""" + + def _get_next_staged_event_id_for_room_txn(txn): + sql = """ + SELECT origin, event_id + FROM federation_inbound_events_staging + WHERE room_id = ? + ORDER BY received_ts ASC + LIMIT 1 + """ + + txn.execute(sql, (room_id,)) + + return txn.fetchone() + + return await self.db_pool.runInteraction( + "get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn + ) + + async def get_next_staged_event_for_room( + self, + room_id: str, + room_version: RoomVersion, + ) -> Optional[Tuple[str, EventBase]]: + """Get the next event in the staging area for the given room.""" + + def _get_next_staged_event_for_room_txn(txn): + sql = """ + SELECT event_json, internal_metadata, origin + FROM federation_inbound_events_staging + WHERE room_id = ? + ORDER BY received_ts ASC + LIMIT 1 + """ + txn.execute(sql, (room_id,)) + + return txn.fetchone() + + row = await self.db_pool.runInteraction( + "get_next_staged_event_for_room", _get_next_staged_event_for_room_txn + ) + + if not row: + return None + + event_d = db_to_json(row[0]) + internal_metadata_d = db_to_json(row[1]) + origin = row[2] + + event = make_event_from_dict( + event_dict=event_d, + room_version=room_version, + internal_metadata_dict=internal_metadata_d, + ) + + return origin, event + class EventFederationStore(EventFederationWorkerStore): """Responsible for storing and serving up the various graphs associated diff --git a/synapse/storage/schema/main/delta/59/16federation_inbound_staging.sql b/synapse/storage/schema/main/delta/59/16federation_inbound_staging.sql new file mode 100644 index 0000000000..43bc5c025f --- /dev/null +++ b/synapse/storage/schema/main/delta/59/16federation_inbound_staging.sql @@ -0,0 +1,32 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A staging area for newly received events over federation. +-- +-- Note we may store the same event multiple times if it comes from different +-- servers; this is to handle the case if we get a redacted and non-redacted +-- versions of the event. +CREATE TABLE federation_inbound_events_staging ( + origin TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + received_ts BIGINT NOT NULL, + event_json TEXT NOT NULL, + internal_metadata TEXT NOT NULL +); + +CREATE INDEX federation_inbound_events_staging_room ON federation_inbound_events_staging(room_id, received_ts); +CREATE UNIQUE INDEX federation_inbound_events_staging_instance_event ON federation_inbound_events_staging(origin, event_id); diff --git a/sytest-blacklist b/sytest-blacklist index de9986357b..89c4e828fd 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -41,3 +41,9 @@ We can't peek into rooms with invited history_visibility We can't peek into rooms with joined history_visibility Local users can peek by room alias Peeked rooms only turn up in the sync for the device who peeked them + + +# Blacklisted due to changes made in #10272 +Outbound federation will ignore a missing event with bad JSON for room version 6 +Backfilled events whose prev_events are in a different room do not allow cross-room back-pagination +Federation rejects inbound events where the prev_events cannot be found -- cgit 1.5.1