From 63f4990298ce6369c540fe8d8d8895b20b288317 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 11 Jan 2021 13:57:33 +0000 Subject: Ensure rejected events get added to some metadata tables (#9016) Co-authored-by: Patrick Cloke --- synapse/storage/databases/main/events.py | 49 ++++---- .../storage/databases/main/events_bg_updates.py | 124 +++++++++++++++++++++ .../schema/delta/58/28rejected_events_metadata.sql | 17 +++ 3 files changed, 166 insertions(+), 24 deletions(-) create mode 100644 synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 90fb1a1f00..5e7753e09b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -799,7 +799,8 @@ class PersistEventsStore: return [ec for ec in events_and_contexts if ec[0] not in to_remove] def _store_event_txn(self, txn, events_and_contexts): - """Insert new events into the event and event_json tables + """Insert new events into the event, event_json, redaction and + state_events tables. Args: txn (twisted.enterprise.adbapi.Connection): db connection @@ -871,6 +872,29 @@ class PersistEventsStore: updatevalues={"have_censored": False}, ) + state_events_and_contexts = [ + ec for ec in events_and_contexts if ec[0].is_state() + ] + + state_values = [] + for event, context in state_events_and_contexts: + vals = { + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + } + + # TODO: How does this work with backfilling? + if hasattr(event, "replaces_state"): + vals["prev_state"] = event.replaces_state + + state_values.append(vals) + + self.db_pool.simple_insert_many_txn( + txn, table="state_events", values=state_values + ) + def _store_rejected_events_txn(self, txn, events_and_contexts): """Add rows to the 'rejections' table for received events which were rejected @@ -987,29 +1011,6 @@ class PersistEventsStore: txn, [event for event, _ in events_and_contexts] ) - state_events_and_contexts = [ - ec for ec in events_and_contexts if ec[0].is_state() - ] - - state_values = [] - for event, context in state_events_and_contexts: - vals = { - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - } - - # TODO: How does this work with backfilling? - if hasattr(event, "replaces_state"): - vals["prev_state"] = event.replaces_state - - state_values.append(vals) - - self.db_pool.simple_insert_many_txn( - txn, table="state_events", values=state_values - ) - # Prefill the event cache self._add_to_cache(txn, events_and_contexts) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 97b6754846..7e4b175d08 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -14,10 +14,15 @@ # limitations under the License. import logging +from typing import List, Tuple from synapse.api.constants import EventContentFields +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import make_event_from_dict from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool +from synapse.storage.types import Cursor +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -99,6 +104,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): columns=["user_id", "created_ts"], ) + self.db_pool.updates.register_background_update_handler( + "rejected_events_metadata", self._rejected_events_metadata, + ) + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -582,3 +591,118 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): await self.db_pool.updates._end_background_update("event_store_labels") return num_rows + + async def _rejected_events_metadata(self, progress: dict, batch_size: int) -> int: + """Adds rejected events to the `state_events` and `event_auth` metadata + tables. + """ + + last_event_id = progress.get("last_event_id", "") + + def get_rejected_events( + txn: Cursor, + ) -> List[Tuple[str, str, JsonDict, bool, bool]]: + # Fetch rejected event json, their room version and whether we have + # inserted them into the state_events or auth_events tables. + # + # Note we can assume that events that don't have a corresponding + # room version are V1 rooms. + sql = """ + SELECT DISTINCT + event_id, + COALESCE(room_version, '1'), + json, + state_events.event_id IS NOT NULL, + event_auth.event_id IS NOT NULL + FROM rejections + INNER JOIN event_json USING (event_id) + LEFT JOIN rooms USING (room_id) + LEFT JOIN state_events USING (event_id) + LEFT JOIN event_auth USING (event_id) + WHERE event_id > ? + ORDER BY event_id + LIMIT ? + """ + + txn.execute(sql, (last_event_id, batch_size,)) + + return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore + + results = await self.db_pool.runInteraction( + desc="_rejected_events_metadata_get", func=get_rejected_events + ) + + if not results: + await self.db_pool.updates._end_background_update( + "rejected_events_metadata" + ) + return 0 + + state_events = [] + auth_events = [] + for event_id, room_version, event_json, has_state, has_event_auth in results: + last_event_id = event_id + + if has_state and has_event_auth: + continue + + room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version) + if not room_version_obj: + # We no longer support this room version, so we just ignore the + # events entirely. + logger.info( + "Ignoring event with unknown room version %r: %r", + room_version, + event_id, + ) + continue + + event = make_event_from_dict(event_json, room_version_obj) + + if not event.is_state(): + continue + + if not has_state: + state_events.append( + { + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + } + ) + + if not has_event_auth: + for auth_id in event.auth_event_ids(): + auth_events.append( + { + "room_id": event.room_id, + "event_id": event.event_id, + "auth_id": auth_id, + } + ) + + if state_events: + await self.db_pool.simple_insert_many( + table="state_events", + values=state_events, + desc="_rejected_events_metadata_state_events", + ) + + if auth_events: + await self.db_pool.simple_insert_many( + table="event_auth", + values=auth_events, + desc="_rejected_events_metadata_event_auth", + ) + + await self.db_pool.updates._background_update_progress( + "rejected_events_metadata", {"last_event_id": last_event_id} + ) + + if len(results) < batch_size: + await self.db_pool.updates._end_background_update( + "rejected_events_metadata" + ) + + return len(results) diff --git a/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql b/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql new file mode 100644 index 0000000000..9c95646281 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5828, 'rejected_events_metadata', '{}'); -- cgit 1.5.1 From 4e04435bda135d3441777a51aa54dbd4c3925f2b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 11 Jan 2021 13:58:19 +0000 Subject: Remove old tables after schema version bump (#9055) These tables are unused, and can be dropped now the schema version has been bumped. --- changelog.d/9055.misc | 1 + synapse/storage/databases/main/account_data.py | 48 +--------------------- .../main/schema/delta/59/04drop_account_data.sql | 17 ++++++++ .../main/schema/delta/59/05cache_invalidation.sql | 17 ++++++++ synapse/storage/databases/main/tags.py | 10 ----- synapse/storage/prepare_database.py | 3 -- 6 files changed, 37 insertions(+), 59 deletions(-) create mode 100644 changelog.d/9055.misc create mode 100644 synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql create mode 100644 synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql (limited to 'synapse/storage/databases') diff --git a/changelog.d/9055.misc b/changelog.d/9055.misc new file mode 100644 index 0000000000..8e0512eb1e --- /dev/null +++ b/changelog.d/9055.misc @@ -0,0 +1 @@ +Drop unused database tables. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index bff51e92b9..bad8260892 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -312,12 +312,9 @@ class AccountDataStore(AccountDataWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( db_conn, - "account_data_max_stream_id", + "room_account_data", "stream_id", - extra_tables=[ - ("room_account_data", "stream_id"), - ("room_tags_revisions", "stream_id"), - ], + extra_tables=[("room_tags_revisions", "stream_id")], ) super().__init__(database, db_conn, hs) @@ -362,14 +359,6 @@ class AccountDataStore(AccountDataWorkerStore): lock=False, ) - # it's theoretically possible for the above to succeed and the - # below to fail - in which case we might reuse a stream id on - # restart, and the above update might not get propagated. That - # doesn't sound any worse than the whole update getting lost, - # which is what would happen if we combined the two into one - # transaction. - await self._update_max_stream_id(next_id) - self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) self.get_account_data_for_room.invalidate((user_id, room_id)) @@ -402,18 +391,6 @@ class AccountDataStore(AccountDataWorkerStore): content, ) - # it's theoretically possible for the above to succeed and the - # below to fail - in which case we might reuse a stream id on - # restart, and the above update might not get propagated. That - # doesn't sound any worse than the whole update getting lost, - # which is what would happen if we combined the two into one - # transaction. - # - # Note: This is only here for backwards compat to allow admins to - # roll back to a previous Synapse version. Next time we update the - # database version we can remove this table. - await self._update_max_stream_id(next_id) - self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) self.get_global_account_data_by_type_for_user.invalidate( @@ -486,24 +463,3 @@ class AccountDataStore(AccountDataWorkerStore): # Invalidate the cache for any ignored users which were added or removed. for ignored_user_id in previously_ignored_users ^ currently_ignored_users: self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) - - async def _update_max_stream_id(self, next_id: int) -> None: - """Update the max stream_id - - Args: - next_id: The the revision to advance to. - """ - - # Note: This is only here for backwards compat to allow admins to - # roll back to a previous Synapse version. Next time we update the - # database version we can remove this table. - - def _update(txn): - update_max_id_sql = ( - "UPDATE account_data_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" - ) - txn.execute(update_max_id_sql, (next_id, next_id)) - - await self.db_pool.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql b/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql new file mode 100644 index 0000000000..64ab696cfe --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql @@ -0,0 +1,17 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This is no longer used and was only kept until we bumped the schema version. +DROP TABLE IF EXISTS account_data_max_stream_id; diff --git a/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql b/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql new file mode 100644 index 0000000000..fb71b360a0 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql @@ -0,0 +1,17 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This is no longer used and was only kept until we bumped the schema version. +DROP TABLE IF EXISTS cache_invalidation_stream; diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 9f120d3cb6..74da9c49f2 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -255,16 +255,6 @@ class TagsStore(TagsWorkerStore): self._account_data_stream_cache.entity_has_changed, user_id, next_id ) - # Note: This is only here for backwards compat to allow admins to - # roll back to a previous Synapse version. Next time we update the - # database version we can remove this table. - update_max_id_sql = ( - "UPDATE account_data_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" - ) - txn.execute(update_max_id_sql, (next_id, next_id)) - update_sql = ( "UPDATE room_tags_revisions" " SET stream_id = ?" diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 01efb2cabb..566ea19bae 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -35,9 +35,6 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -# XXX: If you're about to bump this to 59 (or higher) please create an update -# that drops the unused `cache_invalidation_stream` table, as per #7436! -# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656! SCHEMA_VERSION = 59 dir_path = os.path.abspath(os.path.dirname(__file__)) -- cgit 1.5.1 From 1315a2e8be702a513d49c1142e9e52b642286635 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 11 Jan 2021 16:09:22 +0000 Subject: Use a chain cover index to efficiently calculate auth chain difference (#8868) --- changelog.d/8868.misc | 1 + docs/auth_chain_diff.dot | 32 ++ docs/auth_chain_diff.dot.png | Bin 0 -> 42427 bytes docs/auth_chain_difference_algorithm.md | 108 +++++ synapse/storage/database.py | 22 +- synapse/storage/databases/main/event_federation.py | 185 +++++++ synapse/storage/databases/main/events.py | 535 ++++++++++++++++++++- synapse/storage/databases/main/room.py | 51 +- .../main/schema/delta/59/04_event_auth_chains.sql | 52 ++ .../delta/59/04_event_auth_chains.sql.postgres | 16 + synapse/util/iterutils.py | 53 +- tests/storage/test_event_chain.py | 472 ++++++++++++++++++ tests/storage/test_event_federation.py | 249 +++++++++- tests/util/test_itertools.py | 41 +- 14 files changed, 1769 insertions(+), 48 deletions(-) create mode 100644 changelog.d/8868.misc create mode 100644 docs/auth_chain_diff.dot create mode 100644 docs/auth_chain_diff.dot.png create mode 100644 docs/auth_chain_difference_algorithm.md create mode 100644 synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql create mode 100644 synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres create mode 100644 tests/storage/test_event_chain.py (limited to 'synapse/storage/databases') diff --git a/changelog.d/8868.misc b/changelog.d/8868.misc new file mode 100644 index 0000000000..1a11e30944 --- /dev/null +++ b/changelog.d/8868.misc @@ -0,0 +1 @@ +Improve efficiency of large state resolutions for new rooms. diff --git a/docs/auth_chain_diff.dot b/docs/auth_chain_diff.dot new file mode 100644 index 0000000000..978d579ada --- /dev/null +++ b/docs/auth_chain_diff.dot @@ -0,0 +1,32 @@ +digraph auth { + nodesep=0.5; + rankdir="RL"; + + C [label="Create (1,1)"]; + + BJ [label="Bob's Join (2,1)", color=red]; + BJ2 [label="Bob's Join (2,2)", color=red]; + BJ2 -> BJ [color=red, dir=none]; + + subgraph cluster_foo { + A1 [label="Alice's invite (4,1)", color=blue]; + A2 [label="Alice's Join (4,2)", color=blue]; + A3 [label="Alice's Join (4,3)", color=blue]; + A3 -> A2 -> A1 [color=blue, dir=none]; + color=none; + } + + PL1 [label="Power Level (3,1)", color=darkgreen]; + PL2 [label="Power Level (3,2)", color=darkgreen]; + PL2 -> PL1 [color=darkgreen, dir=none]; + + {rank = same; C; BJ; PL1; A1;} + + A1 -> C [color=grey]; + A1 -> BJ [color=grey]; + PL1 -> C [color=grey]; + BJ2 -> PL1 [penwidth=2]; + + A3 -> PL2 [penwidth=2]; + A1 -> PL1 -> BJ -> C [penwidth=2]; +} diff --git a/docs/auth_chain_diff.dot.png b/docs/auth_chain_diff.dot.png new file mode 100644 index 0000000000..771c07308f Binary files /dev/null and b/docs/auth_chain_diff.dot.png differ diff --git a/docs/auth_chain_difference_algorithm.md b/docs/auth_chain_difference_algorithm.md new file mode 100644 index 0000000000..30f72a70da --- /dev/null +++ b/docs/auth_chain_difference_algorithm.md @@ -0,0 +1,108 @@ +# Auth Chain Difference Algorithm + +The auth chain difference algorithm is used by V2 state resolution, where a +naive implementation can be a significant source of CPU and DB usage. + +### Definitions + +A *state set* is a set of state events; e.g. the input of a state resolution +algorithm is a collection of state sets. + +The *auth chain* of a set of events are all the events' auth events and *their* +auth events, recursively (i.e. the events reachable by walking the graph induced +by an event's auth events links). + +The *auth chain difference* of a collection of state sets is the union minus the +intersection of the sets of auth chains corresponding to the state sets, i.e an +event is in the auth chain difference if it is reachable by walking the auth +event graph from at least one of the state sets but not from *all* of the state +sets. + +## Breadth First Walk Algorithm + +A way of calculating the auth chain difference without calculating the full auth +chains for each state set is to do a parallel breadth first walk (ordered by +depth) of each state set's auth chain. By tracking which events are reachable +from each state set we can finish early if every pending event is reachable from +every state set. + +This can work well for state sets that have a small auth chain difference, but +can be very inefficient for larger differences. However, this algorithm is still +used if we don't have a chain cover index for the room (e.g. because we're in +the process of indexing it). + +## Chain Cover Index + +Synapse computes auth chain differences by pre-computing a "chain cover" index +for the auth chain in a room, allowing efficient reachability queries like "is +event A in the auth chain of event B". This is done by assigning every event a +*chain ID* and *sequence number* (e.g. `(5,3)`), and having a map of *links* +between chains (e.g. `(5,3) -> (2,4)`) such that A is reachable by B (i.e. `A` +is in the auth chain of `B`) if and only if either: + +1. A and B have the same chain ID and `A`'s sequence number is less than `B`'s + sequence number; or +2. there is a link `L` between `B`'s chain ID and `A`'s chain ID such that + `L.start_seq_no` <= `B.seq_no` and `A.seq_no` <= `L.end_seq_no`. + +There are actually two potential implementations, one where we store links from +each chain to every other reachable chain (the transitive closure of the links +graph), and one where we remove redundant links (the transitive reduction of the +links graph) e.g. if we have chains `C3 -> C2 -> C1` then the link `C3 -> C1` +would not be stored. Synapse uses the former implementations so that it doesn't +need to recurse to test reachability between chains. + +### Example + +An example auth graph would look like the following, where chains have been +formed based on type/state_key and are denoted by colour and are labelled with +`(chain ID, sequence number)`. Links are denoted by the arrows (links in grey +are those that would be remove in the second implementation described above). + +![Example](auth_chain_diff.dot.png) + +Note that we don't include all links between events and their auth events, as +most of those links would be redundant. For example, all events point to the +create event, but each chain only needs the one link from it's base to the +create event. + +## Using the Index + +This index can be used to calculate the auth chain difference of the state sets +by looking at the chain ID and sequence numbers reachable from each state set: + +1. For every state set lookup the chain ID/sequence numbers of each state event +2. Use the index to find all chains and the maximum sequence number reachable + from each state set. +3. The auth chain difference is then all events in each chain that have sequence + numbers between the maximum sequence number reachable from *any* state set and + the minimum reachable by *all* state sets (if any). + +Note that steps 2 is effectively calculating the auth chain for each state set +(in terms of chain IDs and sequence numbers), and step 3 is calculating the +difference between the union and intersection of the auth chains. + +### Worked Example + +For example, given the above graph, we can calculate the difference between +state sets consisting of: + +1. `S1`: Alice's invite `(4,1)` and Bob's second join `(2,2)`; and +2. `S2`: Alice's second join `(4,3)` and Bob's first join `(2,1)`. + +Using the index we see that the following auth chains are reachable from each +state set: + +1. `S1`: `(1,1)`, `(2,2)`, `(3,1)` & `(4,1)` +2. `S2`: `(1,1)`, `(2,1)`, `(3,2)` & `(4,3)` + +And so, for each the ranges that are in the auth chain difference: +1. Chain 1: None, (since everything can reach the create event). +2. Chain 2: The range `(1, 2]` (i.e. just `2`), as `1` is reachable by all state + sets and the maximum reachable is `2` (corresponding to Bob's second join). +3. Chain 3: Similarly the range `(1, 2]` (corresponding to the second power + level). +4. Chain 4: The range `(1, 3]` (corresponding to both of Alice's joins). + +So the final result is: Bob's second join `(2,2)`, the second power level +`(3,2)` and both of Alice's joins `(4,2)` & `(4,3)`. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index b70ca3087b..6cfadc2b4e 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -179,6 +179,9 @@ class LoggingDatabaseConnection: _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]] +R = TypeVar("R") + + class LoggingTransaction: """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() @@ -266,6 +269,20 @@ class LoggingTransaction: for val in args: self.execute(sql, val) + def execute_values(self, sql: str, *args: Any) -> List[Tuple]: + """Corresponds to psycopg2.extras.execute_values. Only available when + using postgres. + + Always sets fetch=True when caling `execute_values`, so will return the + results. + """ + assert isinstance(self.database_engine, PostgresEngine) + from psycopg2.extras import execute_values # type: ignore + + return self._do_execute( + lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args + ) + def execute(self, sql: str, *args: Any) -> None: self._do_execute(self.txn.execute, sql, *args) @@ -276,7 +293,7 @@ class LoggingTransaction: "Strip newlines out of SQL so that the loggers in the DB are on one line" return " ".join(line.strip() for line in sql.splitlines() if line.strip()) - def _do_execute(self, func, sql: str, *args: Any) -> None: + def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R: sql = self._make_sql_one_line(sql) # TODO(paul): Maybe use 'info' and 'debug' for values? @@ -347,9 +364,6 @@ class PerformanceCounters: return top_n_counters -R = TypeVar("R") - - class DatabasePool: """Wraps a single physical database and connection pool. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ebffd89251..8326640d20 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.storage.engines import PostgresEngine +from synapse.storage.types import Cursor from synapse.types import Collection from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache @@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) +class _NoChainCoverIndex(Exception): + def __init__(self, room_id: str): + super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,)) + + class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -151,15 +158,193 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas The set of the difference in auth chains. """ + # Check if we have indexed the room so we can use the chain cover + # algorithm. + room = await self.get_room(room_id) + if room["has_auth_chain_index"]: + try: + return await self.db_pool.runInteraction( + "get_auth_chain_difference_chains", + self._get_auth_chain_difference_using_cover_index_txn, + room_id, + state_sets, + ) + except _NoChainCoverIndex: + # For whatever reason we don't actually have a chain cover index + # for the events in question, so we fall back to the old method. + pass + return await self.db_pool.runInteraction( "get_auth_chain_difference", self._get_auth_chain_difference_txn, state_sets, ) + def _get_auth_chain_difference_using_cover_index_txn( + self, txn: Cursor, room_id: str, state_sets: List[Set[str]] + ) -> Set[str]: + """Calculates the auth chain difference using the chain index. + + See docs/auth_chain_difference_algorithm.md for details + """ + + # First we look up the chain ID/sequence numbers for all the events, and + # work out the chain/sequence numbers reachable from each state set. + + initial_events = set(state_sets[0]).union(*state_sets[1:]) + + # Map from event_id -> (chain ID, seq no) + chain_info = {} # type: Dict[str, Tuple[int, int]] + + # Map from chain ID -> seq no -> event Id + chain_to_event = {} # type: Dict[int, Dict[int, str]] + + # All the chains that we've found that are reachable from the state + # sets. + seen_chains = set() # type: Set[int] + + sql = """ + SELECT event_id, chain_id, sequence_number + FROM event_auth_chains + WHERE %s + """ + for batch in batch_iter(initial_events, 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", batch + ) + txn.execute(sql % (clause,), args) + + for event_id, chain_id, sequence_number in txn: + chain_info[event_id] = (chain_id, sequence_number) + seen_chains.add(chain_id) + chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id + + # Check that we actually have a chain ID for all the events. + events_missing_chain_info = initial_events.difference(chain_info) + if events_missing_chain_info: + # This can happen due to e.g. downgrade/upgrade of the server. We + # raise an exception and fall back to the previous algorithm. + logger.info( + "Unexpectedly found that events don't have chain IDs in room %s: %s", + room_id, + events_missing_chain_info, + ) + raise _NoChainCoverIndex(room_id) + + # Corresponds to `state_sets`, except as a map from chain ID to max + # sequence number reachable from the state set. + set_to_chain = [] # type: List[Dict[int, int]] + for state_set in state_sets: + chains = {} # type: Dict[int, int] + set_to_chain.append(chains) + + for event_id in state_set: + chain_id, seq_no = chain_info[event_id] + + chains[chain_id] = max(seq_no, chains.get(chain_id, 0)) + + # Now we look up all links for the chains we have, adding chains to + # set_to_chain that are reachable from each set. + sql = """ + SELECT + origin_chain_id, origin_sequence_number, + target_chain_id, target_sequence_number + FROM event_auth_chain_links + WHERE %s + """ + + # (We need to take a copy of `seen_chains` as we want to mutate it in + # the loop) + for batch in batch_iter(set(seen_chains), 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "origin_chain_id", batch + ) + txn.execute(sql % (clause,), args) + + for ( + origin_chain_id, + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) in txn: + for chains in set_to_chain: + # chains are only reachable if the origin sequence number of + # the link is less than the max sequence number in the + # origin chain. + if origin_sequence_number <= chains.get(origin_chain_id, 0): + chains[target_chain_id] = max( + target_sequence_number, chains.get(target_chain_id, 0), + ) + + seen_chains.add(target_chain_id) + + # Now for each chain we figure out the maximum sequence number reachable + # from *any* state set and the minimum sequence number reachable from + # *all* state sets. Events in that range are in the auth chain + # difference. + result = set() + + # Mapping from chain ID to the range of sequence numbers that should be + # pulled from the database. + chain_to_gap = {} # type: Dict[int, Tuple[int, int]] + + for chain_id in seen_chains: + min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain) + max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain) + + if min_seq_no < max_seq_no: + # We have a non empty gap, try and fill it from the events that + # we have, otherwise add them to the list of gaps to pull out + # from the DB. + for seq_no in range(min_seq_no + 1, max_seq_no + 1): + event_id = chain_to_event.get(chain_id, {}).get(seq_no) + if event_id: + result.add(event_id) + else: + chain_to_gap[chain_id] = (min_seq_no, max_seq_no) + break + + if not chain_to_gap: + # If there are no gaps to fetch, we're done! + return result + + if isinstance(self.database_engine, PostgresEngine): + # We can use `execute_values` to efficiently fetch the gaps when + # using postgres. + sql = """ + SELECT event_id + FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq) + WHERE + c.chain_id = l.chain_id + AND min_seq < sequence_number AND sequence_number <= max_seq + """ + + args = [ + (chain_id, min_no, max_no) + for chain_id, (min_no, max_no) in chain_to_gap.items() + ] + + rows = txn.execute_values(sql, args) + result.update(r for r, in rows) + else: + # For SQLite we just fall back to doing a noddy for loop. + sql = """ + SELECT event_id FROM event_auth_chains + WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ? + """ + for chain_id, (min_no, max_no) in chain_to_gap.items(): + txn.execute(sql, (chain_id, min_no, max_no)) + result.update(r for r, in txn) + + return result + def _get_auth_chain_difference_txn( self, txn, state_sets: List[Set[str]] ) -> Set[str]: + """Calculates the auth chain difference using a breadth first search. + + This is used when we don't have a cover index for the room. + """ # Algorithm Description # ~~~~~~~~~~~~~~~~~~~~~ diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5e7753e09b..186f064036 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -17,7 +17,17 @@ import itertools import logging from collections import OrderedDict, namedtuple -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generator, + Iterable, + List, + Optional, + Set, + Tuple, +) import attr from prometheus_client import Counter @@ -33,9 +43,10 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder -from synapse.util.iterutils import batch_iter +from synapse.util.iterutils import batch_iter, sorted_topologically if TYPE_CHECKING: from synapse.server import HomeServer @@ -89,6 +100,14 @@ class PersistEventsStore: self._clock = hs.get_clock() self._instance_name = hs.get_instance_name() + def get_chain_id_txn(txn): + txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") + return txn.fetchone()[0] + + self._event_chain_id_gen = build_sequence_generator( + db.engine, get_chain_id_txn, "event_auth_chain_id" + ) + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id @@ -366,6 +385,36 @@ class PersistEventsStore: # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) + self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts]) + + # _store_rejected_events_txn filters out any events which were + # rejected, and returns the filtered list. + events_and_contexts = self._store_rejected_events_txn( + txn, events_and_contexts=events_and_contexts + ) + + # From this point onwards the events are only ones that weren't + # rejected. + + self._update_metadata_tables_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + backfilled=backfilled, + ) + + # We call this last as it assumes we've inserted the events into + # room_memberships, where applicable. + self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + + def _persist_event_auth_chain_txn( + self, txn: LoggingTransaction, events: List[EventBase], + ) -> None: + + # We only care about state events, so this if there are no state events. + if not any(e.is_state() for e in events): + return + # We want to store event_auth mappings for rejected events, as they're # used in state res v2. # This is only necessary if the rejected event appears in an accepted @@ -381,31 +430,357 @@ class PersistEventsStore: "room_id": event.room_id, "auth_id": auth_id, } - for event, _ in events_and_contexts + for event in events for auth_id in event.auth_event_ids() if event.is_state() ], ) - # _store_rejected_events_txn filters out any events which were - # rejected, and returns the filtered list. - events_and_contexts = self._store_rejected_events_txn( - txn, events_and_contexts=events_and_contexts + # We now calculate chain ID/sequence numbers for any state events we're + # persisting. We ignore out of band memberships as we're not in the room + # and won't have their auth chain (we'll fix it up later if we join the + # room). + # + # See: docs/auth_chain_difference_algorithm.md + + # We ignore legacy rooms that we aren't filling the chain cover index + # for. + rows = self.db_pool.simple_select_many_txn( + txn, + table="rooms", + column="room_id", + iterable={event.room_id for event in events if event.is_state()}, + keyvalues={}, + retcols=("room_id", "has_auth_chain_index"), ) + rooms_using_chain_index = { + row["room_id"] for row in rows if row["has_auth_chain_index"] + } - # From this point onwards the events are only ones that weren't - # rejected. + state_events = { + event.event_id: event + for event in events + if event.is_state() and event.room_id in rooms_using_chain_index + } - self._update_metadata_tables_txn( + if not state_events: + return + + # Map from event ID to chain ID/sequence number. + chain_map = {} # type: Dict[str, Tuple[int, int]] + + # We need to know the type/state_key and auth events of the events we're + # calculating chain IDs for. We don't rely on having the full Event + # instances as we'll potentially be pulling more events from the DB and + # we don't need the overhead of fetching/parsing the full event JSON. + event_to_types = { + e.event_id: (e.type, e.state_key) for e in state_events.values() + } + event_to_auth_chain = { + e.event_id: e.auth_event_ids() for e in state_events.values() + } + + # Set of event IDs to calculate chain ID/seq numbers for. + events_to_calc_chain_id_for = set(state_events) + + # We check if there are any events that need to be handled in the rooms + # we're looking at. These should just be out of band memberships, where + # we didn't have the auth chain when we first persisted. + rows = self.db_pool.simple_select_many_txn( txn, - events_and_contexts=events_and_contexts, - all_events_and_contexts=all_events_and_contexts, - backfilled=backfilled, + table="event_auth_chain_to_calculate", + keyvalues={}, + column="room_id", + iterable={e.room_id for e in state_events.values()}, + retcols=("event_id", "type", "state_key"), ) + for row in rows: + event_id = row["event_id"] + event_type = row["type"] + state_key = row["state_key"] + + # (We could pull out the auth events for all rows at once using + # simple_select_many, but this case happens rarely and almost always + # with a single row.) + auth_events = self.db_pool.simple_select_onecol_txn( + txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id", + ) - # We call this last as it assumes we've inserted the events into - # room_memberships, where applicable. - self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + events_to_calc_chain_id_for.add(event_id) + event_to_types[event_id] = (event_type, state_key) + event_to_auth_chain[event_id] = auth_events + + # First we get the chain ID and sequence numbers for the events' + # auth events (that aren't also currently being persisted). + # + # Note that there there is an edge case here where we might not have + # calculated chains and sequence numbers for events that were "out + # of band". We handle this case by fetching the necessary info and + # adding it to the set of events to calculate chain IDs for. + + missing_auth_chains = { + a_id + for auth_events in event_to_auth_chain.values() + for a_id in auth_events + if a_id not in events_to_calc_chain_id_for + } + + # We loop here in case we find an out of band membership and need to + # fetch their auth event info. + while missing_auth_chains: + sql = """ + SELECT event_id, events.type, state_key, chain_id, sequence_number + FROM events + INNER JOIN state_events USING (event_id) + LEFT JOIN event_auth_chains USING (event_id) + WHERE + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", missing_auth_chains, + ) + txn.execute(sql + clause, args) + + missing_auth_chains.clear() + + for auth_id, event_type, state_key, chain_id, sequence_number in txn: + event_to_types[auth_id] = (event_type, state_key) + + if chain_id is None: + # No chain ID, so the event was persisted out of band. + # We add to list of events to calculate auth chains for. + + events_to_calc_chain_id_for.add(auth_id) + + event_to_auth_chain[ + auth_id + ] = self.db_pool.simple_select_onecol_txn( + txn, + "event_auth", + keyvalues={"event_id": auth_id}, + retcol="auth_id", + ) + + missing_auth_chains.update( + e + for e in event_to_auth_chain[auth_id] + if e not in event_to_types + ) + else: + chain_map[auth_id] = (chain_id, sequence_number) + + # Now we check if we have any events where we don't have auth chain, + # this should only be out of band memberships. + for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain): + for auth_id in event_to_auth_chain[event_id]: + if ( + auth_id not in chain_map + and auth_id not in events_to_calc_chain_id_for + ): + events_to_calc_chain_id_for.discard(event_id) + + # If this is an event we're trying to persist we add it to + # the list of events to calculate chain IDs for next time + # around. (Otherwise we will have already added it to the + # table). + event = state_events.get(event_id) + if event: + self.db_pool.simple_insert_txn( + txn, + table="event_auth_chain_to_calculate", + values={ + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + }, + ) + + # We stop checking the event's auth events since we've + # discarded it. + break + + if not events_to_calc_chain_id_for: + return + + # We now calculate the chain IDs/sequence numbers for the events. We + # do this by looking at the chain ID and sequence number of any auth + # event with the same type/state_key and incrementing the sequence + # number by one. If there was no match or the chain ID/sequence + # number is already taken we generate a new chain. + # + # We need to do this in a topologically sorted order as we want to + # generate chain IDs/sequence numbers of an event's auth events + # before the event itself. + chains_tuples_allocated = set() # type: Set[Tuple[int, int]] + new_chain_tuples = {} # type: Dict[str, Tuple[int, int]] + for event_id in sorted_topologically( + events_to_calc_chain_id_for, event_to_auth_chain + ): + existing_chain_id = None + for auth_id in event_to_auth_chain[event_id]: + if event_to_types.get(event_id) == event_to_types.get(auth_id): + existing_chain_id = chain_map[auth_id] + break + + new_chain_tuple = None + if existing_chain_id: + # We found a chain ID/sequence number candidate, check its + # not already taken. + proposed_new_id = existing_chain_id[0] + proposed_new_seq = existing_chain_id[1] + 1 + if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated: + already_allocated = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_auth_chains", + keyvalues={ + "chain_id": proposed_new_id, + "sequence_number": proposed_new_seq, + }, + retcol="event_id", + allow_none=True, + ) + if already_allocated: + # Mark it as already allocated so we don't need to hit + # the DB again. + chains_tuples_allocated.add((proposed_new_id, proposed_new_seq)) + else: + new_chain_tuple = ( + proposed_new_id, + proposed_new_seq, + ) + + if not new_chain_tuple: + new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1) + + chains_tuples_allocated.add(new_chain_tuple) + + chain_map[event_id] = new_chain_tuple + new_chain_tuples[event_id] = new_chain_tuple + + self.db_pool.simple_insert_many_txn( + txn, + table="event_auth_chains", + values=[ + {"event_id": event_id, "chain_id": c_id, "sequence_number": seq} + for event_id, (c_id, seq) in new_chain_tuples.items() + ], + ) + + self.db_pool.simple_delete_many_txn( + txn, + table="event_auth_chain_to_calculate", + keyvalues={}, + column="event_id", + iterable=new_chain_tuples, + ) + + # Now we need to calculate any new links between chains caused by + # the new events. + # + # Links are pairs of chain ID/sequence numbers such that for any + # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain + # if and only if there is at least one link (CA, S1) -> (CB, S2) + # where SA >= S1 and S2 >= SB. + # + # We try and avoid adding redundant links to the table, e.g. if we + # have two links between two chains which both start/end at the + # sequence number event (or cross) then one can be safely dropped. + # + # To calculate new links we look at every new event and: + # 1. Fetch the chain ID/sequence numbers of its auth events, + # discarding any that are reachable by other auth events, or + # that have the same chain ID as the event. + # 2. For each retained auth event we: + # a. Add a link from the event's to the auth event's chain + # ID/sequence number; and + # b. Add a link from the event to every chain reachable by the + # auth event. + + # Step 1, fetch all existing links from all the chains we've seen + # referenced. + chain_links = _LinkMap() + rows = self.db_pool.simple_select_many_txn( + txn, + table="event_auth_chain_links", + column="origin_chain_id", + iterable={chain_id for chain_id, _ in chain_map.values()}, + keyvalues={}, + retcols=( + "origin_chain_id", + "origin_sequence_number", + "target_chain_id", + "target_sequence_number", + ), + ) + for row in rows: + chain_links.add_link( + (row["origin_chain_id"], row["origin_sequence_number"]), + (row["target_chain_id"], row["target_sequence_number"]), + new=False, + ) + + # We do this in toplogical order to avoid adding redundant links. + for event_id in sorted_topologically( + events_to_calc_chain_id_for, event_to_auth_chain + ): + chain_id, sequence_number = chain_map[event_id] + + # Filter out auth events that are reachable by other auth + # events. We do this by looking at every permutation of pairs of + # auth events (A, B) to check if B is reachable from A. + reduction = { + a_id + for a_id in event_to_auth_chain[event_id] + if chain_map[a_id][0] != chain_id + } + for start_auth_id, end_auth_id in itertools.permutations( + event_to_auth_chain[event_id], r=2, + ): + if chain_links.exists_path_from( + chain_map[start_auth_id], chain_map[end_auth_id] + ): + reduction.discard(end_auth_id) + + # Step 2, figure out what the new links are from the reduced + # list of auth events. + for auth_id in reduction: + auth_chain_id, auth_sequence_number = chain_map[auth_id] + + # Step 2a, add link between the event and auth event + chain_links.add_link( + (chain_id, sequence_number), (auth_chain_id, auth_sequence_number) + ) + + # Step 2b, add a link to chains reachable from the auth + # event. + for target_id, target_seq in chain_links.get_links_from( + (auth_chain_id, auth_sequence_number) + ): + if target_id == chain_id: + continue + + chain_links.add_link( + (chain_id, sequence_number), (target_id, target_seq) + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="event_auth_chain_links", + values=[ + { + "origin_chain_id": source_id, + "origin_sequence_number": source_seq, + "target_chain_id": target_id, + "target_sequence_number": target_seq, + } + for ( + source_id, + source_seq, + target_id, + target_seq, + ) in chain_links.get_additions() + ], + ) def _persist_transaction_ids_txn( self, @@ -1521,3 +1896,131 @@ class PersistEventsStore: if not ev.internal_metadata.is_outlier() ], ) + + +@attr.s(slots=True) +class _LinkMap: + """A helper type for tracking links between chains. + """ + + # Stores the set of links as nested maps: source chain ID -> target chain ID + # -> source sequence number -> target sequence number. + maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict) + + # Stores the links that have been added (with new set to true), as tuples of + # `(source chain ID, source sequence no, target chain ID, target sequence no.)` + additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set) + + def add_link( + self, + src_tuple: Tuple[int, int], + target_tuple: Tuple[int, int], + new: bool = True, + ) -> bool: + """Add a new link between two chains, ensuring no redundant links are added. + + New links should be added in topological order. + + Args: + src_tuple: The chain ID/sequence number of the source of the link. + target_tuple: The chain ID/sequence number of the target of the link. + new: Whether this is a "new" link, i.e. should it be returned + by `get_additions`. + + Returns: + True if a link was added, false if the given link was dropped as redundant + """ + src_chain, src_seq = src_tuple + target_chain, target_seq = target_tuple + + current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {}) + + assert src_chain != target_chain + + if new: + # Check if the new link is redundant + for current_seq_src, current_seq_target in current_links.items(): + # If a link "crosses" another link then its redundant. For example + # in the following link 1 (L1) is redundant, as any event reachable + # via L1 is *also* reachable via L2. + # + # Chain A Chain B + # | | + # L1 |------ | + # | | | + # L2 |---- | -->| + # | | | + # | |--->| + # | | + # | | + # + # So we only need to keep links which *do not* cross, i.e. links + # that both start and end above or below an existing link. + # + # Note, since we add links in topological ordering we should never + # see `src_seq` less than `current_seq_src`. + + if current_seq_src <= src_seq and target_seq <= current_seq_target: + # This new link is redundant, nothing to do. + return False + + self.additions.add((src_chain, src_seq, target_chain, target_seq)) + + current_links[src_seq] = target_seq + return True + + def get_links_from( + self, src_tuple: Tuple[int, int] + ) -> Generator[Tuple[int, int], None, None]: + """Gets the chains reachable from the given chain/sequence number. + + Yields: + The chain ID and sequence number the link points to. + """ + src_chain, src_seq = src_tuple + for target_id, sequence_numbers in self.maps.get(src_chain, {}).items(): + for link_src_seq, target_seq in sequence_numbers.items(): + if link_src_seq <= src_seq: + yield target_id, target_seq + + def get_links_between( + self, source_chain: int, target_chain: int + ) -> Generator[Tuple[int, int], None, None]: + """Gets the links between two chains. + + Yields: + The source and target sequence numbers. + """ + + yield from self.maps.get(source_chain, {}).get(target_chain, {}).items() + + def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]: + """Gets any newly added links. + + Yields: + The source chain ID/sequence number and target chain ID/sequence number + """ + + for src_chain, src_seq, target_chain, _ in self.additions: + target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq) + if target_seq is not None: + yield (src_chain, src_seq, target_chain, target_seq) + + def exists_path_from( + self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int], + ) -> bool: + """Checks if there is a path between the source chain ID/sequence and + target chain ID/sequence. + """ + src_chain, src_seq = src_tuple + target_chain, target_seq = target_tuple + + if src_chain == target_chain: + return target_seq <= src_seq + + links = self.get_links_between(src_chain, target_chain) + for link_start_seq, link_end_seq in links: + if link_start_seq <= src_seq and target_seq <= link_end_seq: + return True + + return False diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 4650d0689b..284f2ce77c 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -84,7 +84,7 @@ class RoomWorkerStore(SQLBaseStore): return await self.db_pool.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, - retcols=("room_id", "is_public", "creator"), + retcols=("room_id", "is_public", "creator", "has_auth_chain_index"), desc="get_room", allow_none=True, ) @@ -1166,6 +1166,37 @@ class RoomBackgroundUpdateStore(SQLBaseStore): # It's overridden by RoomStore for the synapse master. raise NotImplementedError() + async def has_auth_chain_index(self, room_id: str) -> bool: + """Check if the room has (or can have) a chain cover index. + + Defaults to True if we don't have an entry in `rooms` table nor any + events for the room. + """ + + has_auth_chain_index = await self.db_pool.simple_select_one_onecol( + table="rooms", + keyvalues={"room_id": room_id}, + retcol="has_auth_chain_index", + desc="has_auth_chain_index", + allow_none=True, + ) + + if has_auth_chain_index: + return True + + # It's possible that we already have events for the room in our DB + # without a corresponding room entry. If we do then we don't want to + # mark the room as having an auth chain cover index. + max_ordering = await self.db_pool.simple_select_one_onecol( + table="events", + keyvalues={"room_id": room_id}, + retcol="MAX(stream_ordering)", + allow_none=True, + desc="upsert_room_on_join", + ) + + return max_ordering is None + class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -1179,12 +1210,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): Called when we join a room over federation, and overwrites any room version currently in the table. """ + # It's possible that we already have events for the room in our DB + # without a corresponding room entry. If we do then we don't want to + # mark the room as having an auth chain cover index. + has_auth_chain_index = await self.has_auth_chain_index(room_id) + await self.db_pool.simple_upsert( desc="upsert_room_on_join", table="rooms", keyvalues={"room_id": room_id}, values={"room_version": room_version.identifier}, - insertion_values={"is_public": False, "creator": ""}, + insertion_values={ + "is_public": False, + "creator": "", + "has_auth_chain_index": has_auth_chain_index, + }, # rooms has a unique constraint on room_id, so no need to lock when doing an # emulated upsert. lock=False, @@ -1219,6 +1259,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "creator": room_creator_user_id, "is_public": is_public, "room_version": room_version.identifier, + "has_auth_chain_index": True, }, ) if is_public: @@ -1247,6 +1288,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): When we receive an invite or any other event over federation that may relate to a room we are not in, store the version of the room if we don't already know the room version. """ + # It's possible that we already have events for the room in our DB + # without a corresponding room entry. If we do then we don't want to + # mark the room as having an auth chain cover index. + has_auth_chain_index = await self.has_auth_chain_index(room_id) + await self.db_pool.simple_upsert( desc="maybe_store_room_on_outlier_membership", table="rooms", @@ -1256,6 +1302,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "room_version": room_version.identifier, "is_public": False, "creator": "", + "has_auth_chain_index": has_auth_chain_index, }, # rooms has a unique constraint on room_id, so no need to lock when doing an # emulated upsert. diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql new file mode 100644 index 0000000000..729196cfd5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql @@ -0,0 +1,52 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- See docs/auth_chain_difference_algorithm.md + +CREATE TABLE event_auth_chains ( + event_id TEXT PRIMARY KEY, + chain_id BIGINT NOT NULL, + sequence_number BIGINT NOT NULL +); + +CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number); + + +CREATE TABLE event_auth_chain_links ( + origin_chain_id BIGINT NOT NULL, + origin_sequence_number BIGINT NOT NULL, + + target_chain_id BIGINT NOT NULL, + target_sequence_number BIGINT NOT NULL +); + + +CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id); + + +-- Events that we have persisted but not calculated auth chains for, +-- e.g. out of band memberships (where we don't have the auth chain) +CREATE TABLE event_auth_chain_to_calculate ( + event_id TEXT PRIMARY KEY, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL +); + +CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id); + + +-- Whether we've calculated the above index for a room. +ALTER TABLE rooms ADD COLUMN has_auth_chain_index BOOLEAN; diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres new file mode 100644 index 0000000000..e8a035bbeb --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres @@ -0,0 +1,16 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS event_auth_chain_id; diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index 06faeebe7f..f7b4857a84 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -13,8 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import heapq from itertools import islice -from typing import Iterable, Iterator, Sequence, Tuple, TypeVar +from typing import ( + Dict, + Generator, + Iterable, + Iterator, + Mapping, + Sequence, + Set, + Tuple, + TypeVar, +) + +from synapse.types import Collection T = TypeVar("T") @@ -46,3 +59,41 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]: If the input is empty, no chunks are returned. """ return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen)) + + +def sorted_topologically( + nodes: Iterable[T], graph: Mapping[T, Collection[T]], +) -> Generator[T, None, None]: + """Given a set of nodes and a graph, yield the nodes in toplogical order. + + For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`. + """ + + # This is implemented by Kahn's algorithm. + + degree_map = {node: 0 for node in nodes} + reverse_graph = {} # type: Dict[T, Set[T]] + + for node, edges in graph.items(): + if node not in degree_map: + continue + + for edge in edges: + if edge in degree_map: + degree_map[node] += 1 + + reverse_graph.setdefault(edge, set()).add(node) + reverse_graph.setdefault(node, set()) + + zero_degree = [node for node, degree in degree_map.items() if degree == 0] + heapq.heapify(zero_degree) + + while zero_degree: + node = heapq.heappop(zero_degree) + yield node + + for edge in reverse_graph[node]: + if edge in degree_map: + degree_map[edge] -= 1 + if degree_map[edge] == 0: + heapq.heappush(zero_degree, edge) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py new file mode 100644 index 0000000000..83c377824b --- /dev/null +++ b/tests/storage/test_event_chain.py @@ -0,0 +1,472 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Tuple + +from twisted.trial import unittest + +from synapse.api.constants import EventTypes +from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase +from synapse.storage.databases.main.events import _LinkMap + +from tests.unittest import HomeserverTestCase + + +class EventChainStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self._next_stream_ordering = 1 + + def test_simple(self): + """Test that the example in `docs/auth_chain_difference_algorithm.md` + works. + """ + + event_factory = self.hs.get_event_builder_factory() + bob = "@creator:test" + alice = "@alice:test" + room_id = "!room:test" + + # Ensure that we have a rooms entry so that we generate the chain index. + self.get_success( + self.store.store_room( + room_id=room_id, + room_creator_user_id="", + is_public=True, + room_version=RoomVersions.V6, + ) + ) + + create = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Create, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "create"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[]) + ) + + bob_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": bob, + "sender": bob, + "room_id": room_id, + "content": {"tag": "bob_join"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[create.event_id]) + ) + + power = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.PowerLevels, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "power"}, + }, + ).build( + prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id], + ) + ) + + alice_invite = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": bob, + "room_id": room_id, + "content": {"tag": "alice_invite"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + alice_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id], + ) + ) + + power_2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.PowerLevels, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "power_2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + bob_join_2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": bob, + "sender": bob, + "room_id": room_id, + "content": {"tag": "bob_join_2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + alice_join2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[ + create.event_id, + alice_join.event_id, + power_2.event_id, + ], + ) + ) + + events = [ + create, + bob_join, + power, + alice_invite, + alice_join, + bob_join_2, + power_2, + alice_join2, + ] + + expected_links = [ + (bob_join, create), + (power, create), + (power, bob_join), + (alice_invite, create), + (alice_invite, power), + (alice_invite, bob_join), + (bob_join_2, power), + (alice_join2, power_2), + ] + + self.persist(events) + chain_map, link_map = self.fetch_chains(events) + + # Check that the expected links and only the expected links have been + # added. + self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) + + for start, end in expected_links: + start_id, start_seq = chain_map[start.event_id] + end_id, end_seq = chain_map[end.event_id] + + self.assertIn( + (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) + ) + + # Test that everything can reach the create event, but the create event + # can't reach anything. + for event in events[1:]: + self.assertTrue( + link_map.exists_path_from( + chain_map[event.event_id], chain_map[create.event_id] + ), + ) + + self.assertFalse( + link_map.exists_path_from( + chain_map[create.event_id], chain_map[event.event_id], + ), + ) + + def test_out_of_order_events(self): + """Test that we handle persisting events that we don't have the full + auth chain for yet (which should only happen for out of band memberships). + """ + event_factory = self.hs.get_event_builder_factory() + bob = "@creator:test" + alice = "@alice:test" + room_id = "!room:test" + + # Ensure that we have a rooms entry so that we generate the chain index. + self.get_success( + self.store.store_room( + room_id=room_id, + room_creator_user_id="", + is_public=True, + room_version=RoomVersions.V6, + ) + ) + + # First persist the base room. + create = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Create, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "create"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[]) + ) + + bob_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": bob, + "sender": bob, + "room_id": room_id, + "content": {"tag": "bob_join"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[create.event_id]) + ) + + power = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.PowerLevels, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "power"}, + }, + ).build( + prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id], + ) + ) + + self.persist([create, bob_join, power]) + + # Now persist an invite and a couple of memberships out of order. + alice_invite = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": bob, + "room_id": room_id, + "content": {"tag": "alice_invite"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + alice_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id], + ) + ) + + alice_join2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, alice_join.event_id, power.event_id], + ) + ) + + self.persist([alice_join]) + self.persist([alice_join2]) + self.persist([alice_invite]) + + # The end result should be sane. + events = [create, bob_join, power, alice_invite, alice_join] + + chain_map, link_map = self.fetch_chains(events) + + expected_links = [ + (bob_join, create), + (power, create), + (power, bob_join), + (alice_invite, create), + (alice_invite, power), + (alice_invite, bob_join), + ] + + # Check that the expected links and only the expected links have been + # added. + self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) + + for start, end in expected_links: + start_id, start_seq = chain_map[start.event_id] + end_id, end_seq = chain_map[end.event_id] + + self.assertIn( + (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) + ) + + def persist( + self, events: List[EventBase], + ): + """Persist the given events and check that the links generated match + those given. + """ + + persist_events_store = self.hs.get_datastores().persist_events + + for e in events: + e.internal_metadata.stream_ordering = self._next_stream_ordering + self._next_stream_ordering += 1 + + def _persist(txn): + # We need to persist the events to the events and state_events + # tables. + persist_events_store._store_event_txn(txn, [(e, {}) for e in events]) + + # Actually call the function that calculates the auth chain stuff. + persist_events_store._persist_event_auth_chain_txn(txn, events) + + self.get_success( + persist_events_store.db_pool.runInteraction("_persist", _persist,) + ) + + def fetch_chains( + self, events: List[EventBase] + ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]: + + # Fetch the map from event ID -> (chain ID, sequence number) + rows = self.get_success( + self.store.db_pool.simple_select_many_batch( + table="event_auth_chains", + column="event_id", + iterable=[e.event_id for e in events], + retcols=("event_id", "chain_id", "sequence_number"), + keyvalues={}, + ) + ) + + chain_map = { + row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows + } + + # Fetch all the links and pass them to the _LinkMap. + rows = self.get_success( + self.store.db_pool.simple_select_many_batch( + table="event_auth_chain_links", + column="origin_chain_id", + iterable=[chain_id for chain_id, _ in chain_map.values()], + retcols=( + "origin_chain_id", + "origin_sequence_number", + "target_chain_id", + "target_sequence_number", + ), + keyvalues={}, + ) + ) + + link_map = _LinkMap() + for row in rows: + added = link_map.add_link( + (row["origin_chain_id"], row["origin_sequence_number"]), + (row["target_chain_id"], row["target_sequence_number"]), + ) + + # We shouldn't have persisted any redundant links + self.assertTrue(added) + + return chain_map, link_map + + +class LinkMapTestCase(unittest.TestCase): + def test_simple(self): + """Basic tests for the LinkMap. + """ + link_map = _LinkMap() + + link_map.add_link((1, 1), (2, 1), new=False) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) + self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)]) + self.assertCountEqual(link_map.get_additions(), []) + self.assertTrue(link_map.exists_path_from((1, 5), (2, 1))) + self.assertFalse(link_map.exists_path_from((1, 5), (2, 2))) + self.assertTrue(link_map.exists_path_from((1, 5), (1, 1))) + self.assertFalse(link_map.exists_path_from((1, 1), (1, 5))) + + # Attempting to add a redundant link is ignored. + self.assertFalse(link_map.add_link((1, 4), (2, 1))) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) + + # Adding new non-redundant links works + self.assertTrue(link_map.add_link((1, 3), (2, 3))) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) + + self.assertTrue(link_map.add_link((2, 5), (1, 3))) + self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)]) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) + + self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)]) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 482506d731..9d04a066d8 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -13,6 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import attr +from parameterized import parameterized + +from synapse.events import _EventInternalMetadata + import tests.unittest import tests.utils @@ -113,7 +118,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) self.assertTrue(r == [room2] or r == [room3]) - def test_auth_difference(self): + @parameterized.expand([(True,), (False,)]) + def test_auth_difference(self, use_chain_cover_index: bool): room_id = "@ROOM:local" # The silly auth graph we use to test the auth difference algorithm, @@ -159,46 +165,223 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): "j": 1, } + # Mark the room as not having a cover index + + def store_room(txn): + self.store.db_pool.simple_insert_txn( + txn, + "rooms", + { + "room_id": room_id, + "creator": "room_creator_user_id", + "is_public": True, + "room_version": "6", + "has_auth_chain_index": use_chain_cover_index, + }, + ) + + self.get_success(self.store.db_pool.runInteraction("store_room", store_room)) + # We rudely fiddle with the appropriate tables directly, as that's much # easier than constructing events properly. - def insert_event(txn, event_id, stream_ordering): + def insert_event(txn): + stream_ordering = 0 + + for event_id in auth_graph: + stream_ordering += 1 + depth = depth_map[event_id] + + self.store.db_pool.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + "stream_ordering": stream_ordering, + }, + ) + + self.hs.datastores.persist_events._persist_event_auth_chain_txn( + txn, + [ + FakeEvent(event_id, room_id, auth_graph[event_id]) + for event_id in auth_graph + ], + ) + + self.get_success(self.store.db_pool.runInteraction("insert", insert_event,)) + + # Now actually test that various combinations give the right result: + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "d", "e"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}]) + ) + self.assertSetEqual(difference, set()) + + def test_auth_difference_partial_cover(self): + """Test that we correctly handle rooms where not all events have a chain + cover calculated. This can happen in some obscure edge cases, including + during the background update that calculates the chain cover for old + rooms. + """ + + room_id = "@ROOM:local" + + # The silly auth graph we use to test the auth difference algorithm, + # where the top are the most recent events. + # + # A B + # \ / + # D E + # \ | + # ` F C + # | /| + # G ยด | + # | \ | + # H I + # | | + # K J + + auth_graph = { + "a": ["e"], + "b": ["e"], + "c": ["g", "i"], + "d": ["f"], + "e": ["f"], + "f": ["g"], + "g": ["h", "i"], + "h": ["k"], + "i": ["j"], + "k": [], + "j": [], + } + + depth_map = { + "a": 7, + "b": 7, + "c": 4, + "d": 6, + "e": 6, + "f": 5, + "g": 3, + "h": 2, + "i": 2, + "k": 1, + "j": 1, + } - depth = depth_map[event_id] + # We rudely fiddle with the appropriate tables directly, as that's much + # easier than constructing events properly. + def insert_event(txn): + # First insert the room and mark it as having a chain cover. self.store.db_pool.simple_insert_txn( txn, - table="events", - values={ - "event_id": event_id, + "rooms", + { "room_id": room_id, - "depth": depth, - "topological_ordering": depth, - "type": "m.test", - "processed": True, - "outlier": False, - "stream_ordering": stream_ordering, + "creator": "room_creator_user_id", + "is_public": True, + "room_version": "6", + "has_auth_chain_index": True, }, ) - self.store.db_pool.simple_insert_many_txn( + stream_ordering = 0 + + for event_id in auth_graph: + stream_ordering += 1 + depth = depth_map[event_id] + + self.store.db_pool.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + "stream_ordering": stream_ordering, + }, + ) + + # Insert all events apart from 'B' + self.hs.datastores.persist_events._persist_event_auth_chain_txn( txn, - table="event_auth", - values=[ - {"event_id": event_id, "room_id": room_id, "auth_id": a} - for a in auth_graph[event_id] + [ + FakeEvent(event_id, room_id, auth_graph[event_id]) + for event_id in auth_graph + if event_id != "b" ], ) - next_stream_ordering = 0 - for event_id in auth_graph: - next_stream_ordering += 1 - self.get_success( - self.store.db_pool.runInteraction( - "insert", insert_event, event_id, next_stream_ordering - ) + # Now we insert the event 'B' without a chain cover, by temporarily + # pretending the room doesn't have a chain cover. + + self.store.db_pool.simple_update_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"has_auth_chain_index": False}, + ) + + self.hs.datastores.persist_events._persist_event_auth_chain_txn( + txn, [FakeEvent("b", room_id, auth_graph["b"])], + ) + + self.store.db_pool.simple_update_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"has_auth_chain_index": True}, ) + self.get_success(self.store.db_pool.runInteraction("insert", insert_event,)) + # Now actually test that various combinations give the right result: difference = self.get_success( @@ -240,3 +423,21 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.store.get_auth_chain_difference(room_id, [{"a"}]) ) self.assertSetEqual(difference, set()) + + +@attr.s +class FakeEvent: + event_id = attr.ib() + room_id = attr.ib() + auth_events = attr.ib() + + type = "foo" + state_key = "foo" + + internal_metadata = _EventInternalMetadata({}) + + def auth_event_ids(self): + return self.auth_events + + def is_state(self): + return True diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index 0ab0a91483..1184cea5a3 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -12,7 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.iterutils import chunk_seq +from typing import Dict, List + +from synapse.util.iterutils import chunk_seq, sorted_topologically from tests.unittest import TestCase @@ -45,3 +47,40 @@ class ChunkSeqTests(TestCase): self.assertEqual( list(parts), [], ) + + +class SortTopologically(TestCase): + def test_empty(self): + "Test that an empty graph works correctly" + + graph = {} # type: Dict[int, List[int]] + self.assertEqual(list(sorted_topologically([], graph)), []) + + def test_disconnected(self): + "Test that a graph with no edges work" + + graph = {1: [], 2: []} # type: Dict[int, List[int]] + + # For disconnected nodes the output is simply sorted. + self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) + + def test_linear(self): + "Test that a simple `4 -> 3 -> 2 -> 1` graph works" + + graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] + + self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) + + def test_subset(self): + "Test that only sorting a subset of the graph works" + graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] + + self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4]) + + def test_fork(self): + "Test that a forked graph works" + graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]] + + # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should + # always get the same one. + self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) -- cgit 1.5.1 From c9195744a4c8196f5900a467d63327ad3a9c9bbc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 11 Jan 2021 18:01:27 +0000 Subject: Move more encryption endpoints off master (#9068) --- changelog.d/9068.feature | 1 + synapse/app/generic_worker.py | 12 +++- synapse/storage/databases/main/end_to_end_keys.py | 88 +++++++++++------------ 3 files changed, 55 insertions(+), 46 deletions(-) create mode 100644 changelog.d/9068.feature (limited to 'synapse/storage/databases') diff --git a/changelog.d/9068.feature b/changelog.d/9068.feature new file mode 100644 index 0000000000..cdf1844fa7 --- /dev/null +++ b/changelog.d/9068.feature @@ -0,0 +1 @@ +Add experimental support for handling `/keys/claim` and `/room_keys` APIs on worker processes. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index a57535989a..f24c648ac7 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -100,14 +100,18 @@ from synapse.rest.client.v1.profile import ( ) from synapse.rest.client.v1.push_rule import PushRuleRestServlet from synapse.rest.client.v1.voip import VoipRestServlet -from synapse.rest.client.v2_alpha import groups, sync, user_directory +from synapse.rest.client.v2_alpha import groups, room_keys, sync, user_directory from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha.account import ThreepidRestServlet from synapse.rest.client.v2_alpha.account_data import ( AccountDataServlet, RoomAccountDataServlet, ) -from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet +from synapse.rest.client.v2_alpha.keys import ( + KeyChangesServlet, + KeyQueryServlet, + OneTimeKeyServlet, +) from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet from synapse.rest.client.versions import VersionsRestServlet @@ -116,6 +120,7 @@ from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer, cache_in_self from synapse.storage.databases.main.censor_events import CensorEventsStore from synapse.storage.databases.main.client_ips import ClientIpWorkerStore +from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore from synapse.storage.databases.main.media_repository import MediaRepositoryStore from synapse.storage.databases.main.metrics import ServerMetricsStore from synapse.storage.databases.main.monthly_active_users import ( @@ -447,6 +452,7 @@ class GenericWorkerSlavedStore( UserDirectoryStore, StatsStore, UIAuthWorkerStore, + EndToEndRoomKeyStore, SlavedDeviceInboxStore, SlavedDeviceStore, SlavedReceiptsStore, @@ -504,6 +510,7 @@ class GenericWorkerServer(HomeServer): LoginRestServlet(self).register(resource) ThreepidRestServlet(self).register(resource) KeyQueryServlet(self).register(resource) + OneTimeKeyServlet(self).register(resource) KeyChangesServlet(self).register(resource) VoipRestServlet(self).register(resource) PushRuleRestServlet(self).register(resource) @@ -521,6 +528,7 @@ class GenericWorkerServer(HomeServer): room.register_servlets(self, resource, True) room.register_deprecated_servlets(self, resource) InitialSyncRestServlet(self).register(resource) + room_keys.register_servlets(self, resource) SendToDeviceRestServlet(self).register(resource) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 4d1b92d1aa..1b6ccd51c8 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -707,50 +707,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): """Get the current stream id from the _device_list_id_gen""" ... - -class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): - async def set_e2e_device_keys( - self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict - ) -> bool: - """Stores device keys for a device. Returns whether there was a change - or the keys were already in the database. - """ - - def _set_e2e_device_keys_txn(txn): - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("time_now", time_now) - set_tag("device_keys", device_keys) - - old_key_json = self.db_pool.simple_select_one_onecol_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcol="key_json", - allow_none=True, - ) - - # In py3 we need old_key_json to match new_key_json type. The DB - # returns unicode while encode_canonical_json returns bytes. - new_key_json = encode_canonical_json(device_keys).decode("utf-8") - - if old_key_json == new_key_json: - log_kv({"Message": "Device key already stored."}) - return False - - self.db_pool.simple_upsert_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - values={"ts_added_ms": time_now, "key_json": new_key_json}, - ) - log_kv({"message": "Device keys stored."}) - return True - - return await self.db_pool.runInteraction( - "set_e2e_device_keys", _set_e2e_device_keys_txn - ) - async def claim_e2e_one_time_keys( self, query_list: Iterable[Tuple[str, str, str]] ) -> Dict[str, Dict[str, Dict[str, bytes]]]: @@ -840,6 +796,50 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) + +class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): + async def set_e2e_device_keys( + self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict + ) -> bool: + """Stores device keys for a device. Returns whether there was a change + or the keys were already in the database. + """ + + def _set_e2e_device_keys_txn(txn): + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("time_now", time_now) + set_tag("device_keys", device_keys) + + old_key_json = self.db_pool.simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="key_json", + allow_none=True, + ) + + # In py3 we need old_key_json to match new_key_json type. The DB + # returns unicode while encode_canonical_json returns bytes. + new_key_json = encode_canonical_json(device_keys).decode("utf-8") + + if old_key_json == new_key_json: + log_kv({"Message": "Device key already stored."}) + return False + + self.db_pool.simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"ts_added_ms": time_now, "key_json": new_key_json}, + ) + log_kv({"message": "Device keys stored."}) + return True + + return await self.db_pool.runInteraction( + "set_e2e_device_keys", _set_e2e_device_keys_txn + ) + async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: def delete_e2e_keys_by_device_txn(txn): log_kv( -- cgit 1.5.1 From 7a2e9b549defe3f55531711a863183a33e7af83c Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 12 Jan 2021 22:30:15 +0100 Subject: Remove user's avatar URL and displayname when deactivated. (#8932) This only applies if the user's data is to be erased. --- changelog.d/8932.feature | 1 + docs/admin_api/user_admin_api.rst | 21 +++ synapse/handlers/deactivate_account.py | 18 ++- synapse/handlers/profile.py | 8 +- synapse/rest/admin/users.py | 22 ++- synapse/rest/client/v2_alpha/account.py | 7 +- synapse/server.py | 2 +- synapse/storage/databases/main/profile.py | 2 +- tests/handlers/test_profile.py | 30 ++++ tests/rest/admin/test_user.py | 220 ++++++++++++++++++++++++++++++ tests/rest/client/v1/test_login.py | 5 +- tests/rest/client/v1/test_rooms.py | 6 +- tests/storage/test_profile.py | 26 ++++ 13 files changed, 351 insertions(+), 17 deletions(-) create mode 100644 changelog.d/8932.feature (limited to 'synapse/storage/databases') diff --git a/changelog.d/8932.feature b/changelog.d/8932.feature new file mode 100644 index 0000000000..a1d17394d7 --- /dev/null +++ b/changelog.d/8932.feature @@ -0,0 +1 @@ +Remove a user's avatar URL and display name when deactivated with the Admin API. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 3115951e1f..b3d413cf57 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -98,6 +98,8 @@ Body parameters: - ``deactivated``, optional. If unspecified, deactivation state will be left unchanged on existing accounts and set to ``false`` for new accounts. + A user cannot be erased by deactivating with this API. For details on deactivating users see + `Deactivate Account <#deactivate-account>`_. If the user already exists then optional parameters default to the current value. @@ -248,6 +250,25 @@ server admin: see `README.rst `_. The erase parameter is optional and defaults to ``false``. An empty body may be passed for backwards compatibility. +The following actions are performed when deactivating an user: + +- Try to unpind 3PIDs from the identity server +- Remove all 3PIDs from the homeserver +- Delete all devices and E2EE keys +- Delete all access tokens +- Delete the password hash +- Removal from all rooms the user is a member of +- Remove the user from the user directory +- Reject all pending invites +- Remove all account validity information related to the user + +The following additional actions are performed during deactivation if``erase`` +is set to ``true``: + +- Remove the user's display name +- Remove the user's avatar URL +- Mark the user as erased + Reset password ============== diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index e808142365..c4a3b26a84 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import UserID, create_requester +from synapse.types import Requester, UserID, create_requester from ._base import BaseHandler @@ -38,6 +38,7 @@ class DeactivateAccountHandler(BaseHandler): self._device_handler = hs.get_device_handler() self._room_member_handler = hs.get_room_member_handler() self._identity_handler = hs.get_identity_handler() + self._profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() self._server_name = hs.hostname @@ -52,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler): self._account_validity_enabled = hs.config.account_validity.enabled async def deactivate_account( - self, user_id: str, erase_data: bool, id_server: Optional[str] = None + self, + user_id: str, + erase_data: bool, + requester: Requester, + id_server: Optional[str] = None, + by_admin: bool = False, ) -> bool: """Deactivate a user's account Args: user_id: ID of user to be deactivated erase_data: whether to GDPR-erase the user's data + requester: The user attempting to make this change. id_server: Use the given identity server when unbinding any threepids. If None then will attempt to unbind using the identity server specified when binding (if known). + by_admin: Whether this change was made by an administrator. Returns: True if identity server supports removing threepids, otherwise False. @@ -121,6 +129,12 @@ class DeactivateAccountHandler(BaseHandler): # Mark the user as erased, if they asked for that if erase_data: + user = UserID.from_string(user_id) + # Remove avatar URL from this user + await self._profile_handler.set_avatar_url(user, requester, "", by_admin) + # Remove displayname from this user + await self._profile_handler.set_displayname(user, requester, "", by_admin) + logger.info("Marking %s as erased", user_id) await self.store.mark_user_erased(user_id) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 36f9ee4b71..c02b951031 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -286,13 +286,19 @@ class ProfileHandler(BaseHandler): 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) ) + avatar_url_to_set = new_avatar_url # type: Optional[str] + if new_avatar_url == "": + avatar_url_to_set = None + # Same like set_displayname if by_admin: requester = create_requester( target_user, authenticated_entity=requester.authenticated_entity ) - await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) + await self.store.set_profile_avatar_url( + target_user.localpart, avatar_url_to_set + ) if self.hs.config.user_directory_search_all_users: profile = await self.store.get_profileinfo(target_user.localpart) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index f8a73e7d9d..f39e3d6d5c 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -244,7 +244,7 @@ class UserRestServletV2(RestServlet): if deactivate and not user["deactivated"]: await self.deactivate_account_handler.deactivate_account( - target_user.to_string(), False + target_user.to_string(), False, requester, by_admin=True ) elif not deactivate and user["deactivated"]: if "password" not in body: @@ -486,12 +486,22 @@ class WhoisRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): PATTERNS = admin_patterns("/deactivate/(?P[^/]*)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self._deactivate_account_handler = hs.get_deactivate_account_handler() self.auth = hs.get_auth() + self.is_mine = hs.is_mine + self.store = hs.get_datastore() + + async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + if not self.is_mine(UserID.from_string(target_user_id)): + raise SynapseError(400, "Can only deactivate local users") + + if not await self.store.get_user_by_id(target_user_id): + raise NotFoundError("User not found") - async def on_POST(self, request, target_user_id): - await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request, allow_empty_body=True) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -501,10 +511,8 @@ class DeactivateAccountRestServlet(RestServlet): Codes.BAD_JSON, ) - UserID.from_string(target_user_id) - result = await self._deactivate_account_handler.deactivate_account( - target_user_id, erase + target_user_id, erase, requester, by_admin=True ) if result: id_server_unbind_result = "success" diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 3b50dc885f..65e68d641b 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -305,7 +305,7 @@ class DeactivateAccountRestServlet(RestServlet): # allow ASes to deactivate their own users if requester.app_service: await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase + requester.user.to_string(), erase, requester ) return 200, {} @@ -313,7 +313,10 @@ class DeactivateAccountRestServlet(RestServlet): requester, request, body, "deactivate your account", ) result = await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, id_server=body.get("id_server") + requester.user.to_string(), + erase, + requester, + id_server=body.get("id_server"), ) if result: id_server_unbind_result = "success" diff --git a/synapse/server.py b/synapse/server.py index 12da92b63c..d4c235cda5 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -501,7 +501,7 @@ class HomeServer(metaclass=abc.ABCMeta): return InitialSyncHandler(self) @cache_in_self - def get_profile_handler(self): + def get_profile_handler(self) -> ProfileHandler: return ProfileHandler(self) @cache_in_self diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 0e25ca3d7a..54ef0f1f54 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -82,7 +82,7 @@ class ProfileWorkerStore(SQLBaseStore): ) async def set_profile_avatar_url( - self, user_localpart: str, new_avatar_url: str + self, user_localpart: str, new_avatar_url: Optional[str] ) -> None: await self.db_pool.simple_update_one( table="profiles", diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 919547556b..022943a10a 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -105,6 +105,21 @@ class ProfileTestCase(unittest.TestCase): "Frank", ) + # Set displayname to an empty string + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "" + ) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ) + ) + @defer.inlineCallbacks def test_set_my_name_if_disabled(self): self.hs.config.enable_set_displayname = False @@ -223,6 +238,21 @@ class ProfileTestCase(unittest.TestCase): "http://my.server/me.png", ) + # Set avatar to an empty string + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, synapse.types.create_requester(self.frank), "", + ) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), + ) + @defer.inlineCallbacks def test_set_my_avatar_if_disabled(self): self.hs.config.enable_set_avatar_url = False diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ad4588c1da..04599c2fcf 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -588,6 +588,200 @@ class UsersListTestCase(unittest.HomeserverTestCase): _search_test(None, "bar", "user_id") +class DeactivateAccountTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass", displayname="User1") + self.other_user_token = self.login("user", "pass") + self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( + self.other_user + ) + self.url = "/_synapse/admin/v1/deactivate/%s" % urllib.parse.quote( + self.other_user + ) + + # set attributes for user + self.get_success( + self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") + ) + self.get_success( + self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) + ) + + def test_no_auth(self): + """ + Try to deactivate users without authentication. + """ + channel = self.make_request("POST", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_not_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + url = "/_synapse/admin/v1/deactivate/@bob:test" + + channel = self.make_request("POST", url, access_token=self.other_user_token) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("You are not a server admin", channel.json_body["error"]) + + channel = self.make_request( + "POST", url, access_token=self.other_user_token, content=b"{}", + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("You are not a server admin", channel.json_body["error"]) + + def test_user_does_not_exist(self): + """ + Tests that deactivation for a user that does not exist returns a 404 + """ + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/deactivate/@unknown_person:test", + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_erase_is_not_bool(self): + """ + If parameter `erase` is not boolean, return an error + """ + body = json.dumps({"erase": "False"}) + + channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that deactivation for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain" + + channel = self.make_request("POST", url, access_token=self.admin_user_tok) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only deactivate local users", channel.json_body["error"]) + + def test_deactivate_user_erase_true(self): + """ + Test deactivating an user and set `erase` to `true` + """ + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User1", channel.json_body["displayname"]) + + # Deactivate user + body = json.dumps({"erase": True}) + + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertIsNone(channel.json_body["avatar_url"]) + self.assertIsNone(channel.json_body["displayname"]) + + self._is_erased("@user:test", True) + + def test_deactivate_user_erase_false(self): + """ + Test deactivating an user and set `erase` to `false` + """ + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User1", channel.json_body["displayname"]) + + # Deactivate user + body = json.dumps({"erase": False}) + + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User1", channel.json_body["displayname"]) + + self._is_erased("@user:test", False) + + def _is_erased(self, user_id: str, expect: bool) -> None: + """Assert that the user is erased or not + """ + d = self.store.is_user_erased(user_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertFalse(self.get_success(d)) + + class UserRestTestCase(unittest.HomeserverTestCase): servlets = [ @@ -987,6 +1181,26 @@ class UserRestTestCase(unittest.HomeserverTestCase): Test deactivating another user. """ + # set attributes for user + self.get_success( + self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") + ) + self.get_success( + self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) + ) + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) + # Deactivate user body = json.dumps({"deactivated": True}) @@ -1000,6 +1214,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) # the user is deactivated, the threepid will be deleted # Get user @@ -1010,6 +1227,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) def test_change_name_deactivate_user_user_directory(self): diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 1d1dc9f8a2..f9b8011961 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -30,6 +30,7 @@ from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.synapse.client.pick_idp import PickIdpResource +from synapse.types import create_requester from tests import unittest from tests.handlers.test_oidc import HAS_OIDC @@ -667,7 +668,9 @@ class CASTestCase(unittest.HomeserverTestCase): # Deactivate the account. self.get_success( - self.deactivate_account_handler.deactivate_account(self.user_id, False) + self.deactivate_account_handler.deactivate_account( + self.user_id, False, create_requester(self.user_id) + ) ) # Request the CAS ticket. diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 6105eac47c..d4e3165436 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -29,7 +29,7 @@ from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import account -from synapse.types import JsonDict, RoomAlias, UserID +from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util.stringutils import random_string from tests import unittest @@ -1687,7 +1687,9 @@ class ContextTestCase(unittest.HomeserverTestCase): deactivate_account_handler = self.hs.get_deactivate_account_handler() self.get_success( - deactivate_account_handler.deactivate_account(self.user_id, erase_data=True) + deactivate_account_handler.deactivate_account( + self.user_id, True, create_requester(self.user_id) + ) ) # Invite another user in the room. This is needed because messages will be diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 3fd0a38cf5..ea63bd56b4 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -48,6 +48,19 @@ class ProfileStoreTestCase(unittest.TestCase): ), ) + # test set to None + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.u_frank.localpart, None) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.u_frank.localpart) + ) + ) + ) + @defer.inlineCallbacks def test_avatar_url(self): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) @@ -66,3 +79,16 @@ class ProfileStoreTestCase(unittest.TestCase): ) ), ) + + # test set to None + yield defer.ensureDeferred( + self.store.set_profile_avatar_url(self.u_frank.localpart, None) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.u_frank.localpart) + ) + ) + ) -- cgit 1.5.1