From 54d77107c1fde88333ecccedfe30f26fb7645847 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 15 May 2019 11:27:50 +0100
Subject: Make generating SQL bounds for pagination generic

This will allow us to reuse the same structure when we paginate e.g. relations
---
 synapse/storage/stream.py | 179 ++++++++++++++++++++++++++++++----------------
 1 file changed, 118 insertions(+), 61 deletions(-)

(limited to 'synapse/storage')

diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index d105b6b17d..163363c0c2 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -64,59 +64,120 @@ _EventDictReturn = namedtuple(
 )
 
 
-def lower_bound(token, engine, inclusive=False):
-    inclusive = "=" if inclusive else ""
-    if token.topological is None:
-        return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
-    else:
-        if isinstance(engine, PostgresEngine):
-            # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
-            # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
-            # use the later form when running against postgres.
-            return "((%d,%d) <%s (%s,%s))" % (
-                token.topological,
-                token.stream,
-                inclusive,
-                "topological_ordering",
-                "stream_ordering",
+def generate_pagination_where_clause(
+    direction, column_names, from_token, to_token, engine,
+):
+    """Creates an SQL expression to bound the columns by the pagination
+    tokens.
+
+    For example creates an SQL expression like:
+
+        (6, 7) >= (topological_ordering, stream_ordering)
+        AND (5, 3) < (topological_ordering, stream_ordering)
+
+    would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3).
+
+    Args:
+        direction (str): Whether we're paginating backwards ("b") or
+            forwards ("f").
+        column_names (tuple[str, str]): The column names to bound. Must *not*
+            be user defined as these get inserted directly into the SQL
+            statement without escapes.
+        from_token (tuple[int, int]|None)
+        to_token (tuple[int, int]|None)
+        engine: The database engine to generate the clauses for
+
+    Returns:
+        str: The sql expression
+    """
+    assert direction in ("b", "f")
+
+    where_clause = []
+    if from_token:
+        where_clause.append(
+            _make_generic_sql_bound(
+                bound=">=" if direction == "b" else "<",
+                column_names=column_names,
+                values=from_token,
+                engine=engine,
             )
-        return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
-            token.topological,
-            "topological_ordering",
-            token.topological,
-            "topological_ordering",
-            token.stream,
-            inclusive,
-            "stream_ordering",
-        )
-
-
-def upper_bound(token, engine, inclusive=True):
-    inclusive = "=" if inclusive else ""
-    if token.topological is None:
-        return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
-    else:
-        if isinstance(engine, PostgresEngine):
-            # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
-            # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
-            # use the later form when running against postgres.
-            return "((%d,%d) >%s (%s,%s))" % (
-                token.topological,
-                token.stream,
-                inclusive,
-                "topological_ordering",
-                "stream_ordering",
+        )
+
+    if to_token:
+        where_clause.append(
+            _make_generic_sql_bound(
+                bound="<" if direction == "b" else ">=",
+                column_names=column_names,
+                values=to_token,
+                engine=engine,
             )
-        return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
-            token.topological,
-            "topological_ordering",
-            token.topological,
-            "topological_ordering",
-            token.stream,
-            inclusive,
-            "stream_ordering",
         )
 
+    return " AND ".join(where_clause)
+
+
+def _make_generic_sql_bound(bound, column_names, values, engine):
+    """Create an SQL expression that bounds the given column names by the
+    values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
+
+    Only works with two columns.
+
+    Older versions of SQLite don't support that syntax so we have to expand it
+    out manually.
+
+    Args:
+        bound (str): The comparison operator to use. One of ">", "<", ">=",
+            "<=", where the values are on the left and columns on the right.
+        column_names (tuple[str, str]): The column names. Must *not* be user
+            defined as these get inserted directly into the SQL statement
+            without escapes.
+        values (tuple[int, int]): The values to bound the columns by.
+        engine: The database engine to generate the SQL for
+
+    Returns:
+        str
+    """
+
+    assert(bound in (">", "<", ">=", "<="))
+
+    name1, name2 = column_names
+    val1, val2 = values
+
+    if val1 is None:
+        val2 = int(val2)
+        return "(%d %s %s)" % (val2, bound, name2)
+
+    val1 = int(val1)
+    val2 = int(val2)
+
+    if isinstance(engine, PostgresEngine):
+        # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
+        # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
+        # use the latter form when running against postgres.
+        return "((%d,%d) %s (%s,%s))" % (val1, val2, bound, name1, name2)
+
+    # We want to generate queries of e.g. the form:
+    #
+    #   (val1 < name1 OR (val1 = name1 AND val2 <= name2))
+    #
+    # which is equivalent to (val1, val2) < (name1, name2).
+    return """(
+        {val1:d} {strict_bound} {name1}
+        OR ({val1:d} = {name1} AND {val2:d} {bound} {name2})
+    )""".format(
+        name1=name1,
+        val1=val1,
+        name2=name2,
+        val2=val2,
+        strict_bound=bound[0],  # The first bound must always be strict
+        bound=bound,
+    )
-- cgit 1.5.1
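To make the new helper concrete, here is a small standalone sketch — editorial, not part of the patch — that mirrors the Postgres tuple-comparison path for the docstring's example; the column names and token values are illustrative only:

    # Mirrors generate_pagination_where_clause for the Postgres tuple form.
    def sketch_pagination_clause(direction, from_token, to_token):
        cols = "(topological_ordering, stream_ordering)"
        clauses = []
        if from_token:
            op = ">=" if direction == "b" else "<"
            clauses.append("(%d, %d) %s %s" % (from_token[0], from_token[1], op, cols))
        if to_token:
            op = "<" if direction == "b" else ">="
            clauses.append("(%d, %d) %s %s" % (to_token[0], to_token[1], op, cols))
        return " AND ".join(clauses)

    # Matches the docstring example for dir=b:
    print(sketch_pagination_clause("b", (6, 7), (5, 3)))
    # (6, 7) >= (topological_ordering, stream_ordering) AND (5, 3) < (topological_ordering, stream_ordering)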
From: Erik Johnston
Date: Tue, 14 May 2019 16:59:21 +0100
Subject: Add simple send_relation API and track in DB
---
 synapse/api/constants.py | 8 ++
 synapse/rest/__init__.py | 2 +
 synapse/rest/client/v2_alpha/relations.py | 110 ++++++++++++++++++++++++++
 synapse/storage/__init__.py | 2 +
 synapse/storage/events.py | 2 +
 synapse/storage/relations.py | 62 +++++++++++++++
 synapse/storage/schema/delta/54/relations.sql | 27 +++++++
 tests/rest/client/v2_alpha/test_relations.py | 98 +++++++++++++++++++++++
 8 files changed, 311 insertions(+)
 create mode 100644 synapse/rest/client/v2_alpha/relations.py
 create mode 100644 synapse/storage/relations.py
 create mode 100644 synapse/storage/schema/delta/54/relations.sql
 create mode 100644 tests/rest/client/v2_alpha/test_relations.py

(limited to 'synapse/storage')

diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 8547a63535..30bebd749f 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -116,3 +116,11 @@ class UserTypes(object):
     """
     SUPPORT = "support"
     ALL_USER_TYPES = (SUPPORT,)
+
+
+class RelationTypes(object):
+    """The types of relations known to this server.
+    """
+    ANNOTATION = "m.annotation"
+    REPLACES = "m.replaces"
+    REFERENCES = "m.references"
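For orientation, an event using these constants carries an `m.relates_to` block in its content. The shape below is a hedged sketch based on MSC1849 (the parent event ID is invented, and the field names may change with the MSC):

    from synapse.api.constants import RelationTypes

    # Hypothetical reaction event content, per MSC1849:
    reaction_content = {
        "m.relates_to": {
            "rel_type": RelationTypes.ANNOTATION,  # "m.annotation"
            "event_id": "$parent_event_id",        # invented parent event ID
            "key": "👍",                            # the aggregation key
        }
    }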
+""" + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_json_object_from_request, + parse_string, +) + +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class RelationSendServlet(RestServlet): + """Helper API for sending events that have relation data. + + Example API shape to send a 👍 reaction to a room: + + POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D + {} + + { + "event_id": "$foobar" + } + """ + + PATTERN = ( + "/rooms/(?P[^/]*)/send_relation" + "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)" + ) + + def __init__(self, hs): + super(RelationSendServlet, self).__init__() + self.auth = hs.get_auth() + self.event_creation_handler = hs.get_event_creation_handler() + + def register(self, http_server): + http_server.register_paths( + "POST", + client_v2_patterns(self.PATTERN + "$", releases=()), + self.on_PUT_or_POST, + ) + http_server.register_paths( + "PUT", + client_v2_patterns(self.PATTERN + "/(?P[^/]*)$", releases=()), + self.on_PUT_or_POST, + ) + + @defer.inlineCallbacks + def on_PUT_or_POST( + self, request, room_id, parent_id, relation_type, event_type, txn_id=None + ): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + if event_type == EventTypes.Member: + # Add relations to a membership is meaningless, so we just deny it + # at the CS API rather than trying to handle it correctly. + raise SynapseError(400, "Cannot send member events with relations") + + content = parse_json_object_from_request(request) + + aggregation_key = parse_string(request, "key", encoding="utf-8") + + content["m.relates_to"] = { + "event_id": parent_id, + "key": aggregation_key, + "rel_type": relation_type, + } + + event_dict = { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + } + + event = yield self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict=event_dict, txn_id=txn_id + ) + + defer.returnValue((200, {"event_id": event.event_id})) + + +def register_servlets(hs, http_server): + RelationSendServlet(hs).register(http_server) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index c432041b4e..7522d3fd57 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -49,6 +49,7 @@ from .pusher import PusherStore from .receipts import ReceiptsStore from .registration import RegistrationStore from .rejections import RejectionsStore +from .relations import RelationsStore from .room import RoomStore from .roommember import RoomMemberStore from .search import SearchStore @@ -99,6 +100,7 @@ class DataStore( GroupServerStore, UserErasureStore, MonthlyActiveUsersStore, + RelationsStore, ): def __init__(self, db_conn, hs): self.hs = hs diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 7a7f841c6c..6802bf42ce 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1351,6 +1351,8 @@ class EventsStore( # Insert into the event_search table. self._store_guest_access_txn(txn, event) + self._handle_event_relations(txn, event) + # Insert into the room_memberships table. 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 7a7f841c6c..6802bf42ce 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1351,6 +1351,8 @@ class EventsStore(
                 # Insert into the event_search table.
                 self._store_guest_access_txn(txn, event)
 
+            self._handle_event_relations(txn, event)
+
             # Insert into the room_memberships table.
             self._store_room_members_txn(
                 txn,
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
new file mode 100644
index 0000000000..a4905162e0
--- /dev/null
+++ b/synapse/storage/relations.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.constants import RelationTypes
+from synapse.storage._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class RelationsStore(SQLBaseStore):
+    def _handle_event_relations(self, txn, event):
+        """Handles inserting relation data during persistence of events
+
+        Args:
+            txn
+            event (EventBase)
+        """
+        relation = event.content.get("m.relates_to")
+        if not relation:
+            # No relations
+            return
+
+        rel_type = relation.get("rel_type")
+        if rel_type not in (
+            RelationTypes.ANNOTATION,
+            RelationTypes.REFERENCES,
+            RelationTypes.REPLACES,
+        ):
+            # Unknown relation type
+            return
+
+        parent_id = relation.get("event_id")
+        if not parent_id:
+            # Invalid relation
+            return
+
+        aggregation_key = relation.get("key")
+
+        self._simple_insert_txn(
+            txn,
+            table="event_relations",
+            values={
+                "event_id": event.event_id,
+                "relates_to_id": parent_id,
+                "relation_type": rel_type,
+                "aggregation_key": aggregation_key,
+            },
+        )
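To make the write path concrete: for the reaction content sketched earlier, the row that `_handle_event_relations` inserts would look roughly like this (event IDs invented for illustration):

    # Given a relating event with this content...
    content = {
        "m.relates_to": {
            "rel_type": "m.annotation",
            "event_id": "$parent",
            "key": "👍",
        }
    }

    # ...the inserted event_relations row is, in dict form:
    relation = content["m.relates_to"]
    row = {
        "event_id": "$reaction",                 # the relating (child) event
        "relates_to_id": relation["event_id"],   # "$parent"
        "relation_type": relation["rel_type"],   # "m.annotation"
        "aggregation_key": relation.get("key"),  # "👍"; NULL for references
    }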
diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/schema/delta/54/relations.sql
new file mode 100644
index 0000000000..134862b870
--- /dev/null
+++ b/synapse/storage/schema/delta/54/relations.sql
@@ -0,0 +1,27 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Tracks related events, like reactions, replies, edits, etc. Note that things
+-- in this table are not necessarily "valid", e.g. it may contain edits from
+-- people who don't have power to edit other people's events.
+CREATE TABLE IF NOT EXISTS event_relations (
+    event_id TEXT NOT NULL,
+    relates_to_id TEXT NOT NULL,
+    relation_type TEXT NOT NULL,
+    aggregation_key TEXT
+);
+
+CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id);
+CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key);
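Note that the unique index on `event_id` means each event can relate to at most one parent. A quick editorial demonstration with in-memory SQLite:

    import sqlite3

    db = sqlite3.connect(":memory:")
    db.execute(
        "CREATE TABLE event_relations ("
        " event_id TEXT NOT NULL, relates_to_id TEXT NOT NULL,"
        " relation_type TEXT NOT NULL, aggregation_key TEXT)"
    )
    db.execute("CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id)")

    db.execute("INSERT INTO event_relations VALUES ('$a', '$p1', 'm.annotation', '👍')")
    try:
        # A second relation for the same relating event is rejected.
        db.execute("INSERT INTO event_relations VALUES ('$a', '$p2', 'm.references', NULL)")
    except sqlite3.IntegrityError as e:
        print("rejected:", e)  # UNIQUE constraint failed: event_relations.event_id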
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
new file mode 100644
index 0000000000..61163d5b26
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import six
+
+from synapse.api.constants import EventTypes, RelationTypes
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import relations
+
+from tests import unittest
+
+
+class RelationsTestCase(unittest.HomeserverTestCase):
+    user_id = "@alice:test"
+    servlets = [
+        relations.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.room = self.helper.create_room_as(self.user_id)
+        res = self.helper.send(self.room, body="Hi!")
+        self.parent_id = res["event_id"]
+
+    def test_send_relation(self):
+        """Tests that sending a relation using the new /send_relation API works and
+        creates the right shape of event.
+        """
+
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        event_id = channel.json_body["event_id"]
+
+        request, channel = self.make_request(
+            "GET", "/rooms/%s/event/%s" % (self.room, event_id)
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        self.assert_dict(
+            {
+                "type": "m.reaction",
+                "sender": self.user_id,
+                "content": {
+                    "m.relates_to": {
+                        "event_id": self.parent_id,
+                        "key": u"👍",
+                        "rel_type": RelationTypes.ANNOTATION,
+                    }
+                },
+            },
+            channel.json_body,
+        )
+
+    def test_deny_membership(self):
+        """Test that we deny relations on membership events
+        """
+        channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
+        self.assertEquals(400, channel.code, channel.json_body)
+
+    def _send_relation(self, relation_type, event_type, key=None):
+        """Helper function to send a relation pointing at `self.parent_id`
+
+        Args:
+            relation_type (str): One of `RelationTypes`
+            event_type (str): The type of the event to create
+            key (str|None): The aggregation key used for m.annotation relation
+                type.
+
+        Returns:
+            FakeChannel
+        """
+        query = ""
+        if key:
+            query = "?key=" + six.moves.urllib.parse.quote_plus(key)
+
+        request, channel = self.make_request(
+            "POST",
+            "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
+            % (self.room, self.parent_id, relation_type, event_type, query),
+            b"{}",
+        )
+        self.render(request)
+        return channel
-- cgit 1.5.1


From b50641e357f6139b0807df3714440a8ccc21d628 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 14 May 2019 16:59:21 +0100
Subject: Add simple pagination API
---
 synapse/rest/client/v2_alpha/relations.py | 50 +++++++++++++++++
 synapse/storage/relations.py | 80 ++++++++++++++++++++++++++++
 tests/rest/client/v2_alpha/test_relations.py | 30 +++++++++++
 3 files changed, 160 insertions(+)

(limited to 'synapse/storage')

diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index b504b4a8be..bac9e85c21 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -27,6 +27,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import (
     RestServlet,
+    parse_integer,
     parse_json_object_from_request,
     parse_string,
 )
@@ -106,5 +107,54 @@ class RelationSendServlet(RestServlet):
         defer.returnValue((200, {"event_id": event.event_id}))
 
 
+class RelationPaginationServlet(RestServlet):
+    """API to paginate relations on an event by topological ordering, optionally
+    filtered by relation type and event type.
+    """
+
+    PATTERNS = client_v2_patterns(
+        "/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)"
+        "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
+        releases=(),
+    )
+
+    def __init__(self, hs):
+        super(RelationPaginationServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self._event_serializer = hs.get_event_client_serializer()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+        yield self.auth.check_in_room_or_world_readable(
+            room_id, requester.user.to_string()
+        )
+
+        limit = parse_integer(request, "limit", default=5)
+
+        result = yield self.store.get_relations_for_event(
+            event_id=parent_id,
+            relation_type=relation_type,
+            event_type=event_type,
+            limit=limit,
+        )
+
+        events = yield self.store.get_events_as_list(
+            [c["event_id"] for c in result.chunk]
+        )
+
+        now = self.clock.time_msec()
+        events = yield self._event_serializer.serialize_events(events, now)
+
+        return_value = result.to_dict()
+        return_value["chunk"] = events
+
+        defer.returnValue((200, return_value))
+
+
 def register_servlets(hs, http_server):
     RelationSendServlet(hs).register(http_server)
+    RelationPaginationServlet(hs).register(http_server)
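A hedged client-side sketch of the new read API (invented IDs and token; at this stage the endpoint only honours `limit` — pagination tokens arrive in a later commit):

    import requests  # assumed HTTP client

    resp = requests.get(
        "https://localhost:8008/_matrix/client/unstable"
        "/rooms/%21foo%3Atest/relations/%24bar/m.annotation/m.reaction",
        params={"limit": 3, "access_token": "..."},
    )
    # e.g. {"chunk": [{"event_id": "$r3"}, {"event_id": "$r2"}, {"event_id": "$r1"}]}
    print(resp.json())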
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index a4905162e0..f304b8a9a4 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -15,13 +15,93 @@
 
 import logging
 
+import attr
+
 from synapse.api.constants import RelationTypes
 from synapse.storage._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
 
+@attr.s
+class PaginationChunk(object):
+    """Returned by relation pagination APIs.
+
+    Attributes:
+        chunk (list): The rows returned by pagination
+    """
+
+    chunk = attr.ib()
+
+    def to_dict(self):
+        d = {"chunk": self.chunk}
+
+        return d
+
+
 class RelationsStore(SQLBaseStore):
+    def get_relations_for_event(
+        self, event_id, relation_type=None, event_type=None, limit=5, direction="b"
+    ):
+        """Get a list of relations for an event, ordered by topological ordering.
+
+        Args:
+            event_id (str): Fetch events that relate to this event ID.
+            relation_type (str|None): Only fetch events with this relation
+                type, if given.
+            event_type (str|None): Only fetch events with this event type, if
+                given.
+            limit (int): Only fetch the most recent `limit` events.
+            direction (str): Whether to fetch the most recent first (`"b"`) or
+                the oldest first (`"f"`).
+
+        Returns:
+            Deferred[PaginationChunk]: List of event IDs that match relations
+            requested. The rows are of the form `{"event_id": "..."}`.
+        """
+
+        # TODO: Pagination tokens
+
+        where_clause = ["relates_to_id = ?"]
+        where_args = [event_id]
+
+        if relation_type:
+            where_clause.append("relation_type = ?")
+            where_args.append(relation_type)
+
+        if event_type:
+            where_clause.append("type = ?")
+            where_args.append(event_type)
+
+        order = "ASC"
+        if direction == "b":
+            order = "DESC"
+
+        sql = """
+            SELECT event_id FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE %s
+            ORDER BY topological_ordering %s, stream_ordering %s
+            LIMIT ?
+        """ % (
+            " AND ".join(where_clause),
+            order,
+            order,
+        )
+
+        def _get_recent_references_for_event_txn(txn):
+            txn.execute(sql, where_args + [limit + 1])
+
+            events = [{"event_id": row[0]} for row in txn]
+
+            return PaginationChunk(
+                chunk=list(events[:limit]),
+            )
+
+        return self.runInteraction(
+            "get_recent_references_for_event", _get_recent_references_for_event_txn
+        )
+
     def _handle_event_relations(self, txn, event):
         """Handles inserting relation data during persistence of events
 
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 61163d5b26..bcc1c1bb85 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -72,6 +72,36 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
         self.assertEquals(400, channel.code, channel.json_body)
 
+    def test_paginate(self):
+        """Tests that calling the pagination API correctly returns the latest relations.
+        """
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_id = channel.json_body["event_id"]
+
+        request, channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
+            % (self.room, self.parent_id),
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # We expect to get back a single pagination result, which is the full
+        # relation event we sent above.
+ self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + { + "event_id": annotation_id, + "sender": self.user_id, + "type": "m.reaction", + }, + channel.json_body["chunk"][0], + ) + def _send_relation(self, relation_type, event_type, key=None): """Helper function to send a relation pointing at `self.parent_id` -- cgit 1.5.1 From 5be34fc3e3045b3ba5a97b73ba0f25887874e622 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 15 May 2019 17:30:23 +0100 Subject: Actually check for None rather falsey --- synapse/storage/relations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index f304b8a9a4..31ef6679af 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -65,11 +65,11 @@ class RelationsStore(SQLBaseStore): where_clause = ["relates_to_id = ?"] where_args = [event_id] - if relation_type: + if relation_type is not None: where_clause.append("relation_type = ?") where_args.append(relation_type) - if event_type: + if event_type is not None: where_clause.append("type = ?") where_args.append(event_type) -- cgit 1.5.1 From a0603523d2e210cf59f887bd75e1a755720cb7a8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 14 May 2019 16:59:21 +0100 Subject: Add aggregations API --- synapse/config/server.py | 5 + synapse/events/utils.py | 34 +++- synapse/rest/client/v2_alpha/relations.py | 141 ++++++++++++++- synapse/storage/relations.py | 225 +++++++++++++++++++++++- tests/rest/client/v2_alpha/test_relations.py | 251 ++++++++++++++++++++++++++- 5 files changed, 643 insertions(+), 13 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/config/server.py b/synapse/config/server.py index 7874cd9da7..f403477b54 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -100,6 +100,11 @@ class ServerConfig(Config): "block_non_admin_invites", False, ) + # Whether to enable experimental MSC1849 (aka relations) support + self.experimental_msc1849_support_enabled = config.get( + "experimental_msc1849_support_enabled", False, + ) + # Options to control access by tracking MAU self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) self.max_mau_value = 0 diff --git a/synapse/events/utils.py b/synapse/events/utils.py index a5454556cc..16d0c64372 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -21,7 +21,7 @@ from frozendict import frozendict from twisted.internet import defer -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.util.async_helpers import yieldable_gather_results from . import EventBase @@ -324,8 +324,12 @@ class EventClientSerializer(object): """ def __init__(self, hs): - pass + self.store = hs.get_datastore() + self.experimental_msc1849_support_enabled = ( + hs.config.experimental_msc1849_support_enabled + ) + @defer.inlineCallbacks def serialize_event(self, event, time_now, **kwargs): """Serializes a single event. 
@@ -337,8 +341,32 @@ class EventClientSerializer(object):
         Returns:
             Deferred[dict]: The serialized event
         """
+        # To handle the case of presence events and the like
+        if not isinstance(event, EventBase):
+            defer.returnValue(event)
+
+        event_id = event.event_id
         event = serialize_event(event, time_now, **kwargs)
-        return defer.succeed(event)
+
+        # If MSC1849 is enabled then we need to look if there are any relations
+        # we need to bundle in with the event
+        if self.experimental_msc1849_support_enabled:
+            annotations = yield self.store.get_aggregation_groups_for_event(
+                event_id,
+            )
+            references = yield self.store.get_relations_for_event(
+                event_id, RelationTypes.REFERENCES, direction="f",
+            )
+
+            if annotations.chunk:
+                r = event["unsigned"].setdefault("m.relations", {})
+                r[RelationTypes.ANNOTATION] = annotations.to_dict()
+
+            if references.chunk:
+                r = event["unsigned"].setdefault("m.relations", {})
+                r[RelationTypes.REFERENCES] = references.to_dict()
+
+        defer.returnValue(event)
 
     def serialize_events(self, events, time_now, **kwargs):
         """Serializes multiple events.
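With the flag enabled, a serialized parent event would come out roughly like this (editorial sketch; the event IDs and counts are invented, but the key names follow the code above):

    serialized_event = {
        "type": "m.room.message",
        "content": {"msgtype": "m.text", "body": "Hi!"},
        "unsigned": {
            "m.relations": {
                "m.annotation": {
                    "chunk": [{"type": "m.reaction", "key": "👍", "count": 3}]
                },
                "m.references": {"chunk": [{"event_id": "$reply1"}]},
            }
        },
    }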
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index c3ac73b8c7..f468409b68 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -23,7 +23,7 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import (
     RestServlet,
@@ -141,12 +141,16 @@ class RelationPaginationServlet(RestServlet):
         )
 
         limit = parse_integer(request, "limit", default=5)
+        from_token = parse_string(request, "from")
+        to_token = parse_string(request, "to")
 
         result = yield self.store.get_relations_for_event(
             event_id=parent_id,
             relation_type=relation_type,
             event_type=event_type,
             limit=limit,
+            from_token=from_token,
+            to_token=to_token,
         )
 
         events = yield self.store.get_events_as_list(
@@ -162,6 +166,141 @@ class RelationPaginationServlet(RestServlet):
         defer.returnValue((200, return_value))
 
 
+class RelationAggregationPaginationServlet(RestServlet):
+    """API to paginate aggregation groups of relations, e.g. paginate the
+    types and counts of the reactions on the events.
+
+    Example request and response:
+
+        GET /rooms/{room_id}/aggregations/{parent_id}
+
+        {
+            chunk: [
+                {
+                    "type": "m.reaction",
+                    "key": "👍",
+                    "count": 3
+                }
+            ]
+        }
+    """
+
+    PATTERNS = client_v2_patterns(
+        "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
+        "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
+        releases=(),
+    )
+
+    def __init__(self, hs):
+        super(RelationAggregationPaginationServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+        yield self.auth.check_in_room_or_world_readable(
+            room_id, requester.user.to_string()
+        )
+
+        if relation_type not in (RelationTypes.ANNOTATION, None):
+            raise SynapseError(400, "Relation type must be 'annotation'")
+
+        limit = parse_integer(request, "limit", default=5)
+        from_token = parse_string(request, "from")
+        to_token = parse_string(request, "to")
+
+        res = yield self.store.get_aggregation_groups_for_event(
+            event_id=parent_id,
+            event_type=event_type,
+            limit=limit,
+            from_token=from_token,
+            to_token=to_token,
+        )
+
+        defer.returnValue((200, res.to_dict()))
+
+
+class RelationAggregationGroupPaginationServlet(RestServlet):
+    """API to paginate within an aggregation group of relations, e.g. paginate
+    all the 👍 reactions on an event.
+
+    Example request and response:
+
+        GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍
+
+        {
+            chunk: [
+                {
+                    "type": "m.reaction",
+                    "content": {
+                        "m.relates_to": {
+                            "rel_type": "m.annotation",
+                            "key": "👍"
+                        }
+                    }
+                },
+                ...
+            ]
+        }
+    """
+
+    PATTERNS = client_v2_patterns(
+        "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
+        "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$",
+        releases=(),
+    )
+
+    def __init__(self, hs):
+        super(RelationAggregationGroupPaginationServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self._event_serializer = hs.get_event_client_serializer()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+        yield self.auth.check_in_room_or_world_readable(
+            room_id, requester.user.to_string()
+        )
+
+        if relation_type != RelationTypes.ANNOTATION:
+            raise SynapseError(400, "Relation type must be 'annotation'")
+
+        limit = parse_integer(request, "limit", default=5)
+        from_token = parse_string(request, "from")
+        to_token = parse_string(request, "to")
+
+        result = yield self.store.get_relations_for_event(
+            event_id=parent_id,
+            relation_type=relation_type,
+            event_type=event_type,
+            aggregation_key=key,
+            limit=limit,
+            from_token=from_token,
+            to_token=to_token,
+        )
+
+        events = yield self.store.get_events_as_list(
+            [c["event_id"] for c in result.chunk]
+        )
+
+        now = self.clock.time_msec()
+        events = yield self._event_serializer.serialize_events(events, now)
+
+        return_value = result.to_dict()
+        return_value["chunk"] = events
+
+        defer.returnValue((200, return_value))
+
+
 def register_servlets(hs, http_server):
     RelationSendServlet(hs).register(http_server)
     RelationPaginationServlet(hs).register(http_server)
+    RelationAggregationPaginationServlet(hs).register(http_server)
+    RelationAggregationGroupPaginationServlet(hs).register(http_server)
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 31ef6679af..db4b842c97 100644
---
 a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -18,7 +18,9 @@ import logging
 import attr
 
 from synapse.api.constants import RelationTypes
+from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.stream import generate_pagination_where_clause
 
 logger = logging.getLogger(__name__)
 
@@ -29,19 +31,94 @@ class PaginationChunk(object):
 
     Attributes:
         chunk (list): The rows returned by pagination
+        next_batch (Any|None): Token to fetch next set of results with, if
+            None then there are no more results.
+        prev_batch (Any|None): Token to fetch previous set of results with, if
+            None then there are no previous results.
     """
 
     chunk = attr.ib()
+    next_batch = attr.ib(default=None)
+    prev_batch = attr.ib(default=None)
 
     def to_dict(self):
        d = {"chunk": self.chunk}
 
+        if self.next_batch:
+            d["next_batch"] = self.next_batch.to_string()
+
+        if self.prev_batch:
+            d["prev_batch"] = self.prev_batch.to_string()
+
         return d
 
 
+@attr.s
+class RelationPaginationToken(object):
+    """Pagination token for relation pagination API.
+
+    As the results are ordered by topological ordering, we can use the
+    `topological_ordering` and `stream_ordering` fields of the events at the
+    boundaries of the chunk as pagination tokens.
+
+    Attributes:
+        topological (int): The topological ordering of the boundary event
+        stream (int): The stream ordering of the boundary event.
+    """
+
+    topological = attr.ib()
+    stream = attr.ib()
+
+    @staticmethod
+    def from_string(string):
+        try:
+            t, s = string.split("-")
+            return RelationPaginationToken(int(t), int(s))
+        except ValueError:
+            raise SynapseError(400, "Invalid token")
+
+    def to_string(self):
+        return "%d-%d" % (self.topological, self.stream)
+
+
+@attr.s
+class AggregationPaginationToken(object):
+    """Pagination token for relation aggregation pagination API.
+
+    As the results are ordered by count and then MAX(stream_ordering) of the
+    aggregation groups, we can just use them as our pagination token.
+
+    Attributes:
+        count (int): The count of relations in the boundary group.
+        stream (int): The MAX stream ordering in the boundary group.
+    """
+
+    count = attr.ib()
+    stream = attr.ib()
+
+    @staticmethod
+    def from_string(string):
+        try:
+            c, s = string.split("-")
+            return AggregationPaginationToken(int(c), int(s))
+        except ValueError:
+            raise SynapseError(400, "Invalid token")
+
+    def to_string(self):
+        return "%d-%d" % (self.count, self.stream)
+
+
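A round-trip sketch for these tokens (editorial; assumes the classes above are importable):

    token = RelationPaginationToken(topological=6, stream=7)
    assert token.to_string() == "6-7"
    # attrs-generated equality makes the parse/serialize round trip exact:
    assert RelationPaginationToken.from_string("6-7") == token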
 class RelationsStore(SQLBaseStore):
     def get_relations_for_event(
-        self, event_id, relation_type=None, event_type=None, limit=5, direction="b"
+        self,
+        event_id,
+        relation_type=None,
+        event_type=None,
+        aggregation_key=None,
+        limit=5,
+        direction="b",
+        from_token=None,
+        to_token=None,
     ):
         """Get a list of relations for an event, ordered by topological ordering.
 
@@ -51,16 +128,26 @@ class RelationsStore(SQLBaseStore):
                 type, if given.
             event_type (str|None): Only fetch events with this event type, if
                 given.
+            aggregation_key (str|None): Only fetch events with this aggregation
+                key, if given.
             limit (int): Only fetch the most recent `limit` events.
             direction (str): Whether to fetch the most recent first (`"b"`) or
                 the oldest first (`"f"`).
+            from_token (RelationPaginationToken|None): Fetch rows from the given
+                token, or from the start if None.
+            to_token (RelationPaginationToken|None): Fetch rows up to the given
+                token, or up to the end if None.
 
         Returns:
             Deferred[PaginationChunk]: List of event IDs that match relations
             requested. The rows are of the form `{"event_id": "..."}`.
         """
 
-        # TODO: Pagination tokens
+        if from_token:
+            from_token = RelationPaginationToken.from_string(from_token)
+
+        if to_token:
+            to_token = RelationPaginationToken.from_string(to_token)
 
         where_clause = ["relates_to_id = ?"]
         where_args = [event_id]
@@ -73,12 +160,29 @@ class RelationsStore(SQLBaseStore):
             where_clause.append("type = ?")
             where_args.append(event_type)
 
-        order = "ASC"
+        if aggregation_key:
+            where_clause.append("aggregation_key = ?")
+            where_args.append(aggregation_key)
+
+        pagination_clause = generate_pagination_where_clause(
+            direction=direction,
+            column_names=("topological_ordering", "stream_ordering"),
+            from_token=attr.astuple(from_token) if from_token else None,
+            to_token=attr.astuple(to_token) if to_token else None,
+            engine=self.database_engine,
+        )
+
+        if pagination_clause:
+            where_clause.append(pagination_clause)
+
         if direction == "b":
             order = "DESC"
+        else:
+            order = "ASC"
 
         sql = """
-            SELECT event_id FROM event_relations
+            SELECT event_id, topological_ordering, stream_ordering
+            FROM event_relations
             INNER JOIN events USING (event_id)
             WHERE %s
             ORDER BY topological_ordering %s, stream_ordering %s
             LIMIT ?
         """ % (
             " AND ".join(where_clause),
             order,
             order,
         )
 
         def _get_recent_references_for_event_txn(txn):
             txn.execute(sql, where_args + [limit + 1])
 
-            events = [{"event_id": row[0]} for row in txn]
+            last_topo_id = None
+            last_stream_id = None
+            events = []
+            for row in txn:
+                events.append({"event_id": row[0]})
+                last_topo_id = row[1]
+                last_stream_id = row[2]
+
+            next_batch = None
+            if len(events) > limit and last_topo_id and last_stream_id:
+                next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
 
             return PaginationChunk(
-                chunk=list(events[:limit]),
+                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
             )
 
         return self.runInteraction(
             "get_recent_references_for_event", _get_recent_references_for_event_txn
         )
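For a backwards pagination from token (6, 7) with a relation-type filter, the query built above comes out roughly as follows on Postgres (editorial illustration; the parent event ID is invented):

    sql = """
        SELECT event_id, topological_ordering, stream_ordering
        FROM event_relations
        INNER JOIN events USING (event_id)
        WHERE relates_to_id = ? AND relation_type = ?
            AND ((6,7) >= (topological_ordering,stream_ordering))
        ORDER BY topological_ordering DESC, stream_ordering DESC
        LIMIT ?
    """
    args = ["$parent", "m.annotation", 5 + 1]  # limit + 1 rows to detect a next page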
+ """ + + if from_token: + from_token = AggregationPaginationToken.from_string(from_token) + + if to_token: + to_token = AggregationPaginationToken.from_string(to_token) + + where_clause = ["relates_to_id = ?", "relation_type = ?"] + where_args = [event_id, RelationTypes.ANNOTATION] + + if event_type: + where_clause.append("type = ?") + where_args.append(event_type) + + having_clause = generate_pagination_where_clause( + direction=direction, + column_names=("COUNT(*)", "MAX(stream_ordering)"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + if having_clause: + having_clause = "HAVING " + having_clause + else: + having_clause = "" + + sql = """ + SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE {where_clause} + GROUP BY relation_type, type, aggregation_key + {having_clause} + ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} + LIMIT ? + """.format( + where_clause=" AND ".join(where_clause), + order=order, + having_clause=having_clause, + ) + + def _get_aggregation_groups_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + next_batch = None + events = [] + for row in txn: + events.append({"type": row[0], "key": row[1], "count": row[2]}) + next_batch = AggregationPaginationToken(row[2], row[3]) + + if len(events) <= limit: + next_batch = None + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.runInteraction( + "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + ) + def _handle_event_relations(self, txn, event): """Handles inserting relation data during peristence of events diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index bcc1c1bb85..661dd45755 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools + import six from synapse.api.constants import EventTypes, RelationTypes @@ -30,6 +32,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] + def make_homeserver(self, reactor, clock): + # We need to enable msc1849 support for aggregations + config = self.default_config() + config["experimental_msc1849_support_enabled"] = True + return self.setup_test_homeserver(config=config) + def prepare(self, reactor, clock, hs): self.room = self.helper.create_room_as(self.user_id) res = self.helper.send(self.room, body="Hi!") @@ -40,7 +48,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): creates the right shape of event. """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key=u"👍") self.assertEquals(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] @@ -72,7 +80,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) self.assertEquals(400, channel.code, channel.json_body) - def test_paginate(self): + def test_basic_paginate_relations(self): """Tests that calling pagination API corectly the latest relations. 
""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") @@ -102,6 +110,243 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel.json_body["chunk"][0], ) + # Make sure next_batch has something in it that looks like it could be a + # valid token. + self.assertIsInstance( + channel.json_body.get("next_batch"), six.string_types, channel.json_body + ) + + def test_repeated_paginate_relations(self): + """Test that if we paginate using a limit and tokens then we get the + expected events. + """ + + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + expected_event_ids.append(channel.json_body["event_id"]) + + prev_token = None + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" + % (self.room, self.parent_id, from_token), + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_aggregation_pagination_groups(self): + """Test that we can paginate annotation groups correctly. + """ + + sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1} + for key in itertools.chain.from_iterable( + itertools.repeat(key, num) for key, num in sent_groups.items() + ): + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", key=key + ) + self.assertEquals(200, channel.code, channel.json_body) + + prev_token = None + found_groups = {} + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" + % (self.room, self.parent_id, from_token), + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + for groups in channel.json_body["chunk"]: + # We only expect reactions + self.assertEqual(groups["type"], "m.reaction", channel.json_body) + + # We should only see each key once + self.assertNotIn(groups["key"], found_groups, channel.json_body) + + found_groups[groups["key"]] = groups["count"] + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + self.assertEquals(sent_groups, found_groups) + + def test_aggregation_pagination_within_group(self): + """Test that we can paginate within an annotation group. 
+ """ + + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", key=u"👍" + ) + self.assertEquals(200, channel.code, channel.json_body) + expected_event_ids.append(channel.json_body["event_id"]) + + # Also send a different type of reaction so that we test we don't see it + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self.assertEquals(200, channel.code, channel.json_body) + + prev_token = None + found_event_ids = [] + encoded_key = six.moves.urllib.parse.quote_plus(u"👍".encode("utf-8")) + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s" + "/aggregations/%s/%s/m.reaction/%s?limit=1%s" + % ( + self.room, + self.parent_id, + RelationTypes.ANNOTATION, + encoded_key, + from_token, + ), + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_aggregation(self): + """Test that annotations get correctly aggregated. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + self.assertEquals(200, channel.code, channel.json_body) + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s" + % (self.room, self.parent_id), + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body, + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + ) + + def test_aggregation_must_be_annotation(self): + """Test that aggregations must be annotations. + """ + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/m.replaces/%s?limit=1" + % (self.room, self.parent_id), + ) + self.render(request) + self.assertEquals(400, channel.code, channel.json_body) + + def test_aggregation_get_event(self): + """Test that annotations and references get correctly bundled when + getting the parent event. 
+ """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.REFERENCES, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + reply_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.REFERENCES, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + reply_2 = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id) + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.maxDiff = None + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + RelationTypes.REFERENCES: { + "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] + }, + }, + ) + def _send_relation(self, relation_type, event_type, key=None): """Helper function to send a relation pointing at `self.parent_id` @@ -116,7 +361,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): """ query = "" if key: - query = "?key=" + six.moves.urllib.parse.quote_plus(key) + query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8")) request, channel = self.make_request( "POST", -- cgit 1.5.1 From 33453419b0cbbe29d3f4b376e9eafab10d5d012a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 May 2019 09:56:12 +0100 Subject: Add cache to relations --- synapse/storage/relations.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index db4b842c97..1fd3d4fafc 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -21,6 +21,7 @@ from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore from synapse.storage.stream import generate_pagination_where_clause +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -109,6 +110,7 @@ class AggregationPaginationToken(object): class RelationsStore(SQLBaseStore): + @cached(tree=True) def get_relations_for_event( self, event_id, @@ -216,6 +218,7 @@ class RelationsStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + @cached(tree=True) def get_aggregation_groups_for_event( self, event_id, @@ -353,3 +356,8 @@ class RelationsStore(SQLBaseStore): "aggregation_key": aggregation_key, }, ) + + txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,)) + txn.call_after( + self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) + ) -- cgit 1.5.1 From b5c62c6b2643a138956af0b04521540c270a3e94 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 May 2019 10:18:53 +0100 Subject: Fix relations in worker mode --- synapse/replication/slave/storage/events.py | 13 ++++++++++--- synapse/replication/tcp/streams/_base.py | 1 + synapse/replication/tcp/streams/events.py | 11 ++++++----- synapse/storage/events.py | 12 ++++++++---- 
synapse/storage/relations.py | 4 +++- 5 files changed, 28 insertions(+), 13 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index b457c5563f..797450bc66 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import ( from synapse.storage.event_federation import EventFederationWorkerStore from synapse.storage.event_push_actions import EventPushActionsWorkerStore from synapse.storage.events_worker import EventsWorkerStore +from synapse.storage.relations import RelationsWorkerStore from synapse.storage.roommember import RoomMemberWorkerStore from synapse.storage.signatures import SignatureWorkerStore from synapse.storage.state import StateGroupWorkerStore @@ -52,6 +53,7 @@ class SlavedEventStore(EventFederationWorkerStore, EventsWorkerStore, SignatureWorkerStore, UserErasureWorkerStore, + RelationsWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): @@ -89,7 +91,7 @@ class SlavedEventStore(EventFederationWorkerStore, for row in rows: self.invalidate_caches_for_event( -token, row.event_id, row.room_id, row.type, row.state_key, - row.redacts, + row.redacts, row.relates_to, backfilled=True, ) return super(SlavedEventStore, self).process_replication_rows( @@ -102,7 +104,7 @@ class SlavedEventStore(EventFederationWorkerStore, if row.type == EventsStreamEventRow.TypeId: self.invalidate_caches_for_event( token, data.event_id, data.room_id, data.type, data.state_key, - data.redacts, + data.redacts, data.relates_to, backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: @@ -114,7 +116,8 @@ class SlavedEventStore(EventFederationWorkerStore, raise Exception("Unknown events stream row type %s" % (row.type, )) def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, - etype, state_key, redacts, backfilled): + etype, state_key, redacts, relates_to, + backfilled): self._invalidate_get_event_cache(event_id) self.get_latest_event_ids_in_room.invalidate((room_id,)) @@ -136,3 +139,7 @@ class SlavedEventStore(EventFederationWorkerStore, state_key, stream_ordering ) self.get_invited_rooms_for_user.invalidate((state_key,)) + + if relates_to: + self.get_relations_for_event.invalidate_many((relates_to,)) + self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 8971a6a22e..b6ce7a7bee 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -32,6 +32,7 @@ BackfillStreamRow = namedtuple("BackfillStreamRow", ( "type", # str "state_key", # str, optional "redacts", # str, optional + "relates_to", # str, optional )) PresenceStreamRow = namedtuple("PresenceStreamRow", ( "user_id", # str diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index e0f6e29248..f1290d022a 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -80,11 +80,12 @@ class BaseEventsStreamRow(object): class EventsStreamEventRow(BaseEventsStreamRow): TypeId = "ev" - event_id = attr.ib() # str - room_id = attr.ib() # str - type = attr.ib() # str - state_key = attr.ib() # str, optional - redacts = attr.ib() # str, optional + event_id = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str + state_key = attr.ib() # str, optional + 
redacts = attr.ib() # str, optional + relates_to = attr.ib() # str, optional @attr.s(slots=True, frozen=True) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 6802bf42ce..b025ebc926 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1657,10 +1657,11 @@ class EventsStore( def get_all_new_forward_event_rows(txn): sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? < stream_ordering AND stream_ordering <= ?" " ORDER BY stream_ordering ASC" " LIMIT ?" @@ -1675,11 +1676,12 @@ class EventsStore( sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" " ORDER BY event_stream_ordering DESC" @@ -1700,10 +1702,11 @@ class EventsStore( def get_all_new_backfill_event_rows(txn): sql = ( "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? > stream_ordering AND stream_ordering >= ?" " ORDER BY stream_ordering ASC" " LIMIT ?" @@ -1718,11 +1721,12 @@ class EventsStore( sql = ( "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" 
" ORDER BY event_stream_ordering DESC" diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 1fd3d4fafc..732418ec65 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -109,7 +109,7 @@ class AggregationPaginationToken(object): return "%d-%d" % (self.count, self.stream) -class RelationsStore(SQLBaseStore): +class RelationsWorkerStore(SQLBaseStore): @cached(tree=True) def get_relations_for_event( self, @@ -318,6 +318,8 @@ class RelationsStore(SQLBaseStore): "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) + +class RelationsStore(RelationsWorkerStore): def _handle_event_relations(self, txn, event): """Handles inserting relation data during peristence of events -- cgit 1.5.1 From 2c662ddde4e1277a0dc17295748aa5f0c41fa163 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 May 2019 14:21:39 +0100 Subject: Indirect tuple conversion --- synapse/storage/relations.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 732418ec65..996cb6903a 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -81,6 +81,9 @@ class RelationPaginationToken(object): def to_string(self): return "%d-%d" % (self.topological, self.stream) + def as_tuple(self): + return attr.astuple(self) + @attr.s class AggregationPaginationToken(object): @@ -108,6 +111,9 @@ class AggregationPaginationToken(object): def to_string(self): return "%d-%d" % (self.count, self.stream) + def as_tuple(self): + return attr.astuple(self) + class RelationsWorkerStore(SQLBaseStore): @cached(tree=True) -- cgit 1.5.1 From 7a7eba8302d6566c044ca82c8ac0da65aa85e36b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 May 2019 14:24:58 +0100 Subject: Move parsing of tokens out of storage layer --- synapse/rest/client/v2_alpha/relations.py | 19 +++++++++++++++++++ synapse/storage/relations.py | 16 ++-------------- 2 files changed, 21 insertions(+), 14 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 1b53e638eb..41e0a44936 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -32,6 +32,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.rest.client.transactions import HttpTransactionCache +from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken from ._base import client_v2_patterns @@ -149,6 +150,12 @@ class RelationPaginationServlet(RestServlet): from_token = parse_string(request, "from") to_token = parse_string(request, "to") + if from_token: + from_token = RelationPaginationToken.from_string(from_token) + + if to_token: + to_token = RelationPaginationToken.from_string(to_token) + result = yield self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, @@ -221,6 +228,12 @@ class RelationAggregationPaginationServlet(RestServlet): from_token = parse_string(request, "from") to_token = parse_string(request, "to") + if from_token: + from_token = AggregationPaginationToken.from_string(from_token) + + if to_token: + to_token = AggregationPaginationToken.from_string(to_token) + res = yield self.store.get_aggregation_groups_for_event( event_id=parent_id, event_type=event_type, @@ -289,6 +302,12 @@ class RelationAggregationGroupPaginationServlet(RestServlet): from_token = parse_string(request, "from") to_token = 
parse_string(request, "to") + if from_token: + from_token = RelationPaginationToken.from_string(from_token) + + if to_token: + to_token = RelationPaginationToken.from_string(to_token) + result = yield self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 996cb6903a..de67e305a1 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -54,7 +54,7 @@ class PaginationChunk(object): return d -@attr.s +@attr.s(frozen=True, slots=True) class RelationPaginationToken(object): """Pagination token for relation pagination API. @@ -85,7 +85,7 @@ class RelationPaginationToken(object): return attr.astuple(self) -@attr.s +@attr.s(frozen=True, slots=True) class AggregationPaginationToken(object): """Pagination token for relation aggregation pagination API. @@ -151,12 +151,6 @@ class RelationsWorkerStore(SQLBaseStore): requested. The rows are of the form `{"event_id": "..."}`. """ - if from_token: - from_token = RelationPaginationToken.from_string(from_token) - - if to_token: - to_token = RelationPaginationToken.from_string(to_token) - where_clause = ["relates_to_id = ?"] where_args = [event_id] @@ -258,12 +252,6 @@ class RelationsWorkerStore(SQLBaseStore): match. Each row is a dict with `type`, `key` and `count` fields. """ - if from_token: - from_token = AggregationPaginationToken.from_string(from_token) - - if to_token: - to_token = AggregationPaginationToken.from_string(to_token) - where_clause = ["relates_to_id = ?", "relation_type = ?"] where_args = [event_id, RelationTypes.ANNOTATION] -- cgit 1.5.1 From 895179a4dcc993af9702e3ab36e5d2157fdd0c26 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 May 2019 16:41:05 +0100 Subject: Update docstring --- synapse/storage/stream.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 163363c0c2..824e77c89e 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -77,6 +77,15 @@ def generate_pagination_where_clause( would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3). + Note that tokens are considered to be after the row they are in, e.g. if + a row A has a token T, then we consider A to be before T. This covention + is important when figuring out inequalities for the generated SQL, and + produces the following result: + - If paginatiting forwards then we exclude any rows matching the from + token, but include those that match the to token. + - If paginatiting backwards then we include any rows matching the from + token, but include those that match the to token. + Args: direction (str): Whether we're paginating backwards("b") or forwards ("f"). @@ -131,7 +140,9 @@ def _make_generic_sql_bound(bound, column_names, values, engine): names (tuple[str, str]): The column names. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. - values (tuple[int, int]): The values to bound the columns by. + values (tuple[int|None, int]): The values to bound the columns by. If + the first value is None then only creates a bound on the second + column. 
engine: The database engine to generate the SQL for Returns: -- cgit 1.5.1 From d46aab3fa8bc02d8db9616490e849b9443fed8e7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 14 May 2019 16:59:21 +0100 Subject: Add basic editing support --- synapse/events/utils.py | 30 +++++++-- synapse/replication/slave/storage/events.py | 1 + synapse/storage/relations.py | 60 +++++++++++++++++- tests/rest/client/v2_alpha/test_relations.py | 91 +++++++++++++++++++++++--- 4 files changed, 167 insertions(+), 15 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 16d0c64372..2019ce9b1c 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -346,7 +346,7 @@ class EventClientSerializer(object): defer.returnValue(event) event_id = event.event_id - event = serialize_event(event, time_now, **kwargs) + serialized_event = serialize_event(event, time_now, **kwargs) # If MSC1849 is enabled then we need to look if there are any relations # we need to bundle in with the event @@ -359,14 +359,36 @@ class EventClientSerializer(object): ) if annotations.chunk: - r = event["unsigned"].setdefault("m.relations", {}) + r = serialized_event["unsigned"].setdefault("m.relations", {}) r[RelationTypes.ANNOTATION] = annotations.to_dict() if references.chunk: - r = event["unsigned"].setdefault("m.relations", {}) + r = serialized_event["unsigned"].setdefault("m.relations", {}) r[RelationTypes.REFERENCES] = references.to_dict() - defer.returnValue(event) + edit = None + if event.type == EventTypes.Message: + edit = yield self.store.get_applicable_edit( + event.event_id, event.type, event.sender, + ) + + if edit: + # If there is an edit replace the content, preserving existing + # relations. + + relations = event.content.get("m.relates_to") + serialized_event["content"] = edit.content.get("m.new_content", {}) + if relations: + serialized_event["content"]["m.relates_to"] = relations + else: + serialized_event["content"].pop("m.relates_to", None) + + r = serialized_event["unsigned"].setdefault("m.relations", {}) + r[RelationTypes.REPLACES] = { + "event_id": edit.event_id, + } + + defer.returnValue(serialized_event) def serialize_events(self, events, time_now, **kwargs): """Serializes multiple events.
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 797450bc66..857128b311 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -143,3 +143,4 @@ class SlavedEventStore(EventFederationWorkerStore, if relates_to: self.get_relations_for_event.invalidate_many((relates_to,)) self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) + self.get_applicable_edit.invalidate_many((relates_to,)) diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index de67e305a1..aa91e52733 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -17,11 +17,13 @@ import logging import attr -from synapse.api.constants import RelationTypes +from twisted.internet import defer + +from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore from synapse.storage.stream import generate_pagination_where_clause -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -312,6 +314,59 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) + @cachedInlineCallbacks(tree=True) + def get_applicable_edit(self, event_id, event_type, sender): + """Get the most recent edit (if any) that has happened for the given + event. + + Correctly handles checking whether edits were allowed to happen. + + Args: + event_id (str): The original event ID + event_type (str): The original event type + sender (str): The original event sender + + Returns: + Deferred[EventBase|None]: Returns the most recent edit, if any. + """ + + # We only allow edits for `m.room.message` events that have the same sender + # and event type. We can't assert these things during regular event auth so + # we have to do the post hoc. + + if event_type != EventTypes.Message: + return + + sql = """ + SELECT event_id, origin_server_ts FROM events + INNER JOIN event_relations USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + AND type = ? + AND sender = ? + ORDER by origin_server_ts DESC, event_id DESC + LIMIT 1 + """ + + def _get_applicable_edit_txn(txn): + txn.execute( + sql, (event_id, RelationTypes.REPLACES, event_type, sender) + ) + row = txn.fetchone() + if row: + return row[0] + + edit_id = yield self.runInteraction( + "get_applicable_edit", _get_applicable_edit_txn + ) + + if not edit_id: + return + + edit_event = yield self.get_event(edit_id, allow_none=True) + defer.returnValue(edit_event) + class RelationsStore(RelationsWorkerStore): def _handle_event_relations(self, txn, event): @@ -357,3 +412,4 @@ class RelationsStore(RelationsWorkerStore): txn.call_after( self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) ) + txn.call_after(self.get_applicable_edit.invalidate_many, (parent_id,)) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 775622bd2b..b0e4c47ae3 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -14,6 +14,7 @@ # limitations under the License. import itertools +import json import six @@ -102,11 +103,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # relation event we sent above. 
self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( - { - "event_id": annotation_id, - "sender": self.user_id, - "type": "m.reaction", - }, + {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"}, channel.json_body["chunk"][0], ) @@ -330,8 +327,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(200, channel.code, channel.json_body) - self.maxDiff = None - self.assertEquals( channel.json_body["unsigned"].get("m.relations"), { @@ -347,7 +342,84 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, ) - def _send_relation(self, relation_type, event_type, key=None): + def test_edit(self): + """Test that a simple edit works. + """ + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACES, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id) + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals(channel.json_body["content"], new_body) + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + {RelationTypes.REPLACES: {"event_id": edit_event_id}}, + ) + + def test_multi_edit(self): + """Test that multiple edits, including attempts by people who + shouldn't be allowed, are correctly handled. + """ + + channel = self._send_relation( + RelationTypes.REPLACES, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "First edit"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACES, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.REPLACES, + "m.room.message.WRONG_TYPE", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + request, channel = self.make_request( + "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id) + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals(channel.json_body["content"], new_body) + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + {RelationTypes.REPLACES: {"event_id": edit_event_id}}, + ) + + def _send_relation(self, relation_type, event_type, key=None, content={}): """Helper function to send a relation pointing at `self.parent_id` Args: @@ -355,6 +427,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): event_type (str): The type of the event to create key (str|None): The aggregation key used for m.annotation relation type. + content(dict|None): The content of the created event. 
Returns: FakeChannel @@ -367,7 +440,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" % (self.room, self.parent_id, relation_type, event_type, query), - b"{}", + json.dumps(content).encode("utf-8"), ) self.render(request) return channel -- cgit 1.5.1 From 5dbff345094327e61fa9c1425ba0f955c9530ba7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 May 2019 15:48:04 +0100 Subject: Fixup based on review comments --- synapse/events/utils.py | 4 +--- synapse/replication/slave/storage/events.py | 2 +- synapse/storage/relations.py | 32 +++++++++++++++-------------- 3 files changed, 19 insertions(+), 19 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 2019ce9b1c..bf3c8f8dc1 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -368,9 +368,7 @@ class EventClientSerializer(object): edit = None if event.type == EventTypes.Message: - edit = yield self.store.get_applicable_edit( - event.event_id, event.type, event.sender, - ) + edit = yield self.store.get_applicable_edit(event_id) if edit: # If there is an edit replace the content, preserving existing diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 857128b311..a3952506c1 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -143,4 +143,4 @@ class SlavedEventStore(EventFederationWorkerStore, if relates_to: self.get_relations_for_event.invalidate_many((relates_to,)) self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) - self.get_applicable_edit.invalidate_many((relates_to,)) + self.get_applicable_edit.invalidate((relates_to,)) diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index aa91e52733..6e216066ab 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -19,7 +19,7 @@ import attr from twisted.internet import defer -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore from synapse.storage.stream import generate_pagination_where_clause @@ -314,8 +314,8 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) - @cachedInlineCallbacks(tree=True) - def get_applicable_edit(self, event_id, event_type, sender): + @cachedInlineCallbacks() + def get_applicable_edit(self, event_id): """Get the most recent edit (if any) that has happened for the given event. @@ -323,8 +323,6 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id (str): The original event ID - event_type (str): The original event type - sender (str): The original event sender Returns: Deferred[EventBase|None]: Returns the most recent edit, if any. @@ -332,26 +330,28 @@ class RelationsWorkerStore(SQLBaseStore): # We only allow edits for `m.room.message` events that have the same sender # and event type. We can't assert these things during regular event auth so - # we have to do the post hoc. - - if event_type != EventTypes.Message: - return + # we have to do the checks post hoc. + # Fetches latest edit that has the same type and sender as the + # original, and is an `m.room.message`.
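# As an illustrative sketch (not part of the patch), the SQL below is
# roughly equivalent to this in-memory selection, where `candidate_edits`
# and its attributes are hypothetical stand-ins for the joined rows:
#
#     def pick_applicable_edit(original, candidate_edits):
#         allowed = [
#             e for e in candidate_edits
#             if e.type == original.type == "m.room.message"
#             and e.sender == original.sender
#         ]
#         # "Most recent" means highest (origin_server_ts, event_id),
#         # matching the ORDER BY ... DESC LIMIT 1 in the query.
#         return max(
#             allowed,
#             key=lambda e: (e.origin_server_ts, e.event_id),
#             default=None,
#         )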
sql = """ - SELECT event_id, origin_server_ts FROM events + SELECT edit.event_id FROM events AS edit INNER JOIN event_relations USING (event_id) + INNER JOIN events AS original ON + original.event_id = relates_to_id + AND edit.type = original.type + AND edit.sender = original.sender WHERE relates_to_id = ? AND relation_type = ? - AND type = ? - AND sender = ? - ORDER by origin_server_ts DESC, event_id DESC + AND edit.type = 'm.room.message' + ORDER by edit.origin_server_ts DESC, edit.event_id DESC LIMIT 1 """ def _get_applicable_edit_txn(txn): txn.execute( - sql, (event_id, RelationTypes.REPLACES, event_type, sender) + sql, (event_id, RelationTypes.REPLACES,) ) row = txn.fetchone() if row: @@ -412,4 +412,6 @@ class RelationsStore(RelationsWorkerStore): txn.call_after( self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) ) - txn.call_after(self.get_applicable_edit.invalidate_many, (parent_id,)) + + if rel_type == RelationTypes.REPLACES: + txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) -- cgit 1.5.1 From 8dd9cca8ead9fc0cf0789edf9fcfce04814c5eda Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 May 2019 16:40:51 +0100 Subject: Spelling and clarifications --- synapse/storage/stream.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 824e77c89e..529ad4ea79 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -78,12 +78,12 @@ def generate_pagination_where_clause( would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3). Note that tokens are considered to be after the row they are in, e.g. if - a row A has a token T, then we consider A to be before T. This covention + a row A has a token T, then we consider A to be before T. This convention is important when figuring out inequalities for the generated SQL, and produces the following result: - - If paginatiting forwards then we exclude any rows matching the from + - If paginating forwards then we exclude any rows matching the from token, but include those that match the to token. - - If paginatiting backwards then we include any rows matching the from + - If paginating backwards then we include any rows matching the from token, but include those that match the to token. Args: @@ -92,8 +92,12 @@ def generate_pagination_where_clause( column_names (tuple[str, str]): The column names to bound. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. - from_token (tuple[int, int]|None) - to_token (tuple[int, int]|None) + from_token (tuple[int, int]|None): The start point for the pagination. + This is an exclusive minimum bound if direction is "f", and an + inclusive maximum bound if direction is "b". + to_token (tuple[int, int]|None): The endpoint point for the pagination. + This is an inclusive maximum bound if direction is "f", and an + exclusive minimum bound if direction is "b". engine: The database engine to generate the clauses for Returns: -- cgit 1.5.1 From b63cc325a9362874a13f0079589afccd7d108f1f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 May 2019 18:01:39 +0100 Subject: Only count aggregations from distinct senders As a user isn't allowed to send a single emoji more than once. 
--- synapse/storage/relations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 6e216066ab..42f587a7d8 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -280,7 +280,7 @@ class RelationsWorkerStore(SQLBaseStore): having_clause = "" sql = """ - SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering) + SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) FROM event_relations INNER JOIN events USING (event_id) WHERE {where_clause} -- cgit 1.5.1 From 935af0da380f39ba284b78054270331bdbad7712 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 20 May 2019 10:13:05 +0100 Subject: Correctly update aggregation counts after redaction --- synapse/storage/events.py | 3 +++ synapse/storage/relations.py | 17 +++++++++++++ tests/rest/client/v2_alpha/test_relations.py | 37 ++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index b025ebc926..881d6d0126 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1325,6 +1325,9 @@ class EventsStore( txn, event.room_id, event.redacts ) + # Remove from relations table. + self._handle_redaction(txn, event.redacts) + # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 6e216066ab..63e6185ee3 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -415,3 +415,20 @@ class RelationsStore(RelationsWorkerStore): if rel_type == RelationTypes.REPLACES: txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) + + def _handle_redaction(self, txn, redacted_event_id): + """Handles receiving a redaction and checking whether we need to remove + any redacted relations from the database. + + Args: + txn + redacted_event_id (str): The event that was redacted. + """ + + self._simple_delete_txn( + txn, + table="event_relations", + keyvalues={ + "event_id": redacted_event_id, + } + ) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index cd965167f8..3737cdc396 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -320,6 +320,43 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, ) + def test_aggregation_redactions(self): + """Test that annotations get correctly aggregated after a redaction.
+ """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + to_redact_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Now lets redact the 'a' reaction + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), + access_token=self.user_token, + content={}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body, + {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + ) + def test_aggregation_must_be_annotation(self): """Test that aggregations must be annotations. """ -- cgit 1.5.1 From 1dff859d6aed58561a2b5913e5c9b897bbd3599c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 20 May 2019 14:31:19 +0100 Subject: Rename relation types to match MSC --- synapse/api/constants.py | 4 ++-- synapse/events/utils.py | 6 +++--- synapse/storage/relations.py | 8 ++++---- tests/rest/client/v2_alpha/test_relations.py | 22 +++++++++++----------- 4 files changed, 20 insertions(+), 20 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 28862a1d25..6b347b1749 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -125,5 +125,5 @@ class RelationTypes(object): """The types of relations known to this server. 
""" ANNOTATION = "m.annotation" - REPLACES = "m.replaces" - REFERENCES = "m.references" + REPLACE = "m.replace" + REFERENCE = "m.reference" diff --git a/synapse/events/utils.py b/synapse/events/utils.py index bf3c8f8dc1..27a2a9ef98 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -355,7 +355,7 @@ class EventClientSerializer(object): event_id, ) references = yield self.store.get_relations_for_event( - event_id, RelationTypes.REFERENCES, direction="f", + event_id, RelationTypes.REFERENCE, direction="f", ) if annotations.chunk: @@ -364,7 +364,7 @@ class EventClientSerializer(object): if references.chunk: r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.REFERENCES] = references.to_dict() + r[RelationTypes.REFERENCE] = references.to_dict() edit = None if event.type == EventTypes.Message: @@ -382,7 +382,7 @@ class EventClientSerializer(object): serialized_event["content"].pop("m.relates_to", None) r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.REPLACES] = { + r[RelationTypes.REPLACE] = { "event_id": edit.event_id, } diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 63e6185ee3..493abe405e 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -351,7 +351,7 @@ class RelationsWorkerStore(SQLBaseStore): def _get_applicable_edit_txn(txn): txn.execute( - sql, (event_id, RelationTypes.REPLACES,) + sql, (event_id, RelationTypes.REPLACE,) ) row = txn.fetchone() if row: @@ -384,8 +384,8 @@ class RelationsStore(RelationsWorkerStore): rel_type = relation.get("rel_type") if rel_type not in ( RelationTypes.ANNOTATION, - RelationTypes.REFERENCES, - RelationTypes.REPLACES, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, ): # Unknown relation type return @@ -413,7 +413,7 @@ class RelationsStore(RelationsWorkerStore): self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) ) - if rel_type == RelationTypes.REPLACES: + if rel_type == RelationTypes.REPLACE: txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) def _handle_redaction(self, txn, redacted_event_id): diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index b3dc4bd5bc..3d040cf118 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -363,8 +363,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.replace?limit=1" - % (self.room, self.parent_id), + "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" + % (self.room, self.parent_id, RelationTypes.REPLACE), access_token=self.user_token, ) self.render(request) @@ -386,11 +386,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") self.assertEquals(200, channel.code, channel.json_body) - channel = self._send_relation(RelationTypes.REFERENCES, "m.room.test") + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") self.assertEquals(200, channel.code, channel.json_body) reply_1 = channel.json_body["event_id"] - channel = self._send_relation(RelationTypes.REFERENCES, "m.room.test") + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") self.assertEquals(200, channel.code, channel.json_body) reply_2 = channel.json_body["event_id"] @@ -411,7 +411,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"type": 
"m.reaction", "key": "b", "count": 1}, ] }, - RelationTypes.REFERENCES: { + RelationTypes.REFERENCE: { "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] }, }, @@ -423,7 +423,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( - RelationTypes.REPLACES, + RelationTypes.REPLACE, "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) @@ -443,7 +443,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals( channel.json_body["unsigned"].get("m.relations"), - {RelationTypes.REPLACES: {"event_id": edit_event_id}}, + {RelationTypes.REPLACE: {"event_id": edit_event_id}}, ) def test_multi_edit(self): @@ -452,7 +452,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): """ channel = self._send_relation( - RelationTypes.REPLACES, + RelationTypes.REPLACE, "m.room.message", content={ "msgtype": "m.text", @@ -464,7 +464,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( - RelationTypes.REPLACES, + RelationTypes.REPLACE, "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) @@ -473,7 +473,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): edit_event_id = channel.json_body["event_id"] channel = self._send_relation( - RelationTypes.REPLACES, + RelationTypes.REPLACE, "m.room.message.WRONG_TYPE", content={ "msgtype": "m.text", @@ -495,7 +495,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals( channel.json_body["unsigned"].get("m.relations"), - {RelationTypes.REPLACES: {"event_id": edit_event_id}}, + {RelationTypes.REPLACE: {"event_id": edit_event_id}}, ) def _send_relation( -- cgit 1.5.1 From c7ec06e8a6909343f5860a5eecbe3aa69f03c151 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 20 May 2019 17:39:05 +0100 Subject: Block attempts to annotate the same event twice --- synapse/handlers/message.py | 16 +++++++++- synapse/storage/relations.py | 48 ++++++++++++++++++++++++++-- tests/rest/client/v2_alpha/test_relations.py | 27 +++++++++++++++- 3 files changed, 86 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7b2c33a922..0c892c8dba 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from twisted.internet.defer import succeed -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RelationTypes from synapse.api.errors import ( AuthError, Codes, @@ -601,6 +601,20 @@ class EventCreationHandler(object): self.validator.validate_new(event) + # We now check that if this event is an annotation that the can't + # annotate the same way twice (e.g. stops users from liking an event + # multiple times). 
+ relation = event.content.get("m.relates_to", {}) + if relation.get("rel_type") == RelationTypes.ANNOTATION: + relates_to = relation["event_id"] + aggregation_key = relation["key"] + + already_exists = yield self.store.has_user_annotated_event( + relates_to, event.type, aggregation_key, event.sender, + ) + if already_exists: + raise SynapseError(400, "Can't send same reaction twice") + logger.debug( "Created event %s", event.event_id, diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 493abe405e..7d51b38d77 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -350,9 +350,7 @@ class RelationsWorkerStore(SQLBaseStore): """ def _get_applicable_edit_txn(txn): - txn.execute( - sql, (event_id, RelationTypes.REPLACE,) - ) + txn.execute(sql, (event_id, RelationTypes.REPLACE)) row = txn.fetchone() if row: return row[0] @@ -367,6 +365,50 @@ class RelationsWorkerStore(SQLBaseStore): edit_event = yield self.get_event(edit_id, allow_none=True) defer.returnValue(edit_event) + def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): + """Check if a user has already annotated an event with the same key + (e.g. already liked an event). + + Args: + parent_id (str): The event being annotated + event_type (str): The event type of the annotation + aggregation_key (str): The aggregation key of the annotation + sender (str): The sender of the annotation + + Returns: + Deferred[bool] + """ + + sql = """ + SELECT 1 FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + AND type = ? + AND sender = ? + AND aggregation_key = ? + LIMIT 1; + """ + + def _get_if_user_has_annotated_event(txn): + txn.execute( + sql, + ( + parent_id, + RelationTypes.ANNOTATION, + event_type, + sender, + aggregation_key, + ), + ) + + return bool(txn.fetchone()) + + return self.runInteraction( + "get_if_user_has_annotated_event", _get_if_user_has_annotated_event + ) + class RelationsStore(RelationsWorkerStore): def _handle_event_relations(self, txn, event): diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 3d040cf118..43b3049daa 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -90,6 +90,15 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) self.assertEquals(400, channel.code, channel.json_body) + def test_deny_double_react(self): + """Test that we deny a second annotation with the same key from the same user + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(400, channel.code, channel.json_body) + def test_basic_paginate_relations(self): """Tests that calling the pagination API correctly returns the latest relations. """ @@ -234,14 +243,30 @@ class RelationsTestCase(unittest.HomeserverTestCase): """Test that we can paginate within an annotation group. """ + # We need to create ten separate users to send each reaction.
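# This is a consequence of the guard added in this commit: a single user
# can no longer annotate the same event with the same key twice, so each
# of the ten reactions must come from a distinct sender.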
+ access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 expected_event_ids = [] for _ in range(10): channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", key=u"👍" + RelationTypes.ANNOTATION, + "m.reaction", + key=u"👍", + access_token=access_tokens[idx], ) self.assertEquals(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) + idx += 1 + # Also send a different type of reaction so that we test we don't see it channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") self.assertEquals(200, channel.code, channel.json_body) -- cgit 1.5.1 From c4aef549ad1f55661675a789f89fe9e041fac874 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 21 May 2019 16:10:54 +0100 Subject: Exclude soft-failed events from fwd-extremity candidates. (#5146) When considering the candidates to be forward-extremities, we must exclude soft failures. Hopefully fixes #5090. --- changelog.d/5146.bugfix | 1 + synapse/handlers/federation.py | 7 ++++++- synapse/storage/events.py | 9 +++++++-- 3 files changed, 14 insertions(+), 3 deletions(-) create mode 100644 changelog.d/5146.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/5146.bugfix b/changelog.d/5146.bugfix new file mode 100644 index 0000000000..a54abed92b --- /dev/null +++ b/changelog.d/5146.bugfix @@ -0,0 +1 @@ +Exclude soft-failed events from forward-extremity candidates: fixes "No forward extremities left!" error. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0684778882..2202ed699a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1916,6 +1916,11 @@ class FederationHandler(BaseHandler): event.room_id, latest_event_ids=extrem_ids, ) + logger.debug( + "Doing soft-fail check for %s: state %s", + event.event_id, current_state_ids, + ) + # Now check if event pass auth against said current state auth_types = auth_types_for_event(event) current_state_ids = [ @@ -1932,7 +1937,7 @@ class FederationHandler(BaseHandler): self.auth.check(room_version, event, auth_events=current_auth_events) except AuthError as e: logger.warn( - "Failed current state auth resolution for %r because %s", + "Soft-failing %r because %s", event, e, ) event.internal_metadata.soft_failed = True diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 881d6d0126..2ffc27ff41 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -575,10 +575,11 @@ class EventsStore( def _get_events(txn, batch): sql = """ - SELECT prev_event_id + SELECT prev_event_id, internal_metadata FROM event_edges INNER JOIN events USING (event_id) LEFT JOIN rejections USING (event_id) + LEFT JOIN event_json USING (event_id) WHERE prev_event_id IN (%s) AND NOT events.outlier @@ -588,7 +589,11 @@ class EventsStore( ) txn.execute(sql, batch) - results.extend(r[0] for r in txn) + results.extend( + r[0] + for r in txn + if not json.loads(r[1]).get("soft_failed") + ) for chunk in batch_iter(event_ids, 100): yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk) -- cgit 1.5.1 From df9d9005448d837c5e1a2b75edb5730e2062b0f2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 May 2019 11:56:24 +0100 Subject: Correctly filter out extremities with soft failed 
prevs (#5274) When we receive a soft failed event we, correctly, *do not* update the forward extremity table with the event. However, if we later receive an event that references the soft failed event we then need to remove the soft failed event's prev events from the forward extremities table, otherwise we just build up forward extremities. Fixes #5269 --- changelog.d/5274.bugfix | 1 + synapse/storage/events.py | 82 +++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 changelog.d/5274.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/5274.bugfix b/changelog.d/5274.bugfix new file mode 100644 index 0000000000..9e14d20289 --- /dev/null +++ b/changelog.d/5274.bugfix @@ -0,0 +1 @@ +Fix bug where we leaked extremities when we soft failed events, leading to performance degradation. diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 2ffc27ff41..6e9f3d1dc0 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -554,10 +554,18 @@ class EventsStore( e_id for event in new_events for e_id in event.prev_event_ids() ) - # Finally, remove any events which are prev_events of any existing events. + # Remove any events which are prev_events of any existing events. existing_prevs = yield self._get_events_which_are_prevs(result) result.difference_update(existing_prevs) + # Finally handle the case where the new events have soft-failed prev + # events. If they do we need to remove them and their prev events, + # otherwise we end up with dangling extremities. + existing_prevs = yield self._get_prevs_before_rejected( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + result.difference_update(existing_prevs) + defer.returnValue(result) @defer.inlineCallbacks @@ -573,7 +581,7 @@ """ results = [] - def _get_events(txn, batch): + def _get_events_which_are_prevs_txn(txn, batch): sql = """ SELECT prev_event_id, internal_metadata FROM event_edges @@ -596,10 +604,78 @@ ) for chunk in batch_iter(event_ids, 100): - yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk) + yield self.runInteraction( + "_get_events_which_are_prevs", + _get_events_which_are_prevs_txn, + chunk, + ) defer.returnValue(results) + @defer.inlineCallbacks + def _get_prevs_before_rejected(self, event_ids): + """Get soft-failed ancestors to remove from the extremities. + + Given a set of events, find all those that have been soft-failed or + rejected. Returns those soft failed/rejected events and their prev + events (whether soft-failed/rejected or not), and recurses up the + prev-event graph until it finds no more soft-failed/rejected events. + + This is used to find extremities that are ancestors of new events, but + are separated by soft failed events. + + Args: + event_ids (Iterable[str]): Events to find prev events for. Note + that these must have already been persisted. + + Returns: + Deferred[set[str]] + """ + + # The set of event_ids to return. This includes all soft-failed events + # and their prev events.
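# A compact sketch of the recursion implemented below (illustrative only;
# `prev_ids` and `is_failed` are hypothetical lookups standing in for the
# event_edges/rejections/event_json joins):
#
#     def prevs_before_rejected(event_ids, prev_ids, is_failed):
#         found = set()
#         to_check = set(event_ids)
#         while to_check:
#             eid = to_check.pop()
#             if not is_failed(eid):
#                 continue  # accepted events stop the walk
#             for prev in prev_ids(eid):
#                 if prev not in found:
#                     found.add(prev)     # prev of a failed event
#                     to_check.add(prev)  # and may itself be failed
#         return found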
+ existing_prevs = set() + + def _get_prevs_before_rejected_txn(txn, batch): + to_recursively_check = batch + + while to_recursively_check: + sql = """ + SELECT + event_id, prev_event_id, internal_metadata, + rejections.event_id IS NOT NULL + FROM event_edges + INNER JOIN events USING (event_id) + LEFT JOIN rejections USING (event_id) + LEFT JOIN event_json USING (event_id) + WHERE + event_id IN (%s) + AND NOT events.outlier + """ % ( + ",".join("?" for _ in to_recursively_check), + ) + + txn.execute(sql, to_recursively_check) + to_recursively_check = [] + + for event_id, prev_event_id, metadata, rejected in txn: + if prev_event_id in existing_prevs: + continue + + soft_failed = json.loads(metadata).get("soft_failed") + if soft_failed or rejected: + to_recursively_check.append(prev_event_id) + existing_prevs.add(prev_event_id) + + for chunk in batch_iter(event_ids, 100): + yield self.runInteraction( + "_get_prevs_before_rejected", + _get_prevs_before_rejected_txn, + chunk, + ) + + defer.returnValue(existing_prevs) + @defer.inlineCallbacks def _get_new_state_after_events( self, room_id, events_context, old_latest_event_ids, new_latest_event_ids -- cgit 1.5.1 From 6ebc08c09d4ced251750cb087aa4689f90cdd4b6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 28 May 2019 18:52:41 +0100 Subject: Add DB bg update to cleanup extremities. Due to #5269 we may have extremities in our DB that we shouldn't have, so let's add a cleanup task to remove them. --- synapse/storage/events.py | 186 +++++++++++++++++++++ .../schema/delta/54/delete_forward_extremities.sql | 19 +++ 2 files changed, 205 insertions(+) create mode 100644 synapse/storage/schema/delta/54/delete_forward_extremities.sql (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 6e9f3d1dc0..a9be143bd5 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -221,6 +221,7 @@ class EventsStore( ): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" + EVENT_DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" def __init__(self, db_conn, hs): super(EventsStore, self).__init__(db_conn, hs) @@ -252,6 +253,11 @@ class EventsStore( psql_only=True, ) + self.register_background_update_handler( + self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES, + self._cleanup_extremities_bg_update, + ) + self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() @@ -2341,6 +2347,186 @@ class EventsStore( get_all_updated_current_state_deltas_txn, ) + @defer.inlineCallbacks + def _cleanup_extremities_bg_update(self, progress, batch_size): + """Background update to clean out extremities that should have been + deleted previously. + + Mainly used to deal with the aftermath of #5269. + """ + + # This works by first copying all existing forward extremities into the + # `_extremities_to_check` table at start up, and then checking each + # event in that table whether we have any descendants that are not + # soft-failed/rejected. If that is the case then we delete that event + # from the forward extremities table. + # + # For efficiency, we do this in batches by recursively pulling out all + # descendants of a batch until we find the non soft-failed/rejected + # events, i.e. the set of descendants whose chain of prev events back + # to the batch of extremities are all soft-failed or rejected.
+ # Typically, we won't find any such events as extremities will rarely + # have any descendants, but if they do then we should delete those + # extremities. + + def _cleanup_extremities_bg_update_txn(txn): + # The set of extremity event IDs that we're checking this round + original_set = set() + + # A dict[str, set[str]] of event ID to their prev events. + graph = {} + + # The set of descendants of the original set that are not rejected + # nor soft-failed. Ancestors of these events should be removed + # from the forward extremities table. + non_rejected_leaves = set() + + # Set of event IDs that have been soft failed, and for which we + # should check if they have descendants which haven't been soft + # failed. + soft_failed_events_to_lookup = set() + + # First, we get `batch_size` events from the table, pulling out + # their prev events, if any, and their prev events rejection status. + txn.execute( + """SELECT prev_event_id, event_id, internal_metadata, + rejections.event_id IS NOT NULL, events.outlier + FROM ( + SELECT event_id AS prev_event_id + FROM _extremities_to_check + LIMIT ? + ) AS f + LEFT JOIN event_edges USING (prev_event_id) + LEFT JOIN events USING (event_id) + LEFT JOIN event_json USING (event_id) + LEFT JOIN rejections USING (event_id) + """, (batch_size,) + ) + + for prev_event_id, event_id, metadata, rejected, outlier in txn: + original_set.add(prev_event_id) + + if not event_id or outlier: + # Common case where the forward extremity doesn't have any + # descendants. + continue + + graph.setdefault(event_id, set()).add(prev_event_id) + + soft_failed = False + if metadata: + soft_failed = json.loads(metadata).get("soft_failed") + + if soft_failed or rejected: + soft_failed_events_to_lookup.add(event_id) + else: + non_rejected_leaves.add(event_id) + + # Now we recursively check all the soft-failed descendants we + # found above in the same way, until we have nothing left to + # check. + while soft_failed_events_to_lookup: + # We only want to do 100 at a time, so we split given list + # into two. + batch = list(soft_failed_events_to_lookup) + to_check, to_defer = batch[:100], batch[100:] + soft_failed_events_to_lookup = set(to_defer) + + sql = """SELECT prev_event_id, event_id, internal_metadata, + rejections.event_id IS NOT NULL + FROM event_edges + INNER JOIN events USING (event_id) + INNER JOIN event_json USING (event_id) + LEFT JOIN rejections USING (event_id) + WHERE + prev_event_id IN (%s) + AND NOT events.outlier + """ % ( + ",".join("?" for _ in to_check), + ) + txn.execute(sql, to_check) + + for prev_event_id, event_id, metadata, rejected in txn: + if event_id in graph: + # Already handled this event previously, but we still + # want to record the edge. + graph.setdefault(event_id, set()).add(prev_event_id) + logger.info("Already handled") + continue + + graph.setdefault(event_id, set()).add(prev_event_id) + + soft_failed = json.loads(metadata).get("soft_failed") + if soft_failed or rejected: + soft_failed_events_to_lookup.add(event_id) + else: + non_rejected_leaves.add(event_id) + + # We have a set of non-soft-failed descendants, so we recurse up + # the graph to find all ancestors and add them to the set of event + # IDs that we can delete from forward extremities table. 
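# A worked toy example of the walk above (illustrative, not part of the
# patch): suppose extremity E has a soft-failed child S, which in turn has
# an accepted child N. The first query finds S, the recursive step finds
# N, so graph == {S: {E}, N: {S}} and non_rejected_leaves == {N}. Walking
# upwards from N collects {S, E}; intersecting with original_set == {E}
# leaves E, which is then deleted from event_forward_extremities. Had N
# not existed, no non-rejected leaf would have been found and E would
# correctly remain an extremity.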
+ to_delete = set() + while non_rejected_leaves: + event_id = non_rejected_leaves.pop() + prev_event_ids = graph.get(event_id, set()) + non_rejected_leaves.update(prev_event_ids) + to_delete.update(prev_event_ids) + + to_delete.intersection_update(original_set) + + logger.info("Deleting up to %d forward extremities", len(to_delete)) + + self._simple_delete_many_txn( + txn=txn, + table="event_forward_extremities", + column="event_id", + iterable=to_delete, + keyvalues={}, + ) + + if to_delete: + # We now need to invalidate the caches of these rooms + rows = self._simple_select_many_txn( + txn, + table="events", + column="event_id", + iterable=to_delete, + keyvalues={}, + retcols=("room_id",) + ) + for row in rows: + txn.call_after( + self.get_latest_event_ids_in_room.invalidate, + (row["room_id"],) + ) + + self._simple_delete_many_txn( + txn=txn, + table="_extremities_to_check", + column="event_id", + iterable=original_set, + keyvalues={}, + ) + + return len(original_set) + + num_handled = yield self.runInteraction( + "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn, + ) + + if not num_handled: + yield self._end_background_update(self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES) + + def _drop_table_txn(txn): + txn.execute("DROP TABLE _extremities_to_check") + + yield self.runInteraction( + "_cleanup_extremities_bg_update_drop_table", + _drop_table_txn, + ) + + defer.returnValue(num_handled) + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/schema/delta/54/delete_forward_extremities.sql new file mode 100644 index 0000000000..7056bd1d00 --- /dev/null +++ b/synapse/storage/schema/delta/54/delete_forward_extremities.sql @@ -0,0 +1,19 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('delete_soft_failed_extremities', '{}'); + +CREATE TABLE _extremities_to_check AS SELECT event_id FROM event_forward_extremities; -- cgit 1.5.1 From 1d818fde14595299de9e13008c67239ca677014f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 May 2019 11:58:32 +0100 Subject: Log actual number of entries deleted --- synapse/storage/_base.py | 12 +++++++++--- synapse/storage/events.py | 6 ++++-- 2 files changed, 13 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index fa6839ceca..3fe827cd43 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -1261,7 +1261,8 @@ class SQLBaseStore(object): " AND ".join("%s = ?" 
% (k,) for k in keyvalues), ) - return txn.execute(sql, list(keyvalues.values())) + txn.execute(sql, list(keyvalues.values())) + return txn.rowcount def _simple_delete_many(self, table, column, iterable, keyvalues, desc): return self.runInteraction( @@ -1280,9 +1281,12 @@ class SQLBaseStore(object): column : column name to test for inclusion against `iterable` iterable : list keyvalues : dict of column names and values to select the rows with + + Returns: + int: Number rows deleted """ if not iterable: - return + return 0 sql = "DELETE FROM %s" % table @@ -1297,7 +1301,9 @@ class SQLBaseStore(object): if clauses: sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) - return txn.execute(sql, values) + txn.execute(sql, values) + + return txn.rowcount def _get_cache_dict( self, db_conn, table, entity_column, stream_column, max_value, limit=100000 diff --git a/synapse/storage/events.py b/synapse/storage/events.py index a9be143bd5..a9664928ca 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -2476,7 +2476,7 @@ class EventsStore( logger.info("Deleting up to %d forward extremities", len(to_delete)) - self._simple_delete_many_txn( + deleted = self._simple_delete_many_txn( txn=txn, table="event_forward_extremities", column="event_id", @@ -2484,7 +2484,9 @@ class EventsStore( keyvalues={}, ) - if to_delete: + logger.info("Deleted %d forward extremities", deleted) + + if deleted: # We now need to invalidate the caches of these rooms rows = self._simple_select_many_txn( txn, -- cgit 1.5.1 From 9b8cd66524304f76209a59e12f4eca561b1a43d3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 10:55:55 +0100 Subject: Fixup comments and logging --- synapse/storage/events.py | 21 ++++++++++++--------- .../schema/delta/54/delete_forward_extremities.sql | 3 +++ 2 files changed, 15 insertions(+), 9 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index a9664928ca..418d88b8dc 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -2387,7 +2387,8 @@ class EventsStore( soft_failed_events_to_lookup = set() # First, we get `batch_size` events from the table, pulling out - # their prev events, if any, and their prev events rejection status. + # their successor events, if any, and their successor events + # rejection status. txn.execute( """SELECT prev_event_id, event_id, internal_metadata, rejections.event_id IS NOT NULL, events.outlier @@ -2450,11 +2451,10 @@ class EventsStore( if event_id in graph: # Already handled this event previously, but we still # want to record the edge. 
- graph.setdefault(event_id, set()).add(prev_event_id) - logger.info("Already handled") + graph[event_id].add(prev_event_id) continue - graph.setdefault(event_id, set()).add(prev_event_id) + graph[event_id] = {prev_event_id} soft_failed = json.loads(metadata).get("soft_failed") if soft_failed or rejected: @@ -2474,8 +2474,6 @@ class EventsStore( to_delete.intersection_update(original_set) - logger.info("Deleting up to %d forward extremities", len(to_delete)) - deleted = self._simple_delete_many_txn( txn=txn, table="event_forward_extremities", @@ -2484,7 +2482,11 @@ class EventsStore( keyvalues={}, ) - logger.info("Deleted %d forward extremities", deleted) + logger.info( + "Deleted %d forward extremities of %d checked, to clean up #5269", + deleted, + len(original_set), + ) if deleted: # We now need to invalidate the caches of these rooms @@ -2496,10 +2498,11 @@ class EventsStore( keyvalues={}, retcols=("room_id",) ) - for row in rows: + room_ids = set(row["room_id"] for row in rows) + for room_id in room_ids: txn.call_after( self.get_latest_event_ids_in_room.invalidate, - (row["room_id"],) + (room_id,) ) self._simple_delete_many_txn( diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/schema/delta/54/delete_forward_extremities.sql index 7056bd1d00..aa40f13da7 100644 --- a/synapse/storage/schema/delta/54/delete_forward_extremities.sql +++ b/synapse/storage/schema/delta/54/delete_forward_extremities.sql @@ -13,7 +13,10 @@ * limitations under the License. */ +-- Start a background job to cleanup extremities that were incorrectly added +-- by bug #5269. INSERT INTO background_updates (update_name, progress_json) VALUES ('delete_soft_failed_extremities', '{}'); +DROP TABLE IF EXISTS _extremities_to_check; -- To make this delta schema file idempotent. CREATE TABLE _extremities_to_check AS SELECT event_id FROM event_forward_extremities; -- cgit 1.5.1 From 98f438b52a93b1ce9d1f3e93fa57db0f870f9101 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 11:22:59 +0100 Subject: Move event background updates to a separate file --- synapse/storage/__init__.py | 2 + synapse/storage/events.py | 371 +------------------------------- synapse/storage/events_bg_updates.py | 401 +++++++++++++++++++++++++++++++++++ 3 files changed, 405 insertions(+), 369 deletions(-) create mode 100644 synapse/storage/events_bg_updates.py (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 7522d3fd57..56c434d4e8 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -36,6 +36,7 @@ from .engines import PostgresEngine from .event_federation import EventFederationStore from .event_push_actions import EventPushActionsStore from .events import EventsStore +from .events_bg_updates import EventsBackgroundUpdatesStore from .filtering import FilteringStore from .group_server import GroupServerStore from .keys import KeyStore @@ -65,6 +66,7 @@ logger = logging.getLogger(__name__) class DataStore( + EventsBackgroundUpdatesStore, RoomMemberStore, RoomStore, RegistrationStore, diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 418d88b8dc..f9162be9b9 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -219,47 +220,11 @@ class EventsStore( EventsWorkerStore, BackgroundUpdateStore, ): - EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" - EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" - EVENT_DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" def __init__(self, db_conn, hs): super(EventsStore, self).__init__(db_conn, hs) - self.register_background_update_handler( - self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts - ) - self.register_background_update_handler( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, - self._background_reindex_fields_sender, - ) - - self.register_background_index_update( - "event_contains_url_index", - index_name="event_contains_url_index", - table="events", - columns=["room_id", "topological_ordering", "stream_ordering"], - where_clause="contains_url = true AND outlier = false", - ) - - # an event_id index on event_search is useful for the purge_history - # api. Plus it means we get to enforce some integrity with a UNIQUE - # clause - self.register_background_index_update( - "event_search_event_id_idx", - index_name="event_search_event_id_idx", - table="event_search", - columns=["event_id"], - unique=True, - psql_only=True, - ) - - self.register_background_update_handler( - self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES, - self._cleanup_extremities_bg_update, - ) self._event_persist_queue = _EventPeristenceQueue() - self._state_resolution_handler = hs.get_state_resolution_handler() @defer.inlineCallbacks @@ -1585,153 +1550,6 @@ class EventsStore( ret = yield self.runInteraction("count_daily_active_rooms", _count) defer.returnValue(ret) - @defer.inlineCallbacks - def _background_reindex_fields_sender(self, progress, batch_size): - target_min_stream_id = progress["target_min_stream_id_inclusive"] - max_stream_id = progress["max_stream_id_exclusive"] - rows_inserted = progress.get("rows_inserted", 0) - - INSERT_CLUMP_SIZE = 1000 - - def reindex_txn(txn): - sql = ( - "SELECT stream_ordering, event_id, json FROM events" - " INNER JOIN event_json USING (event_id)" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - rows = txn.fetchall() - if not rows: - return 0 - - min_stream_id = rows[-1][0] - - update_rows = [] - for row in rows: - try: - event_id = row[1] - event_json = json.loads(row[2]) - sender = event_json["sender"] - content = event_json["content"] - - contains_url = "url" in content - if contains_url: - contains_url &= isinstance(content["url"], text_type) - except (KeyError, AttributeError): - # If the event is missing a necessary field then - # skip over it. - continue - - update_rows.append((sender, contains_url, event_id)) - - sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" 
- - for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): - clump = update_rows[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows), - } - - self._background_update_progress_txn( - txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress - ) - - return len(rows) - - result = yield self.runInteraction( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn - ) - - if not result: - yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) - - defer.returnValue(result) - - @defer.inlineCallbacks - def _background_reindex_origin_server_ts(self, progress, batch_size): - target_min_stream_id = progress["target_min_stream_id_inclusive"] - max_stream_id = progress["max_stream_id_exclusive"] - rows_inserted = progress.get("rows_inserted", 0) - - INSERT_CLUMP_SIZE = 1000 - - def reindex_search_txn(txn): - sql = ( - "SELECT stream_ordering, event_id FROM events" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - rows = txn.fetchall() - if not rows: - return 0 - - min_stream_id = rows[-1][0] - event_ids = [row[1] for row in rows] - - rows_to_update = [] - - chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] - for chunk in chunks: - ev_rows = self._simple_select_many_txn( - txn, - table="event_json", - column="event_id", - iterable=chunk, - retcols=["event_id", "json"], - keyvalues={}, - ) - - for row in ev_rows: - event_id = row["event_id"] - event_json = json.loads(row["json"]) - try: - origin_server_ts = event_json["origin_server_ts"] - except (KeyError, AttributeError): - # If the event is missing a necessary field then - # skip over it. - continue - - rows_to_update.append((origin_server_ts, event_id)) - - sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" - - for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): - clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows_to_update), - } - - self._background_update_progress_txn( - txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress - ) - - return len(rows_to_update) - - result = yield self.runInteraction( - self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn - ) - - if not result: - yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) - - defer.returnValue(result) - def get_current_backfill_token(self): """The current minimum token that backfilled events have reached""" return -self._backfill_id_gen.get_current_token() @@ -2347,191 +2165,6 @@ class EventsStore( get_all_updated_current_state_deltas_txn, ) - @defer.inlineCallbacks - def _cleanup_extremities_bg_update(self, progress, batch_size): - """Background update to clean out extremities that should have been - deleted previously. - - Mainly used to deal with the aftermath of #5269. - """ - - # This works by first copying all existing forward extremities into the - # `_extremities_to_check` table at start up, and then checking each - # event in that table whether we have any descendants that are not - # soft-failed/rejected. If that is the case then we delete that event - # from the forward extremities table. 
- # - # For efficiency, we do this in batches by recursively pulling out all - # descendants of a batch until we find the non soft-failed/rejected - # events, i.e. the set of descendants whose chain of prev events back - # to the batch of extremities are all soft-failed or rejected. - # Typically, we won't find any such events as extremities will rarely - # have any descendants, but if they do then we should delete those - # extremities. - - def _cleanup_extremities_bg_update_txn(txn): - # The set of extremity event IDs that we're checking this round - original_set = set() - - # A dict[str, set[str]] of event ID to their prev events. - graph = {} - - # The set of descendants of the original set that are not rejected - # nor soft-failed. Ancestors of these events should be removed - # from the forward extremities table. - non_rejected_leaves = set() - - # Set of event IDs that have been soft failed, and for which we - # should check if they have descendants which haven't been soft - # failed. - soft_failed_events_to_lookup = set() - - # First, we get `batch_size` events from the table, pulling out - # their successor events, if any, and their successor events - # rejection status. - txn.execute( - """SELECT prev_event_id, event_id, internal_metadata, - rejections.event_id IS NOT NULL, events.outlier - FROM ( - SELECT event_id AS prev_event_id - FROM _extremities_to_check - LIMIT ? - ) AS f - LEFT JOIN event_edges USING (prev_event_id) - LEFT JOIN events USING (event_id) - LEFT JOIN event_json USING (event_id) - LEFT JOIN rejections USING (event_id) - """, (batch_size,) - ) - - for prev_event_id, event_id, metadata, rejected, outlier in txn: - original_set.add(prev_event_id) - - if not event_id or outlier: - # Common case where the forward extremity doesn't have any - # descendants. - continue - - graph.setdefault(event_id, set()).add(prev_event_id) - - soft_failed = False - if metadata: - soft_failed = json.loads(metadata).get("soft_failed") - - if soft_failed or rejected: - soft_failed_events_to_lookup.add(event_id) - else: - non_rejected_leaves.add(event_id) - - # Now we recursively check all the soft-failed descendants we - # found above in the same way, until we have nothing left to - # check. - while soft_failed_events_to_lookup: - # We only want to do 100 at a time, so we split given list - # into two. - batch = list(soft_failed_events_to_lookup) - to_check, to_defer = batch[:100], batch[100:] - soft_failed_events_to_lookup = set(to_defer) - - sql = """SELECT prev_event_id, event_id, internal_metadata, - rejections.event_id IS NOT NULL - FROM event_edges - INNER JOIN events USING (event_id) - INNER JOIN event_json USING (event_id) - LEFT JOIN rejections USING (event_id) - WHERE - prev_event_id IN (%s) - AND NOT events.outlier - """ % ( - ",".join("?" for _ in to_check), - ) - txn.execute(sql, to_check) - - for prev_event_id, event_id, metadata, rejected in txn: - if event_id in graph: - # Already handled this event previously, but we still - # want to record the edge. - graph[event_id].add(prev_event_id) - continue - - graph[event_id] = {prev_event_id} - - soft_failed = json.loads(metadata).get("soft_failed") - if soft_failed or rejected: - soft_failed_events_to_lookup.add(event_id) - else: - non_rejected_leaves.add(event_id) - - # We have a set of non-soft-failed descendants, so we recurse up - # the graph to find all ancestors and add them to the set of event - # IDs that we can delete from forward extremities table. 
- to_delete = set() - while non_rejected_leaves: - event_id = non_rejected_leaves.pop() - prev_event_ids = graph.get(event_id, set()) - non_rejected_leaves.update(prev_event_ids) - to_delete.update(prev_event_ids) - - to_delete.intersection_update(original_set) - - deleted = self._simple_delete_many_txn( - txn=txn, - table="event_forward_extremities", - column="event_id", - iterable=to_delete, - keyvalues={}, - ) - - logger.info( - "Deleted %d forward extremities of %d checked, to clean up #5269", - deleted, - len(original_set), - ) - - if deleted: - # We now need to invalidate the caches of these rooms - rows = self._simple_select_many_txn( - txn, - table="events", - column="event_id", - iterable=to_delete, - keyvalues={}, - retcols=("room_id",) - ) - room_ids = set(row["room_id"] for row in rows) - for room_id in room_ids: - txn.call_after( - self.get_latest_event_ids_in_room.invalidate, - (room_id,) - ) - - self._simple_delete_many_txn( - txn=txn, - table="_extremities_to_check", - column="event_id", - iterable=original_set, - keyvalues={}, - ) - - return len(original_set) - - num_handled = yield self.runInteraction( - "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn, - ) - - if not num_handled: - yield self._end_background_update(self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES) - - def _drop_table_txn(txn): - txn.execute("DROP TABLE _extremities_to_check") - - yield self.runInteraction( - "_cleanup_extremities_bg_update_drop_table", - _drop_table_txn, - ) - - defer.returnValue(num_handled) - AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py new file mode 100644 index 0000000000..2eba106abf --- /dev/null +++ b/synapse/storage/events_bg_updates.py @@ -0,0 +1,401 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
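The new module registers each migration under a stable name with register_background_update_handler, and the scheduler then calls the handler repeatedly with a batch size until it reports nothing left to do. As a rough sketch of that contract — hypothetical, and much simpler than Synapse's real BackgroundUpdateStore, which also persists per-update progress to the background_updates table so work survives restarts — the pattern looks like:

# Minimal sketch of a named background-update registry (hypothetical;
# not Synapse's actual BackgroundUpdateStore).

class BackgroundUpdater(object):
    def __init__(self):
        self._handlers = {}

    def register_background_update_handler(self, name, handler):
        # handler(progress, batch_size) -> number of items processed;
        # returning 0 signals that the update has finished.
        self._handlers[name] = handler

    def run_update(self, name, batch_size=100):
        progress = {}
        while True:
            handled = self._handlers[name](progress, batch_size)
            if not handled:
                break


if __name__ == "__main__":
    updater = BackgroundUpdater()
    todo = list(range(250))

    def demo_update(progress, batch_size):
        # Consume up to batch_size items per invocation.
        batch, todo[:] = todo[:batch_size], todo[batch_size:]
        return len(batch)

    updater.register_background_update_handler("demo_update", demo_update)
    updater.run_update("demo_update")
    assert not todo

The important property is that each invocation does a bounded amount of work, so a long-running migration never holds a database transaction open for the whole dataset.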
+ +import logging + +from six import text_type + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.storage.background_updates import BackgroundUpdateStore + +logger = logging.getLogger(__name__) + + +class EventsBackgroundUpdatesStore(BackgroundUpdateStore): + + EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" + EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" + EVENT_DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" + + def __init__(self, db_conn, hs): + super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) + + self.register_background_update_handler( + self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts + ) + self.register_background_update_handler( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, + self._background_reindex_fields_sender, + ) + + self.register_background_index_update( + "event_contains_url_index", + index_name="event_contains_url_index", + table="events", + columns=["room_id", "topological_ordering", "stream_ordering"], + where_clause="contains_url = true AND outlier = false", + ) + + # an event_id index on event_search is useful for the purge_history + # api. Plus it means we get to enforce some integrity with a UNIQUE + # clause + self.register_background_index_update( + "event_search_event_id_idx", + index_name="event_search_event_id_idx", + table="event_search", + columns=["event_id"], + unique=True, + psql_only=True, + ) + + self.register_background_update_handler( + self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES, + self._cleanup_extremities_bg_update, + ) + + @defer.inlineCallbacks + def _background_reindex_fields_sender(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + INSERT_CLUMP_SIZE = 1000 + + def reindex_txn(txn): + sql = ( + "SELECT stream_ordering, event_id, json FROM events" + " INNER JOIN event_json USING (event_id)" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + min_stream_id = rows[-1][0] + + update_rows = [] + for row in rows: + try: + event_id = row[1] + event_json = json.loads(row[2]) + sender = event_json["sender"] + content = event_json["content"] + + contains_url = "url" in content + if contains_url: + contains_url &= isinstance(content["url"], text_type) + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + update_rows.append((sender, contains_url, event_id)) + + sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" 
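The loop that follows feeds these (sender, contains_url, event_id) tuples to executemany in clumps of INSERT_CLUMP_SIZE rows, so no single call carries an unbounded parameter list. A standalone sketch of the clumping idiom, run here against a throwaway sqlite3 table with made-up event IDs:

# Standalone sketch of the clumped-executemany idiom (assumed schema
# and data; not Synapse's actual tables).
import sqlite3

INSERT_CLUMP_SIZE = 1000

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (event_id TEXT PRIMARY KEY, sender TEXT)")
conn.executemany(
    "INSERT INTO events (event_id) VALUES (?)",
    [("$event%d" % i,) for i in range(2500)],
)

update_rows = [("@alice:example.com", "$event%d" % i) for i in range(2500)]
sql = "UPDATE events SET sender = ? WHERE event_id = ?"

# Apply the updates in bounded slices rather than all at once.
for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
    clump = update_rows[index : index + INSERT_CLUMP_SIZE]
    conn.executemany(sql, clump)

conn.commit()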
+ + for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): + clump = update_rows[index : index + INSERT_CLUMP_SIZE] + txn.executemany(sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows), + } + + self._background_update_progress_txn( + txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress + ) + + return len(rows) + + result = yield self.runInteraction( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn + ) + + if not result: + yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) + + defer.returnValue(result) + + @defer.inlineCallbacks + def _background_reindex_origin_server_ts(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + INSERT_CLUMP_SIZE = 1000 + + def reindex_search_txn(txn): + sql = ( + "SELECT stream_ordering, event_id FROM events" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + min_stream_id = rows[-1][0] + event_ids = [row[1] for row in rows] + + rows_to_update = [] + + chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] + for chunk in chunks: + ev_rows = self._simple_select_many_txn( + txn, + table="event_json", + column="event_id", + iterable=chunk, + retcols=["event_id", "json"], + keyvalues={}, + ) + + for row in ev_rows: + event_id = row["event_id"] + event_json = json.loads(row["json"]) + try: + origin_server_ts = event_json["origin_server_ts"] + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + rows_to_update.append((origin_server_ts, event_id)) + + sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" + + for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): + clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] + txn.executemany(sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows_to_update), + } + + self._background_update_progress_txn( + txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress + ) + + return len(rows_to_update) + + result = yield self.runInteraction( + self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn + ) + + if not result: + yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) + + defer.returnValue(result) + + @defer.inlineCallbacks + def _cleanup_extremities_bg_update(self, progress, batch_size): + """Background update to clean out extremities that should have been + deleted previously. + + Mainly used to deal with the aftermath of #5269. + """ + + # This works by first copying all existing forward extremities into the + # `_extremities_to_check` table at start up, and then checking each + # event in that table whether we have any descendants that are not + # soft-failed/rejected. If that is the case then we delete that event + # from the forward extremities table. + # + # For efficiency, we do this in batches by recursively pulling out all + # descendants of a batch until we find the non soft-failed/rejected + # events, i.e. 
the set of descendants whose chain of prev events back + # to the batch of extremities are all soft-failed or rejected. + # Typically, we won't find any such events as extremities will rarely + # have any descendants, but if they do then we should delete those + # extremities. + + def _cleanup_extremities_bg_update_txn(txn): + # The set of extremity event IDs that we're checking this round + original_set = set() + + # A dict[str, set[str]] of event ID to their prev events. + graph = {} + + # The set of descendants of the original set that are not rejected + # nor soft-failed. Ancestors of these events should be removed + # from the forward extremities table. + non_rejected_leaves = set() + + # Set of event IDs that have been soft failed, and for which we + # should check if they have descendants which haven't been soft + # failed. + soft_failed_events_to_lookup = set() + + # First, we get `batch_size` events from the table, pulling out + # their successor events, if any, and their successor events + # rejection status. + txn.execute( + """SELECT prev_event_id, event_id, internal_metadata, + rejections.event_id IS NOT NULL, events.outlier + FROM ( + SELECT event_id AS prev_event_id + FROM _extremities_to_check + LIMIT ? + ) AS f + LEFT JOIN event_edges USING (prev_event_id) + LEFT JOIN events USING (event_id) + LEFT JOIN event_json USING (event_id) + LEFT JOIN rejections USING (event_id) + """, (batch_size,) + ) + + for prev_event_id, event_id, metadata, rejected, outlier in txn: + original_set.add(prev_event_id) + + if not event_id or outlier: + # Common case where the forward extremity doesn't have any + # descendants. + continue + + graph.setdefault(event_id, set()).add(prev_event_id) + + soft_failed = False + if metadata: + soft_failed = json.loads(metadata).get("soft_failed") + + if soft_failed or rejected: + soft_failed_events_to_lookup.add(event_id) + else: + non_rejected_leaves.add(event_id) + + # Now we recursively check all the soft-failed descendants we + # found above in the same way, until we have nothing left to + # check. + while soft_failed_events_to_lookup: + # We only want to do 100 at a time, so we split given list + # into two. + batch = list(soft_failed_events_to_lookup) + to_check, to_defer = batch[:100], batch[100:] + soft_failed_events_to_lookup = set(to_defer) + + sql = """SELECT prev_event_id, event_id, internal_metadata, + rejections.event_id IS NOT NULL + FROM event_edges + INNER JOIN events USING (event_id) + INNER JOIN event_json USING (event_id) + LEFT JOIN rejections USING (event_id) + WHERE + prev_event_id IN (%s) + AND NOT events.outlier + """ % ( + ",".join("?" for _ in to_check), + ) + txn.execute(sql, to_check) + + for prev_event_id, event_id, metadata, rejected in txn: + if event_id in graph: + # Already handled this event previously, but we still + # want to record the edge. + graph[event_id].add(prev_event_id) + continue + + graph[event_id] = {prev_event_id} + + soft_failed = json.loads(metadata).get("soft_failed") + if soft_failed or rejected: + soft_failed_events_to_lookup.add(event_id) + else: + non_rejected_leaves.add(event_id) + + # We have a set of non-soft-failed descendants, so we recurse up + # the graph to find all ancestors and add them to the set of event + # IDs that we can delete from forward extremities table. 
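In memory, that upward walk is just repeated set expansion over the graph mapping, followed by intersecting with the extremities we started from; the code below performs exactly this walk. A pure-Python sketch over a hypothetical prev-event graph, with no database involved:

# Pure-Python sketch of the extremity-cleanup walk (hypothetical
# in-memory graph; the real code builds `graph` from event_edges).
def extremities_to_delete(graph, original_set, non_rejected_leaves):
    """graph maps event_id -> set of prev_event_ids.

    Walk up from every non-rejected, non-soft-failed descendant and
    collect its ancestors; any extremity in that ancestor set has a
    "live" descendant and can be deleted.
    """
    non_rejected_leaves = set(non_rejected_leaves)
    to_delete = set()
    while non_rejected_leaves:
        event_id = non_rejected_leaves.pop()
        prev_event_ids = graph.get(event_id, set())
        non_rejected_leaves.update(prev_event_ids)
        to_delete.update(prev_event_ids)

    to_delete.intersection_update(original_set)
    return to_delete


# $E is an extremity whose descendants $A (soft-failed) and $B
# (accepted) chain back to it, so $E should go; $F has no live
# descendants, so it stays.
graph = {"$A": {"$E"}, "$B": {"$A"}}
assert extremities_to_delete(graph, {"$E", "$F"}, {"$B"}) == {"$E"}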
+ to_delete = set() + while non_rejected_leaves: + event_id = non_rejected_leaves.pop() + prev_event_ids = graph.get(event_id, set()) + non_rejected_leaves.update(prev_event_ids) + to_delete.update(prev_event_ids) + + to_delete.intersection_update(original_set) + + deleted = self._simple_delete_many_txn( + txn=txn, + table="event_forward_extremities", + column="event_id", + iterable=to_delete, + keyvalues={}, + ) + + logger.info( + "Deleted %d forward extremities of %d checked, to clean up #5269", + deleted, + len(original_set), + ) + + if deleted: + # We now need to invalidate the caches of these rooms + rows = self._simple_select_many_txn( + txn, + table="events", + column="event_id", + iterable=to_delete, + keyvalues={}, + retcols=("room_id",) + ) + room_ids = set(row["room_id"] for row in rows) + for room_id in room_ids: + txn.call_after( + self.get_latest_event_ids_in_room.invalidate, + (room_id,) + ) + + self._simple_delete_many_txn( + txn=txn, + table="_extremities_to_check", + column="event_id", + iterable=original_set, + keyvalues={}, + ) + + return len(original_set) + + num_handled = yield self.runInteraction( + "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn, + ) + + if not num_handled: + yield self._end_background_update(self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES) + + def _drop_table_txn(txn): + txn.execute("DROP TABLE _extremities_to_check") + + yield self.runInteraction( + "_cleanup_extremities_bg_update_drop_table", + _drop_table_txn, + ) + + defer.returnValue(num_handled) -- cgit 1.5.1 From 7386c35f58f360269df1410b2d1ec6d179081b32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 11:24:42 +0100 Subject: Rename constant --- synapse/storage/events_bg_updates.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py index 2eba106abf..22aac1393d 100644 --- a/synapse/storage/events_bg_updates.py +++ b/synapse/storage/events_bg_updates.py @@ -30,7 +30,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" - EVENT_DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" + DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" def __init__(self, db_conn, hs): super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) @@ -64,7 +64,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) self.register_background_update_handler( - self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES, + self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update, ) @@ -388,7 +388,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not num_handled: - yield self._end_background_update(self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES) + yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES) def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") -- cgit 1.5.1 From 06eb408da5c4ab52a3072dc6d76fe5ac3b9b1e83 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 14:06:42 +0100 Subject: Update synapse/storage/events_bg_updates.py Co-Authored-By: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- synapse/storage/events_bg_updates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py index 
22aac1393d..75c1935bf3 100644 --- a/synapse/storage/events_bg_updates.py +++ b/synapse/storage/events_bg_updates.py @@ -255,7 +255,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): soft_failed_events_to_lookup = set() # First, we get `batch_size` events from the table, pulling out - # their successor events, if any, and their successor events + # their successor events, if any, and the successor events' # rejection status. txn.execute( """SELECT prev_event_id, event_id, internal_metadata, -- cgit 1.5.1 From e2c3660a0ffb13d4198893e91a90ae1abcad8915 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 14:54:56 +0100 Subject: Add index to temp table --- synapse/storage/schema/delta/54/delete_forward_extremities.sql | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/schema/delta/54/delete_forward_extremities.sql index aa40f13da7..b062ec840c 100644 --- a/synapse/storage/schema/delta/54/delete_forward_extremities.sql +++ b/synapse/storage/schema/delta/54/delete_forward_extremities.sql @@ -20,3 +20,4 @@ INSERT INTO background_updates (update_name, progress_json) VALUES DROP TABLE IF EXISTS _extremities_to_check; -- To make this delta schema file idempotent. CREATE TABLE _extremities_to_check AS SELECT event_id FROM event_forward_extremities; +CREATE INDEX _extremities_to_check_id ON _extremities_to_check(event_id); -- cgit 1.5.1
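The index added by that last delta matters because the background update repeatedly pulls LIMIT-sized batches out of _extremities_to_check and then deletes the handled rows by event_id; without it, every _simple_delete_many_txn call scans the whole temporary table. A small sqlite3 sketch (made-up rows) showing the query-plan change:

# Sketch of why the delta adds an index on _extremities_to_check(event_id),
# using sqlite3's EXPLAIN QUERY PLAN; data is hypothetical, and the exact
# plan strings vary between SQLite versions.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE _extremities_to_check (event_id TEXT)")
conn.executemany(
    "INSERT INTO _extremities_to_check VALUES (?)",
    [("$event%d" % i,) for i in range(1000)],
)

def plan(sql):
    # The human-readable plan detail is the last column of each row.
    return [row[3] for row in conn.execute("EXPLAIN QUERY PLAN " + sql)]

delete = "DELETE FROM _extremities_to_check WHERE event_id = '$event1'"
print(plan(delete))  # full scan of _extremities_to_check

conn.execute(
    "CREATE INDEX _extremities_to_check_id ON _extremities_to_check(event_id)"
)
print(plan(delete))  # search using index _extremities_to_check_id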