From a7bdf98d01d2225a479753a85ba81adf02b16a32 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 5 Aug 2020 21:38:57 +0100
Subject: Rename database classes to make some sense (#8033)

---
 synapse/storage/databases/main/stream.py | 1064 ++++++++++++++++++++++++++++++
 1 file changed, 1064 insertions(+)
 create mode 100644 synapse/storage/databases/main/stream.py

diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
new file mode 100644
index 0000000000..aaf225894e
--- /dev/null
+++ b/synapse/storage/databases/main/stream.py
@@ -0,0 +1,1064 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" This module is responsible for getting events from the DB for pagination
+and event streaming.
+
+The order it returns events in depends on whether we are streaming forwards or
+are paginating backwards. We do this because we want to handle out-of-order
+messages nicely, while still returning them in the correct order when we
+paginate backwards.
+
+This is implemented by keeping two ordering columns: stream_ordering and
+topological_ordering. Stream ordering is basically insertion/received order
+(except for events from backfill requests). The topological_ordering is a
+weak ordering of events based on the pdu graph.
+
+This means that we have to have two different types of tokens, depending on
+what sort order was used:
+    - stream tokens are of the form: "s%d", which maps directly to the column
+    - topological tokens: "t%d-%d", where the integers map to the topological
+      and stream ordering columns respectively.
+"""
+
+import abc
+import logging
+from collections import namedtuple
+from typing import Optional
+
+from twisted.internet import defer
+
+from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.types import RoomStreamToken
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STREAM_SIZE = 1000
+
+
+_STREAM_TOKEN = "stream"
+_TOPOLOGICAL_TOKEN = "topological"
+
+
+# Used as return values for pagination APIs
+_EventDictReturn = namedtuple(
+    "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
+)
+
+
+def generate_pagination_where_clause(
+    direction, column_names, from_token, to_token, engine
+):
+    """Creates an SQL expression to bound the columns by the pagination
+    tokens.
+ + For example creates an SQL expression like: + + (6, 7) >= (topological_ordering, stream_ordering) + AND (5, 3) < (topological_ordering, stream_ordering) + + would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3). + + Note that tokens are considered to be after the row they are in, e.g. if + a row A has a token T, then we consider A to be before T. This convention + is important when figuring out inequalities for the generated SQL, and + produces the following result: + - If paginating forwards then we exclude any rows matching the from + token, but include those that match the to token. + - If paginating backwards then we include any rows matching the from + token, but include those that match the to token. + + Args: + direction (str): Whether we're paginating backwards("b") or + forwards ("f"). + column_names (tuple[str, str]): The column names to bound. Must *not* + be user defined as these get inserted directly into the SQL + statement without escapes. + from_token (tuple[int, int]|None): The start point for the pagination. + This is an exclusive minimum bound if direction is "f", and an + inclusive maximum bound if direction is "b". + to_token (tuple[int, int]|None): The endpoint point for the pagination. + This is an inclusive maximum bound if direction is "f", and an + exclusive minimum bound if direction is "b". + engine: The database engine to generate the clauses for + + Returns: + str: The sql expression + """ + assert direction in ("b", "f") + + where_clause = [] + if from_token: + where_clause.append( + _make_generic_sql_bound( + bound=">=" if direction == "b" else "<", + column_names=column_names, + values=from_token, + engine=engine, + ) + ) + + if to_token: + where_clause.append( + _make_generic_sql_bound( + bound="<" if direction == "b" else ">=", + column_names=column_names, + values=to_token, + engine=engine, + ) + ) + + return " AND ".join(where_clause) + + +def _make_generic_sql_bound(bound, column_names, values, engine): + """Create an SQL expression that bounds the given column names by the + values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. + + Only works with two columns. + + Older versions of SQLite don't support that syntax so we have to expand it + out manually. + + Args: + bound (str): The comparison operator to use. One of ">", "<", ">=", + "<=", where the values are on the left and columns on the right. + names (tuple[str, str]): The column names. Must *not* be user defined + as these get inserted directly into the SQL statement without + escapes. + values (tuple[int|None, int]): The values to bound the columns by. If + the first value is None then only creates a bound on the second + column. + engine: The database engine to generate the SQL for + + Returns: + str + """ + + assert bound in (">", "<", ">=", "<=") + + name1, name2 = column_names + val1, val2 = values + + if val1 is None: + val2 = int(val2) + return "(%d %s %s)" % (val2, bound, name2) + + val1 = int(val1) + val2 = int(val2) + + if isinstance(engine, PostgresEngine): + # Postgres doesn't optimise ``(x < a) OR (x=a AND y ? AND stream_ordering <= ?" + " ORDER BY stream_ordering %s LIMIT ?" 
+ ) % (order,) + txn.execute(sql, (room_id, from_id, to_id, limit)) + + rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + return rows + + rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f) + + ret = yield self.get_events_as_list( + [r.event_id for r in rows], get_prev_content=True + ) + + self._set_before_and_after(ret, rows, topo_order=from_id is None) + + if order.lower() == "desc": + ret.reverse() + + if rows: + key = "s%d" % min(r.stream_ordering for r in rows) + else: + # Assume we didn't get anything because there was nothing to + # get. + key = from_key + + return ret, key + + @defer.inlineCallbacks + def get_membership_changes_for_user(self, user_id, from_key, to_key): + from_id = RoomStreamToken.parse_stream_token(from_key).stream + to_id = RoomStreamToken.parse_stream_token(to_key).stream + + if from_key == to_key: + return [] + + if from_id: + has_changed = self._membership_stream_cache.has_entity_changed( + user_id, int(from_id) + ) + if not has_changed: + return [] + + def f(txn): + sql = ( + "SELECT m.event_id, stream_ordering FROM events AS e," + " room_memberships AS m" + " WHERE e.event_id = m.event_id" + " AND m.user_id = ?" + " AND e.stream_ordering > ? AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + ) + txn.execute(sql, (user_id, from_id, to_id)) + + rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + + return rows + + rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f) + + ret = yield self.get_events_as_list( + [r.event_id for r in rows], get_prev_content=True + ) + + self._set_before_and_after(ret, rows, topo_order=False) + + return ret + + @defer.inlineCallbacks + def get_recent_events_for_room(self, room_id, limit, end_token): + """Get the most recent events in the room in topological ordering. + + Args: + room_id (str) + limit (int) + end_token (str): The stream token representing now. + + Returns: + Deferred[tuple[list[FrozenEvent], str]]: Returns a list of + events and a token pointing to the start of the returned + events. + The events returned are in ascending order. + """ + + rows, token = yield self.get_recent_event_ids_for_room( + room_id, limit, end_token + ) + + events = yield self.get_events_as_list( + [r.event_id for r in rows], get_prev_content=True + ) + + self._set_before_and_after(events, rows) + + return (events, token) + + @defer.inlineCallbacks + def get_recent_event_ids_for_room(self, room_id, limit, end_token): + """Get the most recent events in the room in topological ordering. + + Args: + room_id (str) + limit (int) + end_token (str): The stream token representing now. + + Returns: + Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of + _EventDictReturn and a token pointing to the start of the returned + events. + The events returned are in ascending order. + """ + # Allow a zero limit here, and no-op. + if limit == 0: + return [], end_token + + end_token = RoomStreamToken.parse(end_token) + + rows, token = yield self.db_pool.runInteraction( + "get_recent_event_ids_for_room", + self._paginate_room_events_txn, + room_id, + from_token=end_token, + limit=limit, + ) + + # We want to return the results in ascending order. 
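+        # (_paginate_room_events_txn pages backwards from `end_token`, so the
+        # rows come back newest-first; reversing them restores the ascending
+        # order promised in the docstring.)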
+ rows.reverse() + + return rows, token + + def get_room_event_before_stream_ordering(self, room_id, stream_ordering): + """Gets details of the first event in a room at or before a stream ordering + + Args: + room_id (str): + stream_ordering (int): + + Returns: + Deferred[(int, int, str)]: + (stream ordering, topological ordering, event_id) + """ + + def _f(txn): + sql = ( + "SELECT stream_ordering, topological_ordering, event_id" + " FROM events" + " WHERE room_id = ? AND stream_ordering <= ?" + " AND NOT outlier" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + txn.execute(sql, (room_id, stream_ordering)) + return txn.fetchone() + + return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f) + + async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str: + """Returns the current token for rooms stream. + + By default, it returns the current global stream token. Specifying a + `room_id` causes it to return the current room specific topological + token. + """ + token = self.get_room_max_stream_ordering() + if room_id is None: + return "s%d" % (token,) + else: + topo = await self.db_pool.runInteraction( + "_get_max_topological_txn", self._get_max_topological_txn, room_id + ) + return "t%d-%d" % (topo, token) + + def get_stream_token_for_event(self, event_id): + """The stream token for an event + Args: + event_id(str): The id of the event to look up a stream token for. + Raises: + StoreError if the event wasn't in the database. + Returns: + A deferred "s%d" stream token. + """ + return self.db_pool.simple_select_one_onecol( + table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" + ).addCallback(lambda row: "s%d" % (row,)) + + def get_topological_token_for_event(self, event_id): + """The stream token for an event + Args: + event_id(str): The id of the event to look up a stream token for. + Raises: + StoreError if the event wasn't in the database. + Returns: + A deferred "t%d-%d" topological token. + """ + return self.db_pool.simple_select_one( + table="events", + keyvalues={"event_id": event_id}, + retcols=("stream_ordering", "topological_ordering"), + desc="get_topological_token_for_event", + ).addCallback( + lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) + ) + + def get_max_topological_token(self, room_id, stream_key): + """Get the max topological token in a room before the given stream + ordering. + + Args: + room_id (str) + stream_key (int) + + Returns: + Deferred[int] + """ + sql = ( + "SELECT coalesce(max(topological_ordering), 0) FROM events" + " WHERE room_id = ? AND stream_ordering < ?" + ) + return self.db_pool.execute( + "get_max_topological_token", None, sql, room_id, stream_key + ).addCallback(lambda r: r[0][0] if r else 0) + + def _get_max_topological_txn(self, txn, room_id): + txn.execute( + "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?", + (room_id,), + ) + + rows = txn.fetchall() + return rows[0][0] if rows else 0 + + @staticmethod + def _set_before_and_after(events, rows, topo_order=True): + """Inserts ordering information to events' internal metadata from + the DB rows. + + Args: + events (list[FrozenEvent]) + rows (list[_EventDictReturn]) + topo_order (bool): Whether the events were ordered topologically + or by stream ordering. If true then all rows should have a non + null topological_ordering. 
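+
+        The tokens written to each event are topological ("t%d-%d") when a
+        topological ordering is used, and stream tokens ("s%d") otherwise.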
+ """ + for event, row in zip(events, rows): + stream = row.stream_ordering + if topo_order and row.topological_ordering: + topo = row.topological_ordering + else: + topo = None + internal = event.internal_metadata + internal.before = str(RoomStreamToken(topo, stream - 1)) + internal.after = str(RoomStreamToken(topo, stream)) + internal.order = (int(topo) if topo else 0, int(stream)) + + @defer.inlineCallbacks + def get_events_around( + self, room_id, event_id, before_limit, after_limit, event_filter=None + ): + """Retrieve events and pagination tokens around a given event in a + room. + + Args: + room_id (str) + event_id (str) + before_limit (int) + after_limit (int) + event_filter (Filter|None) + + Returns: + dict + """ + + results = yield self.db_pool.runInteraction( + "get_events_around", + self._get_events_around_txn, + room_id, + event_id, + before_limit, + after_limit, + event_filter, + ) + + events_before = yield self.get_events_as_list( + list(results["before"]["event_ids"]), get_prev_content=True + ) + + events_after = yield self.get_events_as_list( + list(results["after"]["event_ids"]), get_prev_content=True + ) + + return { + "events_before": events_before, + "events_after": events_after, + "start": results["before"]["token"], + "end": results["after"]["token"], + } + + def _get_events_around_txn( + self, txn, room_id, event_id, before_limit, after_limit, event_filter + ): + """Retrieves event_ids and pagination tokens around a given event in a + room. + + Args: + room_id (str) + event_id (str) + before_limit (int) + after_limit (int) + event_filter (Filter|None) + + Returns: + dict + """ + + results = self.db_pool.simple_select_one_txn( + txn, + "events", + keyvalues={"event_id": event_id, "room_id": room_id}, + retcols=["stream_ordering", "topological_ordering"], + ) + + # Paginating backwards includes the event at the token, but paginating + # forward doesn't. + before_token = RoomStreamToken( + results["topological_ordering"] - 1, results["stream_ordering"] + ) + + after_token = RoomStreamToken( + results["topological_ordering"], results["stream_ordering"] + ) + + rows, start_token = self._paginate_room_events_txn( + txn, + room_id, + before_token, + direction="b", + limit=before_limit, + event_filter=event_filter, + ) + events_before = [r.event_id for r in rows] + + rows, end_token = self._paginate_room_events_txn( + txn, + room_id, + after_token, + direction="f", + limit=after_limit, + event_filter=event_filter, + ) + events_after = [r.event_id for r in rows] + + return { + "before": {"event_ids": events_before, "token": start_token}, + "after": {"event_ids": events_after, "token": end_token}, + } + + @defer.inlineCallbacks + def get_all_new_events_stream(self, from_id, current_id, limit): + """Get all new events + + Returns all events with from_id < stream_ordering <= current_id. + + Args: + from_id (int): the stream_ordering of the last event we processed + current_id (int): the stream_ordering of the most recently processed event + limit (int): the maximum number of events to return + + Returns: + Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where + `next_id` is the next value to pass as `from_id` (it will either be the + stream_ordering of the last returned event, or, if fewer than `limit` events + were found, `current_id`. + """ + + def get_all_new_events_stream_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id" + " FROM events AS e" + " WHERE" + " ? < e.stream_ordering AND e.stream_ordering <= ?" 
+ " ORDER BY e.stream_ordering ASC" + " LIMIT ?" + ) + + txn.execute(sql, (from_id, current_id, limit)) + rows = txn.fetchall() + + upper_bound = current_id + if len(rows) == limit: + upper_bound = rows[-1][0] + + return upper_bound, [row[1] for row in rows] + + upper_bound, event_ids = yield self.db_pool.runInteraction( + "get_all_new_events_stream", get_all_new_events_stream_txn + ) + + events = yield self.get_events_as_list(event_ids) + + return upper_bound, events + + async def get_federation_out_pos(self, typ: str) -> int: + if self._need_to_reset_federation_stream_positions: + await self.db_pool.runInteraction( + "_reset_federation_positions_txn", self._reset_federation_positions_txn + ) + self._need_to_reset_federation_stream_positions = False + + return await self.db_pool.simple_select_one_onecol( + table="federation_stream_position", + retcol="stream_id", + keyvalues={"type": typ, "instance_name": self._instance_name}, + desc="get_federation_out_pos", + ) + + async def update_federation_out_pos(self, typ, stream_id): + if self._need_to_reset_federation_stream_positions: + await self.db_pool.runInteraction( + "_reset_federation_positions_txn", self._reset_federation_positions_txn + ) + self._need_to_reset_federation_stream_positions = False + + return await self.db_pool.simple_update_one( + table="federation_stream_position", + keyvalues={"type": typ, "instance_name": self._instance_name}, + updatevalues={"stream_id": stream_id}, + desc="update_federation_out_pos", + ) + + def _reset_federation_positions_txn(self, txn): + """Fiddles with the `federation_stream_position` table to make it match + the configured federation sender instances during start up. + """ + + # The federation sender instances may have changed, so we need to + # massage the `federation_stream_position` table to have a row per type + # per instance sending federation. If there is a mismatch we update the + # table with the correct rows using the *minimum* stream ID seen. This + # may result in resending of events/EDUs to remote servers, but that is + # preferable to dropping them. + + if not self._send_federation: + return + + # Pull out the configured instances. If we don't have a shard config then + # we assume that we're the only instance sending. 
+ configured_instances = self._federation_shard_config.instances + if not configured_instances: + configured_instances = [self._instance_name] + elif self._instance_name not in configured_instances: + return + + instances_in_table = self.db_pool.simple_select_onecol_txn( + txn, + table="federation_stream_position", + keyvalues={}, + retcol="instance_name", + ) + + if set(instances_in_table) == set(configured_instances): + # Nothing to do + return + + sql = """ + SELECT type, MIN(stream_id) FROM federation_stream_position + GROUP BY type + """ + txn.execute(sql) + min_positions = dict(txn) # Map from type -> min position + + # Ensure we do actually have some values here + assert set(min_positions) == {"federation", "events"} + + sql = """ + DELETE FROM federation_stream_position + WHERE NOT (%s) + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "instance_name", configured_instances + ) + txn.execute(sql % (clause,), args) + + for typ, stream_id in min_positions.items(): + self.db_pool.simple_upsert_txn( + txn, + table="federation_stream_position", + keyvalues={"type": typ, "instance_name": self._instance_name}, + values={"stream_id": stream_id}, + ) + + def has_room_changed_since(self, room_id, stream_id): + return self._events_stream_cache.has_entity_changed(room_id, stream_id) + + def _paginate_room_events_txn( + self, + txn, + room_id, + from_token, + to_token=None, + direction="b", + limit=-1, + event_filter=None, + ): + """Returns list of events before or after a given token. + + Args: + txn + room_id (str) + from_token (RoomStreamToken): The token used to stream from + to_token (RoomStreamToken|None): A token which if given limits the + results to only those before + direction(char): Either 'b' or 'f' to indicate whether we are + paginating forwards or backwards from `from_key`. + limit (int): The maximum number of events to return. + event_filter (Filter|None): If provided filters the events to + those that match the filter. + + Returns: + Deferred[tuple[list[_EventDictReturn], str]]: Returns the results + as a list of _EventDictReturn and a token that points to the end + of the result set. If no events are returned then the end of the + stream has been reached (i.e. there are no events between + `from_token` and `to_token`), or `limit` is zero. + """ + + assert int(limit) >= 0 + + # Tokens really represent positions between elements, but we use + # the convention of pointing to the event before the gap. Hence + # we have a bit of asymmetry when it comes to equalities. + args = [False, room_id] + if direction == "b": + order = "DESC" + else: + order = "ASC" + + bounds = generate_pagination_where_clause( + direction=direction, + column_names=("topological_ordering", "stream_ordering"), + from_token=from_token, + to_token=to_token, + engine=self.database_engine, + ) + + filter_clause, filter_args = filter_to_clause(event_filter) + + if filter_clause: + bounds += " AND " + filter_clause + args.extend(filter_args) + + args.append(int(limit)) + + select_keywords = "SELECT" + join_clause = "" + if event_filter and event_filter.labels: + # If we're not filtering on a label, then joining on event_labels will + # return as many row for a single event as the number of labels it has. To + # avoid this, only join if we're filtering on at least one label. 
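+            # (e.g. an event labelled both "#fun" and "#work" would match the
+            # join twice and so appear twice in an unfiltered result set.)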
+ join_clause = """ + LEFT JOIN event_labels + USING (event_id, room_id, topological_ordering) + """ + if len(event_filter.labels) > 1: + # Using DISTINCT in this SELECT query is quite expensive, because it + # requires the engine to sort on the entire (not limited) result set, + # i.e. the entire events table. We only need to use it when we're + # filtering on more than two labels, because that's the only scenario + # in which we can possibly to get multiple times the same event ID in + # the results. + select_keywords += "DISTINCT" + + sql = """ + %(select_keywords)s event_id, topological_ordering, stream_ordering + FROM events + %(join_clause)s + WHERE outlier = ? AND room_id = ? AND %(bounds)s + ORDER BY topological_ordering %(order)s, + stream_ordering %(order)s LIMIT ? + """ % { + "select_keywords": select_keywords, + "join_clause": join_clause, + "bounds": bounds, + "order": order, + } + + txn.execute(sql, args) + + rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn] + + if rows: + topo = rows[-1].topological_ordering + toke = rows[-1].stream_ordering + if direction == "b": + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + toke -= 1 + next_token = RoomStreamToken(topo, toke) + else: + # TODO (erikj): We should work out what to do here instead. + next_token = to_token if to_token else from_token + + return rows, str(next_token) + + @defer.inlineCallbacks + def paginate_room_events( + self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None + ): + """Returns list of events before or after a given token. + + Args: + room_id (str) + from_key (str): The token used to stream from + to_key (str|None): A token which if given limits the results to + only those before + direction(char): Either 'b' or 'f' to indicate whether we are + paginating forwards or backwards from `from_key`. + limit (int): The maximum number of events to return. + event_filter (Filter|None): If provided filters the events to + those that match the filter. + + Returns: + tuple[list[FrozenEvent], str]: Returns the results as a list of + events and a token that points to the end of the result set. If no + events are returned then the end of the stream has been reached + (i.e. there are no events between `from_key` and `to_key`). + """ + + from_key = RoomStreamToken.parse(from_key) + if to_key: + to_key = RoomStreamToken.parse(to_key) + + rows, token = yield self.db_pool.runInteraction( + "paginate_room_events", + self._paginate_room_events_txn, + room_id, + from_key, + to_key, + direction, + limit, + event_filter, + ) + + events = yield self.get_events_as_list( + [r.event_id for r in rows], get_prev_content=True + ) + + self._set_before_and_after(events, rows) + + return (events, token) + + +class StreamStore(StreamWorkerStore): + def get_room_max_stream_ordering(self): + return self._stream_id_gen.get_current_token() + + def get_room_min_stream_ordering(self): + return self._backfill_id_gen.get_current_token() -- cgit 1.5.1 From ad6190c9252aafd37cd8c229b70853bfc4ef0e64 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 17 Aug 2020 07:24:46 -0400 Subject: Convert stream database to async/await. 
(#8074) --- changelog.d/8074.misc | 1 + synapse/api/filtering.py | 2 +- synapse/api/presence.py | 69 ++++ synapse/federation/send_queue.py | 2 +- synapse/federation/sender/__init__.py | 2 +- synapse/federation/sender/per_destination_queue.py | 2 +- synapse/handlers/presence.py | 2 +- synapse/storage/databases/main/presence.py | 2 +- synapse/storage/databases/main/stream.py | 387 +++++++++++---------- synapse/storage/presence.py | 69 ---- tests/handlers/test_presence.py | 2 +- tests/storage/test_purge.py | 49 +-- 12 files changed, 293 insertions(+), 296 deletions(-) create mode 100644 changelog.d/8074.misc create mode 100644 synapse/api/presence.py delete mode 100644 synapse/storage/presence.py (limited to 'synapse/storage/databases/main/stream.py') diff --git a/changelog.d/8074.misc b/changelog.d/8074.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8074.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 7393d6cb74..a8937d2595 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -23,7 +23,7 @@ from jsonschema import FormatChecker from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError -from synapse.storage.presence import UserPresenceState +from synapse.api.presence import UserPresenceState from synapse.types import RoomID, UserID FILTER_SCHEMA = { diff --git a/synapse/api/presence.py b/synapse/api/presence.py new file mode 100644 index 0000000000..18a462f0ee --- /dev/null +++ b/synapse/api/presence.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple + +from synapse.api.constants import PresenceState + + +class UserPresenceState( + namedtuple( + "UserPresenceState", + ( + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + ), + ) +): + """Represents the current presence state of the user. + + user_id (str) + last_active (int): Time in msec that the user last interacted with server. + last_federation_update (int): Time in msec since either a) we sent a presence + update to other servers or b) we received a presence update, depending + on if is a local user or not. + last_user_sync (int): Time in msec that the user last *completed* a sync + (or event stream). + status_msg (str): User set status message. + """ + + def as_dict(self): + return dict(self._asdict()) + + @staticmethod + def from_dict(d): + return UserPresenceState(**d) + + def copy_and_replace(self, **kwargs): + return self._replace(**kwargs) + + @classmethod + def default(cls, user_id): + """Returns a default presence state. 
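+
+        The state is offline, all timestamps are zero and no status message
+        is set.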
+ """ + return cls( + user_id=user_id, + state=PresenceState.OFFLINE, + last_active_ts=0, + last_federation_update_ts=0, + last_user_sync_ts=0, + status_msg=None, + currently_active=False, + ) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 2b0ab2dcbf..4d65d4aeea 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -37,8 +37,8 @@ from sortedcontainers import SortedDict from twisted.internet import defer +from synapse.api.presence import UserPresenceState from synapse.metrics import LaterGauge -from synapse.storage.presence import UserPresenceState from synapse.util.metrics import Measure from .units import Edu diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 94cc63001e..e53b6ac456 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -22,6 +22,7 @@ from twisted.internet import defer import synapse import synapse.metrics +from synapse.api.presence import UserPresenceState from synapse.events import EventBase from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager @@ -39,7 +40,6 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.presence import UserPresenceState from synapse.types import ReadReceipt from synapse.util.metrics import Measure, measure_func diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 8cbc23d901..c09ffcaf4c 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -24,12 +24,12 @@ from synapse.api.errors import ( HttpResponseException, RequestSendFailed, ) +from synapse.api.presence import UserPresenceState from synapse.events import EventBase from synapse.federation.units import Edu from synapse.handlers.presence import format_user_presence_state from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.presence import UserPresenceState from synapse.types import ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 5387b3724f..24e1940ee5 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -33,13 +33,13 @@ from typing_extensions import ContextManager import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError +from synapse.api.presence import UserPresenceState from synapse.logging.context import run_in_background from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateHandler from synapse.storage.databases.main import DataStore -from synapse.storage.presence import UserPresenceState from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cached diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 9f691e5792..4e3ec02d14 100644 --- 
a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -15,8 +15,8 @@ from typing import List, Tuple +from synapse.api.presence import UserPresenceState from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.presence import UserPresenceState from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index aaf225894e..8ccfb8fc46 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -39,15 +39,17 @@ what sort order was used: import abc import logging from collections import namedtuple -from typing import Optional +from typing import Dict, Iterable, List, Optional, Tuple from twisted.internet import defer +from synapse.api.filtering import Filter +from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.engines import PostgresEngine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -68,8 +70,12 @@ _EventDictReturn = namedtuple( def generate_pagination_where_clause( - direction, column_names, from_token, to_token, engine -): + direction: str, + column_names: Tuple[str, str], + from_token: Optional[Tuple[int, int]], + to_token: Optional[Tuple[int, int]], + engine: BaseDatabaseEngine, +) -> str: """Creates an SQL expression to bound the columns by the pagination tokens. @@ -90,21 +96,19 @@ def generate_pagination_where_clause( token, but include those that match the to token. Args: - direction (str): Whether we're paginating backwards("b") or - forwards ("f"). - column_names (tuple[str, str]): The column names to bound. Must *not* - be user defined as these get inserted directly into the SQL - statement without escapes. - from_token (tuple[int, int]|None): The start point for the pagination. - This is an exclusive minimum bound if direction is "f", and an - inclusive maximum bound if direction is "b". - to_token (tuple[int, int]|None): The endpoint point for the pagination. - This is an inclusive maximum bound if direction is "f", and an - exclusive minimum bound if direction is "b". + direction: Whether we're paginating backwards("b") or forwards ("f"). + column_names: The column names to bound. Must *not* be user defined as + these get inserted directly into the SQL statement without escapes. + from_token: The start point for the pagination. This is an exclusive + minimum bound if direction is "f", and an inclusive maximum bound if + direction is "b". + to_token: The endpoint point for the pagination. This is an inclusive + maximum bound if direction is "f", and an exclusive minimum bound if + direction is "b". 
engine: The database engine to generate the clauses for Returns: - str: The sql expression + The sql expression """ assert direction in ("b", "f") @@ -132,7 +136,12 @@ def generate_pagination_where_clause( return " AND ".join(where_clause) -def _make_generic_sql_bound(bound, column_names, values, engine): +def _make_generic_sql_bound( + bound: str, + column_names: Tuple[str, str], + values: Tuple[Optional[int], int], + engine: BaseDatabaseEngine, +) -> str: """Create an SQL expression that bounds the given column names by the values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. @@ -142,18 +151,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine): out manually. Args: - bound (str): The comparison operator to use. One of ">", "<", ">=", + bound: The comparison operator to use. One of ">", "<", ">=", "<=", where the values are on the left and columns on the right. - names (tuple[str, str]): The column names. Must *not* be user defined + names: The column names. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. - values (tuple[int|None, int]): The values to bound the columns by. If + values: The values to bound the columns by. If the first value is None then only creates a bound on the second column. engine: The database engine to generate the SQL for Returns: - str + The SQL statement """ assert bound in (">", "<", ">=", "<=") @@ -193,7 +202,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine): ) -def filter_to_clause(event_filter): +def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. it may create # "room_id == X AND room_id != X", which postgres doesn't optimise. @@ -291,34 +300,35 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def get_room_min_stream_ordering(self): raise NotImplementedError() - @defer.inlineCallbacks - def get_room_events_stream_for_rooms( - self, room_ids, from_key, to_key, limit=0, order="DESC" - ): + async def get_room_events_stream_for_rooms( + self, + room_ids: Iterable[str], + from_key: str, + to_key: str, + limit: int = 0, + order: str = "DESC", + ) -> Dict[str, Tuple[List[EventBase], str]]: """Get new room events in stream ordering since `from_key`. Args: - room_id (str) - from_key (str): Token from which no events are returned before - to_key (str): Token from which no events are returned after. (This + room_ids + from_key: Token from which no events are returned before + to_key: Token from which no events are returned after. (This is typically the current stream token) - limit (int): Maximum number of events to return - order (str): Either "DESC" or "ASC". Determines which events are + limit: Maximum number of events to return + order: Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. Returns: - Deferred[dict[str,tuple[list[FrozenEvent], str]]] - A map from room id to a tuple containing: - - list of recent events in the room - - stream ordering key for the start of the chunk of events returned. + A map from room id to a tuple containing: + - list of recent events in the room + - stream ordering key for the start of the chunk of events returned. 
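+            Only rooms that may have changed since `from_key` appear in the
+            map; unchanged rooms are filtered out up front via the stream
+            change cache.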
""" from_id = RoomStreamToken.parse_stream_token(from_key).stream - room_ids = yield self._events_stream_cache.get_entities_changed( - room_ids, from_id - ) + room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id) if not room_ids: return {} @@ -326,7 +336,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): results = {} room_ids = list(room_ids) for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)): - res = yield make_deferred_yieldable( + res = await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -361,28 +371,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if self._events_stream_cache.has_entity_changed(room_id, from_key) } - @defer.inlineCallbacks - def get_room_events_stream_for_room( - self, room_id, from_key, to_key, limit=0, order="DESC" - ): + async def get_room_events_stream_for_room( + self, + room_id: str, + from_key: str, + to_key: str, + limit: int = 0, + order: str = "DESC", + ) -> Tuple[List[EventBase], str]: """Get new room events in stream ordering since `from_key`. Args: - room_id (str) - from_key (str): Token from which no events are returned before - to_key (str): Token from which no events are returned after. (This + room_id + from_key: Token from which no events are returned before + to_key: Token from which no events are returned after. (This is typically the current stream token) - limit (int): Maximum number of events to return - order (str): Either "DESC" or "ASC". Determines which events are + limit: Maximum number of events to return + order: Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns the list of - events (in ascending order) and the token from the start of - the chunk of events returned. + The list of events (in ascending order) and the token from the start + of the chunk of events returned. 
""" if from_key == to_key: return [], from_key @@ -390,9 +403,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream - has_changed = yield self._events_stream_cache.has_entity_changed( - room_id, from_id - ) + has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) if not has_changed: return [], from_key @@ -410,9 +421,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows - rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f) + rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) - ret = yield self.get_events_as_list( + ret = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -430,8 +441,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret, key - @defer.inlineCallbacks - def get_membership_changes_for_user(self, user_id, from_key, to_key): + async def get_membership_changes_for_user(self, user_id, from_key, to_key): from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream @@ -460,9 +470,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows - rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f) + rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f) - ret = yield self.get_events_as_list( + ret = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -470,27 +480,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret - @defer.inlineCallbacks - def get_recent_events_for_room(self, room_id, limit, end_token): + async def get_recent_events_for_room( + self, room_id: str, limit: int, end_token: str + ) -> Tuple[List[EventBase], str]: """Get the most recent events in the room in topological ordering. Args: - room_id (str) - limit (int) - end_token (str): The stream token representing now. + room_id + limit + end_token: The stream token representing now. Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns a list of - events and a token pointing to the start of the returned - events. - The events returned are in ascending order. + A list of events and a token pointing to the start of the returned + events. The events returned are in ascending order. """ - rows, token = yield self.get_recent_event_ids_for_room( + rows, token = await self.get_recent_event_ids_for_room( room_id, limit, end_token ) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -498,20 +507,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return (events, token) - @defer.inlineCallbacks - def get_recent_event_ids_for_room(self, room_id, limit, end_token): + async def get_recent_event_ids_for_room( + self, room_id: str, limit: int, end_token: str + ) -> Tuple[List[_EventDictReturn], str]: """Get the most recent events in the room in topological ordering. Args: - room_id (str) - limit (int) - end_token (str): The stream token representing now. + room_id + limit + end_token: The stream token representing now. Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of - _EventDictReturn and a token pointing to the start of the returned - events. - The events returned are in ascending order. 
+ A list of _EventDictReturn and a token pointing to the start of the + returned events. The events returned are in ascending order. """ # Allow a zero limit here, and no-op. if limit == 0: @@ -519,7 +527,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token = RoomStreamToken.parse(end_token) - rows, token = yield self.db_pool.runInteraction( + rows, token = await self.db_pool.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, @@ -532,12 +540,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, token - def get_room_event_before_stream_ordering(self, room_id, stream_ordering): + def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int): """Gets details of the first event in a room at or before a stream ordering Args: - room_id (str): - stream_ordering (int): + room_id: + stream_ordering: Returns: Deferred[(int, int, str)]: @@ -574,55 +582,56 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) return "t%d-%d" % (topo, token) - def get_stream_token_for_event(self, event_id): + async def get_stream_token_for_event(self, event_id: str) -> str: """The stream token for an event Args: - event_id(str): The id of the event to look up a stream token for. + event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A deferred "s%d" stream token. + A "s%d" stream token. """ - return self.db_pool.simple_select_one_onecol( + row = await self.db_pool.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" - ).addCallback(lambda row: "s%d" % (row,)) + ) + return "s%d" % (row,) - def get_topological_token_for_event(self, event_id): + async def get_topological_token_for_event(self, event_id: str) -> str: """The stream token for an event Args: - event_id(str): The id of the event to look up a stream token for. + event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A deferred "t%d-%d" topological token. + A "t%d-%d" topological token. """ - return self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", - ).addCallback( - lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) ) + return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) - def get_max_topological_token(self, room_id, stream_key): + async def get_max_topological_token(self, room_id: str, stream_key: int) -> int: """Get the max topological token in a room before the given stream ordering. Args: - room_id (str) - stream_key (int) + room_id + stream_key Returns: - Deferred[int] + The maximum topological token. """ sql = ( "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" 
) - return self.db_pool.execute( + row = await self.db_pool.execute( "get_max_topological_token", None, sql, room_id, stream_key - ).addCallback(lambda r: r[0][0] if r else 0) + ) + return row[0][0] if row else 0 def _get_max_topological_txn(self, txn, room_id): txn.execute( @@ -634,16 +643,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows[0][0] if rows else 0 @staticmethod - def _set_before_and_after(events, rows, topo_order=True): + def _set_before_and_after( + events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True + ): """Inserts ordering information to events' internal metadata from the DB rows. Args: - events (list[FrozenEvent]) - rows (list[_EventDictReturn]) - topo_order (bool): Whether the events were ordered topologically - or by stream ordering. If true then all rows should have a non - null topological_ordering. + events + rows + topo_order: Whether the events were ordered topologically or by stream + ordering. If true then all rows should have a non null + topological_ordering. """ for event, row in zip(events, rows): stream = row.stream_ordering @@ -656,25 +667,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): internal.after = str(RoomStreamToken(topo, stream)) internal.order = (int(topo) if topo else 0, int(stream)) - @defer.inlineCallbacks - def get_events_around( - self, room_id, event_id, before_limit, after_limit, event_filter=None - ): + async def get_events_around( + self, + room_id: str, + event_id: str, + before_limit: int, + after_limit: int, + event_filter: Optional[Filter] = None, + ) -> dict: """Retrieve events and pagination tokens around a given event in a room. - - Args: - room_id (str) - event_id (str) - before_limit (int) - after_limit (int) - event_filter (Filter|None) - - Returns: - dict """ - results = yield self.db_pool.runInteraction( + results = await self.db_pool.runInteraction( "get_events_around", self._get_events_around_txn, room_id, @@ -684,11 +689,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter, ) - events_before = yield self.get_events_as_list( + events_before = await self.get_events_as_list( list(results["before"]["event_ids"]), get_prev_content=True ) - events_after = yield self.get_events_as_list( + events_after = await self.get_events_as_list( list(results["after"]["event_ids"]), get_prev_content=True ) @@ -700,17 +705,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): } def _get_events_around_txn( - self, txn, room_id, event_id, before_limit, after_limit, event_filter - ): + self, + txn, + room_id: str, + event_id: str, + before_limit: int, + after_limit: int, + event_filter: Optional[Filter], + ) -> dict: """Retrieves event_ids and pagination tokens around a given event in a room. Args: - room_id (str) - event_id (str) - before_limit (int) - after_limit (int) - event_filter (Filter|None) + room_id + event_id + before_limit + after_limit + event_filter Returns: dict @@ -758,22 +769,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "after": {"event_ids": events_after, "token": end_token}, } - @defer.inlineCallbacks - def get_all_new_events_stream(self, from_id, current_id, limit): + async def get_all_new_events_stream( + self, from_id: int, current_id: int, limit: int + ) -> Tuple[int, List[EventBase]]: """Get all new events Returns all events with from_id < stream_ordering <= current_id. 
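+        (The interval is half-open: the event at `from_id` is excluded and
+        the event at `current_id`, if present, is included.)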
Args: - from_id (int): the stream_ordering of the last event we processed - current_id (int): the stream_ordering of the most recently processed event - limit (int): the maximum number of events to return + from_id: the stream_ordering of the last event we processed + current_id: the stream_ordering of the most recently processed event + limit: the maximum number of events to return Returns: - Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where - `next_id` is the next value to pass as `from_id` (it will either be the - stream_ordering of the last returned event, or, if fewer than `limit` events - were found, `current_id`. + A tuple of (next_id, events), where `next_id` is the next value to + pass as `from_id` (it will either be the stream_ordering of the + last returned event, or, if fewer than `limit` events were found, + the `current_id`). """ def get_all_new_events_stream_txn(txn): @@ -795,11 +807,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.db_pool.runInteraction( + upper_bound, event_ids = await self.db_pool.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return upper_bound, events @@ -817,21 +829,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): desc="get_federation_out_pos", ) - async def update_federation_out_pos(self, typ, stream_id): + async def update_federation_out_pos(self, typ: str, stream_id: int) -> None: if self._need_to_reset_federation_stream_positions: await self.db_pool.runInteraction( "_reset_federation_positions_txn", self._reset_federation_positions_txn ) self._need_to_reset_federation_stream_positions = False - return await self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="federation_stream_position", keyvalues={"type": typ, "instance_name": self._instance_name}, updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) - def _reset_federation_positions_txn(self, txn): + def _reset_federation_positions_txn(self, txn) -> None: """Fiddles with the `federation_stream_position` table to make it match the configured federation sender instances during start up. """ @@ -892,39 +904,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): values={"stream_id": stream_id}, ) - def has_room_changed_since(self, room_id, stream_id): + def has_room_changed_since(self, room_id: str, stream_id: int) -> bool: return self._events_stream_cache.has_entity_changed(room_id, stream_id) def _paginate_room_events_txn( self, txn, - room_id, - from_token, - to_token=None, - direction="b", - limit=-1, - event_filter=None, - ): + room_id: str, + from_token: RoomStreamToken, + to_token: Optional[RoomStreamToken] = None, + direction: str = "b", + limit: int = -1, + event_filter: Optional[Filter] = None, + ) -> Tuple[List[_EventDictReturn], str]: """Returns list of events before or after a given token. Args: txn - room_id (str) - from_token (RoomStreamToken): The token used to stream from - to_token (RoomStreamToken|None): A token which if given limits the - results to only those before - direction(char): Either 'b' or 'f' to indicate whether we are - paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. 
- event_filter (Filter|None): If provided filters the events to + room_id + from_token: The token used to stream from + to_token: A token which if given limits the results to only those before + direction: Either 'b' or 'f' to indicate whether we are paginating + forwards or backwards from `from_key`. + limit: The maximum number of events to return. + event_filter: If provided filters the events to those that match the filter. Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns the results - as a list of _EventDictReturn and a token that points to the end - of the result set. If no events are returned then the end of the - stream has been reached (i.e. there are no events between - `from_token` and `to_token`), or `limit` is zero. + A list of _EventDictReturn and a token that points to the end of the + result set. If no events are returned then the end of the stream has + been reached (i.e. there are no events between `from_token` and + `to_token`), or `limit` is zero. """ assert int(limit) >= 0 @@ -1008,35 +1018,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, str(next_token) - @defer.inlineCallbacks - def paginate_room_events( - self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None - ): + async def paginate_room_events( + self, + room_id: str, + from_key: str, + to_key: Optional[str] = None, + direction: str = "b", + limit: int = -1, + event_filter: Optional[Filter] = None, + ) -> Tuple[List[EventBase], str]: """Returns list of events before or after a given token. Args: - room_id (str) - from_key (str): The token used to stream from - to_key (str|None): A token which if given limits the results to - only those before - direction(char): Either 'b' or 'f' to indicate whether we are - paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. - event_filter (Filter|None): If provided filters the events to - those that match the filter. + room_id + from_key: The token used to stream from + to_key: A token which if given limits the results to only those before + direction: Either 'b' or 'f' to indicate whether we are paginating + forwards or backwards from `from_key`. + limit: The maximum number of events to return. + event_filter: If provided filters the events to those that match the filter. Returns: - tuple[list[FrozenEvent], str]: Returns the results as a list of - events and a token that points to the end of the result set. If no - events are returned then the end of the stream has been reached - (i.e. there are no events between `from_key` and `to_key`). + The results as a list of events and a token that points to the end + of the result set. If no events are returned then the end of the + stream has been reached (i.e. there are no events between `from_key` + and `to_key`). 
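+            The returned token is usually a topological token ("t%d-%d"); if
+            no events matched, one of the input tokens is handed back
+            unchanged.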
""" from_key = RoomStreamToken.parse(from_key) if to_key: to_key = RoomStreamToken.parse(to_key) - rows, token = yield self.db_pool.runInteraction( + rows, token = await self.db_pool.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, @@ -1047,7 +1060,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter, ) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -1057,8 +1070,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): class StreamStore(StreamWorkerStore): - def get_room_max_stream_ordering(self): + def get_room_max_stream_ordering(self) -> int: return self._stream_id_gen.get_current_token() - def get_room_min_stream_ordering(self): + def get_room_min_stream_ordering(self) -> int: return self._backfill_id_gen.get_current_token() diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py deleted file mode 100644 index 18a462f0ee..0000000000 --- a/synapse/storage/presence.py +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import namedtuple - -from synapse.api.constants import PresenceState - - -class UserPresenceState( - namedtuple( - "UserPresenceState", - ( - "user_id", - "state", - "last_active_ts", - "last_federation_update_ts", - "last_user_sync_ts", - "status_msg", - "currently_active", - ), - ) -): - """Represents the current presence state of the user. - - user_id (str) - last_active (int): Time in msec that the user last interacted with server. - last_federation_update (int): Time in msec since either a) we sent a presence - update to other servers or b) we received a presence update, depending - on if is a local user or not. - last_user_sync (int): Time in msec that the user last *completed* a sync - (or event stream). - status_msg (str): User set status message. - """ - - def as_dict(self): - return dict(self._asdict()) - - @staticmethod - def from_dict(d): - return UserPresenceState(**d) - - def copy_and_replace(self, **kwargs): - return self._replace(**kwargs) - - @classmethod - def default(cls, user_id): - """Returns a default presence state. 
- """ - return cls( - user_id=user_id, - state=PresenceState.OFFLINE, - last_active_ts=0, - last_federation_update_ts=0, - last_user_sync_ts=0, - status_msg=None, - currently_active=False, - ) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 05ea40a7de..306dcfe944 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -19,6 +19,7 @@ from mock import Mock, call from signedjson.key import generate_signing_key from synapse.api.constants import EventTypes, Membership, PresenceState +from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events.builder import EventBuilder from synapse.handlers.presence import ( @@ -32,7 +33,6 @@ from synapse.handlers.presence import ( handle_update, ) from synapse.rest.client.v1 import room -from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id from tests import unittest diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index a6012c973d..918387733b 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -15,6 +15,7 @@ from twisted.internet import defer +from synapse.api.errors import NotFoundError from synapse.rest.client.v1 import room from tests.unittest import HomeserverTestCase @@ -46,30 +47,19 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_storage() # Get the topological token - event = store.get_topological_token_for_event(last["event_id"]) - self.pump() - event = self.successResultOf(event) - - # Purge everything before this topological token - purge = defer.ensureDeferred( - storage.purge_events.purge_history(self.room_id, event, True) + event = self.get_success( + store.get_topological_token_for_event(last["event_id"]) ) - self.pump() - self.assertEqual(self.successResultOf(purge), None) - # Try and get the events - get_first = store.get_event(first["event_id"]) - get_second = store.get_event(second["event_id"]) - get_third = store.get_event(third["event_id"]) - get_last = store.get_event(last["event_id"]) - self.pump() + # Purge everything before this topological token + self.get_success(storage.purge_events.purge_history(self.room_id, event, True)) # 1-3 should fail and last will succeed, meaning that 1-3 are deleted # and last is not. 
- self.failureResultOf(get_first) - self.failureResultOf(get_second) - self.failureResultOf(get_third) - self.successResultOf(get_last) + self.get_failure(store.get_event(first["event_id"]), NotFoundError) + self.get_failure(store.get_event(second["event_id"]), NotFoundError) + self.get_failure(store.get_event(third["event_id"]), NotFoundError) + self.get_success(store.get_event(last["event_id"])) def test_purge_wont_delete_extrems(self): """ @@ -84,9 +74,9 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_datastore() # Set the topological token higher than it should be - event = storage.get_topological_token_for_event(last["event_id"]) - self.pump() - event = self.successResultOf(event) + event = self.get_success( + storage.get_topological_token_for_event(last["event_id"]) + ) event = "t{}-{}".format( *list(map(lambda x: x + 1, map(int, event[1:].split("-")))) ) @@ -98,14 +88,7 @@ class PurgeTests(HomeserverTestCase): self.assertIn("greater than forward", f.value.args[0]) # Try and get the events - get_first = storage.get_event(first["event_id"]) - get_second = storage.get_event(second["event_id"]) - get_third = storage.get_event(third["event_id"]) - get_last = storage.get_event(last["event_id"]) - self.pump() - - # Nothing is deleted. - self.successResultOf(get_first) - self.successResultOf(get_second) - self.successResultOf(get_third) - self.successResultOf(get_last) + self.get_success(storage.get_event(first["event_id"])) + self.get_success(storage.get_event(second["event_id"])) + self.get_success(storage.get_event(third["event_id"])) + self.get_success(storage.get_event(last["event_id"])) -- cgit 1.5.1 From 3c01724b330ac99d6defb12634ea9046ae52fe63 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Aug 2020 09:53:13 -0400 Subject: Fix the return type of send_nonmember_events. (#8112) --- changelog.d/8112.misc | 1 + synapse/handlers/message.py | 2 +- synapse/storage/databases/main/stream.py | 19 +++++++++++++++---- 3 files changed, 17 insertions(+), 5 deletions(-) create mode 100644 changelog.d/8112.misc (limited to 'synapse/storage/databases/main/stream.py') diff --git a/changelog.d/8112.misc b/changelog.d/8112.misc new file mode 100644 index 0000000000..80045dde1a --- /dev/null +++ b/changelog.d/8112.misc @@ -0,0 +1 @@ +Return the previous stream token if a non-member event is a duplicate. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index f242d3c6ac..532fc30681 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -674,7 +674,7 @@ class EventCreationHandler(object): event.event_id, prev_event.event_id, ) - return await self.store.get_stream_token_for_event(prev_event.event_id) + return await self.store.get_stream_id_for_event(prev_event.event_id) return await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 8ccfb8fc46..4377bddb8c 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -582,6 +582,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) return "t%d-%d" % (topo, token) + async def get_stream_id_for_event(self, event_id: str) -> int: + """The stream ID for an event + Args: + event_id: The id of the event to look up a stream token for. + Raises: + StoreError if the event wasn't in the database. + Returns: + A stream ID. 
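[Editorial aside: `test_purge_wont_delete_extrems` above keeps a fairly dense one-liner that bumps both halves of a topological token. An equivalent, more explicit helper, purely illustrative; the `bump_topological_token` name is not part of the codebase.]

def bump_topological_token(token: str) -> str:
    """Increment both ordering components of a "t<topo>-<stream>" token."""
    assert token[0] == "t"
    topo, stream = (int(part) for part in token[1:].split("-"))
    return "t{}-{}".format(topo + 1, stream + 1)

# e.g. bump_topological_token("t5-12") == "t6-13"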
+ """ + return await self.db_pool.simple_select_one_onecol( + table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" + ) + async def get_stream_token_for_event(self, event_id: str) -> str: """The stream token for an event Args: @@ -591,10 +604,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A "s%d" stream token. """ - row = await self.db_pool.simple_select_one_onecol( - table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" - ) - return "s%d" % (row,) + stream_id = await self.get_stream_id_for_event(event_id) + return "s%d" % (stream_id,) async def get_topological_token_for_event(self, event_id: str) -> str: """The stream token for an event -- cgit 1.5.1 From f40645e60b9cab69c953094848be61c0989a91cb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Aug 2020 16:20:49 -0400 Subject: Convert events worker database to async/await. (#8071) --- changelog.d/8071.misc | 1 + synapse/event_auth.py | 2 +- synapse/handlers/federation.py | 16 +-- synapse/handlers/message.py | 6 +- synapse/handlers/room_member.py | 2 +- synapse/spam_checker_api/__init__.py | 2 +- synapse/state/__init__.py | 2 +- synapse/storage/databases/main/event_federation.py | 30 +++-- synapse/storage/databases/main/events_worker.py | 132 ++++++++++++--------- synapse/storage/databases/main/stream.py | 1 - .../test_resource_limits_server_notices.py | 6 +- tests/storage/test_appservice.py | 3 +- 12 files changed, 106 insertions(+), 97 deletions(-) create mode 100644 changelog.d/8071.misc (limited to 'synapse/storage/databases/main/stream.py') diff --git a/changelog.d/8071.misc b/changelog.d/8071.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8071.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c0981eee62..8c907ad596 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -47,7 +47,7 @@ def check( Args: room_version_obj: the version of the room event: the event being checked. - auth_events (dict: event-key -> event): the existing room state. + auth_events: the existing room state. Raises: AuthError if the checks fail diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 593932adb7..5b270228e7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1777,9 +1777,7 @@ class FederationHandler(BaseHandler): """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups(room_id, [event_id]) @@ -1805,9 +1803,7 @@ class FederationHandler(BaseHandler): async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event. 
""" - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -2155,9 +2151,9 @@ class FederationHandler(BaseHandler): auth_types = auth_types_for_event(event) current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] - current_auth_events = await self.store.get_events(current_state_ids) + auth_events_map = await self.store.get_events(current_state_ids) current_auth_events = { - (e.type, e.state_key): e for e in current_auth_events.values() + (e.type, e.state_key): e for e in auth_events_map.values() } try: @@ -2173,9 +2169,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) # Just go through and process each event in `remote_auth_chain`. We # don't want to fall into the trap of `missing` being wrong. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 532fc30681..b999d91d1a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -960,7 +960,7 @@ class EventCreationHandler(object): allow_none=True, ) - is_admin_redaction = ( + is_admin_redaction = bool( original_event and event.sender != original_event.sender ) @@ -1080,8 +1080,8 @@ class EventCreationHandler(object): auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + auth_events_map = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()} room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 31705cdbdb..aa1ccde211 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -716,7 +716,7 @@ class RoomMemberHandler(object): guest_access = await self.store.get_event(guest_access_id) - return ( + return bool( guest_access and guest_access.content and "guest_access" in guest_access.content diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 9b78924d96..4d9b13ac04 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -51,5 +51,5 @@ class SpamCheckerApi(object): state_ids = yield self._store.get_filtered_current_state_ids( room_id=room_id, state_filter=StateFilter.from_types(types) ) - state = yield self._store.get_events(state_ids.values()) + state = yield defer.ensureDeferred(self._store.get_events(state_ids.values())) return state.values() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a1d3884667..dba8d91eef 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -641,7 +641,7 @@ class StateResolutionStore(object): allow_rejected (bool): If True return rejected events. Returns: - Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event. 
""" return self.store.get_events( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 431bd76693..4826be630c 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def get_auth_chain(self, event_ids, include_given=False): + async def get_auth_chain(self, event_ids, include_given=False): """Get auth events for given event_ids. The events *must* be state events. Args: @@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: list of events """ - return self.get_auth_chain_ids( + event_ids = await self.get_auth_chain_ids( event_ids, include_given=include_given - ).addCallback(self.get_events_as_list) + ) + return await self.get_events_as_list(event_ids) def get_auth_chain_ids( self, @@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` @@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_list (list) limit (int) """ - return ( - self.db_pool.runInteraction( - "get_backfill_events", - self._get_backfill_events, - room_id, - event_list, - limit, - ) - .addCallback(self.get_events_as_list) - .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) + event_ids = await self.db_pool.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + event_list, + limit, ) + events = await self.get_events_as_list(event_ids) + return sorted(events, key=lambda e: -e.depth) def _get_backfill_events(self, txn, room_id, event_list, limit): logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) @@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = await self.get_events_as_list(ids) - return events + return await self.get_events_as_list(ids) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8c63a0dc4d..e3a154a527 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -19,9 +19,10 @@ import itertools import logging import threading from collections import namedtuple -from typing import List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, overload from constantly import NamedConstant, Names +from typing_extensions import Literal from twisted.internet import defer @@ -32,7 +33,7 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersions, ) -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process @@ -42,8 +43,8 @@ from 
synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks +from synapse.types import Collection, get_domain_from_id +from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore): desc="get_received_ts", ) - @defer.inlineCallbacks - def get_event( + # Inform mypy that if allow_none is False (the default) then get_event + # always returns an EventBase. + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[False] = False, + check_room_id: Optional[str] = None, + ) -> EventBase: + ... + + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[True] = False, + check_room_id: Optional[str] = None, + ) -> Optional[EventBase]: + ... + + async def get_event( self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, @@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore): allow_rejected: bool = False, allow_none: bool = False, check_room_id: Optional[str] = None, - ): + ) -> Optional[EventBase]: """Get an event from the database by event_id. Args: @@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - Deferred[EventBase|None] + The event, or None if the event was not found. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [event_id], redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore): return event - @defer.inlineCallbacks - def get_events( + async def get_events( self, - event_ids: List[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> Dict[str, EventBase]: """Get events from the database Args: @@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore): omits rejeted events from the response. Returns: - Deferred : Dict from event_id to event. + A mapping from event_id to event. """ - events = yield self.get_events_as_list( + events = await self.get_events_as_list( event_ids, redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore): return {e.event_id: e for e in events} - @defer.inlineCallbacks - def get_events_as_list( + async def get_events_as_list( self, - event_ids: List[str], + event_ids: Collection[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> List[EventBase]: """Get events from the database and return in a list in the same order as given by `event_ids` arg. 
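[Editorial aside: the `@overload` pair added above teaches mypy that `get_event(...)` only returns an Optional when `allow_none=True` is passed. A self-contained sketch of the same trick on a simplified lookup; the `find_user` function, `User` type and `USERS` registry are invented, and the second overload is written without a default, which is the cleaner form of the idiom.]

from typing import Optional, overload

from typing_extensions import Literal

class User:
    """Placeholder record type; purely illustrative."""

USERS = {}  # assumed {user_id: User} registry, for illustration only

@overload
def find_user(user_id: str, allow_none: Literal[False] = False) -> "User":
    ...

@overload
def find_user(user_id: str, allow_none: Literal[True]) -> Optional["User"]:
    ...

def find_user(user_id: str, allow_none: bool = False) -> Optional["User"]:
    user = USERS.get(user_id)
    if user is None and not allow_none:
        raise KeyError(user_id)
    return user

# find_user("a") narrows to User; find_user("a", allow_none=True) to Optional[User].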
@@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore): omits rejected events from the response. Returns: - Deferred[list[EventBase]]: List of events fetched from the database. The - events are in the same order as `event_ids` arg. + List of events fetched from the database. The events are in the same + order as `event_ids` arg. Note that the returned list may be smaller than the list of event IDs if not all events could be fetched. @@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = yield self._get_events_from_cache_or_db( + event_entry_map = await self._get_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self._get_events_from_cache_or_db([redacted_event_id]) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore): if get_prev_content: if "replaces_state" in event.unsigned: - prev = yield self.get_event( + prev = await self.get_event( event.unsigned["replaces_state"], get_prev_content=False, allow_none=True, @@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore): return events - @defer.inlineCallbacks - def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result """ event_entry_map = self._get_events_from_cache( @@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore): # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = yield self._get_events_from_db( + missing_events = await self._get_events_from_db( missing_events_ids, allow_rejected=allow_rejected ) @@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - @defer.inlineCallbacks - def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the database. Returned events will be added to the cache for future lookups. @@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result. May return extra events which weren't asked for. 
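[Editorial aside: `_get_events_from_cache_or_db` follows a common two-tier shape: serve what the in-memory cache holds, hit the database only for the misses, then prime the cache with what came back. A generic sketch of that shape under those assumptions; the names here are illustrative, not the store's API.]

from typing import Awaitable, Callable, Dict, Iterable, TypeVar

K = TypeVar("K")
V = TypeVar("V")

async def fetch_with_cache(
    keys: Iterable[K],
    cache: Dict[K, V],
    fetch_from_db: Callable[[Iterable[K]], Awaitable[Dict[K, V]]],
) -> Dict[K, V]:
    keys = set(keys)  # callers may pass duplicates
    found = {k: cache[k] for k in keys if k in cache}
    missing = keys - found.keys()
    if missing:
        fetched = await fetch_from_db(missing)
        cache.update(fetched)  # prime the cache for future lookups
        found.update(fetched)
    return found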
""" @@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore): events_to_fetch = event_ids while events_to_fetch: - row_map = yield self._enqueue_events(events_to_fetch) + row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events redaction_ids = set() @@ -650,8 +672,7 @@ class EventsWorkerStore(SQLBaseStore): return result_map - @defer.inlineCallbacks - def _enqueue_events(self, events): + async def _enqueue_events(self, events): """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -660,7 +681,7 @@ class EventsWorkerStore(SQLBaseStore): events (Iterable[str]): events to be fetched. Returns: - Deferred[Dict[str, Dict]]: map from event id to row data from the database. + Dict[str, Dict]: map from event id to row data from the database. May contain events that weren't requested. """ @@ -683,7 +704,7 @@ class EventsWorkerStore(SQLBaseStore): logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): - row_map = yield events_d + row_map = await events_d logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) return row_map @@ -842,33 +863,29 @@ class EventsWorkerStore(SQLBaseStore): # no valid redaction found for this event return None - @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): + async def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield defer.ensureDeferred( - self.db_pool.simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", - ) + rows = await self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", ) return {r["event_id"] for r in rows} - @defer.inlineCallbacks - def have_seen_events(self, event_ids): + async def have_seen_events(self, event_ids): """Given a list of event ids, check if we have already processed them. Args: event_ids (iterable[str]): Returns: - Deferred[set[str]]: The events we have already seen. + set[str]: The events we have already seen. """ results = set() @@ -884,7 +901,7 @@ class EventsWorkerStore(SQLBaseStore): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) return results @@ -914,8 +931,7 @@ class EventsWorkerStore(SQLBaseStore): room_id, ) - @defer.inlineCallbacks - def get_room_complexity(self, room_id): + async def get_room_complexity(self, room_id): """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -926,9 +942,9 @@ class EventsWorkerStore(SQLBaseStore): room_id (str) Returns: - Deferred[dict[str:int]] of complexity version to complexity. + dict[str:int] of complexity version to complexity. 
""" - state_events = yield self.get_current_state_event_counts(room_id) + state_events = await self.get_current_state_event_counts(room_id) # Call this one "v1", so we can introduce new ones as we want to develop # it. @@ -1165,9 +1181,9 @@ class EventsWorkerStore(SQLBaseStore): to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) - @cachedInlineCallbacks(max_entries=5000) - def get_event_ordering(self, event_id): - res = yield self.db_pool.simple_select_one( + @cached(max_entries=5000) + async def get_event_ordering(self, event_id): + res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 4377bddb8c..497f607703 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): limit: int = 0, order: str = "DESC", ) -> Tuple[List[EventBase], str]: - """Get new room events in stream ordering since `from_key`. Args: diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 2858d13558..23db821fb7 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event @@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index a425e66f37..17fbde284a 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import ( ) from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=defer.succeed(events)) + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) -- cgit 1.5.1 From 
aec7085179464cc0a05177108ad2962d09ca4f4a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 09:37:55 -0400 Subject: Convert state and stream stores and related code to async (#8194) --- changelog.d/8194.misc | 1 + synapse/handlers/room.py | 2 +- synapse/storage/databases/main/state.py | 19 ++++++++++--------- synapse/storage/databases/main/state_deltas.py | 21 +++++++++++---------- synapse/storage/databases/main/stream.py | 11 +++++++---- synapse/storage/databases/state/store.py | 26 +++++++++++++------------- synapse/storage/state.py | 16 ++++++++-------- 7 files changed, 51 insertions(+), 45 deletions(-) create mode 100644 changelog.d/8194.misc (limited to 'synapse/storage/databases/main/stream.py') diff --git a/changelog.d/8194.misc b/changelog.d/8194.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8194.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 1419d72e94..9d5b1828df 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -451,7 +451,7 @@ class RoomCreationHandler(BaseHandler): old_room_member_state_events = await self.store.get_events( old_room_member_state_ids.values() ) - for k, old_event in old_room_member_state_events.items(): + for old_event in old_room_member_state_events.values(): # Only transfer ban events if ( "membership" in old_event.content diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 458f169617..5c6168e301 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -27,6 +27,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.state import StateFilter +from synapse.types import StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -163,15 +164,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return create_event @cached(max_entries=100000, iterable=True) - def get_current_state_ids(self, room_id): + async def get_current_state_ids(self, room_id: str) -> StateMap[str]: """Get the current state event ids for a room based on the current_state_events table. Args: - room_id (str) + room_id: The room to get the state IDs of. Returns: - deferred: dict of (type, state_key) -> event_id + The current state of the room. """ def _get_current_state_ids_txn(txn): @@ -184,14 +185,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_current_state_ids", _get_current_state_ids_txn ) # FIXME: how should this be cached? - def get_filtered_current_state_ids( + async def get_filtered_current_state_ids( self, room_id: str, state_filter: StateFilter = StateFilter.all() - ): + ) -> StateMap[str]: """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state @@ -202,14 +203,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): from the database. Returns: - defer.Deferred[StateMap[str]]: Map from type/state_key to event ID. 
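[Editorial aside: one detail of this conversion is that `@cached` methods can now be plain coroutines, and a filtered accessor can delegate to the cached one for the unfiltered case, as `get_filtered_current_state_ids` does above. A skeletal sketch of that shape; `ExampleStore`, the widget names and the elided `get_widgets_txn` helper are all invented.]

from synapse.util.caches.descriptors import cached

class ExampleStore:
    """Illustrative shape only; the pool and txn helper are assumed."""

    def __init__(self, db_pool):
        self.db_pool = db_pool

    @cached(max_entries=10000)
    async def get_all_widgets(self, owner_id: str) -> dict:
        # First call per owner hits the database; repeats are served from
        # the cache until something invalidates it.
        return await self.db_pool.runInteraction(
            "get_all_widgets", self.get_widgets_txn, owner_id
        )

    async def get_some_widgets(self, owner_id: str, wanted: set) -> dict:
        # Delegate to the cached coroutine rather than re-querying, mirroring
        # how the filtered state-ids accessor falls back to the cached one.
        all_widgets = await self.get_all_widgets(owner_id)
        return {k: v for k, v in all_widgets.items() if k in wanted}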
+ Map from type/state_key to event ID. """ where_clause, where_args = state_filter.make_sql_filter_clause() if not where_clause: # We delegate to the cached version - return self.get_current_state_ids(room_id) + return await self.get_current_state_ids(room_id) def _get_filtered_current_state_ids_txn(txn): results = {} @@ -231,7 +232,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 0d963c98ff..356623fc6e 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -14,8 +14,7 @@ # limitations under the License. import logging - -from twisted.internet import defer +from typing import Any, Dict, List, Tuple from synapse.storage._base import SQLBaseStore @@ -23,7 +22,9 @@ logger = logging.getLogger(__name__) class StateDeltasStore(SQLBaseStore): - def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int): + async def get_current_state_deltas( + self, prev_stream_id: int, max_stream_id: int + ) -> Tuple[int, List[Dict[str, Any]]]: """Fetch a list of room state changes since the given stream id Each entry in the result contains the following fields: @@ -37,12 +38,12 @@ class StateDeltasStore(SQLBaseStore): if it's new state. Args: - prev_stream_id (int): point to get changes since (exclusive) - max_stream_id (int): the point that we know has been correctly persisted + prev_stream_id: point to get changes since (exclusive) + max_stream_id: the point that we know has been correctly persisted - ie, an upper limit to return changes from. Returns: - Deferred[tuple[int, list[dict]]: A tuple consisting of: + A tuple consisting of: - the stream id which these results go up to - list of current_state_delta_stream rows. If it is empty, we are up to date. @@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore): # if the CSDs haven't changed between prev_stream_id and now, we # know for certain that they haven't changed between prev_stream_id and # max_stream_id. 
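[Editorial aside: the `defer.succeed` removal just above is the characteristic mechanical step of these conversions: inside a coroutine an already-available value is returned directly, and only the database path is awaited. A trimmed-down before/after; the `store` attributes and txn helper referenced here are assumed, not real API.]

from twisted.internet import defer

# Before: a Deferred-returning method has to wrap values it already holds.
def get_deltas_old(store, have_changed):
    if not have_changed:
        return defer.succeed((store.max_stream_id, []))
    return store.db_pool.runInteraction("get_deltas", store.get_deltas_txn)

# After: the coroutine returns the value itself; callers await both paths
# identically.
async def get_deltas(store, have_changed):
    if not have_changed:
        return (store.max_stream_id, [])
    return await store.db_pool.runInteraction("get_deltas", store.get_deltas_txn)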
- return defer.succeed((max_stream_id, [])) + return (max_stream_id, []) def get_current_state_deltas_txn(txn): # First we calculate the max stream id that will give us less than @@ -102,7 +103,7 @@ class StateDeltasStore(SQLBaseStore): txn.execute(sql, (prev_stream_id, clipped_stream_id)) return clipped_stream_id, self.db_pool.cursor_to_dict(txn) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_current_state_deltas", get_current_state_deltas_txn ) @@ -114,8 +115,8 @@ class StateDeltasStore(SQLBaseStore): retcol="COALESCE(MAX(stream_id), -1)", ) - def get_max_stream_id_in_current_state_deltas(self): - return self.db_pool.runInteraction( + async def get_max_stream_id_in_current_state_deltas(self): + return await self.db_pool.runInteraction( "get_max_stream_id_in_current_state_deltas", self._get_max_stream_id_in_current_state_deltas_txn, ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 497f607703..24f44a7e36 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -539,7 +539,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, token - def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int): + async def get_room_event_before_stream_ordering( + self, room_id: str, stream_ordering: int + ) -> Tuple[int, int, str]: """Gets details of the first event in a room at or before a stream ordering Args: @@ -547,8 +549,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): stream_ordering: Returns: - Deferred[(int, int, str)]: - (stream ordering, topological ordering, event_id) + A tuple of (stream ordering, topological ordering, event_id) """ def _f(txn): @@ -563,7 +564,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() - return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f) + return await self.db_pool.runInteraction( + "get_room_event_before_stream_ordering", _f + ) async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str: """Returns the current token for rooms stream. diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 7f104ad936..e924f1ca3b 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -17,8 +17,6 @@ import logging from collections import namedtuple from typing import Dict, Iterable, List, Set, Tuple -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -103,7 +101,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) @cached(max_entries=10000, iterable=True) - def get_state_group_delta(self, state_group): + async def get_state_group_delta(self, state_group): """Given a state group try to return a previous group and a delta between the old and the new. 
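[Editorial aside: state groups are stored as deltas against a previous group, which is what `get_state_group_delta` above exposes as a `(prev_group, delta_ids)` pair. A toy resolver of that structure, assuming an in-memory map of such pairs rather than the real tables, with all names invented.]

from typing import Dict, Optional, Tuple

StateMap = Dict[Tuple[str, str], str]  # (type, state_key) -> event_id

def resolve_state_group(
    group: int,
    deltas: Dict[int, Tuple[Optional[int], StateMap]],
) -> StateMap:
    """Walk the prev_group chain, applying each delta over its parent."""
    prev_group, delta_ids = deltas[group]
    if prev_group is None:
        return dict(delta_ids)  # a root group stores its full state
    state = resolve_state_group(prev_group, deltas)
    state.update(delta_ids)  # child entries override the parent's
    return state

# e.g. with deltas = {1: (None, {("m.room.name", ""): "$a"}),
#                     2: (1,    {("m.room.name", ""): "$b"})}
# resolve_state_group(2, deltas) == {("m.room.name", ""): "$b"}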
@@ -135,7 +133,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_state_group_delta", _get_state_group_delta_txn ) @@ -367,9 +365,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): fetched_keys=non_member_types, ) - def store_state_group( + async def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids - ): + ) -> int: """Store a new set of state, returning a newly assigned state group. Args: @@ -383,7 +381,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): to event_id. Returns: - Deferred[int]: The state group ID + The state group ID """ def _store_state_group_txn(txn): @@ -484,11 +482,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_group - return self.db_pool.runInteraction("store_state_group", _store_state_group_txn) + return await self.db_pool.runInteraction( + "store_state_group", _store_state_group_txn + ) - def purge_unreferenced_state_groups( + async def purge_unreferenced_state_groups( self, room_id: str, state_groups_to_delete - ) -> defer.Deferred: + ) -> None: """Deletes no longer referenced state groups and de-deltas any state groups that reference them. @@ -499,7 +499,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): to delete. """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "purge_unreferenced_state_groups", self._purge_unreferenced_state_groups, room_id, @@ -594,7 +594,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return {row["state_group"]: row["prev_state_group"] for row in rows} - def purge_room_state(self, room_id, state_groups_to_delete): + async def purge_room_state(self, room_id, state_groups_to_delete): """Deletes all record of a room from state tables Args: @@ -602,7 +602,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state_groups_to_delete (list[int]): State groups to delete """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "purge_room_state", self._purge_room_state_txn, room_id, diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 534883361f..96a1b59d64 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -333,7 +333,7 @@ class StateGroupStorage(object): def __init__(self, hs, stores): self.stores = stores - def get_state_group_delta(self, state_group: int): + async def get_state_group_delta(self, state_group: int): """Given a state group try to return a previous group and a delta between the old and the new. @@ -341,11 +341,11 @@ class StateGroupStorage(object): state_group: The state group used to retrieve state deltas. Returns: - Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: + Tuple[Optional[int], Optional[StateMap[str]]]: (prev_group, delta_ids) """ - return self.stores.state.get_state_group_delta(state_group) + return await self.stores.state.get_state_group_delta(state_group) async def get_state_groups_ids( self, _room_id: str, event_ids: Iterable[str] @@ -525,7 +525,7 @@ class StateGroupStorage(object): state_filter: The state filter used to fetch state from the database. 
Returns: - A deferred dict from (type, state_key) -> state_event + A dict from (type, state_key) -> state_event """ state_map = await self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] @@ -546,14 +546,14 @@ class StateGroupStorage(object): """ return self.stores.state._get_state_for_groups(groups, state_filter) - def store_state_group( + async def store_state_group( self, event_id: str, room_id: str, prev_group: Optional[int], delta_ids: Optional[dict], current_state_ids: dict, - ): + ) -> int: """Store a new set of state, returning a newly assigned state group. Args: @@ -567,8 +567,8 @@ class StateGroupStorage(object): to event_id. Returns: - Deferred[int]: The state group ID + The state group ID """ - return self.stores.state.store_state_group( + return await self.stores.state.store_state_group( event_id, room_id, prev_group, delta_ids, current_state_ids ) -- cgit 1.5.1 From 5a1dd297c3ce105a7f516d9d9fe87b94b9d356c8 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 2 Sep 2020 17:19:37 +0100 Subject: Re-implement unread counts (again) (#8059) --- changelog.d/8059.feature | 1 + synapse/handlers/sync.py | 33 +-- synapse/push/bulk_push_rule_evaluator.py | 92 ++++++--- synapse/push/push_tools.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 1 + .../storage/databases/main/event_push_actions.py | 224 +++++++++++++++------ synapse/storage/databases/main/events.py | 4 +- .../main/schema/delta/58/15unread_count.sql | 26 +++ synapse/storage/databases/main/stream.py | 21 +- tests/replication/slave/storage/test_events.py | 10 +- tests/rest/client/v2_alpha/test_sync.py | 157 ++++++++++++++- tests/storage/test_event_push_actions.py | 8 +- 12 files changed, 457 insertions(+), 122 deletions(-) create mode 100644 changelog.d/8059.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/15unread_count.sql (limited to 'synapse/storage/databases/main/stream.py') diff --git a/changelog.d/8059.feature b/changelog.d/8059.feature new file mode 100644 index 0000000000..feb02be234 --- /dev/null +++ b/changelog.d/8059.feature @@ -0,0 +1 @@ +Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654). diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9a86eb01c9..43fc40fc2f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -95,7 +95,12 @@ class TimelineBatch: __bool__ = __nonzero__ # python3 -@attr.s(slots=True, frozen=True) +# We can't freeze this class, because we need to update it after it's instantiated to +# update its unread count. This is because we calculate the unread count for a room only +# if there are updates for it, which we check after the instance has been created. +# This should not be a big deal because we update the notification counts afterwards as +# well anyway. +@attr.s(slots=True) class JoinedSyncResult: room_id = attr.ib(type=str) timeline = attr.ib(type=TimelineBatch) @@ -104,6 +109,7 @@ class JoinedSyncResult: account_data = attr.ib(type=List[JsonDict]) unread_notifications = attr.ib(type=JsonDict) summary = attr.ib(type=Optional[JsonDict]) + unread_count = attr.ib(type=int) def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. 
This is used @@ -931,7 +937,7 @@ class SyncHandler(object): async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig - ) -> Optional[Dict[str, str]]: + ) -> Dict[str, int]: with Measure(self.clock, "unread_notifs_for_room_id"): last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), @@ -939,15 +945,10 @@ class SyncHandler(object): receipt_type="m.read", ) - if last_unread_event_id: - notifs = await self.store.get_unread_event_push_actions_by_room_for_user( - room_id, sync_config.user.to_string(), last_unread_event_id - ) - return notifs - - # There is no new information in this period, so your notification - # count is whatever it was last time. - return None + notifs = await self.store.get_unread_event_push_actions_by_room_for_user( + room_id, sync_config.user.to_string(), last_unread_event_id + ) + return notifs async def generate_sync_result( self, @@ -1886,7 +1887,7 @@ class SyncHandler(object): ) if room_builder.rtype == "joined": - unread_notifications = {} # type: Dict[str, str] + unread_notifications = {} # type: Dict[str, int] room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1895,14 +1896,16 @@ class SyncHandler(object): account_data=account_data_events, unread_notifications=unread_notifications, summary=summary, + unread_count=0, ) if room_sync or always_include: notifs = await self.unread_notifs_for_room_id(room_id, sync_config) - if notifs is not None: - unread_notifications["notification_count"] = notifs["notify_count"] - unread_notifications["highlight_count"] = notifs["highlight_count"] + unread_notifications["notification_count"] = notifs["notify_count"] + unread_notifications["highlight_count"] = notifs["highlight_count"] + + room_sync.unread_count = notifs["unread_count"] sync_result_builder.joined.append(room_sync) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index e7fcee0e87..e7fa02b78b 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -19,8 +19,10 @@ from collections import namedtuple from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RelationTypes from synapse.event_auth import get_user_power_level +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.util.async_helpers import Linearizer from synapse.util.caches import register_cache @@ -51,6 +53,48 @@ push_rules_delta_state_cache_metric = register_cache( ) +STATE_EVENT_TYPES_TO_MARK_UNREAD = { + EventTypes.Topic, + EventTypes.Name, + EventTypes.RoomAvatar, + EventTypes.Tombstone, +} + + +def _should_count_as_unread(event: EventBase, context: EventContext) -> bool: + # Exclude rejected and soft-failed events. + if context.rejected or event.internal_metadata.is_soft_failed(): + return False + + # Exclude notices. + if ( + not event.is_state() + and event.type == EventTypes.Message + and event.content.get("msgtype") == "m.notice" + ): + return False + + # Exclude edits. + relates_to = event.content.get("m.relates_to", {}) + if relates_to.get("rel_type") == RelationTypes.REPLACE: + return False + + # Mark events that have a non-empty string body as unread. + body = event.content.get("body") + if isinstance(body, str) and body: + return True + + # Mark some state events as unread. 
+ if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: + return True + + # Mark encrypted events as unread. + if not event.is_state() and event.type == EventTypes.Encrypted: + return True + + return False + + class BulkPushRuleEvaluator(object): """Calculates the outcome of push rules for an event for all users in the room at once. @@ -133,9 +177,12 @@ class BulkPushRuleEvaluator(object): return pl_event.content if pl_event else {}, sender_level async def action_for_event_by_user(self, event, context) -> None: - """Given an event and context, evaluate the push rules and insert the - results into the event_push_actions_staging table. + """Given an event and context, evaluate the push rules, check if the message + should increment the unread count, and insert the results into the + event_push_actions_staging table. """ + count_as_unread = _should_count_as_unread(event, context) + rules_by_user = await self._get_rules_for_event(event, context) actions_by_user = {} @@ -172,6 +219,8 @@ class BulkPushRuleEvaluator(object): if event.type == EventTypes.Member and event.state_key == uid: display_name = event.content.get("displayname", None) + actions_by_user[uid] = [] + for rule in rules: if "enabled" in rule and not rule["enabled"]: continue @@ -189,7 +238,9 @@ class BulkPushRuleEvaluator(object): # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist # the event) - await self.store.add_push_actions_to_staging(event.event_id, actions_by_user) + await self.store.add_push_actions_to_staging( + event.event_id, actions_by_user, count_as_unread, + ) def _condition_checker(evaluator, conditions, uid, display_name, cache): @@ -369,8 +420,8 @@ class RulesForRoom(object): Args: ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets updated with any new rules. - member_event_ids (list): List of event ids for membership events that - have happened since the last time we filled rules_by_user + member_event_ids (dict): Dict of user id to event id for membership events + that have happened since the last time we filled rules_by_user state_group: The state group we are currently computing push rules for. Used when updating the cache. """ @@ -390,34 +441,19 @@ class RulesForRoom(object): if logger.isEnabledFor(logging.DEBUG): logger.debug("Found members %r: %r", self.room_id, members.values()) - interested_in_user_ids = { + user_ids = { user_id for user_id, membership in members.values() if membership == Membership.JOIN } - logger.debug("Joined: %r", interested_in_user_ids) - - if_users_with_pushers = await self.store.get_if_users_have_pushers( - interested_in_user_ids, on_invalidate=self.invalidate_all_cb - ) - - user_ids = { - uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - } - - logger.debug("With pushers: %r", user_ids) - - users_with_receipts = await self.store.get_users_with_read_receipts_in_room( - self.room_id, on_invalidate=self.invalidate_all_cb - ) - - logger.debug("With receipts: %r", users_with_receipts) + logger.debug("Joined: %r", user_ids) - # any users with pushers must be ours: they have pushers - for uid in users_with_receipts: - if uid in interested_in_user_ids: - user_ids.add(uid) + # Previously we only considered users with pushers or read receipts in that + # room. We can't do this anymore because we use push actions to calculate unread + # counts, which don't rely on the user having pushers or sent a read receipt into + # the room. 
Therefore we just need to filter for local users here. + user_ids = list(filter(self.is_mine_id, user_ids)) rules_by_user = await self.store.bulk_get_push_rules( user_ids, on_invalidate=self.invalidate_all_cb diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index d0145666bf..f7a25571f3 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -36,7 +36,7 @@ async def get_badge_count(store, user_id): ) # return one badge count per conversation, as count per # message is so noisy as to be almost useless - badge += 1 if notifs["notify_count"] else 0 + badge += 1 if notifs["unread_count"] else 0 return badge diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 96488b131a..a0b00135e1 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -425,6 +425,7 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary + result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 384c39f4d7..001d06378d 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -15,7 +15,9 @@ # limitations under the License. import logging -from typing import Dict, List, Union +from typing import Dict, List, Optional, Tuple, Union + +import attr from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json @@ -88,8 +90,26 @@ class EventPushActionsWorkerStore(SQLBaseStore): @cached(num_args=3, tree=True, max_entries=5000) async def get_unread_event_push_actions_by_room_for_user( - self, room_id, user_id, last_read_event_id - ): + self, room_id: str, user_id: str, last_read_event_id: Optional[str], + ) -> Dict[str, int]: + """Get the notification count, the highlight count and the unread message count + for a given user in a given room after the given read receipt. + + Note that this function assumes the user to be a current member of the room, + since it's either called by the sync handler to handle joined room entries, or by + the HTTP pusher to calculate the badge of unread joined rooms. + + Args: + room_id: The room to retrieve the counts in. + user_id: The user to retrieve the counts for. + last_read_event_id: The event associated with the latest read receipt for + this user in this room. None if no receipt for this user in this room. + + Returns + A dict containing the counts mentioned earlier in this docstring, + respectively under the keys "notify_count", "highlight_count" and + "unread_count". + """ return await self.db_pool.runInteraction( "get_unread_event_push_actions_by_room", self._get_unread_counts_by_receipt_txn, @@ -99,69 +119,71 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) def _get_unread_counts_by_receipt_txn( - self, txn, room_id, user_id, last_read_event_id + self, txn, room_id, user_id, last_read_event_id, ): - sql = ( - "SELECT stream_ordering" - " FROM events" - " WHERE room_id = ? AND event_id = ?" 
-        )
-        txn.execute(sql, (room_id, last_read_event_id))
-        results = txn.fetchall()
-        if len(results) == 0:
-            return {"notify_count": 0, "highlight_count": 0}
+        stream_ordering = None
+
+        if last_read_event_id is not None:
+            stream_ordering = self.get_stream_id_for_event_txn(
+                txn, last_read_event_id, allow_none=True,
+            )
+
+        if stream_ordering is None:
+            # Either last_read_event_id is None, or it's an event we don't have (e.g.
+            # because it's been purged), in which case retrieve the stream ordering for
+            # the latest membership event from this user in this room (which we assume is
+            # a join).
+            event_id = self.db_pool.simple_select_one_onecol_txn(
+                txn=txn,
+                table="local_current_membership",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+                retcol="event_id",
+            )
 
-        stream_ordering = results[0][0]
+            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
 
         return self._get_unread_counts_by_pos_txn(
             txn, room_id, user_id, stream_ordering
         )
 
     def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
-
-        # First get number of notifications.
-        # We don't need to put a notif=1 clause as all rows always have
-        # notif=1
         sql = (
-            "SELECT count(*)"
+            "SELECT"
+            " COUNT(CASE WHEN notif = 1 THEN 1 END),"
+            " COUNT(CASE WHEN highlight = 1 THEN 1 END),"
+            " COUNT(CASE WHEN unread = 1 THEN 1 END)"
             " FROM event_push_actions ea"
-            " WHERE"
-            " user_id = ?"
-            " AND room_id = ?"
-            " AND stream_ordering > ?"
+            " WHERE user_id = ?"
+            " AND room_id = ?"
+            " AND stream_ordering > ?"
         )
 
         txn.execute(sql, (user_id, room_id, stream_ordering))
         row = txn.fetchone()
-        notify_count = row[0] if row else 0
+
+        (notif_count, highlight_count, unread_count) = (0, 0, 0)
+
+        if row:
+            (notif_count, highlight_count, unread_count) = row
 
         txn.execute(
             """
-                SELECT notif_count FROM event_push_summary
-                WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
-            """,
+            SELECT notif_count, unread_count FROM event_push_summary
+            WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+        """,
             (room_id, user_id, stream_ordering),
         )
-        rows = txn.fetchall()
-        if rows:
-            notify_count += rows[0][0]
-
-        # Now get the number of highlights
-        sql = (
-            "SELECT count(*)"
-            " FROM event_push_actions ea"
-            " WHERE"
-            " highlight = 1"
-            " AND user_id = ?"
-            " AND room_id = ?"
-            " AND stream_ordering > ?"
-        )
-
-        txn.execute(sql, (user_id, room_id, stream_ordering))
         row = txn.fetchone()
-        highlight_count = row[0] if row else 0
 
-        return {"notify_count": notify_count, "highlight_count": highlight_count}
+        if row:
+            notif_count += row[0]
+            unread_count += row[1]
+
+        return {
+            "notify_count": notif_count,
+            "unread_count": unread_count,
+            "highlight_count": highlight_count,
+        }
 
     async def get_push_action_users_in_range(
         self, min_stream_ordering, max_stream_ordering
@@ -222,6 +244,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             " AND ep.user_id = ?"
             " AND ep.stream_ordering > ?"
            " AND ep.stream_ordering <= ?"
+            " AND ep.notif = 1"
             " ORDER BY ep.stream_ordering ASC LIMIT ?"
         )
         args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -250,6 +273,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             " AND ep.user_id = ?"
             " AND ep.stream_ordering > ?"
             " AND ep.stream_ordering <= ?"
+            " AND ep.notif = 1"
             " ORDER BY ep.stream_ordering ASC LIMIT ?"
         )
         args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -324,6 +348,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             " AND ep.user_id = ?"
             " AND ep.stream_ordering > ?"
             " AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1" " ORDER BY ep.stream_ordering DESC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -352,6 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" + " AND ep.notif = 1" " ORDER BY ep.stream_ordering DESC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -402,7 +428,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _get_if_maybe_push_in_range_for_user_txn(txn): sql = """ SELECT 1 FROM event_push_actions - WHERE user_id = ? AND stream_ordering > ? + WHERE user_id = ? AND stream_ordering > ? AND notif = 1 LIMIT 1 """ @@ -415,7 +441,10 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) async def add_push_actions_to_staging( - self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]] + self, + event_id: str, + user_id_actions: Dict[str, List[Union[dict, str]]], + count_as_unread: bool, ) -> None: """Add the push actions for the event to the push action staging area. @@ -423,21 +452,23 @@ class EventPushActionsWorkerStore(SQLBaseStore): event_id user_id_actions: A mapping of user_id to list of push actions, where an action can either be a string or dict. + count_as_unread: Whether this event should increment unread counts. """ - if not user_id_actions: return # This is a helper function for generating the necessary tuple that - # can be used to inert into the `event_push_actions_staging` table. + # can be used to insert into the `event_push_actions_staging` table. def _gen_entry(user_id, actions): is_highlight = 1 if _action_has_highlight(actions) else 0 + notif = 1 if "notify" in actions else 0 return ( event_id, # event_id column user_id, # user_id column _serialize_action(actions, is_highlight), # actions column - 1, # notif column + notif, # notif column is_highlight, # highlight column + int(count_as_unread), # unread column ) def _add_push_actions_to_staging_txn(txn): @@ -446,8 +477,8 @@ class EventPushActionsWorkerStore(SQLBaseStore): sql = """ INSERT INTO event_push_actions_staging - (event_id, user_id, actions, notif, highlight) - VALUES (?, ?, ?, ?, ?) + (event_id, user_id, actions, notif, highlight, unread) + VALUES (?, ?, ?, ?, ?, ?) """ txn.executemany( @@ -811,24 +842,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # Calculate the new counts that should be upserted into event_push_summary sql = """ SELECT user_id, room_id, - coalesce(old.notif_count, 0) + upd.notif_count, + coalesce(old.%s, 0) + upd.cnt, upd.stream_ordering, old.user_id FROM ( - SELECT user_id, room_id, count(*) as notif_count, + SELECT user_id, room_id, count(*) as cnt, max(stream_ordering) as stream_ordering FROM event_push_actions WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0 + AND %s = 1 GROUP BY user_id, room_id ) AS upd LEFT JOIN event_push_summary AS old USING (user_id, room_id) """ - txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering)) - rows = txn.fetchall() + # First get the count of unread messages. + txn.execute( + sql % ("unread_count", "unread"), + (old_rotate_stream_ordering, rotate_to_stream_ordering), + ) + + # We need to merge results from the two requests (the one that retrieves the + # unread count and the one that retrieves the notifications count) into a single + # object because we might not have the same amount of rows in each of them. 
To do + # this, we use a dict indexed on the user ID and room ID to make it easier to + # populate. + summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary] + for row in txn: + summaries[(row[0], row[1])] = _EventPushSummary( + unread_count=row[2], + stream_ordering=row[3], + old_user_id=row[4], + notif_count=0, + ) + + # Then get the count of notifications. + txn.execute( + sql % ("notif_count", "notif"), + (old_rotate_stream_ordering, rotate_to_stream_ordering), + ) + + for row in txn: + if (row[0], row[1]) in summaries: + summaries[(row[0], row[1])].notif_count = row[2] + else: + # Because the rules on notifying are different than the rules on marking + # a message unread, we might end up with messages that notify but aren't + # marked unread, so we might not have a summary for this (user, room) + # tuple to complete. + summaries[(row[0], row[1])] = _EventPushSummary( + unread_count=0, + stream_ordering=row[3], + old_user_id=row[4], + notif_count=row[2], + ) - logger.info("Rotating notifications, handling %d rows", len(rows)) + logger.info("Rotating notifications, handling %d rows", len(summaries)) # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the @@ -838,22 +908,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore): table="event_push_summary", values=[ { - "user_id": row[0], - "room_id": row[1], - "notif_count": row[2], - "stream_ordering": row[3], + "user_id": user_id, + "room_id": room_id, + "notif_count": summary.notif_count, + "unread_count": summary.unread_count, + "stream_ordering": summary.stream_ordering, } - for row in rows - if row[4] is None + for ((user_id, room_id), summary) in summaries.items() + if summary.old_user_id is None ], ) txn.executemany( """ - UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? + UPDATE event_push_summary + SET notif_count = ?, unread_count = ?, stream_ordering = ? WHERE user_id = ? AND room_id = ? """, - ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None), + ( + ( + summary.notif_count, + summary.unread_count, + summary.stream_ordering, + user_id, + room_id, + ) + for ((user_id, room_id), summary) in summaries.items() + if summary.old_user_id is not None + ), ) txn.execute( @@ -879,3 +961,15 @@ def _action_has_highlight(actions): pass return False + + +@attr.s +class _EventPushSummary: + """Summary of pending event push actions for a given user in a given room. + Used in _rotate_notifs_before_txn to manipulate results from event_push_actions. + """ + + unread_count = attr.ib(type=int) + stream_ordering = attr.ib(type=int) + old_user_id = attr.ib(type=str) + notif_count = attr.ib(type=int) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 46b11e705b..b94fe7ac17 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1298,9 +1298,9 @@ class PersistEventsStore: sql = """ INSERT INTO event_push_actions ( room_id, event_id, user_id, actions, stream_ordering, - topological_ordering, notif, highlight + topological_ordering, notif, highlight, unread ) - SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread FROM event_push_actions_staging WHERE event_id = ? 
""" diff --git a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql new file mode 100644 index 0000000000..b451e8663a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql @@ -0,0 +1,26 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We're hijacking the push actions to store unread messages and unread counts (specified +-- in MSC2654) because doing otherwise would result in either performance issues or +-- reimplementing a consequent bit of the push actions. + +-- Add columns to event_push_actions and event_push_actions_staging to track unread +-- messages and calculate unread counts. +ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0; +ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0; + +-- Add column to event_push_summary +ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0; \ No newline at end of file diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 24f44a7e36..83c1ddf95a 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -47,7 +47,11 @@ from synapse.api.filtering import Filter from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.types import RoomStreamToken @@ -593,8 +597,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A stream ID. 
""" - return await self.db_pool.simple_select_one_onecol( - table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" + return await self.db_pool.runInteraction( + "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id, + ) + + def get_stream_id_for_event_txn( + self, txn: LoggingTransaction, event_id: str, allow_none=False, + ) -> int: + return self.db_pool.simple_select_one_onecol_txn( + txn=txn, + table="events", + keyvalues={"event_id": event_id}, + retcol="stream_ordering", + allow_none=allow_none, ) async def get_stream_token_for_event(self, event_id: str) -> str: diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 0b5204654c..561258a356 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 0}, + {"highlight_count": 0, "unread_count": 0, "notify_count": 0}, ) self.persist( @@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 1}, + {"highlight_count": 0, "unread_count": 0, "notify_count": 1}, ) self.persist( @@ -188,7 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 1, "notify_count": 2}, + {"highlight_count": 1, "unread_count": 0, "notify_count": 2}, ) def test_get_rooms_for_user_with_stream_ordering(self): @@ -368,7 +368,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.get_success( self.master_store.add_push_actions_to_staging( - event.event_id, {user_id: actions for user_id, actions in push_actions} + event.event_id, + {user_id: actions for user_id, actions in push_actions}, + False, ) ) return event, context diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index fa3a3ec1bd..a31e44c97e 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -16,9 +16,9 @@ import json import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v2_alpha import read_marker, sync from tests import unittest from tests.server import TimedOutException @@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase): "GET", sync_url % (access_token, next_batch) ) self.assertRaises(TimedOutException, self.render, request) + + +class UnreadMessagesTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + read_marker.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user (used to check the unread counts). + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room we'll check unread counts for. 
+        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+        # Register the second user (used to send events to the room).
+        self.user2 = self.register_user("kermit2", "monkey")
+        self.tok2 = self.login("kermit2", "monkey")
+
+        # Change the power levels of the room so that the second user can send state
+        # events.
+        self.helper.send_state(
+            self.room_id,
+            EventTypes.PowerLevels,
+            {
+                "users": {self.user_id: 100, self.user2: 100},
+                "users_default": 0,
+                "events": {
+                    "m.room.name": 50,
+                    "m.room.power_levels": 100,
+                    "m.room.history_visibility": 100,
+                    "m.room.canonical_alias": 50,
+                    "m.room.avatar": 50,
+                    "m.room.tombstone": 100,
+                    "m.room.server_acl": 100,
+                    "m.room.encryption": 100,
+                },
+                "events_default": 0,
+                "state_default": 50,
+                "ban": 50,
+                "kick": 50,
+                "redact": 50,
+                "invite": 0,
+            },
+            tok=self.tok,
+        )
+
+    def test_unread_counts(self):
+        """Tests that /sync returns the right value for the unread count (MSC2654)."""
+
+        # Check that our own messages don't increase the unread count.
+        self.helper.send(self.room_id, "hello", tok=self.tok)
+        self._check_unread_count(0)
+
+        # Join the new user and check that this doesn't increase the unread count.
+        self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
+        self._check_unread_count(0)
+
+        # Check that the new user sending a message increases our unread count.
+        res = self.helper.send(self.room_id, "hello", tok=self.tok2)
+        self._check_unread_count(1)
+
+        # Send a read receipt to tell the server we've read the latest event.
+        body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
+        request, channel = self.make_request(
+            "POST",
+            "/rooms/%s/read_markers" % self.room_id,
+            body,
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the unread counter is back to 0.
+        self._check_unread_count(0)
+
+        # Check that room name changes increase the unread counter.
+        self.helper.send_state(
+            self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+        )
+        self._check_unread_count(1)
+
+        # Check that room topic changes increase the unread counter.
+        self.helper.send_state(
+            self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+        )
+        self._check_unread_count(2)
+
+        # Check that encrypted messages increase the unread counter.
+        self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
+        self._check_unread_count(3)
+
+        # Check that custom events with a body increase the unread counter.
+        self.helper.send_event(
+            self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that edits don't increase the unread counter.
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "body": "hello",
+                "msgtype": "m.text",
+                "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+            },
+            tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that notices don't increase the unread counter.
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"body": "hello", "msgtype": "m.notice"},
+            tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that tombstone events increase the unread counter.
+        self.helper.send_state(
+            self.room_id,
+            EventTypes.Tombstone,
+            {"replacement_room": "!someroom:test"},
+            tok=self.tok2,
+        )
+        self._check_unread_count(5)
+
+    def _check_unread_count(self, expected_count: int):
+        """Syncs and compares the unread count with the expected value."""
+
+        request, channel = self.make_request(
+            "GET", self.url % self.next_batch, access_token=self.tok,
+        )
+        self.render(request)
+
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        room_entry = channel.json_body["rooms"]["join"][self.room_id]
+        self.assertEqual(
+            room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+        )
+
+        # Store the next batch for the next request.
+        self.next_batch = channel.json_body["next_batch"]
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index cdfd2634aa..c0595963dd 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -67,7 +67,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         )
         self.assertEquals(
             counts,
-            {"notify_count": noitf_count, "highlight_count": highlight_count},
+            {
+                "notify_count": noitf_count,
+                "unread_count": 0,  # Unread counts are tested in the sync tests.
+                "highlight_count": highlight_count,
+            },
         )
 
     @defer.inlineCallbacks
@@ -80,7 +84,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
         yield defer.ensureDeferred(
             self.store.add_push_actions_to_staging(
-                event.event_id, {user_id: action}
+                event.event_id, {user_id: action}, False,
            )
         )
         yield defer.ensureDeferred(
--
cgit 1.5.1


From 112266eafd457204a34a76fa51d7074d0809a1db Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 2 Sep 2020 17:52:38 +0100
Subject: Add StreamStore to mypy (#8232)

---
 changelog.d/8232.misc                    |  1 +
 mypy.ini                                 |  1 +
 synapse/events/__init__.py               |  4 +--
 synapse/storage/database.py              | 34 +++++++++++++++++++++++
 synapse/storage/databases/main/stream.py | 46 +++++++++++++++++++-------------
 5 files changed, 66 insertions(+), 20 deletions(-)
 create mode 100644 changelog.d/8232.misc

(limited to 'synapse/storage/databases/main/stream.py')

diff --git a/changelog.d/8232.misc b/changelog.d/8232.misc
new file mode 100644
index 0000000000..3a7a352c4f
--- /dev/null
+++ b/changelog.d/8232.misc
@@ -0,0 +1 @@
+Add type hints to `StreamStore`.
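[Note on the counting query in the patch above: the `COUNT(CASE WHEN ... THEN 1 END)` expression added to `_get_unread_counts_by_pos_txn` computes all three counters in a single pass over `event_push_actions`, where the old code issued one `count(*)` query per counter. `COUNT` ignores NULLs, and the `CASE` yields NULL for rows whose flag isn't set. A minimal standalone sketch of the pattern, not Synapse code; the table layout is cut down and the sample rows are invented:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE event_push_actions ("
    " user_id TEXT, room_id TEXT, stream_ordering INTEGER,"
    " notif SMALLINT, highlight SMALLINT, unread SMALLINT)"
)
conn.executemany(
    "INSERT INTO event_push_actions VALUES (?, ?, ?, ?, ?, ?)",
    [
        ("@user:test", "!room:test", 1, 1, 0, 1),  # notifies and counts as unread
        ("@user:test", "!room:test", 2, 0, 0, 1),  # unread only, e.g. a room name change
        ("@user:test", "!room:test", 3, 1, 1, 1),  # a highlight, e.g. a mention
    ],
)

# One pass of conditional aggregation instead of three count(*) queries.
notif, highlight, unread = conn.execute(
    "SELECT"
    " COUNT(CASE WHEN notif = 1 THEN 1 END),"
    " COUNT(CASE WHEN highlight = 1 THEN 1 END),"
    " COUNT(CASE WHEN unread = 1 THEN 1 END)"
    " FROM event_push_actions"
    " WHERE user_id = ? AND room_id = ? AND stream_ordering > ?",
    ("@user:test", "!room:test", 0),
).fetchone()

assert (notif, highlight, unread) == (2, 1, 3)

The same SQL runs unchanged on Postgres and SQLite, which is presumably why the patch prefers it over three engine-tuned queries.]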
diff --git a/mypy.ini b/mypy.ini
index 21c6f523a0..ae3290d5bb 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -43,6 +43,7 @@ files =
   synapse/server_notices,
   synapse/spam_checker_api,
   synapse/state,
+  synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
   synapse/storage/database.py,
   synapse/storage/engines,
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 67db763dbf..62ea44fa49 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -18,7 +18,7 @@
 import abc
 import os
 from distutils.util import strtobool
-from typing import Dict, Optional, Type
+from typing import Dict, Optional, Tuple, Type
 
 from unpaddedbase64 import encode_base64
 
@@ -120,7 +120,7 @@ class _EventInternalMetadata(object):
     # be here
     before = DictProperty("before")  # type: str
     after = DictProperty("after")  # type: str
-    order = DictProperty("order")  # type: int
+    order = DictProperty("order")  # type: Tuple[int, int]
 
     def get_dict(self) -> JsonDict:
         return dict(self._dict)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 7ab370efef..af8796ad92 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -604,6 +604,18 @@ class DatabasePool(object):
             results = [dict(zip(col_headers, row)) for row in cursor]
             return results
 
+    @overload
+    async def execute(
+        self, desc: str, decoder: Literal[None], query: str, *args: Any
+    ) -> List[Tuple[Any, ...]]:
+        ...
+
+    @overload
+    async def execute(
+        self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
+    ) -> R:
+        ...
+
     async def execute(
         self,
         desc: str,
@@ -1088,6 +1100,28 @@ class DatabasePool(object):
             desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
         )
 
+    @overload
+    async def simple_select_one_onecol(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: Literal[False] = False,
+        desc: str = "simple_select_one_onecol",
+    ) -> Any:
+        ...
+
+    @overload
+    async def simple_select_one_onecol(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: Literal[True] = True,
+        desc: str = "simple_select_one_onecol",
+    ) -> Optional[Any]:
+        ...
+
     async def simple_select_one_onecol(
         self,
         table: str,
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 83c1ddf95a..be6df8a6d1 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,7 +39,7 @@ what sort order was used:
 import abc
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 from twisted.internet import defer
 
@@ -54,9 +54,12 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.types import RoomStreamToken
+from synapse.types import Collection, RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -206,7 +209,7 @@ def _make_generic_sql_bound(
     )
 
 
-def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
+def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
     # "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -264,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super(StreamWorkerStore, self).__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
@@ -297,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         self._stream_order_on_start = self.get_room_max_stream_ordering()
 
     @abc.abstractmethod
-    def get_room_max_stream_ordering(self):
+    def get_room_max_stream_ordering(self) -> int:
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def get_room_min_stream_ordering(self):
+    def get_room_min_stream_ordering(self) -> int:
         raise NotImplementedError()
 
     async def get_room_events_stream_for_rooms(
         self,
-        room_ids: Iterable[str],
+        room_ids: Collection[str],
         from_key: str,
         to_key: str,
         limit: int = 0,
@@ -360,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return results
 
-    def get_rooms_that_changed(self, room_ids, from_key):
+    def get_rooms_that_changed(
+        self, room_ids: Collection[str], from_key: str
+    ) -> Set[str]:
         """Given a list of rooms and a token, return rooms where there may have
         been changes.
 
         Args:
-            room_ids (list)
-            from_key (str): The room_key portion of a StreamToken
+            room_ids
+            from_key: The room_key portion of a StreamToken
         """
-        from_key = RoomStreamToken.parse_stream_token(from_key).stream
+        from_id = RoomStreamToken.parse_stream_token(from_key).stream
         return {
             room_id
             for room_id in room_ids
-            if self._events_stream_cache.has_entity_changed(room_id, from_key)
+            if self._events_stream_cache.has_entity_changed(room_id, from_id)
         }
 
     async def get_room_events_stream_for_room(
@@ -444,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return ret, key
 
-    async def get_membership_changes_for_user(self, user_id, from_key, to_key):
+    async def get_membership_changes_for_user(
+        self, user_id: str, from_key: str, to_key: str
+    ) -> List[EventBase]:
         from_id = RoomStreamToken.parse_stream_token(from_key).stream
         to_id = RoomStreamToken.parse_stream_token(to_key).stream
 
@@ -661,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
         return row[0][0] if row else 0
 
-    def _get_max_topological_txn(self, txn, room_id):
+    def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
         txn.execute(
             "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
             (room_id,),
@@ -734,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     def _get_events_around_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         room_id: str,
         event_id: str,
         before_limit: int,
@@ -762,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             retcols=["stream_ordering", "topological_ordering"],
         )
 
+        # This cannot happen as `allow_none=False`.
+        assert results is not None
+
         # Paginating backwards includes the event at the token, but paginating
         # forward doesn't.
         before_token = RoomStreamToken(
@@ -871,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             desc="update_federation_out_pos",
         )
 
-    def _reset_federation_positions_txn(self, txn) -> None:
+    def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
         """Fiddles with the `federation_stream_position` table to make it
         match the configured federation sender instances during start up.
""" @@ -910,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): GROUP BY type """ txn.execute(sql) - min_positions = dict(txn) # Map from type -> min position + min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position # Ensure we do actually have some values here assert set(min_positions) == {"federation", "events"} @@ -937,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def _paginate_room_events_txn( self, - txn, + txn: LoggingTransaction, room_id: str, from_token: RoomStreamToken, to_token: Optional[RoomStreamToken] = None, -- cgit 1.5.1