From a7bdf98d01d2225a479753a85ba81adf02b16a32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2020 21:38:57 +0100 Subject: Rename database classes to make some sense (#8033) --- synapse/storage/databases/state/__init__.py | 16 + synapse/storage/databases/state/bg_updates.py | 374 ++++++++++++ .../state/schema/delta/23/drop_state_index.sql | 16 + .../state/schema/delta/30/state_stream.sql | 33 ++ .../state/schema/delta/32/remove_state_indices.sql | 19 + .../state/schema/delta/35/add_state_index.sql | 17 + .../databases/state/schema/delta/35/state.sql | 22 + .../state/schema/delta/35/state_dedupe.sql | 17 + .../state/schema/delta/47/state_group_seq.py | 34 ++ .../state/schema/delta/56/state_group_room_idx.sql | 17 + .../state/schema/full_schemas/54/full.sql | 37 ++ .../schema/full_schemas/54/sequence.sql.postgres | 21 + synapse/storage/databases/state/store.py | 644 +++++++++++++++++++++ 13 files changed, 1267 insertions(+) create mode 100644 synapse/storage/databases/state/__init__.py create mode 100644 synapse/storage/databases/state/bg_updates.py create mode 100644 synapse/storage/databases/state/schema/delta/23/drop_state_index.sql create mode 100644 synapse/storage/databases/state/schema/delta/30/state_stream.sql create mode 100644 synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql create mode 100644 synapse/storage/databases/state/schema/delta/35/add_state_index.sql create mode 100644 synapse/storage/databases/state/schema/delta/35/state.sql create mode 100644 synapse/storage/databases/state/schema/delta/35/state_dedupe.sql create mode 100644 synapse/storage/databases/state/schema/delta/47/state_group_seq.py create mode 100644 synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql create mode 100644 synapse/storage/databases/state/schema/full_schemas/54/full.sql create mode 100644 synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres create mode 100644 synapse/storage/databases/state/store.py (limited to 'synapse/storage/databases/state') diff --git a/synapse/storage/databases/state/__init__.py b/synapse/storage/databases/state/__init__.py new file mode 100644 index 0000000000..c90d022899 --- /dev/null +++ b/synapse/storage/databases/state/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.databases.state.store import StateGroupDataStore # noqa: F401 diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py new file mode 100644 index 0000000000..1e2d584098 --- /dev/null +++ b/synapse/storage/databases/state/bg_updates.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool +from synapse.storage.engines import PostgresEngine +from synapse.storage.state import StateFilter + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class StateGroupBackgroundUpdateStore(SQLBaseStore): + """Defines functions related to state groups needed to run the state backgroud + updates. + """ + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """ + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + def _get_state_groups_from_groups_txn( + self, txn, groups, state_filter=StateFilter.all() + ): + results = {group: {} for group in groups} + + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + if isinstance(self.database_engine, PostgresEngine): + # Temporarily disable sequential scans in this transaction. This is + # a temporary hack until we can add the right indices in + txn.execute("SET LOCAL enable_seqscan=off") + + # The below query walks the state_group tree so that the "state" + # table includes all state_groups in the tree. It then joins + # against `state_groups_state` to fetch the latest state. + # It assumes that previous state groups are always numerically + # lesser. + # The PARTITION is used to get the event_id in the greatest state + # group for the given type, state_key. + # This may return multiple rows per (type, state_key), but last_value + # should be the same. 
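+ # As an illustration: if group 5 has prev_state_group 3 and group 3 has
+ # prev_state_group 1, the recursive CTE yields state groups {5, 3, 1}; the
+ # DISTINCT ON clause, together with the descending state_group ordering,
+ # then keeps the event_id from the highest-numbered group that sets each
+ # (type, state_key) pair.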
+ sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT DISTINCT ON (type, state_key) + type, state_key, event_id + FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM state + ) %s + ORDER BY type, state_key, state_group DESC + """ + + for group in groups: + args = [group] + args.extend(where_args) + + txn.execute(sql % (where_clause,), args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id + else: + max_entries_returned = state_filter.max_entries_returned() + + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + for group in groups: + next_group = group + + while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indices (which we can't add until + # after we finish deduping state, which requires this func) + args = [next_group] + args.extend(where_args) + + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? " + where_clause, + args, + ) + results[group].update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn + if (typ, state_key) not in results[group] + ) + + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + max_entries_returned is not None + and len(results[group]) == max_entries_returned + ): + break + + next_group = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + + return results + + +class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" + + def __init__(self, database: DatabasePool, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db_pool.updates.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.db_pool.updates.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state + ) + self.db_pool.updates.register_background_index_update( + self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME, + index_name="state_groups_room_id_idx", + table="state_groups", + columns=["room_id"], + ) + + @defer.inlineCallbacks + def _background_deduplicate_state(self, progress, batch_size): + """This background update will slowly deduplicate state by reencoding + them as deltas. 
+ """ + last_state_group = progress.get("last_state_group", 0) + rows_inserted = progress.get("rows_inserted", 0) + max_group = progress.get("max_group", None) + + BATCH_SIZE_SCALE_FACTOR = 100 + + batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) + + if max_group is None: + rows = yield self.db_pool.execute( + "_background_deduplicate_state", + None, + "SELECT coalesce(max(id), 0) FROM state_groups", + ) + max_group = rows[0][0] + + def reindex_txn(txn): + new_last_state_group = last_state_group + for count in range(batch_size): + txn.execute( + "SELECT id, room_id FROM state_groups" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC" + " LIMIT 1", + (new_last_state_group, max_group), + ) + row = txn.fetchone() + if row: + state_group, room_id = row + + if not row or not state_group: + return True, count + + txn.execute( + "SELECT state_group FROM state_group_edges" + " WHERE state_group = ?", + (state_group,), + ) + + # If we reach a point where we've already started inserting + # edges we should stop. + if txn.fetchall(): + return True, count + + txn.execute( + "SELECT coalesce(max(id), 0) FROM state_groups" + " WHERE id < ? AND room_id = ?", + (state_group, room_id), + ) + (prev_group,) = txn.fetchone() + new_last_state_group = state_group + + if prev_group: + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if potential_hops >= MAX_STATE_DELTA_HOPS: + # We want to ensure chains are at most this long,# + # otherwise read performance degrades. + continue + + prev_state = self._get_state_groups_from_groups_txn( + txn, [prev_group] + ) + prev_state = prev_state[prev_group] + + curr_state = self._get_state_groups_from_groups_txn( + txn, [state_group] + ) + curr_state = curr_state[state_group] + + if not set(prev_state.keys()) - set(curr_state.keys()): + # We can only do a delta if the current has a strict super set + # of keys + + delta_state = { + key: value + for key, value in curr_state.items() + if prev_state.get(key, None) != value + } + + self.db_pool.simple_delete_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + ) + + self.db_pool.simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + }, + ) + + self.db_pool.simple_delete_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in delta_state.items() + ], + ) + + progress = { + "last_state_group": state_group, + "rows_inserted": rows_inserted + batch_size, + "max_group": max_group, + } + + self.db_pool.updates._background_update_progress_txn( + txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress + ) + + return False, batch_size + + finished, result = yield self.db_pool.runInteraction( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn + ) + + if finished: + yield self.db_pool.updates._end_background_update( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME + ) + + return result * BATCH_SIZE_SCALE_FACTOR + + @defer.inlineCallbacks + def _background_index_state(self, progress, batch_size): + def reindex_txn(conn): + conn.rollback() + if isinstance(self.database_engine, PostgresEngine): + # postgres insists on autocommit for the index + conn.set_session(autocommit=True) + try: + txn = conn.cursor() + txn.execute( + "CREATE 
INDEX CONCURRENTLY state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + finally: + conn.set_session(autocommit=False) + else: + txn = conn.cursor() + txn.execute( + "CREATE INDEX state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + + yield self.db_pool.runWithConnection(reindex_txn) + + yield self.db_pool.updates._end_background_update( + self.STATE_GROUP_INDEX_UPDATE_NAME + ) + + return 1 diff --git a/synapse/storage/databases/state/schema/delta/23/drop_state_index.sql b/synapse/storage/databases/state/schema/delta/23/drop_state_index.sql new file mode 100644 index 0000000000..ae09fa0065 --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/23/drop_state_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP INDEX IF EXISTS state_groups_state_tuple; diff --git a/synapse/storage/databases/state/schema/delta/30/state_stream.sql b/synapse/storage/databases/state/schema/delta/30/state_stream.sql new file mode 100644 index 0000000000..e85699e82e --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/30/state_stream.sql @@ -0,0 +1,33 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/* We used to create a table called current_state_resets, but this is no + * longer used and is removed in delta 54. + */ + +/* The outlier events that have aquired a state group typically through + * backfill. This is tracked separately to the events table, as assigning a + * state group change the position of the existing event in the stream + * ordering. + * However since a stream_ordering is assigned in persist_event for the + * (event, state) pair, we can use that stream_ordering to identify when + * the new state was assigned for the event. 
+ */ +CREATE TABLE IF NOT EXISTS ex_outlier_stream( + event_stream_ordering BIGINT PRIMARY KEY NOT NULL, + event_id TEXT NOT NULL, + state_group BIGINT NOT NULL +); diff --git a/synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql new file mode 100644 index 0000000000..1450313bfa --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql @@ -0,0 +1,19 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- The following indices are redundant, other indices are equivalent or +-- supersets +DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY diff --git a/synapse/storage/databases/state/schema/delta/35/add_state_index.sql b/synapse/storage/databases/state/schema/delta/35/add_state_index.sql new file mode 100644 index 0000000000..33980d02f0 --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/35/add_state_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication'); diff --git a/synapse/storage/databases/state/schema/delta/35/state.sql b/synapse/storage/databases/state/schema/delta/35/state.sql new file mode 100644 index 0000000000..0f1fa68a89 --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/35/state.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +CREATE TABLE state_group_edges( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); diff --git a/synapse/storage/databases/state/schema/delta/35/state_dedupe.sql b/synapse/storage/databases/state/schema/delta/35/state_dedupe.sql new file mode 100644 index 0000000000..97e5067ef4 --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/35/state_dedupe.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('state_group_state_deduplication', '{}'); diff --git a/synapse/storage/databases/state/schema/delta/47/state_group_seq.py b/synapse/storage/databases/state/schema/delta/47/state_group_seq.py new file mode 100644 index 0000000000..9fd1ccf6f7 --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/47/state_group_seq.py @@ -0,0 +1,34 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + # if we already have some state groups, we want to start making new + # ones with a higher id. + cur.execute("SELECT max(id) FROM state_groups") + row = cur.fetchone() + + if row[0] is None: + start_val = 1 + else: + start_val = row[0] + 1 + + cur.execute("CREATE SEQUENCE state_group_id_seq START WITH %s", (start_val,)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql new file mode 100644 index 0000000000..7916ef18b2 --- /dev/null +++ b/synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('state_groups_room_id_idx', '{}'); diff --git a/synapse/storage/databases/state/schema/full_schemas/54/full.sql b/synapse/storage/databases/state/schema/full_schemas/54/full.sql new file mode 100644 index 0000000000..35f97d6b3d --- /dev/null +++ b/synapse/storage/databases/state/schema/full_schemas/54/full.sql @@ -0,0 +1,37 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE state_groups ( + id BIGINT PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_groups_state ( + state_group BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_group_edges ( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges (state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group); +CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key); diff --git a/synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres new file mode 100644 index 0000000000..fcd926c9fb --- /dev/null +++ b/synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE state_group_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py new file mode 100644 index 0000000000..7f104ad936 --- /dev/null +++ b/synapse/storage/databases/state/store.py @@ -0,0 +1,644 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import namedtuple +from typing import Dict, Iterable, List, Set, Tuple + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool +from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore +from synapse.storage.state import StateFilter +from synapse.storage.types import Cursor +from synapse.storage.util.sequence import build_sequence_generator +from synapse.types import StateMap +from synapse.util.caches.descriptors import cached +from synapse.util.caches.dictionary_cache import DictionaryCache + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class _GetStateGroupDelta( + namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) +): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + +class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): + """A data store for fetching/storing state groups. + """ + + def __init__(self, database: DatabasePool, db_conn, hs): + super(StateGroupDataStore, self).__init__(database, db_conn, hs) + + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. 
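+ #
+ # As a concrete example, a lazy-loading /sync request for a large room can
+ # ask the members cache for just the handful of (m.room.member, user_id)
+ # keys it needs, while asking the non-members cache for the full set of
+ # remaining state, and then merge the two results.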
+ + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000, + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", 500000, + ) + + def get_max_state_group_txn(txn: Cursor): + txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") + return txn.fetchone()[0] + + self._state_group_seq_gen = build_sequence_generator( + self.database_engine, get_max_state_group_txn, "state_group_id_seq" + ) + + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + + def _get_state_group_delta_txn(txn): + prev_group = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self.db_pool.simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + retcols=("type", "state_key", "event_id"), + ) + + return _GetStateGroupDelta( + prev_group, + {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, + ) + + return self.db_pool.runInteraction( + "get_state_group_delta", _get_state_group_delta_txn + ) + + async def _get_state_groups_from_groups( + self, groups: List[int], state_filter: StateFilter + ) -> Dict[int, StateMap[str]]: + """Returns the state groups for a given set of groups from the + database, filtering on types of state events. + + Args: + groups: list of state group IDs to query + state_filter: The state filter used to fetch state + from the database. + Returns: + Dict of state group to state map. + """ + results = {} + + chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] + for chunk in chunks: + res = await self.db_pool.runInteraction( + "_get_state_groups_from_groups", + self._get_state_groups_from_groups_txn, + chunk, + state_filter, + ) + results.update(res) + + return results + + def _get_state_for_group_using_cache(self, cache, group, state_filter): + """Checks if group is in cache. See `_get_state_for_groups` + + Args: + cache(DictionaryCache): the state group cache to use + group(int): The state group to lookup + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns 2-tuple (`state_dict`, `got_all`). + `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + """ + is_all, known_absent, state_dict_ids = cache.get(group) + + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all + + # tracks whether any of our requested types are missing from the cache + missing_types = False + + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. 
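+ # For instance, a filter asking only for ("m.room.name", "") and
+ # ("m.room.topic", "") is fully concrete: if either key is neither in the
+ # cached dict nor recorded as known-absent, we cannot answer from the
+ # cache alone and must fall back to the database.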
+ for key in state_filter.concrete_types(): + if key not in state_dict_ids and key not in known_absent: + missing_types = True + break + + return state_filter.filter_state(state_dict_ids), not missing_types + + async def _get_state_for_groups( + self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() + ) -> Dict[int, StateMap[str]]: + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups: list of state groups for which we want + to get the state. + state_filter: The state filter used to fetch state + from the database. + Returns: + Dict of state group to state map. + """ + + member_filter, non_member_filter = state_filter.get_member_split() + + # Now we look them up in the member and non-member caches + ( + non_member_state, + incomplete_groups_nm, + ) = self._get_state_for_groups_using_cache( + groups, self._state_group_cache, state_filter=non_member_filter + ) + + (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, state_filter=member_filter + ) + + state = dict(non_member_state) + for group in groups: + state[group].update(member_state[group]) + + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + return state + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = await self._get_state_groups_from_groups( + list(incomplete_groups), state_filter=db_state_filter + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in group_to_state_dict.items(): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + + return state + + def _get_state_for_groups_using_cache( + self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter + ) -> Tuple[Dict[int, StateMap[str]], Set[int]]: + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key, querying from a specific cache. + + Args: + groups: list of state groups for which we want to get the state. + cache: the cache of group ids to state dicts which + we will pass through - either the normal state cache or the + specific members state cache. + state_filter: The state filter used to fetch state from the + database. + + Returns: + Tuple of dict of state_group_id to state map of entries in the + cache, and the state group ids either missing from the cache or + incomplete. 
+ """ + results = {} + incomplete_groups = set() + for group in set(groups): + state_dict_ids, got_all = self._get_state_for_group_using_cache( + cache, group, state_filter + ) + results[group] = state_dict_ids + + if not got_all: + incomplete_groups.add(group) + + return results, incomplete_groups + + def _insert_into_cache( + self, + group_to_state_dict, + state_filter, + cache_seq_num_members, + cache_seq_num_non_members, + ): + """Inserts results from querying the database into the relevant cache. + + Args: + group_to_state_dict (dict): The new entries pulled from database. + Map from state group to state dict + state_filter (StateFilter): The state filter used to fetch state + from the database. + cache_seq_num_members (int): Sequence number of member cache since + last lookup in cache + cache_seq_num_non_members (int): Sequence number of member cache since + last lookup in cache + """ + + # We need to work out which types we've fetched from the DB for the + # member vs non-member caches. This should be as accurate as possible, + # but can be an underestimate (e.g. when we have wild cards) + + member_filter, non_member_filter = state_filter.get_member_split() + if member_filter.is_full(): + # We fetched all member events + member_types = None + else: + # `concrete_types()` will only return a subset when there are wild + # cards in the filter, but that's fine. + member_types = member_filter.concrete_types() + + if non_member_filter.is_full(): + # We fetched all non member events + non_member_types = None + else: + non_member_types = non_member_filter.concrete_types() + + for group, group_state_dict in group_to_state_dict.items(): + state_dict_members = {} + state_dict_non_members = {} + + for k, v in group_state_dict.items(): + if k[0] == EventTypes.Member: + state_dict_members[k] = v + else: + state_dict_non_members[k] = v + + self._state_group_members_cache.update( + cache_seq_num_members, + key=group, + value=state_dict_members, + fetched_keys=member_types, + ) + + self._state_group_cache.update( + cache_seq_num_non_members, + key=group, + value=state_dict_non_members, + fetched_keys=non_member_types, + ) + + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + + def _store_state_group_txn(txn): + if current_state_ids is None: + # AFAIK, this can never happen + raise Exception("current_state_ids cannot be None") + + state_group = self._state_group_seq_gen.get_next_id_txn(txn) + + self.db_pool.simple_insert_txn( + txn, + table="state_groups", + values={"id": state_group, "room_id": room_id, "event_id": event_id}, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. 
+ if prev_group: + is_in_db = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self.db_pool.simple_insert_txn( + txn, + table="state_group_edges", + values={"state_group": state_group, "prev_state_group": prev_group}, + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in delta_ids.items() + ], + ) + else: + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in current_state_ids.items() + ], + ) + + # Prefill the state group caches with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in current_state_ids.items() + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in current_state_ids.items() + if s[0] != EventTypes.Member + } + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=state_group, + value=dict(current_non_member_state_ids), + ) + + return state_group + + return self.db_pool.runInteraction("store_state_group", _store_state_group_txn) + + def purge_unreferenced_state_groups( + self, room_id: str, state_groups_to_delete + ) -> defer.Deferred: + """Deletes no longer referenced state groups and de-deltas any state + groups that reference them. + + Args: + room_id: The room the state groups belong to (must all be in the + same room). + state_groups_to_delete (Collection[int]): Set of all state groups + to delete. + """ + + return self.db_pool.runInteraction( + "purge_unreferenced_state_groups", + self._purge_unreferenced_state_groups, + room_id, + state_groups_to_delete, + ) + + def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): + logger.info( + "[purge] found %i state groups to delete", len(state_groups_to_delete) + ) + + rows = self.db_pool.simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=state_groups_to_delete, + keyvalues={}, + retcols=("state_group",), + ) + + remaining_state_groups = { + row["state_group"] + for row in rows + if row["state_group"] not in state_groups_to_delete + } + + logger.info( + "[purge] de-delta-ing %i remaining state groups", + len(remaining_state_groups), + ) + + # Now we turn the state groups that reference to-be-deleted state + # groups to non delta versions. 
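+ # In other words, for each surviving group we materialise its full state
+ # into state_groups_state and drop its edge in state_group_edges, so it
+ # no longer depends on any group that is about to be deleted.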
+ for sg in remaining_state_groups: + logger.info("[purge] de-delta-ing remaining state group %s", sg) + curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) + curr_state = curr_state[sg] + + self.db_pool.simple_delete_txn( + txn, table="state_groups_state", keyvalues={"state_group": sg} + ) + + self.db_pool.simple_delete_txn( + txn, table="state_group_edges", keyvalues={"state_group": sg} + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": sg, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in curr_state.items() + ], + ) + + logger.info("[purge] removing redundant state groups") + txn.executemany( + "DELETE FROM state_groups_state WHERE state_group = ?", + ((sg,) for sg in state_groups_to_delete), + ) + txn.executemany( + "DELETE FROM state_groups WHERE id = ?", + ((sg,) for sg in state_groups_to_delete), + ) + + async def get_previous_state_groups( + self, state_groups: Iterable[int] + ) -> Dict[int, int]: + """Fetch the previous groups of the given state groups. + + Args: + state_groups + + Returns: + A mapping from state group to previous state group. + """ + + rows = await self.db_pool.simple_select_many_batch( + table="state_group_edges", + column="prev_state_group", + iterable=state_groups, + keyvalues={}, + retcols=("prev_state_group", "state_group"), + desc="get_previous_state_groups", + ) + + return {row["state_group"]: row["prev_state_group"] for row in rows} + + def purge_room_state(self, room_id, state_groups_to_delete): + """Deletes all record of a room from state tables + + Args: + room_id (str): + state_groups_to_delete (list[int]): State groups to delete + """ + + return self.db_pool.runInteraction( + "purge_room_state", + self._purge_room_state_txn, + room_id, + state_groups_to_delete, + ) + + def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): + # first we have to delete the state groups states + logger.info("[purge] removing %s from state_groups_state", room_id) + + self.db_pool.simple_delete_many_txn( + txn, + table="state_groups_state", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state group edges + logger.info("[purge] removing %s from state_group_edges", room_id) + + self.db_pool.simple_delete_many_txn( + txn, + table="state_group_edges", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... 
and the state groups + logger.info("[purge] removing %s from state_groups", room_id) + + self.db_pool.simple_delete_many_txn( + txn, + table="state_groups", + column="id", + iterable=state_groups_to_delete, + keyvalues={}, + ) -- cgit 1.5.1 From a0acdfa9e93ae63a3adee264d5420fdd1d38d76e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 11 Aug 2020 17:21:13 -0400 Subject: Converts event_federation and registration databases to async/await (#8061) --- changelog.d/8061.misc | 1 + synapse/storage/databases/main/event_federation.py | 38 ++-- synapse/storage/databases/main/registration.py | 233 ++++++++++----------- synapse/storage/databases/state/bg_updates.py | 18 +- tests/handlers/test_register.py | 11 +- tests/storage/test_monthly_active_users.py | 8 +- tests/storage/test_registration.py | 18 +- 7 files changed, 150 insertions(+), 177 deletions(-) create mode 100644 changelog.d/8061.misc (limited to 'synapse/storage/databases/state') diff --git a/changelog.d/8061.misc b/changelog.d/8061.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8061.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index eddb32b4d3..484875f989 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -15,9 +15,7 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import Dict, List, Optional, Set, Tuple - -from twisted.internet import defer +from typing import Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process @@ -286,17 +284,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return dict(txn) - @defer.inlineCallbacks - def get_max_depth_of(self, event_ids): + async def get_max_depth_of(self, event_ids: List[str]) -> int: """Returns the max depth of a set of event IDs Args: - event_ids (list[str]) - - Returns - Deferred[int] + event_ids: The event IDs to calculate the max depth of. 
""" - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="events", column="event_id", iterable=event_ids, @@ -550,9 +544,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return event_results - @defer.inlineCallbacks - def get_missing_events(self, room_id, earliest_events, latest_events, limit): - ids = yield self.db_pool.runInteraction( + async def get_missing_events(self, room_id, earliest_events, latest_events, limit): + ids = await self.db_pool.runInteraction( "get_missing_events", self._get_missing_events, room_id, @@ -560,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = yield self.get_events_as_list(ids) + events = await self.get_events_as_list(ids) return events def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): @@ -595,17 +588,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_results.reverse() return event_results - @defer.inlineCallbacks - def get_successor_events(self, event_ids): + async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]: """Fetch all events that have the given events as a prev event Args: - event_ids (iterable[str]) - - Returns: - Deferred[list[str]] + event_ids: The events to use as the previous events. """ - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="event_edges", column="prev_event_id", iterable=event_ids, @@ -674,8 +663,7 @@ class EventFederationStore(EventFederationWorkerStore): txn.execute(query, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - @defer.inlineCallbacks - def _background_delete_non_state_event_auth(self, progress, batch_size): + async def _background_delete_non_state_event_auth(self, progress, batch_size): def delete_event_auth(txn): target_min_stream_id = progress.get("target_min_stream_id_inclusive") max_stream_id = progress.get("max_stream_id_exclusive") @@ -714,12 +702,12 @@ class EventFederationStore(EventFederationWorkerStore): return min_stream_id >= target_min_stream_id - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_AUTH_STATE_ONLY, delete_event_auth ) if not result: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_AUTH_STATE_ONLY ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index f618629e09..402ae25571 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,9 +17,8 @@ import logging import re -from typing import Optional +from typing import Dict, List, Optional -from twisted.internet import defer from twisted.internet.defer import Deferred from synapse.api.constants import UserTypes @@ -30,7 +29,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 @@ -69,19 +68,15 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_user_by_id", ) - @defer.inlineCallbacks - def is_trial_user(self, user_id): + async def 
is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first N days of registration defined by `mau_trial_days` config Args: - user_id (str) - - Returns: - Deferred[bool] + user_id: The user to check for trial status. """ - info = yield self.get_user_by_id(user_id) + info = await self.get_user_by_id(user_id) if not info: return False @@ -105,41 +100,42 @@ class RegistrationWorkerStore(SQLBaseStore): "get_user_by_access_token", self._query_for_auth, token ) - @cachedInlineCallbacks() - def get_expiration_ts_for_user(self, user_id): + @cached() + async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]: """Get the expiration timestamp for the account bearing a given user ID. Args: - user_id (str): The ID of the user. + user_id: The ID of the user. Returns: - defer.Deferred: None, if the account has no expiration timestamp, - otherwise int representation of the timestamp (as a number of - milliseconds since epoch). + None, if the account has no expiration timestamp, otherwise int + representation of the timestamp (as a number of milliseconds since epoch). """ - res = yield self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="expiration_ts_ms", allow_none=True, desc="get_expiration_ts_for_user", ) - return res - @defer.inlineCallbacks - def set_account_validity_for_user( - self, user_id, expiration_ts, email_sent, renewal_token=None - ): + async def set_account_validity_for_user( + self, + user_id: str, + expiration_ts: int, + email_sent: bool, + renewal_token: Optional[str] = None, + ) -> None: """Updates the account validity properties of the given account, with the given values. Args: - user_id (str): ID of the account to update properties for. - expiration_ts (int): New expiration date, as a timestamp in milliseconds + user_id: ID of the account to update properties for. + expiration_ts: New expiration date, as a timestamp in milliseconds since epoch. - email_sent (bool): True means a renewal email has been sent for this - account and there's no need to send another one for the current validity + email_sent: True means a renewal email has been sent for this account + and there's no need to send another one for the current validity period. - renewal_token (str): Renewal token the user can use to extend the validity + renewal_token: Renewal token the user can use to extend the validity of their account. Defaults to no token. """ @@ -158,75 +154,69 @@ class RegistrationWorkerStore(SQLBaseStore): txn, self.get_expiration_ts_for_user, (user_id,) ) - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_account_validity_for_user", set_account_validity_for_user_txn ) - @defer.inlineCallbacks - def set_renewal_token_for_user(self, user_id, renewal_token): + async def set_renewal_token_for_user( + self, user_id: str, renewal_token: str + ) -> None: """Defines a renewal token for a given user. Args: - user_id (str): ID of the user to set the renewal token for. - renewal_token (str): Random unique string that will be used to renew the + user_id: ID of the user to set the renewal token for. + renewal_token: Random unique string that will be used to renew the user's account. Raises: StoreError: The provided token is already set for another user. 
""" - yield self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"renewal_token": renewal_token}, desc="set_renewal_token_for_user", ) - @defer.inlineCallbacks - def get_user_from_renewal_token(self, renewal_token): + async def get_user_from_renewal_token(self, renewal_token: str) -> str: """Get a user ID from a renewal token. Args: - renewal_token (str): The renewal token to perform the lookup with. + renewal_token: The renewal token to perform the lookup with. Returns: - defer.Deferred[str]: The ID of the user to which the token belongs. + The ID of the user to which the token belongs. """ - res = yield self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"renewal_token": renewal_token}, retcol="user_id", desc="get_user_from_renewal_token", ) - return res - - @defer.inlineCallbacks - def get_renewal_token_for_user(self, user_id): + async def get_renewal_token_for_user(self, user_id: str) -> str: """Get the renewal token associated with a given user ID. Args: - user_id (str): The user ID to lookup a token for. + user_id: The user ID to lookup a token for. Returns: - defer.Deferred[str]: The renewal token associated with this user ID. + The renewal token associated with this user ID. """ - res = yield self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="renewal_token", desc="get_renewal_token_for_user", ) - return res - - @defer.inlineCallbacks - def get_users_expiring_soon(self): + async def get_users_expiring_soon(self) -> List[Dict[str, int]]: """Selects users whose account will expire in the [now, now + renew_at] time window (see configuration for account_validity for information on what renew_at refers to). Returns: - Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] + A list of dictionaries mapping user ID to expiration time (in milliseconds). """ def select_users_txn(txn, now_ms, renew_at): @@ -238,53 +228,49 @@ class RegistrationWorkerStore(SQLBaseStore): txn.execute(sql, values) return self.db_pool.cursor_to_dict(txn) - res = yield self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_expiring_soon", select_users_txn, self.clock.time_msec(), self.config.account_validity.renew_at, ) - return res - - @defer.inlineCallbacks - def set_renewal_mail_status(self, user_id, email_sent): + async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None: """Sets or unsets the flag that indicates whether a renewal email has been sent to the user (and the user hasn't renewed their account yet). Args: - user_id (str): ID of the user to set/unset the flag for. - email_sent (bool): Flag which indicates whether a renewal email has been sent + user_id: ID of the user to set/unset the flag for. + email_sent: Flag which indicates whether a renewal email has been sent to this user. """ - yield self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"email_sent": email_sent}, desc="set_renewal_mail_status", ) - @defer.inlineCallbacks - def delete_account_validity_for_user(self, user_id): + async def delete_account_validity_for_user(self, user_id: str) -> None: """Deletes the entry for the given user in the account validity table, removing their expiration date and renewal token. 
         Args:
-            user_id (str): ID of the user to remove from the account validity table.
+            user_id: ID of the user to remove from the account validity table.
         """
-        yield self.db_pool.simple_delete_one(
+        await self.db_pool.simple_delete_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             desc="delete_account_validity_for_user",
         )
 
-    async def is_server_admin(self, user):
+    async def is_server_admin(self, user: UserID) -> bool:
         """Determines if a user is an admin of this homeserver.
 
         Args:
-            user (UserID): user ID of the user to test
+            user: user ID of the user to test
 
-        Returns (bool):
+        Returns:
             true iff the user is a server admin, false otherwise.
         """
         res = await self.db_pool.simple_select_one_onecol(
@@ -332,32 +318,31 @@ class RegistrationWorkerStore(SQLBaseStore):
         return None
 
-    @cachedInlineCallbacks()
-    def is_real_user(self, user_id):
+    @cached()
+    async def is_real_user(self, user_id: str) -> bool:
         """Determines if the user is a real user, ie does not have a 'user_type'.
 
         Args:
-            user_id (str): user id to test
+            user_id: user id to test
 
         Returns:
-            Deferred[bool]: True if user 'user_type' is null or empty string
+            True if user 'user_type' is null or empty string
         """
-        res = yield self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "is_real_user", self.is_real_user_txn, user_id
         )
-        return res
 
     @cached()
-    def is_support_user(self, user_id):
+    async def is_support_user(self, user_id: str) -> bool:
         """Determines if the user is of type UserTypes.SUPPORT
 
         Args:
-            user_id (str): user id to test
+            user_id: user id to test
 
         Returns:
-            Deferred[bool]: True if user is of type UserTypes.SUPPORT
+            True if user is of type UserTypes.SUPPORT
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "is_support_user", self.is_support_user_txn, user_id
         )
 
@@ -413,8 +398,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_user_by_external_id",
         )
 
-    @defer.inlineCallbacks
-    def count_all_users(self):
+    async def count_all_users(self):
         """Counts all users registered on the homeserver."""
 
         def _count_users(txn):
@@ -424,8 +408,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 return rows[0]["users"]
             return 0
 
-        ret = yield self.db_pool.runInteraction("count_users", _count_users)
-        return ret
+        return await self.db_pool.runInteraction("count_users", _count_users)
 
     def count_daily_user_type(self):
         """
@@ -460,8 +443,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             "count_daily_user_type", _count_daily_user_type
         )
 
-    @defer.inlineCallbacks
-    def count_nonbridged_users(self):
+    async def count_nonbridged_users(self):
         def _count_users(txn):
             txn.execute(
                 """
@@ -472,11 +454,9 @@ class RegistrationWorkerStore(SQLBaseStore):
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.db_pool.runInteraction("count_users", _count_users)
-        return ret
+        return await self.db_pool.runInteraction("count_users", _count_users)
 
-    @defer.inlineCallbacks
-    def count_real_users(self):
+    async def count_real_users(self):
         """Counts all users without a special user_type registered on the homeserver."""
 
         def _count_users(txn):
@@ -486,8 +466,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 return rows[0]["users"]
             return 0
 
-        ret = yield self.db_pool.runInteraction("count_real_users", _count_users)
-        return ret
+        return await self.db_pool.runInteraction("count_real_users", _count_users)
 
     async def generate_user_id(self) -> str:
         """Generate a suitable localpart for a guest user
@@ -537,23 +516,20 @@ class RegistrationWorkerStore(SQLBaseStore):
             return ret["user_id"]
         return None
 
-    @defer.inlineCallbacks
-    def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
-        yield self.db_pool.simple_upsert(
+    async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+        await self.db_pool.simple_upsert(
             "user_threepids",
             {"medium": medium, "address": address},
             {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
         )
 
-    @defer.inlineCallbacks
-    def user_get_threepids(self, user_id):
-        ret = yield self.db_pool.simple_select_list(
+    async def user_get_threepids(self, user_id):
+        return await self.db_pool.simple_select_list(
             "user_threepids",
             {"user_id": user_id},
             ["medium", "address", "validated_at", "added_at"],
             "user_get_threepids",
         )
-        return ret
 
     def user_delete_threepid(self, user_id, medium, address):
         return self.db_pool.simple_delete(
@@ -668,18 +644,18 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_id_servers_user_bound",
         )
 
-    @cachedInlineCallbacks()
-    def get_user_deactivated_status(self, user_id):
+    @cached()
+    async def get_user_deactivated_status(self, user_id: str) -> bool:
         """Retrieve the value for the `deactivated` property for the provided user.
 
         Args:
-            user_id (str): The ID of the user to retrieve the status for.
+            user_id: The ID of the user to retrieve the status for.
 
         Returns:
-            defer.Deferred(bool): The requested value.
+            True if the user was deactivated, false if the user is still active.
         """
-        res = yield self.db_pool.simple_select_one_onecol(
+        res = await self.db_pool.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user_id},
             retcol="deactivated",
@@ -818,8 +794,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             "users_set_deactivated_flag", self._background_update_set_deactivated_flag
         )
 
-    @defer.inlineCallbacks
-    def _background_update_set_deactivated_flag(self, progress, batch_size):
+    async def _background_update_set_deactivated_flag(self, progress, batch_size):
         """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
         for each of them.
         """
@@ -870,19 +845,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             else:
                 return False, len(rows)
 
-        end, nb_processed = yield self.db_pool.runInteraction(
+        end, nb_processed = await self.db_pool.runInteraction(
             "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
         )
 
         if end:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 "users_set_deactivated_flag"
             )
 
         return nb_processed
 
-    @defer.inlineCallbacks
-    def _bg_user_threepids_grandfather(self, progress, batch_size):
+    async def _bg_user_threepids_grandfather(self, progress, batch_size):
         """We now track which identity servers a user binds their 3PID to, so we need to
         handle the case of existing bindings where we didn't track this.
 
@@ -903,11 +877,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             txn.executemany(sql, [(id_server,) for id_server in id_servers])
 
         if id_servers:
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
             )
 
-        yield self.db_pool.updates._end_background_update("user_threepids_grandfather")
+        await self.db_pool.updates._end_background_update("user_threepids_grandfather")
 
         return 1
 
@@ -937,23 +911,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
 
-    @defer.inlineCallbacks
-    def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
+    async def add_access_token_to_user(
+        self,
+        user_id: str,
+        token: str,
+        device_id: Optional[str],
+        valid_until_ms: Optional[int],
+    ) -> None:
         """Adds an access token for the given user.
 
         Args:
-            user_id (str): The user ID.
-            token (str): The new access token to add.
-            device_id (str): ID of the device to associate with the access
-                token
-            valid_until_ms (int|None): when the token is valid until. None for
-                no expiry.
+            user_id: The user ID.
+            token: The new access token to add.
+            device_id: ID of the device to associate with the access token
+            valid_until_ms: when the token is valid until. None for no expiry.
 
         Raises:
             StoreError if there was a problem adding this.
         """
         next_id = self._access_tokens_id_gen.get_next()
 
-        yield self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             "access_tokens",
             {
                 "id": next_id,
@@ -1097,7 +1074,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         )
 
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
-        txn.call_after(self.is_guest.invalidate, (user_id,))
 
     def record_user_external_id(
         self, auth_provider: str, external_id: str, user_id: str
@@ -1241,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         return self.db_pool.runInteraction("delete_access_token", f)
 
-    @cachedInlineCallbacks()
-    def is_guest(self, user_id):
-        res = yield self.db_pool.simple_select_one_onecol(
+    @cached()
+    async def is_guest(self, user_id: str) -> bool:
+        res = await self.db_pool.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user_id},
             retcol="is_guest",
@@ -1481,16 +1457,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             self.clock.time_msec(),
         )
 
-    @defer.inlineCallbacks
-    def set_user_deactivated_status(self, user_id, deactivated):
+    async def set_user_deactivated_status(
+        self, user_id: str, deactivated: bool
+    ) -> None:
         """Set the `deactivated` property for the provided user to the provided value.
 
         Args:
-            user_id (str): The ID of the user to set the status for.
-            deactivated (bool): The value to set for `deactivated`.
+            user_id: The ID of the user to set the status for.
+            deactivated: The value to set for `deactivated`.
         """
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "set_user_deactivated_status",
             self.set_user_deactivated_status_txn,
             user_id,
@@ -1507,9 +1484,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         self._invalidate_cache_and_stream(
             txn, self.get_user_deactivated_status, (user_id,)
         )
+        txn.call_after(self.is_guest.invalidate, (user_id,))
 
-    @defer.inlineCallbacks
-    def _set_expiration_date_when_missing(self):
+    async def _set_expiration_date_when_missing(self):
         """
         Retrieves the list of registered users that don't have an expiration date, and
         adds an expiration date for each of them.
 
@@ -1533,7 +1510,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                     txn, user["name"], use_delta=True
                 )
 
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "get_users_with_no_expiration_date",
             select_users_with_no_expiration_date_txn,
         )
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 1e2d584098..139085b672 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
@@ -198,8 +196,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
             columns=["room_id"],
         )
 
-    @defer.inlineCallbacks
-    def _background_deduplicate_state(self, progress, batch_size):
+    async def _background_deduplicate_state(self, progress, batch_size):
         """This background update will slowly deduplicate state by reencoding them as
         deltas.
         """
@@ -212,7 +209,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
         batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
 
         if max_group is None:
-            rows = yield self.db_pool.execute(
+            rows = await self.db_pool.execute(
                 "_background_deduplicate_state",
                 None,
                 "SELECT coalesce(max(id), 0) FROM state_groups",
@@ -330,19 +327,18 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
 
             return False, batch_size
 
-        finished, result = yield self.db_pool.runInteraction(
+        finished, result = await self.db_pool.runInteraction(
            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
        )
 
        if finished:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
            )
 
        return result * BATCH_SIZE_SCALE_FACTOR
 
-    @defer.inlineCallbacks
-    def _background_index_state(self, progress, batch_size):
+    async def _background_index_state(self, progress, batch_size):
        def reindex_txn(conn):
            conn.rollback()
            if isinstance(self.database_engine, PostgresEngine):
@@ -365,9 +361,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
                )
                txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
 
-        yield self.db_pool.runWithConnection(reindex_txn)
+        await self.db_pool.runWithConnection(reindex_txn)
 
-        yield self.db_pool.updates._end_background_update(
+        await self.db_pool.updates._end_background_update(
            self.STATE_GROUP_INDEX_UPDATE_NAME
        )
 
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 6d45c4b233..e364b1bd62 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -22,6 +22,7 @@ from synapse.api.errors import Codes, ResourceLimitError, SynapseError
 from synapse.handlers.register import RegistrationHandler
 from synapse.types import RoomAlias, UserID, create_requester
 
+from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 from .. import unittest
@@ -187,7 +188,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
 
-        self.store.is_real_user = Mock(return_value=defer.succeed(False))
+        self.store.is_real_user = Mock(return_value=make_awaitable(False))
         user_id = self.get_success(self.handler.register_user(localpart="support"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
@@ -199,8 +200,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
         room_alias_str = "#room:test"
 
-        self.store.count_real_users = Mock(return_value=defer.succeed(1))
-        self.store.is_real_user = Mock(return_value=defer.succeed(True))
+        self.store.count_real_users = Mock(return_value=make_awaitable(1))
+        self.store.is_real_user = Mock(return_value=make_awaitable(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         directory_handler = self.hs.get_handlers().directory_handler
@@ -214,8 +215,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
 
-        self.store.count_real_users = Mock(return_value=defer.succeed(2))
-        self.store.is_real_user = Mock(return_value=defer.succeed(True))
+        self.store.count_real_users = Mock(return_value=make_awaitable(2))
+        self.store.is_real_user = Mock(return_value=make_awaitable(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index e793781a26..9870c74883 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -300,8 +300,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.get_success(self.store.register_user(user_id=user2, password_hash=None))
 
         now = int(self.hs.get_clock().time_msec())
-        self.store.user_add_threepid(user1, "email", user1_email, now, now)
-        self.store.user_add_threepid(user2, "email", user2_email, now, now)
+        self.get_success(
+            self.store.user_add_threepid(user1, "email", user1_email, now, now)
+        )
+        self.get_success(
+            self.store.user_add_threepid(user2, "email", user2_email, now, now)
+        )
 
         users = self.get_success(self.store.get_registered_reserved_users())
         self.assertEqual(len(users), len(threepids))
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 71a40a0a49..840db66072 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -58,8 +58,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_add_tokens(self):
         yield self.store.register_user(self.user_id, self.pwhash)
-        yield self.store.add_access_token_to_user(
-            self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+        yield defer.ensureDeferred(
+            self.store.add_access_token_to_user(
+                self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+            )
         )
 
         result = yield self.store.get_user_by_access_token(self.tokens[1])
@@ -74,11 +76,15 @@ class RegistrationStoreTestCase(unittest.TestCase):
     def test_user_delete_access_tokens(self):
         # add some tokens
         yield self.store.register_user(self.user_id, self.pwhash)
-        yield self.store.add_access_token_to_user(
-            self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
+        yield defer.ensureDeferred(
+            self.store.add_access_token_to_user(
+                self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
+            )
         )
-        yield self.store.add_access_token_to_user(
-            self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+        yield defer.ensureDeferred(
+            self.store.add_access_token_to_user(
+                self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+            )
         )
 
         # now delete some
--
cgit 1.5.1
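
Note (illustrative sketch, not part of the patch): the test hunks above bridge old Deferred-based test code and the newly async store methods in two ways. Mocked store methods now return awaitables via make_awaitable rather than Deferreds via defer.succeed, and inlineCallbacks-style tests drive an async method by wrapping its coroutine with defer.ensureDeferred. The snippet below shows both patterns in isolation; the local make_awaitable here is a stand-in written for this sketch and is only assumed to behave like the helper imported from tests.test_utils.

    from unittest.mock import Mock

    from twisted.internet import defer


    def make_awaitable(result):
        # Stand-in helper for this sketch: returns an object that can be
        # awaited once and resolves immediately to `result`.
        async def _resolved():
            return result

        return _resolved()


    store = Mock()
    # The store method is async after this change, so the mock must return an
    # awaitable rather than a plain value or a Deferred.
    store.is_real_user = Mock(return_value=make_awaitable(True))


    @defer.inlineCallbacks
    def legacy_test_body():
        # Old inlineCallbacks-style test code can still call the async method
        # by wrapping the coroutine in a Deferred, as the updated
        # RegistrationStoreTestCase does with add_access_token_to_user.
        is_real = yield defer.ensureDeferred(store.is_real_user("@user:test"))
        return is_real


    d = legacy_test_body()
    d.addCallback(print)  # prints True; the mocked awaitable resolves immediately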