From bd0d45ca69587f4f258b738dfa3a55704838081e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 26 Apr 2019 11:13:16 +0100 Subject: Fix infinite loop in presence handler Fixes #5102 --- synapse/storage/state_deltas.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py index 56e42f583d..31a0279b18 100644 --- a/synapse/storage/state_deltas.py +++ b/synapse/storage/state_deltas.py @@ -22,6 +22,24 @@ logger = logging.getLogger(__name__) class StateDeltasStore(SQLBaseStore): def get_current_state_deltas(self, prev_stream_id): + """Fetch a list of room state changes since the given stream id + + Each entry in the result contains the following fields: + - stream_id (int) + - room_id (str) + - type (str): event type + - state_key (str): + - event_id (str|None): new event_id for this state key. None if the + state has been deleted. + - prev_event_id (str|None): previous event_id for this state key. None + if it's new state. + + Args: + prev_stream_id (int): point to get changes since (exclusive) + + Returns: + Deferred[list[dict]]: results + """ prev_stream_id = int(prev_stream_id) if not self._curr_state_delta_stream_cache.has_any_entity_changed( prev_stream_id -- cgit 1.5.1 From 11ea16777f52264f0a1d6f4db5b7db5c6c147523 Mon Sep 17 00:00:00 2001 From: Quentin Dufour Date: Thu, 9 May 2019 12:01:41 +0200 Subject: Limit the number of EDUs in transactions to 100 as expected by receiver (#5138) Fixes #3951. --- changelog.d/5138.bugfix | 1 + synapse/federation/sender/per_destination_queue.py | 56 ++++++++++++---------- synapse/storage/deviceinbox.py | 2 +- 3 files changed, 32 insertions(+), 27 deletions(-) create mode 100644 changelog.d/5138.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/5138.bugfix b/changelog.d/5138.bugfix new file mode 100644 index 0000000000..c1945a8eb2 --- /dev/null +++ b/changelog.d/5138.bugfix @@ -0,0 +1 @@ +Limit the number of EDUs in transactions to 100 as expected by synapse. Thanks to @superboum for this work! diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index be99211003..4d96f026c6 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -33,6 +33,9 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage import UserPresenceState from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter +# This is defined in the Matrix spec and enforced by the receiver. +MAX_EDUS_PER_TRANSACTION = 100 + logger = logging.getLogger(__name__) @@ -197,7 +200,8 @@ class PerDestinationQueue(object): pending_pdus = [] while True: device_message_edus, device_stream_id, dev_list_id = ( - yield self._get_new_device_messages() + # We have to keep 2 free slots for presence and rr_edus + yield self._get_new_device_messages(MAX_EDUS_PER_TRANSACTION - 2) ) # BEGIN CRITICAL SECTION @@ -216,19 +220,9 @@ class PerDestinationQueue(object): pending_edus = [] - pending_edus.extend(self._get_rr_edus(force_flush=False)) - # We can only include at most 100 EDUs per transactions - pending_edus.extend(self._pop_pending_edus(100 - len(pending_edus))) - - pending_edus.extend( - self._pending_edus_keyed.values() - ) - - self._pending_edus_keyed = {} - - pending_edus.extend(device_message_edus) - + # rr_edus and pending_presence take at most one slot each + pending_edus.extend(self._get_rr_edus(force_flush=False)) pending_presence = self._pending_presence self._pending_presence = {} if pending_presence: @@ -248,6 +242,12 @@ class PerDestinationQueue(object): ) ) + pending_edus.extend(device_message_edus) + pending_edus.extend(self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))) + while len(pending_edus) < MAX_EDUS_PER_TRANSACTION and self._pending_edus_keyed: + _, val = self._pending_edus_keyed.popitem() + pending_edus.append(val) + if pending_pdus: logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", self._destination, len(pending_pdus)) @@ -259,7 +259,7 @@ class PerDestinationQueue(object): # if we've decided to send a transaction anyway, and we have room, we # may as well send any pending RRs - if len(pending_edus) < 100: + if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: pending_edus.extend(self._get_rr_edus(force_flush=True)) # END CRITICAL SECTION @@ -346,33 +346,37 @@ class PerDestinationQueue(object): return pending_edus @defer.inlineCallbacks - def _get_new_device_messages(self): - last_device_stream_id = self._last_device_stream_id - to_device_stream_id = self._store.get_to_device_stream_token() - contents, stream_id = yield self._store.get_new_device_msgs_for_remote( - self._destination, last_device_stream_id, to_device_stream_id + def _get_new_device_messages(self, limit): + last_device_list = self._last_device_list_stream_id + # Will return at most 20 entries + now_stream_id, results = yield self._store.get_devices_by_remote( + self._destination, last_device_list ) edus = [ Edu( origin=self._server_name, destination=self._destination, - edu_type="m.direct_to_device", + edu_type="m.device_list_update", content=content, ) - for content in contents + for content in results ] - last_device_list = self._last_device_list_stream_id - now_stream_id, results = yield self._store.get_devices_by_remote( - self._destination, last_device_list + assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs" + + last_device_stream_id = self._last_device_stream_id + to_device_stream_id = self._store.get_to_device_stream_token() + contents, stream_id = yield self._store.get_new_device_msgs_for_remote( + self._destination, last_device_stream_id, to_device_stream_id, limit - len(edus) ) edus.extend( Edu( origin=self._server_name, destination=self._destination, - edu_type="m.device_list_update", + edu_type="m.direct_to_device", content=content, ) - for content in results + for content in contents ) + defer.returnValue((edus, stream_id, now_stream_id)) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index fed4ea3610..9b0a99cb49 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -118,7 +118,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): defer.returnValue(count) def get_new_device_msgs_for_remote( - self, destination, last_stream_id, current_stream_id, limit=100 + self, destination, last_stream_id, current_stream_id, limit ): """ Args: -- cgit 1.5.1 From 4fb44fb5b9f0043ab7bfbc0aa7251c92680cc639 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 14 May 2019 13:37:44 +0100 Subject: Expose DataStore._get_events as get_events_as_list This is in preparation for reaction work which requires it. --- synapse/storage/appservice.py | 4 +-- synapse/storage/event_federation.py | 6 ++--- synapse/storage/events_worker.py | 50 ++++++++++++++++++++++++++----------- synapse/storage/search.py | 4 +-- synapse/storage/stream.py | 18 +++++++------ tests/storage/test_appservice.py | 2 +- 6 files changed, 54 insertions(+), 30 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 6092f600ba..eb329ebd8b 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -302,7 +302,7 @@ class ApplicationServiceTransactionWorkerStore( event_ids = json.loads(entry["event_ids"]) - events = yield self._get_events(event_ids) + events = yield self.get_events_as_list(event_ids) defer.returnValue( AppServiceTransaction(service=service, id=entry["txn_id"], events=events) @@ -358,7 +358,7 @@ class ApplicationServiceTransactionWorkerStore( "get_new_events_for_appservice", get_new_events_for_appservice_txn ) - events = yield self._get_events(event_ids) + events = yield self.get_events_as_list(event_ids) defer.returnValue((upper_bound, events)) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 956f876572..09e39c2c28 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -45,7 +45,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ return self.get_auth_chain_ids( event_ids, include_given=include_given - ).addCallback(self._get_events) + ).addCallback(self.get_events_as_list) def get_auth_chain_ids(self, event_ids, include_given=False): """Get auth events for given event_ids. The events *must* be state events. @@ -316,7 +316,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_list, limit, ) - .addCallback(self._get_events) + .addCallback(self.get_events_as_list) .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) ) @@ -382,7 +382,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = yield self._get_events(ids) + events = yield self.get_events_as_list(ids) defer.returnValue(events) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 663991a9b6..5c40056d11 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -103,7 +103,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: Deferred : A FrozenEvent. """ - events = yield self._get_events( + events = yield self.get_events_as_list( [event_id], check_redacted=check_redacted, get_prev_content=get_prev_content, @@ -142,7 +142,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: Deferred : Dict from event_id to event. """ - events = yield self._get_events( + events = yield self.get_events_as_list( event_ids, check_redacted=check_redacted, get_prev_content=get_prev_content, @@ -152,13 +152,32 @@ class EventsWorkerStore(SQLBaseStore): defer.returnValue({e.event_id: e for e in events}) @defer.inlineCallbacks - def _get_events( + def get_events_as_list( self, event_ids, check_redacted=True, get_prev_content=False, allow_rejected=False, ): + """Get events from the database and return in a list in the same order + as given by `event_ids` arg. + + Args: + event_ids (list): The event_ids of the events to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred[list]: List of events fetched from the database. The + events are in the same order as `event_ids` arg. + + Note that the returned list may be smaller than the list of event + IDs if not all events could be fetched. + """ + if not event_ids: defer.returnValue([]) @@ -202,21 +221,22 @@ class EventsWorkerStore(SQLBaseStore): # # The problem is that we end up at this point when an event # which has been redacted is pulled out of the database by - # _enqueue_events, because _enqueue_events needs to check the - # redaction before it can cache the redacted event. So obviously, - # calling get_event to get the redacted event out of the database - # gives us an infinite loop. + # _enqueue_events, because _enqueue_events needs to check + # the redaction before it can cache the redacted event. So + # obviously, calling get_event to get the redacted event out + # of the database gives us an infinite loop. # - # For now (quick hack to fix during 0.99 release cycle), we just - # go and fetch the relevant row from the db, but it would be nice - # to think about how we can cache this rather than hit the db - # every time we access a redaction event. + # For now (quick hack to fix during 0.99 release cycle), we + # just go and fetch the relevant row from the db, but it + # would be nice to think about how we can cache this rather + # than hit the db every time we access a redaction event. # # One thought on how to do this: - # 1. split _get_events up so that it is divided into (a) get the - # rawish event from the db/cache, (b) do the redaction/rejection - # filtering - # 2. have _get_event_from_row just call the first half of that + # 1. split get_events_as_list up so that it is divided into + # (a) get the rawish event from the db/cache, (b) do the + # redaction/rejection filtering + # 2. have _get_event_from_row just call the first half of + # that orig_sender = yield self._simple_select_one_onecol( table="events", diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 226f8f1b7e..ff49eaae02 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -460,7 +460,7 @@ class SearchStore(BackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self._get_events([r["event_id"] for r in results]) + events = yield self.get_events_as_list([r["event_id"] for r in results]) event_map = {ev.event_id: ev for ev in events} @@ -605,7 +605,7 @@ class SearchStore(BackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self._get_events([r["event_id"] for r in results]) + events = yield self.get_events_as_list([r["event_id"] for r in results]) event_map = {ev.event_id: ev for ev in events} diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 9cd1e0f9fe..d105b6b17d 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -319,7 +319,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = yield self.runInteraction("get_room_events_stream_for_room", f) - ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True) + ret = yield self.get_events_as_list([ + r.event_id for r in rows], get_prev_content=True, + ) self._set_before_and_after(ret, rows, topo_order=from_id is None) @@ -367,7 +369,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = yield self.runInteraction("get_membership_changes_for_user", f) - ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True) + ret = yield self.get_events_as_list( + [r.event_id for r in rows], get_prev_content=True, + ) self._set_before_and_after(ret, rows, topo_order=False) @@ -394,7 +398,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) logger.debug("stream before") - events = yield self._get_events( + events = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) logger.debug("stream after") @@ -580,11 +584,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter, ) - events_before = yield self._get_events( + events_before = yield self.get_events_as_list( [e for e in results["before"]["event_ids"]], get_prev_content=True ) - events_after = yield self._get_events( + events_after = yield self.get_events_as_list( [e for e in results["after"]["event_ids"]], get_prev_content=True ) @@ -697,7 +701,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "get_all_new_events_stream", get_all_new_events_stream_txn ) - events = yield self._get_events(event_ids) + events = yield self.get_events_as_list(event_ids) defer.returnValue((upper_bound, events)) @@ -849,7 +853,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter, ) - events = yield self._get_events( + events = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 3f0083831b..25a6c89ef5 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -340,7 +340,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store._get_events = Mock(return_value=events) + self.store.get_events_as_list = Mock(return_value=events) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) -- cgit 1.5.1 From 52ddc6c0ed40dfc15e8e4abd383a2a5fc12fcfec Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 15 May 2019 09:52:15 +0100 Subject: Update docstring with correct type Co-Authored-By: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- synapse/storage/events_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 5c40056d11..adc6cf26b5 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -171,7 +171,7 @@ class EventsWorkerStore(SQLBaseStore): allow_rejected (bool): If True return rejected events. Returns: - Deferred[list]: List of events fetched from the database. The + Deferred[list[EventBase]]: List of events fetched from the database. The events are in the same order as `event_ids` arg. Note that the returned list may be smaller than the list of event -- cgit 1.5.1 From 54d77107c1fde88333ecccedfe30f26fb7645847 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 15 May 2019 11:27:50 +0100 Subject: Make generating SQL bounds for pagination generic This will allow us to reuse the same structure when we paginate e.g. relations --- synapse/storage/stream.py | 179 ++++++++++++++++++++++++++++++---------------- 1 file changed, 118 insertions(+), 61 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index d105b6b17d..163363c0c2 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -64,59 +64,120 @@ _EventDictReturn = namedtuple( ) -def lower_bound(token, engine, inclusive=False): - inclusive = "=" if inclusive else "" - if token.topological is None: - return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering") - else: - if isinstance(engine, PostgresEngine): - # Postgres doesn't optimise ``(x < a) OR (x=a AND y= (topological_ordering, stream_ordering) + AND (5, 3) < (topological_ordering, stream_ordering) + + would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3). + + Args: + direction (str): Whether we're paginating backwards("b") or + forwards ("f"). + column_names (tuple[str, str]): The column names to bound. Must *not* + be user defined as these get inserted directly into the SQL + statement without escapes. + from_token (tuple[int, int]|None) + to_token (tuple[int, int]|None) + engine: The database engine to generate the clauses for + + Returns: + str: The sql expression + """ + assert direction in ("b", "f") + + where_clause = [] + if from_token: + where_clause.append( + _make_generic_sql_bound( + bound=">=" if direction == "b" else "<", + column_names=column_names, + values=from_token, + engine=engine, ) - return "(%d < %s OR (%d = %s AND %d <%s %s))" % ( - token.topological, - "topological_ordering", - token.topological, - "topological_ordering", - token.stream, - inclusive, - "stream_ordering", - ) - - -def upper_bound(token, engine, inclusive=True): - inclusive = "=" if inclusive else "" - if token.topological is None: - return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering") - else: - if isinstance(engine, PostgresEngine): - # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well - # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we - # use the later form when running against postgres. - return "((%d,%d) >%s (%s,%s))" % ( - token.topological, - token.stream, - inclusive, - "topological_ordering", - "stream_ordering", + ) + + if to_token: + where_clause.append( + _make_generic_sql_bound( + bound="<" if direction == "b" else ">=", + column_names=column_names, + values=to_token, + engine=engine, ) - return "(%d > %s OR (%d = %s AND %d >%s %s))" % ( - token.topological, - "topological_ordering", - token.topological, - "topological_ordering", - token.stream, - inclusive, - "stream_ordering", ) + return " AND ".join(where_clause) + + +def _make_generic_sql_bound(bound, column_names, values, engine): + """Create an SQL expression that bounds the given column names by the + values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. + + Only works with two columns. + + Older versions of SQLite don't support that syntax so we have to expand it + out manually. + + Args: + bound (str): The comparison operator to use. One of ">", "<", ">=", + "<=", where the values are on the left and columns on the right. + names (tuple[str, str]): The column names. Must *not* be user defined + as these get inserted directly into the SQL statement without + escapes. + values (tuple[int, int]): The values to bound the columns by. + engine: The database engine to generate the SQL for + + Returns: + str + """ + + assert(bound in (">", "<", ">=", "<=")) + + name1, name2 = column_names + val1, val2 = values + + if val1 is None: + val2 = int(val2) + return "(%d %s %s)" % (val2, bound, name2) + + val1 = int(val1) + val2 = int(val2) + + if isinstance(engine, PostgresEngine): + # Postgres doesn't optimise ``(x < a) OR (x=a AND y Date: Tue, 14 May 2019 16:59:21 +0100 Subject: Add simple send_relation API and track in DB --- synapse/api/constants.py | 8 ++ synapse/rest/__init__.py | 2 + synapse/rest/client/v2_alpha/relations.py | 110 ++++++++++++++++++++++++++ synapse/storage/__init__.py | 2 + synapse/storage/events.py | 2 + synapse/storage/relations.py | 62 +++++++++++++++ synapse/storage/schema/delta/54/relations.sql | 27 +++++++ tests/rest/client/v2_alpha/test_relations.py | 98 +++++++++++++++++++++++ 8 files changed, 311 insertions(+) create mode 100644 synapse/rest/client/v2_alpha/relations.py create mode 100644 synapse/storage/relations.py create mode 100644 synapse/storage/schema/delta/54/relations.sql create mode 100644 tests/rest/client/v2_alpha/test_relations.py (limited to 'synapse/storage') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 8547a63535..30bebd749f 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -116,3 +116,11 @@ class UserTypes(object): """ SUPPORT = "support" ALL_USER_TYPES = (SUPPORT,) + + +class RelationTypes(object): + """The types of relations known to this server. + """ + ANNOTATION = "m.annotation" + REPLACES = "m.replaces" + REFERENCES = "m.references" diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 3a24d31d1b..e6110ad9b1 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.client.v2_alpha import ( read_marker, receipts, register, + relations, report_event, room_keys, room_upgrade_rest_servlet, @@ -115,6 +116,7 @@ class ClientRestResource(JsonResource): room_upgrade_rest_servlet.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) + relations.register_servlets(hs, client_resource) # moving to /_synapse/admin synapse.rest.admin.register_servlets_for_client_rest_resource( diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py new file mode 100644 index 0000000000..b504b4a8be --- /dev/null +++ b/synapse/rest/client/v2_alpha/relations.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This class implements the proposed relation APIs from MSC 1849. + +Since the MSC has not been approved all APIs here are unstable and may change at +any time to reflect changes in the MSC. +""" + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_json_object_from_request, + parse_string, +) + +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class RelationSendServlet(RestServlet): + """Helper API for sending events that have relation data. + + Example API shape to send a 👍 reaction to a room: + + POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D + {} + + { + "event_id": "$foobar" + } + """ + + PATTERN = ( + "/rooms/(?P[^/]*)/send_relation" + "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)" + ) + + def __init__(self, hs): + super(RelationSendServlet, self).__init__() + self.auth = hs.get_auth() + self.event_creation_handler = hs.get_event_creation_handler() + + def register(self, http_server): + http_server.register_paths( + "POST", + client_v2_patterns(self.PATTERN + "$", releases=()), + self.on_PUT_or_POST, + ) + http_server.register_paths( + "PUT", + client_v2_patterns(self.PATTERN + "/(?P[^/]*)$", releases=()), + self.on_PUT_or_POST, + ) + + @defer.inlineCallbacks + def on_PUT_or_POST( + self, request, room_id, parent_id, relation_type, event_type, txn_id=None + ): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + if event_type == EventTypes.Member: + # Add relations to a membership is meaningless, so we just deny it + # at the CS API rather than trying to handle it correctly. + raise SynapseError(400, "Cannot send member events with relations") + + content = parse_json_object_from_request(request) + + aggregation_key = parse_string(request, "key", encoding="utf-8") + + content["m.relates_to"] = { + "event_id": parent_id, + "key": aggregation_key, + "rel_type": relation_type, + } + + event_dict = { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + } + + event = yield self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict=event_dict, txn_id=txn_id + ) + + defer.returnValue((200, {"event_id": event.event_id})) + + +def register_servlets(hs, http_server): + RelationSendServlet(hs).register(http_server) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index c432041b4e..7522d3fd57 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -49,6 +49,7 @@ from .pusher import PusherStore from .receipts import ReceiptsStore from .registration import RegistrationStore from .rejections import RejectionsStore +from .relations import RelationsStore from .room import RoomStore from .roommember import RoomMemberStore from .search import SearchStore @@ -99,6 +100,7 @@ class DataStore( GroupServerStore, UserErasureStore, MonthlyActiveUsersStore, + RelationsStore, ): def __init__(self, db_conn, hs): self.hs = hs diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 7a7f841c6c..6802bf42ce 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1351,6 +1351,8 @@ class EventsStore( # Insert into the event_search table. self._store_guest_access_txn(txn, event) + self._handle_event_relations(txn, event) + # Insert into the room_memberships table. self._store_room_members_txn( txn, diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py new file mode 100644 index 0000000000..a4905162e0 --- /dev/null +++ b/synapse/storage/relations.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.constants import RelationTypes +from synapse.storage._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class RelationsStore(SQLBaseStore): + def _handle_event_relations(self, txn, event): + """Handles inserting relation data during peristence of events + + Args: + txn + event (EventBase) + """ + relation = event.content.get("m.relates_to") + if not relation: + # No relations + return + + rel_type = relation.get("rel_type") + if rel_type not in ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCES, + RelationTypes.REPLACES, + ): + # Unknown relation type + return + + parent_id = relation.get("event_id") + if not parent_id: + # Invalid relation + return + + aggregation_key = relation.get("key") + + self._simple_insert_txn( + txn, + table="event_relations", + values={ + "event_id": event.event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + "aggregation_key": aggregation_key, + }, + ) diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/schema/delta/54/relations.sql new file mode 100644 index 0000000000..134862b870 --- /dev/null +++ b/synapse/storage/schema/delta/54/relations.sql @@ -0,0 +1,27 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Tracks related events, like reactions, replies, edits, etc. Note that things +-- in this table are not necessarily "valid", e.g. it may contain edits from +-- people who don't have power to edit other peoples events. +CREATE TABLE IF NOT EXISTS event_relations ( + event_id TEXT NOT NULL, + relates_to_id TEXT NOT NULL, + relation_type TEXT NOT NULL, + aggregation_key TEXT +); + +CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id); +CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key); diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py new file mode 100644 index 0000000000..61163d5b26 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six + +from synapse.api.constants import EventTypes, RelationTypes +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import relations + +from tests import unittest + + +class RelationsTestCase(unittest.HomeserverTestCase): + user_id = "@alice:test" + servlets = [ + relations.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.room = self.helper.create_room_as(self.user_id) + res = self.helper.send(self.room, body="Hi!") + self.parent_id = res["event_id"] + + def test_send_relation(self): + """Tests that sending a relation using the new /send_relation works + creates the right shape of event. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") + self.assertEquals(200, channel.code, channel.json_body) + + event_id = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", "/rooms/%s/event/%s" % (self.room, event_id) + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assert_dict( + { + "type": "m.reaction", + "sender": self.user_id, + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "key": u"👍", + "rel_type": RelationTypes.ANNOTATION, + } + }, + }, + channel.json_body, + ) + + def test_deny_membership(self): + """Test that we deny relations on membership events + """ + channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) + self.assertEquals(400, channel.code, channel.json_body) + + def _send_relation(self, relation_type, event_type, key=None): + """Helper function to send a relation pointing at `self.parent_id` + + Args: + relation_type (str): One of `RelationTypes` + event_type (str): The type of the event to create + key (str|None): The aggregation key used for m.annotation relation + type. + + Returns: + FakeChannel + """ + query = "" + if key: + query = "?key=" + six.moves.urllib.parse.quote_plus(key) + + request, channel = self.make_request( + "POST", + "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" + % (self.room, self.parent_id, relation_type, event_type, query), + b"{}", + ) + self.render(request) + return channel -- cgit 1.5.1 From b50641e357f6139b0807df3714440a8ccc21d628 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 14 May 2019 16:59:21 +0100 Subject: Add simple pagination API --- synapse/rest/client/v2_alpha/relations.py | 50 +++++++++++++++++ synapse/storage/relations.py | 80 ++++++++++++++++++++++++++++ tests/rest/client/v2_alpha/test_relations.py | 30 +++++++++++ 3 files changed, 160 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index b504b4a8be..bac9e85c21 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -27,6 +27,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.http.servlet import ( RestServlet, + parse_integer, parse_json_object_from_request, parse_string, ) @@ -106,5 +107,54 @@ class RelationSendServlet(RestServlet): defer.returnValue((200, {"event_id": event.event_id})) +class RelationPaginationServlet(RestServlet): + """API to paginate relations on an event by topological ordering, optionally + filtered by relation type and event type. + """ + + PATTERNS = client_v2_patterns( + "/rooms/(?P[^/]*)/relations/(?P[^/]*)" + "(/(?P[^/]*)(/(?P[^/]*))?)?$", + releases=(), + ) + + def __init__(self, hs): + super(RelationPaginationServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + yield self.auth.check_in_room_or_world_readable( + room_id, requester.user.to_string() + ) + + limit = parse_integer(request, "limit", default=5) + + result = yield self.store.get_relations_for_event( + event_id=parent_id, + relation_type=relation_type, + event_type=event_type, + limit=limit, + ) + + events = yield self.store.get_events_as_list( + [c["event_id"] for c in result.chunk] + ) + + now = self.clock.time_msec() + events = yield self._event_serializer.serialize_events(events, now) + + return_value = result.to_dict() + return_value["chunk"] = events + + defer.returnValue((200, return_value)) + + def register_servlets(hs, http_server): RelationSendServlet(hs).register(http_server) + RelationPaginationServlet(hs).register(http_server) diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index a4905162e0..f304b8a9a4 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -15,13 +15,93 @@ import logging +import attr + from synapse.api.constants import RelationTypes from synapse.storage._base import SQLBaseStore logger = logging.getLogger(__name__) +@attr.s +class PaginationChunk(object): + """Returned by relation pagination APIs. + + Attributes: + chunk (list): The rows returned by pagination + """ + + chunk = attr.ib() + + def to_dict(self): + d = {"chunk": self.chunk} + + return d + + class RelationsStore(SQLBaseStore): + def get_relations_for_event( + self, event_id, relation_type=None, event_type=None, limit=5, direction="b" + ): + """Get a list of relations for an event, ordered by topological ordering. + + Args: + event_id (str): Fetch events that relate to this event ID. + relation_type (str|None): Only fetch events with this relation + type, if given. + event_type (str|None): Only fetch events with this event type, if + given. + limit (int): Only fetch the most recent `limit` events. + direction (str): Whether to fetch the most recent first (`"b"`) or + the oldest first (`"f"`). + + Returns: + Deferred[PaginationChunk]: List of event IDs that match relations + requested. The rows are of the form `{"event_id": "..."}`. + """ + + # TODO: Pagination tokens + + where_clause = ["relates_to_id = ?"] + where_args = [event_id] + + if relation_type: + where_clause.append("relation_type = ?") + where_args.append(relation_type) + + if event_type: + where_clause.append("type = ?") + where_args.append(event_type) + + order = "ASC" + if direction == "b": + order = "DESC" + + sql = """ + SELECT event_id FROM event_relations + INNER JOIN events USING (event_id) + WHERE %s + ORDER BY topological_ordering %s, stream_ordering %s + LIMIT ? + """ % ( + " AND ".join(where_clause), + order, + order, + ) + + def _get_recent_references_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + events = [{"event_id": row[0]} for row in txn] + + return PaginationChunk( + chunk=list(events[:limit]), + ) + + return self.runInteraction( + "get_recent_references_for_event", _get_recent_references_for_event_txn + ) + def _handle_event_relations(self, txn, event): """Handles inserting relation data during peristence of events diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 61163d5b26..bcc1c1bb85 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -72,6 +72,36 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) self.assertEquals(400, channel.code, channel.json_body) + def test_paginate(self): + """Tests that calling pagination API corectly the latest relations. + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + annotation_id = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" + % (self.room, self.parent_id), + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the full + # relation event we sent above. + self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + { + "event_id": annotation_id, + "sender": self.user_id, + "type": "m.reaction", + }, + channel.json_body["chunk"][0], + ) + def _send_relation(self, relation_type, event_type, key=None): """Helper function to send a relation pointing at `self.parent_id` -- cgit 1.5.1 From 5be34fc3e3045b3ba5a97b73ba0f25887874e622 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 15 May 2019 17:30:23 +0100 Subject: Actually check for None rather falsey --- synapse/storage/relations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index f304b8a9a4..31ef6679af 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -65,11 +65,11 @@ class RelationsStore(SQLBaseStore): where_clause = ["relates_to_id = ?"] where_args = [event_id] - if relation_type: + if relation_type is not None: where_clause.append("relation_type = ?") where_args.append(relation_type) - if event_type: + if event_type is not None: where_clause.append("type = ?") where_args.append(event_type) -- cgit 1.5.1 From a0603523d2e210cf59f887bd75e1a755720cb7a8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 14 May 2019 16:59:21 +0100 Subject: Add aggregations API --- synapse/config/server.py | 5 + synapse/events/utils.py | 34 +++- synapse/rest/client/v2_alpha/relations.py | 141 ++++++++++++++- synapse/storage/relations.py | 225 +++++++++++++++++++++++- tests/rest/client/v2_alpha/test_relations.py | 251 ++++++++++++++++++++++++++- 5 files changed, 643 insertions(+), 13 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/config/server.py b/synapse/config/server.py index 7874cd9da7..f403477b54 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -100,6 +100,11 @@ class ServerConfig(Config): "block_non_admin_invites", False, ) + # Whether to enable experimental MSC1849 (aka relations) support + self.experimental_msc1849_support_enabled = config.get( + "experimental_msc1849_support_enabled", False, + ) + # Options to control access by tracking MAU self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) self.max_mau_value = 0 diff --git a/synapse/events/utils.py b/synapse/events/utils.py index a5454556cc..16d0c64372 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -21,7 +21,7 @@ from frozendict import frozendict from twisted.internet import defer -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.util.async_helpers import yieldable_gather_results from . import EventBase @@ -324,8 +324,12 @@ class EventClientSerializer(object): """ def __init__(self, hs): - pass + self.store = hs.get_datastore() + self.experimental_msc1849_support_enabled = ( + hs.config.experimental_msc1849_support_enabled + ) + @defer.inlineCallbacks def serialize_event(self, event, time_now, **kwargs): """Serializes a single event. @@ -337,8 +341,32 @@ class EventClientSerializer(object): Returns: Deferred[dict]: The serialized event """ + # To handle the case of presence events and the like + if not isinstance(event, EventBase): + defer.returnValue(event) + + event_id = event.event_id event = serialize_event(event, time_now, **kwargs) - return defer.succeed(event) + + # If MSC1849 is enabled then we need to look if thre are any relations + # we need to bundle in with the event + if self.experimental_msc1849_support_enabled: + annotations = yield self.store.get_aggregation_groups_for_event( + event_id, + ) + references = yield self.store.get_relations_for_event( + event_id, RelationTypes.REFERENCES, direction="f", + ) + + if annotations.chunk: + r = event["unsigned"].setdefault("m.relations", {}) + r[RelationTypes.ANNOTATION] = annotations.to_dict() + + if references.chunk: + r = event["unsigned"].setdefault("m.relations", {}) + r[RelationTypes.REFERENCES] = references.to_dict() + + defer.returnValue(event) def serialize_events(self, events, time_now, **kwargs): """Serializes multiple events. diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index c3ac73b8c7..f468409b68 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -23,7 +23,7 @@ import logging from twisted.internet import defer -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.http.servlet import ( RestServlet, @@ -141,12 +141,16 @@ class RelationPaginationServlet(RestServlet): ) limit = parse_integer(request, "limit", default=5) + from_token = parse_string(request, "from") + to_token = parse_string(request, "to") result = yield self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, event_type=event_type, limit=limit, + from_token=from_token, + to_token=to_token, ) events = yield self.store.get_events_as_list( @@ -162,6 +166,141 @@ class RelationPaginationServlet(RestServlet): defer.returnValue((200, return_value)) +class RelationAggregationPaginationServlet(RestServlet): + """API to paginate aggregation groups of relations, e.g. paginate the + types and counts of the reactions on the events. + + Example request and response: + + GET /rooms/{room_id}/aggregations/{parent_id} + + { + chunk: [ + { + "type": "m.reaction", + "key": "👍", + "count": 3 + } + ] + } + """ + + PATTERNS = client_v2_patterns( + "/rooms/(?P[^/]*)/aggregations/(?P[^/]*)" + "(/(?P[^/]*)(/(?P[^/]*))?)?$", + releases=(), + ) + + def __init__(self, hs): + super(RelationAggregationPaginationServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + yield self.auth.check_in_room_or_world_readable( + room_id, requester.user.to_string() + ) + + if relation_type not in (RelationTypes.ANNOTATION, None): + raise SynapseError(400, "Relation type must be 'annotation'") + + limit = parse_integer(request, "limit", default=5) + from_token = parse_string(request, "from") + to_token = parse_string(request, "to") + + res = yield self.store.get_aggregation_groups_for_event( + event_id=parent_id, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + defer.returnValue((200, res.to_dict())) + + +class RelationAggregationGroupPaginationServlet(RestServlet): + """API to paginate within an aggregation group of relations, e.g. paginate + all the 👍 reactions on an event. + + Example request and response: + + GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍 + + { + chunk: [ + { + "type": "m.reaction", + "content": { + "m.relates_to": { + "rel_type": "m.annotation", + "key": "👍" + } + } + }, + ... + ] + } + """ + + PATTERNS = client_v2_patterns( + "/rooms/(?P[^/]*)/aggregations/(?P[^/]*)" + "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)$", + releases=(), + ) + + def __init__(self, hs): + super(RelationAggregationGroupPaginationServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + yield self.auth.check_in_room_or_world_readable( + room_id, requester.user.to_string() + ) + + if relation_type != RelationTypes.ANNOTATION: + raise SynapseError(400, "Relation type must be 'annotation'") + + limit = parse_integer(request, "limit", default=5) + from_token = parse_string(request, "from") + to_token = parse_string(request, "to") + + result = yield self.store.get_relations_for_event( + event_id=parent_id, + relation_type=relation_type, + event_type=event_type, + aggregation_key=key, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + events = yield self.store.get_events_as_list( + [c["event_id"] for c in result.chunk] + ) + + now = self.clock.time_msec() + events = yield self._event_serializer.serialize_events(events, now) + + return_value = result.to_dict() + return_value["chunk"] = events + + defer.returnValue((200, return_value)) + + defer.returnValue((200, return_value)) + + def register_servlets(hs, http_server): RelationSendServlet(hs).register(http_server) RelationPaginationServlet(hs).register(http_server) + RelationAggregationPaginationServlet(hs).register(http_server) + RelationAggregationGroupPaginationServlet(hs).register(http_server) diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 31ef6679af..db4b842c97 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -18,7 +18,9 @@ import logging import attr from synapse.api.constants import RelationTypes +from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore +from synapse.storage.stream import generate_pagination_where_clause logger = logging.getLogger(__name__) @@ -29,19 +31,94 @@ class PaginationChunk(object): Attributes: chunk (list): The rows returned by pagination + next_batch (Any|None): Token to fetch next set of results with, if + None then there are no more results. + prev_batch (Any|None): Token to fetch previous set of results with, if + None then there are no previous results. """ chunk = attr.ib() + next_batch = attr.ib(default=None) + prev_batch = attr.ib(default=None) def to_dict(self): d = {"chunk": self.chunk} + if self.next_batch: + d["next_batch"] = self.next_batch.to_string() + + if self.prev_batch: + d["prev_batch"] = self.prev_batch.to_string() + return d +@attr.s +class RelationPaginationToken(object): + """Pagination token for relation pagination API. + + As the results are order by topological ordering, we can use the + `topological_ordering` and `stream_ordering` fields of the events at the + boundaries of the chunk as pagination tokens. + + Attributes: + topological (int): The topological ordering of the boundary event + stream (int): The stream ordering of the boundary event. + """ + + topological = attr.ib() + stream = attr.ib() + + @staticmethod + def from_string(string): + try: + t, s = string.split("-") + return RelationPaginationToken(int(t), int(s)) + except ValueError: + raise SynapseError(400, "Invalid token") + + def to_string(self): + return "%d-%d" % (self.topological, self.stream) + + +@attr.s +class AggregationPaginationToken(object): + """Pagination token for relation aggregation pagination API. + + As the results are order by count and then MAX(stream_ordering) of the + aggregation groups, we can just use them as our pagination token. + + Attributes: + count (int): The count of relations in the boundar group. + stream (int): The MAX stream ordering in the boundary group. + """ + + count = attr.ib() + stream = attr.ib() + + @staticmethod + def from_string(string): + try: + c, s = string.split("-") + return AggregationPaginationToken(int(c), int(s)) + except ValueError: + raise SynapseError(400, "Invalid token") + + def to_string(self): + return "%d-%d" % (self.count, self.stream) + + class RelationsStore(SQLBaseStore): def get_relations_for_event( - self, event_id, relation_type=None, event_type=None, limit=5, direction="b" + self, + event_id, + relation_type=None, + event_type=None, + aggregation_key=None, + limit=5, + direction="b", + from_token=None, + to_token=None, ): """Get a list of relations for an event, ordered by topological ordering. @@ -51,16 +128,26 @@ class RelationsStore(SQLBaseStore): type, if given. event_type (str|None): Only fetch events with this event type, if given. + aggregation_key (str|None): Only fetch events with this aggregation + key, if given. limit (int): Only fetch the most recent `limit` events. direction (str): Whether to fetch the most recent first (`"b"`) or the oldest first (`"f"`). + from_token (RelationPaginationToken|None): Fetch rows from the given + token, or from the start if None. + to_token (RelationPaginationToken|None): Fetch rows up to the given + token, or up to the end if None. Returns: Deferred[PaginationChunk]: List of event IDs that match relations requested. The rows are of the form `{"event_id": "..."}`. """ - # TODO: Pagination tokens + if from_token: + from_token = RelationPaginationToken.from_string(from_token) + + if to_token: + to_token = RelationPaginationToken.from_string(to_token) where_clause = ["relates_to_id = ?"] where_args = [event_id] @@ -73,12 +160,29 @@ class RelationsStore(SQLBaseStore): where_clause.append("type = ?") where_args.append(event_type) - order = "ASC" + if aggregation_key: + where_clause.append("aggregation_key = ?") + where_args.append(aggregation_key) + + pagination_clause = generate_pagination_where_clause( + direction=direction, + column_names=("topological_ordering", "stream_ordering"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if pagination_clause: + where_clause.append(pagination_clause) + if direction == "b": order = "DESC" + else: + order = "ASC" sql = """ - SELECT event_id FROM event_relations + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations INNER JOIN events USING (event_id) WHERE %s ORDER BY topological_ordering %s, stream_ordering %s @@ -92,16 +196,125 @@ class RelationsStore(SQLBaseStore): def _get_recent_references_for_event_txn(txn): txn.execute(sql, where_args + [limit + 1]) - events = [{"event_id": row[0]} for row in txn] + last_topo_id = None + last_stream_id = None + events = [] + for row in txn: + events.append({"event_id": row[0]}) + last_topo_id = row[1] + last_stream_id = row[2] + + next_batch = None + if len(events) > limit and last_topo_id and last_stream_id: + next_batch = RelationPaginationToken(last_topo_id, last_stream_id) return PaginationChunk( - chunk=list(events[:limit]), + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) return self.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn ) + def get_aggregation_groups_for_event( + self, + event_id, + event_type=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of annotations on the event, grouped by event type and + aggregation key, sorted by count. + + This is used e.g. to get the what and how many reactions have happend + on an event. + + Args: + event_id (str): Fetch events that relate to this event ID. + event_type (str|None): Only fetch events with this event type, if + given. + limit (int): Only fetch the `limit` groups. + direction (str): Whether to fetch the highest count first (`"b"`) or + the lowest count first (`"f"`). + from_token (AggregationPaginationToken|None): Fetch rows from the + given token, or from the start if None. + to_token (AggregationPaginationToken|None): Fetch rows up to the + given token, or up to the end if None. + + + Returns: + Deferred[PaginationChunk]: List of groups of annotations that + match. Each row is a dict with `type`, `key` and `count` fields. + """ + + if from_token: + from_token = AggregationPaginationToken.from_string(from_token) + + if to_token: + to_token = AggregationPaginationToken.from_string(to_token) + + where_clause = ["relates_to_id = ?", "relation_type = ?"] + where_args = [event_id, RelationTypes.ANNOTATION] + + if event_type: + where_clause.append("type = ?") + where_args.append(event_type) + + having_clause = generate_pagination_where_clause( + direction=direction, + column_names=("COUNT(*)", "MAX(stream_ordering)"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + if having_clause: + having_clause = "HAVING " + having_clause + else: + having_clause = "" + + sql = """ + SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE {where_clause} + GROUP BY relation_type, type, aggregation_key + {having_clause} + ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} + LIMIT ? + """.format( + where_clause=" AND ".join(where_clause), + order=order, + having_clause=having_clause, + ) + + def _get_aggregation_groups_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + next_batch = None + events = [] + for row in txn: + events.append({"type": row[0], "key": row[1], "count": row[2]}) + next_batch = AggregationPaginationToken(row[2], row[3]) + + if len(events) <= limit: + next_batch = None + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.runInteraction( + "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + ) + def _handle_event_relations(self, txn, event): """Handles inserting relation data during peristence of events diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index bcc1c1bb85..661dd45755 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools + import six from synapse.api.constants import EventTypes, RelationTypes @@ -30,6 +32,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] + def make_homeserver(self, reactor, clock): + # We need to enable msc1849 support for aggregations + config = self.default_config() + config["experimental_msc1849_support_enabled"] = True + return self.setup_test_homeserver(config=config) + def prepare(self, reactor, clock, hs): self.room = self.helper.create_room_as(self.user_id) res = self.helper.send(self.room, body="Hi!") @@ -40,7 +48,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): creates the right shape of event. """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key=u"👍") self.assertEquals(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] @@ -72,7 +80,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) self.assertEquals(400, channel.code, channel.json_body) - def test_paginate(self): + def test_basic_paginate_relations(self): """Tests that calling pagination API corectly the latest relations. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") @@ -102,6 +110,243 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel.json_body["chunk"][0], ) + # Make sure next_batch has something in it that looks like it could be a + # valid token. + self.assertIsInstance( + channel.json_body.get("next_batch"), six.string_types, channel.json_body + ) + + def test_repeated_paginate_relations(self): + """Test that if we paginate using a limit and tokens then we get the + expected events. + """ + + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + expected_event_ids.append(channel.json_body["event_id"]) + + prev_token = None + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" + % (self.room, self.parent_id, from_token), + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_aggregation_pagination_groups(self): + """Test that we can paginate annotation groups correctly. + """ + + sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1} + for key in itertools.chain.from_iterable( + itertools.repeat(key, num) for key, num in sent_groups.items() + ): + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", key=key + ) + self.assertEquals(200, channel.code, channel.json_body) + + prev_token = None + found_groups = {} + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" + % (self.room, self.parent_id, from_token), + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + for groups in channel.json_body["chunk"]: + # We only expect reactions + self.assertEqual(groups["type"], "m.reaction", channel.json_body) + + # We should only see each key once + self.assertNotIn(groups["key"], found_groups, channel.json_body) + + found_groups[groups["key"]] = groups["count"] + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + self.assertEquals(sent_groups, found_groups) + + def test_aggregation_pagination_within_group(self): + """Test that we can paginate within an annotation group. + """ + + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", key=u"👍" + ) + self.assertEquals(200, channel.code, channel.json_body) + expected_event_ids.append(channel.json_body["event_id"]) + + # Also send a different type of reaction so that we test we don't see it + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self.assertEquals(200, channel.code, channel.json_body) + + prev_token = None + found_event_ids = [] + encoded_key = six.moves.urllib.parse.quote_plus(u"👍".encode("utf-8")) + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s" + "/aggregations/%s/%s/m.reaction/%s?limit=1%s" + % ( + self.room, + self.parent_id, + RelationTypes.ANNOTATION, + encoded_key, + from_token, + ), + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_aggregation(self): + """Test that annotations get correctly aggregated. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + self.assertEquals(200, channel.code, channel.json_body) + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s" + % (self.room, self.parent_id), + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body, + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + ) + + def test_aggregation_must_be_annotation(self): + """Test that aggregations must be annotations. + """ + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/m.replaces/%s?limit=1" + % (self.room, self.parent_id), + ) + self.render(request) + self.assertEquals(400, channel.code, channel.json_body) + + def test_aggregation_get_event(self): + """Test that annotations and references get correctly bundled when + getting the parent event. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.REFERENCES, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + reply_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.REFERENCES, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + reply_2 = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id) + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.maxDiff = None + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + RelationTypes.REFERENCES: { + "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] + }, + }, + ) + def _send_relation(self, relation_type, event_type, key=None): """Helper function to send a relation pointing at `self.parent_id` @@ -116,7 +361,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): """ query = "" if key: - query = "?key=" + six.moves.urllib.parse.quote_plus(key) + query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8")) request, channel = self.make_request( "POST", -- cgit 1.5.1 From 33453419b0cbbe29d3f4b376e9eafab10d5d012a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 May 2019 09:56:12 +0100 Subject: Add cache to relations --- synapse/storage/relations.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index db4b842c97..1fd3d4fafc 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -21,6 +21,7 @@ from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore from synapse.storage.stream import generate_pagination_where_clause +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -109,6 +110,7 @@ class AggregationPaginationToken(object): class RelationsStore(SQLBaseStore): + @cached(tree=True) def get_relations_for_event( self, event_id, @@ -216,6 +218,7 @@ class RelationsStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + @cached(tree=True) def get_aggregation_groups_for_event( self, event_id, @@ -353,3 +356,8 @@ class RelationsStore(SQLBaseStore): "aggregation_key": aggregation_key, }, ) + + txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,)) + txn.call_after( + self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) + ) -- cgit 1.5.1 From b5c62c6b2643a138956af0b04521540c270a3e94 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 May 2019 10:18:53 +0100 Subject: Fix relations in worker mode --- synapse/replication/slave/storage/events.py | 13 ++++++++++--- synapse/replication/tcp/streams/_base.py | 1 + synapse/replication/tcp/streams/events.py | 11 ++++++----- synapse/storage/events.py | 12 ++++++++---- synapse/storage/relations.py | 4 +++- 5 files changed, 28 insertions(+), 13 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index b457c5563f..797450bc66 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import ( from synapse.storage.event_federation import EventFederationWorkerStore from synapse.storage.event_push_actions import EventPushActionsWorkerStore from synapse.storage.events_worker import EventsWorkerStore +from synapse.storage.relations import RelationsWorkerStore from synapse.storage.roommember import RoomMemberWorkerStore from synapse.storage.signatures import SignatureWorkerStore from synapse.storage.state import StateGroupWorkerStore @@ -52,6 +53,7 @@ class SlavedEventStore(EventFederationWorkerStore, EventsWorkerStore, SignatureWorkerStore, UserErasureWorkerStore, + RelationsWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): @@ -89,7 +91,7 @@ class SlavedEventStore(EventFederationWorkerStore, for row in rows: self.invalidate_caches_for_event( -token, row.event_id, row.room_id, row.type, row.state_key, - row.redacts, + row.redacts, row.relates_to, backfilled=True, ) return super(SlavedEventStore, self).process_replication_rows( @@ -102,7 +104,7 @@ class SlavedEventStore(EventFederationWorkerStore, if row.type == EventsStreamEventRow.TypeId: self.invalidate_caches_for_event( token, data.event_id, data.room_id, data.type, data.state_key, - data.redacts, + data.redacts, data.relates_to, backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: @@ -114,7 +116,8 @@ class SlavedEventStore(EventFederationWorkerStore, raise Exception("Unknown events stream row type %s" % (row.type, )) def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, - etype, state_key, redacts, backfilled): + etype, state_key, redacts, relates_to, + backfilled): self._invalidate_get_event_cache(event_id) self.get_latest_event_ids_in_room.invalidate((room_id,)) @@ -136,3 +139,7 @@ class SlavedEventStore(EventFederationWorkerStore, state_key, stream_ordering ) self.get_invited_rooms_for_user.invalidate((state_key,)) + + if relates_to: + self.get_relations_for_event.invalidate_many((relates_to,)) + self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 8971a6a22e..b6ce7a7bee 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -32,6 +32,7 @@ BackfillStreamRow = namedtuple("BackfillStreamRow", ( "type", # str "state_key", # str, optional "redacts", # str, optional + "relates_to", # str, optional )) PresenceStreamRow = namedtuple("PresenceStreamRow", ( "user_id", # str diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index e0f6e29248..f1290d022a 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -80,11 +80,12 @@ class BaseEventsStreamRow(object): class EventsStreamEventRow(BaseEventsStreamRow): TypeId = "ev" - event_id = attr.ib() # str - room_id = attr.ib() # str - type = attr.ib() # str - state_key = attr.ib() # str, optional - redacts = attr.ib() # str, optional + event_id = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str + state_key = attr.ib() # str, optional + redacts = attr.ib() # str, optional + relates_to = attr.ib() # str, optional @attr.s(slots=True, frozen=True) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 6802bf42ce..b025ebc926 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1657,10 +1657,11 @@ class EventsStore( def get_all_new_forward_event_rows(txn): sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? < stream_ordering AND stream_ordering <= ?" " ORDER BY stream_ordering ASC" " LIMIT ?" @@ -1675,11 +1676,12 @@ class EventsStore( sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" " ORDER BY event_stream_ordering DESC" @@ -1700,10 +1702,11 @@ class EventsStore( def get_all_new_backfill_event_rows(txn): sql = ( "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? > stream_ordering AND stream_ordering >= ?" " ORDER BY stream_ordering ASC" " LIMIT ?" @@ -1718,11 +1721,12 @@ class EventsStore( sql = ( "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" " ORDER BY event_stream_ordering DESC" diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 1fd3d4fafc..732418ec65 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -109,7 +109,7 @@ class AggregationPaginationToken(object): return "%d-%d" % (self.count, self.stream) -class RelationsStore(SQLBaseStore): +class RelationsWorkerStore(SQLBaseStore): @cached(tree=True) def get_relations_for_event( self, @@ -318,6 +318,8 @@ class RelationsStore(SQLBaseStore): "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) + +class RelationsStore(RelationsWorkerStore): def _handle_event_relations(self, txn, event): """Handles inserting relation data during peristence of events -- cgit 1.5.1 From 2c662ddde4e1277a0dc17295748aa5f0c41fa163 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 May 2019 14:21:39 +0100 Subject: Indirect tuple conversion --- synapse/storage/relations.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 732418ec65..996cb6903a 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -81,6 +81,9 @@ class RelationPaginationToken(object): def to_string(self): return "%d-%d" % (self.topological, self.stream) + def as_tuple(self): + return attr.astuple(self) + @attr.s class AggregationPaginationToken(object): @@ -108,6 +111,9 @@ class AggregationPaginationToken(object): def to_string(self): return "%d-%d" % (self.count, self.stream) + def as_tuple(self): + return attr.astuple(self) + class RelationsWorkerStore(SQLBaseStore): @cached(tree=True) -- cgit 1.5.1 From 7a7eba8302d6566c044ca82c8ac0da65aa85e36b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 May 2019 14:24:58 +0100 Subject: Move parsing of tokens out of storage layer --- synapse/rest/client/v2_alpha/relations.py | 19 +++++++++++++++++++ synapse/storage/relations.py | 16 ++-------------- 2 files changed, 21 insertions(+), 14 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 1b53e638eb..41e0a44936 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -32,6 +32,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.rest.client.transactions import HttpTransactionCache +from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken from ._base import client_v2_patterns @@ -149,6 +150,12 @@ class RelationPaginationServlet(RestServlet): from_token = parse_string(request, "from") to_token = parse_string(request, "to") + if from_token: + from_token = RelationPaginationToken.from_string(from_token) + + if to_token: + to_token = RelationPaginationToken.from_string(to_token) + result = yield self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, @@ -221,6 +228,12 @@ class RelationAggregationPaginationServlet(RestServlet): from_token = parse_string(request, "from") to_token = parse_string(request, "to") + if from_token: + from_token = AggregationPaginationToken.from_string(from_token) + + if to_token: + to_token = AggregationPaginationToken.from_string(to_token) + res = yield self.store.get_aggregation_groups_for_event( event_id=parent_id, event_type=event_type, @@ -289,6 +302,12 @@ class RelationAggregationGroupPaginationServlet(RestServlet): from_token = parse_string(request, "from") to_token = parse_string(request, "to") + if from_token: + from_token = RelationPaginationToken.from_string(from_token) + + if to_token: + to_token = RelationPaginationToken.from_string(to_token) + result = yield self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 996cb6903a..de67e305a1 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -54,7 +54,7 @@ class PaginationChunk(object): return d -@attr.s +@attr.s(frozen=True, slots=True) class RelationPaginationToken(object): """Pagination token for relation pagination API. @@ -85,7 +85,7 @@ class RelationPaginationToken(object): return attr.astuple(self) -@attr.s +@attr.s(frozen=True, slots=True) class AggregationPaginationToken(object): """Pagination token for relation aggregation pagination API. @@ -151,12 +151,6 @@ class RelationsWorkerStore(SQLBaseStore): requested. The rows are of the form `{"event_id": "..."}`. """ - if from_token: - from_token = RelationPaginationToken.from_string(from_token) - - if to_token: - to_token = RelationPaginationToken.from_string(to_token) - where_clause = ["relates_to_id = ?"] where_args = [event_id] @@ -258,12 +252,6 @@ class RelationsWorkerStore(SQLBaseStore): match. Each row is a dict with `type`, `key` and `count` fields. """ - if from_token: - from_token = AggregationPaginationToken.from_string(from_token) - - if to_token: - to_token = AggregationPaginationToken.from_string(to_token) - where_clause = ["relates_to_id = ?", "relation_type = ?"] where_args = [event_id, RelationTypes.ANNOTATION] -- cgit 1.5.1 From 895179a4dcc993af9702e3ab36e5d2157fdd0c26 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 May 2019 16:41:05 +0100 Subject: Update docstring --- synapse/storage/stream.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 163363c0c2..824e77c89e 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -77,6 +77,15 @@ def generate_pagination_where_clause( would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3). + Note that tokens are considered to be after the row they are in, e.g. if + a row A has a token T, then we consider A to be before T. This covention + is important when figuring out inequalities for the generated SQL, and + produces the following result: + - If paginatiting forwards then we exclude any rows matching the from + token, but include those that match the to token. + - If paginatiting backwards then we include any rows matching the from + token, but include those that match the to token. + Args: direction (str): Whether we're paginating backwards("b") or forwards ("f"). @@ -131,7 +140,9 @@ def _make_generic_sql_bound(bound, column_names, values, engine): names (tuple[str, str]): The column names. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. - values (tuple[int, int]): The values to bound the columns by. + values (tuple[int|None, int]): The values to bound the columns by. If + the first value is None then only creates a bound on the second + column. engine: The database engine to generate the SQL for Returns: -- cgit 1.5.1 From d46aab3fa8bc02d8db9616490e849b9443fed8e7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 14 May 2019 16:59:21 +0100 Subject: Add basic editing support --- synapse/events/utils.py | 30 +++++++-- synapse/replication/slave/storage/events.py | 1 + synapse/storage/relations.py | 60 +++++++++++++++++- tests/rest/client/v2_alpha/test_relations.py | 91 +++++++++++++++++++++++++--- 4 files changed, 167 insertions(+), 15 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 16d0c64372..2019ce9b1c 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -346,7 +346,7 @@ class EventClientSerializer(object): defer.returnValue(event) event_id = event.event_id - event = serialize_event(event, time_now, **kwargs) + serialized_event = serialize_event(event, time_now, **kwargs) # If MSC1849 is enabled then we need to look if thre are any relations # we need to bundle in with the event @@ -359,14 +359,36 @@ class EventClientSerializer(object): ) if annotations.chunk: - r = event["unsigned"].setdefault("m.relations", {}) + r = serialized_event["unsigned"].setdefault("m.relations", {}) r[RelationTypes.ANNOTATION] = annotations.to_dict() if references.chunk: - r = event["unsigned"].setdefault("m.relations", {}) + r = serialized_event["unsigned"].setdefault("m.relations", {}) r[RelationTypes.REFERENCES] = references.to_dict() - defer.returnValue(event) + edit = None + if event.type == EventTypes.Message: + edit = yield self.store.get_applicable_edit( + event.event_id, event.type, event.sender, + ) + + if edit: + # If there is an edit replace the content, preserving existing + # relations. + + relations = event.content.get("m.relates_to") + serialized_event["content"] = edit.content.get("m.new_content", {}) + if relations: + serialized_event["content"]["m.relates_to"] = relations + else: + serialized_event["content"].pop("m.relates_to", None) + + r = serialized_event["unsigned"].setdefault("m.relations", {}) + r[RelationTypes.REPLACES] = { + "event_id": edit.event_id, + } + + defer.returnValue(serialized_event) def serialize_events(self, events, time_now, **kwargs): """Serializes multiple events. diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 797450bc66..857128b311 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -143,3 +143,4 @@ class SlavedEventStore(EventFederationWorkerStore, if relates_to: self.get_relations_for_event.invalidate_many((relates_to,)) self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) + self.get_applicable_edit.invalidate_many((relates_to,)) diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index de67e305a1..aa91e52733 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -17,11 +17,13 @@ import logging import attr -from synapse.api.constants import RelationTypes +from twisted.internet import defer + +from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore from synapse.storage.stream import generate_pagination_where_clause -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -312,6 +314,59 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) + @cachedInlineCallbacks(tree=True) + def get_applicable_edit(self, event_id, event_type, sender): + """Get the most recent edit (if any) that has happened for the given + event. + + Correctly handles checking whether edits were allowed to happen. + + Args: + event_id (str): The original event ID + event_type (str): The original event type + sender (str): The original event sender + + Returns: + Deferred[EventBase|None]: Returns the most recent edit, if any. + """ + + # We only allow edits for `m.room.message` events that have the same sender + # and event type. We can't assert these things during regular event auth so + # we have to do the post hoc. + + if event_type != EventTypes.Message: + return + + sql = """ + SELECT event_id, origin_server_ts FROM events + INNER JOIN event_relations USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + AND type = ? + AND sender = ? + ORDER by origin_server_ts DESC, event_id DESC + LIMIT 1 + """ + + def _get_applicable_edit_txn(txn): + txn.execute( + sql, (event_id, RelationTypes.REPLACES, event_type, sender) + ) + row = txn.fetchone() + if row: + return row[0] + + edit_id = yield self.runInteraction( + "get_applicable_edit", _get_applicable_edit_txn + ) + + if not edit_id: + return + + edit_event = yield self.get_event(edit_id, allow_none=True) + defer.returnValue(edit_event) + class RelationsStore(RelationsWorkerStore): def _handle_event_relations(self, txn, event): @@ -357,3 +412,4 @@ class RelationsStore(RelationsWorkerStore): txn.call_after( self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) ) + txn.call_after(self.get_applicable_edit.invalidate_many, (parent_id,)) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 775622bd2b..b0e4c47ae3 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -14,6 +14,7 @@ # limitations under the License. import itertools +import json import six @@ -102,11 +103,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # relation event we sent above. self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( - { - "event_id": annotation_id, - "sender": self.user_id, - "type": "m.reaction", - }, + {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"}, channel.json_body["chunk"][0], ) @@ -330,8 +327,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(200, channel.code, channel.json_body) - self.maxDiff = None - self.assertEquals( channel.json_body["unsigned"].get("m.relations"), { @@ -347,7 +342,84 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, ) - def _send_relation(self, relation_type, event_type, key=None): + def test_edit(self): + """Test that a simple edit works. + """ + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACES, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id) + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals(channel.json_body["content"], new_body) + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + {RelationTypes.REPLACES: {"event_id": edit_event_id}}, + ) + + def test_multi_edit(self): + """Test that multiple edits, including attempts by people who + shouldn't be allowed, are correctly handled. + """ + + channel = self._send_relation( + RelationTypes.REPLACES, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "First edit"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACES, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.REPLACES, + "m.room.message.WRONG_TYPE", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + request, channel = self.make_request( + "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id) + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals(channel.json_body["content"], new_body) + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + {RelationTypes.REPLACES: {"event_id": edit_event_id}}, + ) + + def _send_relation(self, relation_type, event_type, key=None, content={}): """Helper function to send a relation pointing at `self.parent_id` Args: @@ -355,6 +427,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): event_type (str): The type of the event to create key (str|None): The aggregation key used for m.annotation relation type. + content(dict|None): The content of the created event. Returns: FakeChannel @@ -367,7 +440,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" % (self.room, self.parent_id, relation_type, event_type, query), - b"{}", + json.dumps(content).encode("utf-8"), ) self.render(request) return channel -- cgit 1.5.1 From 5dbff345094327e61fa9c1425ba0f955c9530ba7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 May 2019 15:48:04 +0100 Subject: Fixup bsaed on review comments --- synapse/events/utils.py | 4 +--- synapse/replication/slave/storage/events.py | 2 +- synapse/storage/relations.py | 32 +++++++++++++++-------------- 3 files changed, 19 insertions(+), 19 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 2019ce9b1c..bf3c8f8dc1 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -368,9 +368,7 @@ class EventClientSerializer(object): edit = None if event.type == EventTypes.Message: - edit = yield self.store.get_applicable_edit( - event.event_id, event.type, event.sender, - ) + edit = yield self.store.get_applicable_edit(event_id) if edit: # If there is an edit replace the content, preserving existing diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 857128b311..a3952506c1 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -143,4 +143,4 @@ class SlavedEventStore(EventFederationWorkerStore, if relates_to: self.get_relations_for_event.invalidate_many((relates_to,)) self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) - self.get_applicable_edit.invalidate_many((relates_to,)) + self.get_applicable_edit.invalidate((relates_to,)) diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index aa91e52733..6e216066ab 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -19,7 +19,7 @@ import attr from twisted.internet import defer -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore from synapse.storage.stream import generate_pagination_where_clause @@ -314,8 +314,8 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) - @cachedInlineCallbacks(tree=True) - def get_applicable_edit(self, event_id, event_type, sender): + @cachedInlineCallbacks() + def get_applicable_edit(self, event_id): """Get the most recent edit (if any) that has happened for the given event. @@ -323,8 +323,6 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id (str): The original event ID - event_type (str): The original event type - sender (str): The original event sender Returns: Deferred[EventBase|None]: Returns the most recent edit, if any. @@ -332,26 +330,28 @@ class RelationsWorkerStore(SQLBaseStore): # We only allow edits for `m.room.message` events that have the same sender # and event type. We can't assert these things during regular event auth so - # we have to do the post hoc. - - if event_type != EventTypes.Message: - return + # we have to do the checks post hoc. + # Fetches latest edit that has the same type and sender as the + # original, and is an `m.room.message`. sql = """ - SELECT event_id, origin_server_ts FROM events + SELECT edit.event_id FROM events AS edit INNER JOIN event_relations USING (event_id) + INNER JOIN events AS original ON + original.event_id = relates_to_id + AND edit.type = original.type + AND edit.sender = original.sender WHERE relates_to_id = ? AND relation_type = ? - AND type = ? - AND sender = ? - ORDER by origin_server_ts DESC, event_id DESC + AND edit.type = 'm.room.message' + ORDER by edit.origin_server_ts DESC, edit.event_id DESC LIMIT 1 """ def _get_applicable_edit_txn(txn): txn.execute( - sql, (event_id, RelationTypes.REPLACES, event_type, sender) + sql, (event_id, RelationTypes.REPLACES,) ) row = txn.fetchone() if row: @@ -412,4 +412,6 @@ class RelationsStore(RelationsWorkerStore): txn.call_after( self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) ) - txn.call_after(self.get_applicable_edit.invalidate_many, (parent_id,)) + + if rel_type == RelationTypes.REPLACES: + txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) -- cgit 1.5.1 From 8dd9cca8ead9fc0cf0789edf9fcfce04814c5eda Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 May 2019 16:40:51 +0100 Subject: Spelling and clarifications --- synapse/storage/stream.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 824e77c89e..529ad4ea79 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -78,12 +78,12 @@ def generate_pagination_where_clause( would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3). Note that tokens are considered to be after the row they are in, e.g. if - a row A has a token T, then we consider A to be before T. This covention + a row A has a token T, then we consider A to be before T. This convention is important when figuring out inequalities for the generated SQL, and produces the following result: - - If paginatiting forwards then we exclude any rows matching the from + - If paginating forwards then we exclude any rows matching the from token, but include those that match the to token. - - If paginatiting backwards then we include any rows matching the from + - If paginating backwards then we include any rows matching the from token, but include those that match the to token. Args: @@ -92,8 +92,12 @@ def generate_pagination_where_clause( column_names (tuple[str, str]): The column names to bound. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. - from_token (tuple[int, int]|None) - to_token (tuple[int, int]|None) + from_token (tuple[int, int]|None): The start point for the pagination. + This is an exclusive minimum bound if direction is "f", and an + inclusive maximum bound if direction is "b". + to_token (tuple[int, int]|None): The endpoint point for the pagination. + This is an inclusive maximum bound if direction is "f", and an + exclusive minimum bound if direction is "b". engine: The database engine to generate the clauses for Returns: -- cgit 1.5.1 From b63cc325a9362874a13f0079589afccd7d108f1f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 May 2019 18:01:39 +0100 Subject: Only count aggregations from distinct senders As a user isn't allowed to send a single emoji more than once. --- synapse/storage/relations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 6e216066ab..42f587a7d8 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -280,7 +280,7 @@ class RelationsWorkerStore(SQLBaseStore): having_clause = "" sql = """ - SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering) + SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) FROM event_relations INNER JOIN events USING (event_id) WHERE {where_clause} -- cgit 1.5.1 From ad5b4074e1a84b1e50a97144fbf1902ddc89db73 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 17 May 2019 19:37:31 +0100 Subject: Add startup background job for account validity If account validity is enabled in the server's configuration, this job will run at startup as a background job and will stick an expiration date to any registered account missing one. --- synapse/storage/_base.py | 62 +++++++++++++++++++++++++++++ synapse/storage/registration.py | 16 ++------ tests/rest/client/v2_alpha/test_register.py | 55 +++++++++++++++++++++++++ 3 files changed, 121 insertions(+), 12 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 983ce026e1..06d516c9e2 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -227,6 +229,8 @@ class SQLBaseStore(object): # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) + self._account_validity = self.hs.config.account_validity + # We add the user_directory_search table to the blacklist on SQLite # because the existing search table does not have an index, making it # unsafe to use native upserts. @@ -243,6 +247,14 @@ class SQLBaseStore(object): self._check_safe_to_upsert, ) + if self._account_validity.enabled: + self._clock.call_later( + 0.0, + run_as_background_process, + "account_validity_set_expiration_dates", + self._set_expiration_date_when_missing, + ) + @defer.inlineCallbacks def _check_safe_to_upsert(self): """ @@ -275,6 +287,56 @@ class SQLBaseStore(object): self._check_safe_to_upsert, ) + @defer.inlineCallbacks + def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL;" + ) + txn.execute(sql, []) + return self.cursor_to_dict(txn) + + res = yield self.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + if res: + for user in res: + self.runInteraction( + "set_expiration_date_for_user_background", + self.set_expiration_date_for_user_txn, + user["name"], + ) + + def set_expiration_date_for_user_txn(self, txn, user_id): + """Sets an expiration date to the account with the given user ID. + + Args: + user_id (str): User ID to set an expiration date for. + """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + self._simple_insert_txn( + txn, + "account_validity", + values={ + "user_id": user_id, + "expiration_ts_ms": expiration_ts, + "email_sent": False, + }, + ) + def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 03a06a83d6..4cf159ba81 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- -# Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -725,17 +727,7 @@ class RegistrationStore( raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) if self._account_validity.enabled: - now_ms = self.clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - self._simple_insert_txn( - txn, - "account_validity", - values={ - "user_id": user_id, - "expiration_ts_ms": expiration_ts, - "email_sent": False, - } - ) + self.set_expiration_date_for_user_txn(txn, user_id) if token: # it's possible for this to get a conflict, but only for a single user diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 65685883db..d4a1d4d50c 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -1,3 +1,20 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import datetime import json import os @@ -409,3 +426,41 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEqual(len(self.email_attempts), 1) + + +class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + self.validity_period = 10 + + config = self.default_config() + + config["enable_registration"] = True + config["account_validity"] = { + "enabled": False, + } + + self.hs = self.setup_test_homeserver(config=config) + self.hs.config.account_validity.period = self.validity_period + + self.store = self.hs.get_datastore() + + return self.hs + + def test_background_job(self): + """ + Tests whether the account validity startup background job does the right thing, + which is sticking an expiration date to every account that doesn't already have + one. + """ + user_id = self.register_user("kermit", "user") + + now_ms = self.hs.clock.time_msec() + self.get_success(self.store._set_expiration_date_when_missing()) + + res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + self.assertEqual(res, now_ms + self.validity_period) -- cgit 1.5.1 From 935af0da380f39ba284b78054270331bdbad7712 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 20 May 2019 10:13:05 +0100 Subject: Correctly update aggregation counts after redaction --- synapse/storage/events.py | 3 +++ synapse/storage/relations.py | 17 +++++++++++++ tests/rest/client/v2_alpha/test_relations.py | 37 ++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index b025ebc926..881d6d0126 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1325,6 +1325,9 @@ class EventsStore( txn, event.room_id, event.redacts ) + # Remove from relations table. + self._handle_redaction(txn, event.redacts) + # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 6e216066ab..63e6185ee3 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -415,3 +415,20 @@ class RelationsStore(RelationsWorkerStore): if rel_type == RelationTypes.REPLACES: txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) + + def _handle_redaction(self, txn, redacted_event_id): + """Handles receiving a redaction and checking whether we need to remove + any redacted relations from the database. + + Args: + txn + redacted_event_id (str): The event that was redacted. + """ + + self._simple_delete_txn( + txn, + table="event_relations", + keyvalues={ + "event_id": redacted_event_id, + } + ) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index cd965167f8..3737cdc396 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -320,6 +320,43 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, ) + def test_aggregation_redactions(self): + """Test that annotations get correctly aggregated after a redactions. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + to_redact_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Now lets redact the 'a' reaction + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), + access_token=self.user_token, + content={}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body, + {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + ) + def test_aggregation_must_be_annotation(self): """Test that aggregations must be annotations. """ -- cgit 1.5.1 From 1dff859d6aed58561a2b5913e5c9b897bbd3599c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 20 May 2019 14:31:19 +0100 Subject: Rename relation types to match MSC --- synapse/api/constants.py | 4 ++-- synapse/events/utils.py | 6 +++--- synapse/storage/relations.py | 8 ++++---- tests/rest/client/v2_alpha/test_relations.py | 22 +++++++++++----------- 4 files changed, 20 insertions(+), 20 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 28862a1d25..6b347b1749 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -125,5 +125,5 @@ class RelationTypes(object): """The types of relations known to this server. """ ANNOTATION = "m.annotation" - REPLACES = "m.replaces" - REFERENCES = "m.references" + REPLACE = "m.replace" + REFERENCE = "m.reference" diff --git a/synapse/events/utils.py b/synapse/events/utils.py index bf3c8f8dc1..27a2a9ef98 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -355,7 +355,7 @@ class EventClientSerializer(object): event_id, ) references = yield self.store.get_relations_for_event( - event_id, RelationTypes.REFERENCES, direction="f", + event_id, RelationTypes.REFERENCE, direction="f", ) if annotations.chunk: @@ -364,7 +364,7 @@ class EventClientSerializer(object): if references.chunk: r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.REFERENCES] = references.to_dict() + r[RelationTypes.REFERENCE] = references.to_dict() edit = None if event.type == EventTypes.Message: @@ -382,7 +382,7 @@ class EventClientSerializer(object): serialized_event["content"].pop("m.relates_to", None) r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.REPLACES] = { + r[RelationTypes.REPLACE] = { "event_id": edit.event_id, } diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 63e6185ee3..493abe405e 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -351,7 +351,7 @@ class RelationsWorkerStore(SQLBaseStore): def _get_applicable_edit_txn(txn): txn.execute( - sql, (event_id, RelationTypes.REPLACES,) + sql, (event_id, RelationTypes.REPLACE,) ) row = txn.fetchone() if row: @@ -384,8 +384,8 @@ class RelationsStore(RelationsWorkerStore): rel_type = relation.get("rel_type") if rel_type not in ( RelationTypes.ANNOTATION, - RelationTypes.REFERENCES, - RelationTypes.REPLACES, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, ): # Unknown relation type return @@ -413,7 +413,7 @@ class RelationsStore(RelationsWorkerStore): self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) ) - if rel_type == RelationTypes.REPLACES: + if rel_type == RelationTypes.REPLACE: txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) def _handle_redaction(self, txn, redacted_event_id): diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index b3dc4bd5bc..3d040cf118 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -363,8 +363,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.replace?limit=1" - % (self.room, self.parent_id), + "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" + % (self.room, self.parent_id, RelationTypes.REPLACE), access_token=self.user_token, ) self.render(request) @@ -386,11 +386,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") self.assertEquals(200, channel.code, channel.json_body) - channel = self._send_relation(RelationTypes.REFERENCES, "m.room.test") + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") self.assertEquals(200, channel.code, channel.json_body) reply_1 = channel.json_body["event_id"] - channel = self._send_relation(RelationTypes.REFERENCES, "m.room.test") + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") self.assertEquals(200, channel.code, channel.json_body) reply_2 = channel.json_body["event_id"] @@ -411,7 +411,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"type": "m.reaction", "key": "b", "count": 1}, ] }, - RelationTypes.REFERENCES: { + RelationTypes.REFERENCE: { "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] }, }, @@ -423,7 +423,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( - RelationTypes.REPLACES, + RelationTypes.REPLACE, "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) @@ -443,7 +443,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals( channel.json_body["unsigned"].get("m.relations"), - {RelationTypes.REPLACES: {"event_id": edit_event_id}}, + {RelationTypes.REPLACE: {"event_id": edit_event_id}}, ) def test_multi_edit(self): @@ -452,7 +452,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): """ channel = self._send_relation( - RelationTypes.REPLACES, + RelationTypes.REPLACE, "m.room.message", content={ "msgtype": "m.text", @@ -464,7 +464,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( - RelationTypes.REPLACES, + RelationTypes.REPLACE, "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) @@ -473,7 +473,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): edit_event_id = channel.json_body["event_id"] channel = self._send_relation( - RelationTypes.REPLACES, + RelationTypes.REPLACE, "m.room.message.WRONG_TYPE", content={ "msgtype": "m.text", @@ -495,7 +495,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals( channel.json_body["unsigned"].get("m.relations"), - {RelationTypes.REPLACES: {"event_id": edit_event_id}}, + {RelationTypes.REPLACE: {"event_id": edit_event_id}}, ) def _send_relation( -- cgit 1.5.1 From c7ec06e8a6909343f5860a5eecbe3aa69f03c151 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 20 May 2019 17:39:05 +0100 Subject: Block attempts to annotate the same event twice --- synapse/handlers/message.py | 16 +++++++++- synapse/storage/relations.py | 48 ++++++++++++++++++++++++++-- tests/rest/client/v2_alpha/test_relations.py | 27 +++++++++++++++- 3 files changed, 86 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7b2c33a922..0c892c8dba 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from twisted.internet.defer import succeed -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RelationTypes from synapse.api.errors import ( AuthError, Codes, @@ -601,6 +601,20 @@ class EventCreationHandler(object): self.validator.validate_new(event) + # We now check that if this event is an annotation that the can't + # annotate the same way twice (e.g. stops users from liking an event + # multiple times). + relation = event.content.get("m.relates_to", {}) + if relation.get("rel_type") == RelationTypes.ANNOTATION: + relates_to = relation["event_id"] + aggregation_key = relation["key"] + + already_exists = yield self.store.has_user_annotated_event( + relates_to, event.type, aggregation_key, event.sender, + ) + if already_exists: + raise SynapseError(400, "Can't send same reaction twice") + logger.debug( "Created event %s", event.event_id, diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 493abe405e..7d51b38d77 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -350,9 +350,7 @@ class RelationsWorkerStore(SQLBaseStore): """ def _get_applicable_edit_txn(txn): - txn.execute( - sql, (event_id, RelationTypes.REPLACE,) - ) + txn.execute(sql, (event_id, RelationTypes.REPLACE)) row = txn.fetchone() if row: return row[0] @@ -367,6 +365,50 @@ class RelationsWorkerStore(SQLBaseStore): edit_event = yield self.get_event(edit_id, allow_none=True) defer.returnValue(edit_event) + def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): + """Check if a user has already annotated an event with the same key + (e.g. already liked an event). + + Args: + parent_id (str): The event being annotated + event_type (str): The event type of the annotation + aggregation_key (str): The aggregation key of the annotation + sender (str): The sender of the annotation + + Returns: + Deferred[bool] + """ + + sql = """ + SELECT 1 FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + AND type = ? + AND sender = ? + AND aggregation_key = ? + LIMIT 1; + """ + + def _get_if_user_has_annotated_event(txn): + txn.execute( + sql, + ( + parent_id, + RelationTypes.ANNOTATION, + event_type, + sender, + aggregation_key, + ), + ) + + return bool(txn.fetchone()) + + return self.runInteraction( + "get_if_user_has_annotated_event", _get_if_user_has_annotated_event + ) + class RelationsStore(RelationsWorkerStore): def _handle_event_relations(self, txn, event): diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 3d040cf118..43b3049daa 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -90,6 +90,15 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) self.assertEquals(400, channel.code, channel.json_body) + def test_deny_double_react(self): + """Test that we deny relations on membership events + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(400, channel.code, channel.json_body) + def test_basic_paginate_relations(self): """Tests that calling pagination API corectly the latest relations. """ @@ -234,14 +243,30 @@ class RelationsTestCase(unittest.HomeserverTestCase): """Test that we can paginate within an annotation group. """ + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 expected_event_ids = [] for _ in range(10): channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", key=u"👍" + RelationTypes.ANNOTATION, + "m.reaction", + key=u"👍", + access_token=access_tokens[idx], ) self.assertEquals(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) + idx += 1 + # Also send a different type of reaction so that we test we don't see it channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") self.assertEquals(200, channel.code, channel.json_body) -- cgit 1.5.1 From 5ceee46c6bd55376ff2188b6e674035d7da95f24 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 21 May 2019 13:38:51 +0100 Subject: Do the select and insert in a single transaction --- synapse/storage/_base.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 06d516c9e2..fa6839ceca 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -304,21 +304,17 @@ class SQLBaseStore(object): " WHERE account_validity.user_id is NULL;" ) txn.execute(sql, []) - return self.cursor_to_dict(txn) - res = yield self.runInteraction( + res = self.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn(txn, user["name"]) + + yield self.runInteraction( "get_users_with_no_expiration_date", select_users_with_no_expiration_date_txn, ) - if res: - for user in res: - self.runInteraction( - "set_expiration_date_for_user_background", - self.set_expiration_date_for_user_txn, - user["name"], - ) - def set_expiration_date_for_user_txn(self, txn, user_id): """Sets an expiration date to the account with the given user ID. -- cgit 1.5.1 From c4aef549ad1f55661675a789f89fe9e041fac874 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 21 May 2019 16:10:54 +0100 Subject: Exclude soft-failed events from fwd-extremity candidates. (#5146) When considering the candidates to be forward-extremities, we must exclude soft failures. Hopefully fixes #5090. --- changelog.d/5146.bugfix | 1 + synapse/handlers/federation.py | 7 ++++++- synapse/storage/events.py | 9 +++++++-- 3 files changed, 14 insertions(+), 3 deletions(-) create mode 100644 changelog.d/5146.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/5146.bugfix b/changelog.d/5146.bugfix new file mode 100644 index 0000000000..a54abed92b --- /dev/null +++ b/changelog.d/5146.bugfix @@ -0,0 +1 @@ +Exclude soft-failed events from forward-extremity candidates: fixes "No forward extremities left!" error. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0684778882..2202ed699a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1916,6 +1916,11 @@ class FederationHandler(BaseHandler): event.room_id, latest_event_ids=extrem_ids, ) + logger.debug( + "Doing soft-fail check for %s: state %s", + event.event_id, current_state_ids, + ) + # Now check if event pass auth against said current state auth_types = auth_types_for_event(event) current_state_ids = [ @@ -1932,7 +1937,7 @@ class FederationHandler(BaseHandler): self.auth.check(room_version, event, auth_events=current_auth_events) except AuthError as e: logger.warn( - "Failed current state auth resolution for %r because %s", + "Soft-failing %r because %s", event, e, ) event.internal_metadata.soft_failed = True diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 881d6d0126..2ffc27ff41 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -575,10 +575,11 @@ class EventsStore( def _get_events(txn, batch): sql = """ - SELECT prev_event_id + SELECT prev_event_id, internal_metadata FROM event_edges INNER JOIN events USING (event_id) LEFT JOIN rejections USING (event_id) + LEFT JOIN event_json USING (event_id) WHERE prev_event_id IN (%s) AND NOT events.outlier @@ -588,7 +589,11 @@ class EventsStore( ) txn.execute(sql, batch) - results.extend(r[0] for r in txn) + results.extend( + r[0] + for r in txn + if not json.loads(r[1]).get("soft_failed") + ) for chunk in batch_iter(event_ids, 100): yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk) -- cgit 1.5.1 From 4a30e4acb4ef14431914bd42ad09a51bd81d6c3e Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 21 May 2019 11:36:50 -0500 Subject: Room Statistics (#4338) --- changelog.d/4338.feature | 1 + docs/sample_config.yaml | 16 ++ synapse/api/constants.py | 1 + synapse/config/homeserver.py | 42 ++- synapse/config/stats.py | 60 ++++ synapse/handlers/stats.py | 325 +++++++++++++++++++++ synapse/server.py | 6 + synapse/storage/__init__.py | 2 + synapse/storage/events_worker.py | 24 ++ synapse/storage/roommember.py | 32 +++ synapse/storage/schema/delta/54/stats.sql | 80 ++++++ synapse/storage/state_deltas.py | 12 +- synapse/storage/stats.py | 450 ++++++++++++++++++++++++++++++ tests/handlers/test_stats.py | 251 +++++++++++++++++ tests/rest/client/v1/utils.py | 17 ++ 15 files changed, 1306 insertions(+), 13 deletions(-) create mode 100644 changelog.d/4338.feature create mode 100644 synapse/config/stats.py create mode 100644 synapse/handlers/stats.py create mode 100644 synapse/storage/schema/delta/54/stats.sql create mode 100644 synapse/storage/stats.py create mode 100644 tests/handlers/test_stats.py (limited to 'synapse/storage') diff --git a/changelog.d/4338.feature b/changelog.d/4338.feature new file mode 100644 index 0000000000..01285e965c --- /dev/null +++ b/changelog.d/4338.feature @@ -0,0 +1 @@ +Synapse now more efficiently collates room statistics. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index f658ec8ecd..559fbcdd01 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1153,6 +1153,22 @@ password_config: # + +# Local statistics collection. Used in populating the room directory. +# +# 'bucket_size' controls how large each statistics timeslice is. It can +# be defined in a human readable short form -- e.g. "1d", "1y". +# +# 'retention' controls how long historical statistics will be kept for. +# It can be defined in a human readable short form -- e.g. "1d", "1y". +# +# +#stats: +# enabled: true +# bucket_size: 1d +# retention: 1y + + # Server Notices room configuration # # Uncomment this section to enable a room which can be used to send notices diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 6b347b1749..ee129c8689 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -79,6 +79,7 @@ class EventTypes(object): RoomHistoryVisibility = "m.room.history_visibility" CanonicalAlias = "m.room.canonical_alias" + Encryption = "m.room.encryption" RoomAvatar = "m.room.avatar" RoomEncryption = "m.room.encryption" GuestAccess = "m.room.guest_access" diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 727fdc54d8..5c4fc8ff21 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from .api import ApiConfig from .appservice import AppServiceConfig from .captcha import CaptchaConfig @@ -36,20 +37,41 @@ from .saml2_config import SAML2Config from .server import ServerConfig from .server_notices_config import ServerNoticesConfig from .spam_checker import SpamCheckerConfig +from .stats import StatsConfig from .tls import TlsConfig from .user_directory import UserDirectoryConfig from .voip import VoipConfig from .workers import WorkerConfig -class HomeServerConfig(ServerConfig, TlsConfig, DatabaseConfig, LoggingConfig, - RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, - VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, - AppServiceConfig, KeyConfig, SAML2Config, CasConfig, - JWTConfig, PasswordConfig, EmailConfig, - WorkerConfig, PasswordAuthProviderConfig, PushConfig, - SpamCheckerConfig, GroupsConfig, UserDirectoryConfig, - ConsentConfig, - ServerNoticesConfig, RoomDirectoryConfig, - ): +class HomeServerConfig( + ServerConfig, + TlsConfig, + DatabaseConfig, + LoggingConfig, + RatelimitConfig, + ContentRepositoryConfig, + CaptchaConfig, + VoipConfig, + RegistrationConfig, + MetricsConfig, + ApiConfig, + AppServiceConfig, + KeyConfig, + SAML2Config, + CasConfig, + JWTConfig, + PasswordConfig, + EmailConfig, + WorkerConfig, + PasswordAuthProviderConfig, + PushConfig, + SpamCheckerConfig, + GroupsConfig, + UserDirectoryConfig, + ConsentConfig, + StatsConfig, + ServerNoticesConfig, + RoomDirectoryConfig, +): pass diff --git a/synapse/config/stats.py b/synapse/config/stats.py new file mode 100644 index 0000000000..80fc1b9dd0 --- /dev/null +++ b/synapse/config/stats.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import sys + +from ._base import Config + + +class StatsConfig(Config): + """Stats Configuration + Configuration for the behaviour of synapse's stats engine + """ + + def read_config(self, config): + self.stats_enabled = True + self.stats_bucket_size = 86400 + self.stats_retention = sys.maxsize + stats_config = config.get("stats", None) + if stats_config: + self.stats_enabled = stats_config.get("enabled", self.stats_enabled) + self.stats_bucket_size = ( + self.parse_duration(stats_config.get("bucket_size", "1d")) / 1000 + ) + self.stats_retention = ( + self.parse_duration( + stats_config.get("retention", "%ds" % (sys.maxsize,)) + ) + / 1000 + ) + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # Local statistics collection. Used in populating the room directory. + # + # 'bucket_size' controls how large each statistics timeslice is. It can + # be defined in a human readable short form -- e.g. "1d", "1y". + # + # 'retention' controls how long historical statistics will be kept for. + # It can be defined in a human readable short form -- e.g. "1d", "1y". + # + # + #stats: + # enabled: true + # bucket_size: 1d + # retention: 1y + """ diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py new file mode 100644 index 0000000000..0e92b405ba --- /dev/null +++ b/synapse/handlers/stats.py @@ -0,0 +1,325 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.handlers.state_deltas import StateDeltasHandler +from synapse.metrics import event_processing_positions +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import UserID +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + + +class StatsHandler(StateDeltasHandler): + """Handles keeping the *_stats tables updated with a simple time-series of + information about the users, rooms and media on the server, such that admins + have some idea of who is consuming their resources. + + Heavily derived from UserDirectoryHandler + """ + + def __init__(self, hs): + super(StatsHandler, self).__init__(hs) + self.hs = hs + self.store = hs.get_datastore() + self.state = hs.get_state_handler() + self.server_name = hs.hostname + self.clock = hs.get_clock() + self.notifier = hs.get_notifier() + self.is_mine_id = hs.is_mine_id + self.stats_bucket_size = hs.config.stats_bucket_size + + # The current position in the current_state_delta stream + self.pos = None + + # Guard to ensure we only process deltas one at a time + self._is_processing = False + + if hs.config.stats_enabled: + self.notifier.add_replication_callback(self.notify_new_event) + + # We kick this off so that we don't have to wait for a change before + # we start populating stats + self.clock.call_later(0, self.notify_new_event) + + def notify_new_event(self): + """Called when there may be more deltas to process + """ + if not self.hs.config.stats_enabled: + return + + if self._is_processing: + return + + @defer.inlineCallbacks + def process(): + try: + yield self._unsafe_process() + finally: + self._is_processing = False + + self._is_processing = True + run_as_background_process("stats.notify_new_event", process) + + @defer.inlineCallbacks + def _unsafe_process(self): + # If self.pos is None then means we haven't fetched it from DB + if self.pos is None: + self.pos = yield self.store.get_stats_stream_pos() + + # If still None then the initial background update hasn't happened yet + if self.pos is None: + defer.returnValue(None) + + # Loop round handling deltas until we're up to date + while True: + with Measure(self.clock, "stats_delta"): + deltas = yield self.store.get_current_state_deltas(self.pos) + if not deltas: + return + + logger.info("Handling %d state deltas", len(deltas)) + yield self._handle_deltas(deltas) + + self.pos = deltas[-1]["stream_id"] + yield self.store.update_stats_stream_pos(self.pos) + + event_processing_positions.labels("stats").set(self.pos) + + @defer.inlineCallbacks + def _handle_deltas(self, deltas): + """ + Called with the state deltas to process + """ + for delta in deltas: + typ = delta["type"] + state_key = delta["state_key"] + room_id = delta["room_id"] + event_id = delta["event_id"] + stream_id = delta["stream_id"] + prev_event_id = delta["prev_event_id"] + + logger.debug("Handling: %r %r, %s", typ, state_key, event_id) + + token = yield self.store.get_earliest_token_for_room_stats(room_id) + + # If the earliest token to begin from is larger than our current + # stream ID, skip processing this delta. + if token is not None and token >= stream_id: + logger.debug( + "Ignoring: %s as earlier than this room's initial ingestion event", + event_id, + ) + continue + + if event_id is None and prev_event_id is None: + # Errr... + continue + + event_content = {} + + if event_id is not None: + event_content = (yield self.store.get_event(event_id)).content or {} + + # quantise time to the nearest bucket + now = yield self.store.get_received_ts(event_id) + now = (now // 1000 // self.stats_bucket_size) * self.stats_bucket_size + + if typ == EventTypes.Member: + # we could use _get_key_change here but it's a bit inefficient + # given we're not testing for a specific result; might as well + # just grab the prev_membership and membership strings and + # compare them. + prev_event_content = {} + if prev_event_id is not None: + prev_event_content = ( + yield self.store.get_event(prev_event_id) + ).content + + membership = event_content.get("membership", Membership.LEAVE) + prev_membership = prev_event_content.get("membership", Membership.LEAVE) + + if prev_membership == membership: + continue + + if prev_membership == Membership.JOIN: + yield self.store.update_stats_delta( + now, "room", room_id, "joined_members", -1 + ) + elif prev_membership == Membership.INVITE: + yield self.store.update_stats_delta( + now, "room", room_id, "invited_members", -1 + ) + elif prev_membership == Membership.LEAVE: + yield self.store.update_stats_delta( + now, "room", room_id, "left_members", -1 + ) + elif prev_membership == Membership.BAN: + yield self.store.update_stats_delta( + now, "room", room_id, "banned_members", -1 + ) + else: + err = "%s is not a valid prev_membership" % (repr(prev_membership),) + logger.error(err) + raise ValueError(err) + + if membership == Membership.JOIN: + yield self.store.update_stats_delta( + now, "room", room_id, "joined_members", +1 + ) + elif membership == Membership.INVITE: + yield self.store.update_stats_delta( + now, "room", room_id, "invited_members", +1 + ) + elif membership == Membership.LEAVE: + yield self.store.update_stats_delta( + now, "room", room_id, "left_members", +1 + ) + elif membership == Membership.BAN: + yield self.store.update_stats_delta( + now, "room", room_id, "banned_members", +1 + ) + else: + err = "%s is not a valid membership" % (repr(membership),) + logger.error(err) + raise ValueError(err) + + user_id = state_key + if self.is_mine_id(user_id): + # update user_stats as it's one of our users + public = yield self._is_public_room(room_id) + + if membership == Membership.LEAVE: + yield self.store.update_stats_delta( + now, + "user", + user_id, + "public_rooms" if public else "private_rooms", + -1, + ) + elif membership == Membership.JOIN: + yield self.store.update_stats_delta( + now, + "user", + user_id, + "public_rooms" if public else "private_rooms", + +1, + ) + + elif typ == EventTypes.Create: + # Newly created room. Add it with all blank portions. + yield self.store.update_room_state( + room_id, + { + "join_rules": None, + "history_visibility": None, + "encryption": None, + "name": None, + "topic": None, + "avatar": None, + "canonical_alias": None, + }, + ) + + elif typ == EventTypes.JoinRules: + yield self.store.update_room_state( + room_id, {"join_rules": event_content.get("join_rule")} + ) + + is_public = yield self._get_key_change( + prev_event_id, event_id, "join_rule", JoinRules.PUBLIC + ) + if is_public is not None: + yield self.update_public_room_stats(now, room_id, is_public) + + elif typ == EventTypes.RoomHistoryVisibility: + yield self.store.update_room_state( + room_id, + {"history_visibility": event_content.get("history_visibility")}, + ) + + is_public = yield self._get_key_change( + prev_event_id, event_id, "history_visibility", "world_readable" + ) + if is_public is not None: + yield self.update_public_room_stats(now, room_id, is_public) + + elif typ == EventTypes.Encryption: + yield self.store.update_room_state( + room_id, {"encryption": event_content.get("algorithm")} + ) + elif typ == EventTypes.Name: + yield self.store.update_room_state( + room_id, {"name": event_content.get("name")} + ) + elif typ == EventTypes.Topic: + yield self.store.update_room_state( + room_id, {"topic": event_content.get("topic")} + ) + elif typ == EventTypes.RoomAvatar: + yield self.store.update_room_state( + room_id, {"avatar": event_content.get("url")} + ) + elif typ == EventTypes.CanonicalAlias: + yield self.store.update_room_state( + room_id, {"canonical_alias": event_content.get("alias")} + ) + + @defer.inlineCallbacks + def update_public_room_stats(self, ts, room_id, is_public): + """ + Increment/decrement a user's number of public rooms when a room they are + in changes to/from public visibility. + + Args: + ts (int): Timestamp in seconds + room_id (str) + is_public (bool) + """ + # For now, blindly iterate over all local users in the room so that + # we can handle the whole problem of copying buckets over as needed + user_ids = yield self.store.get_users_in_room(room_id) + + for user_id in user_ids: + if self.hs.is_mine(UserID.from_string(user_id)): + yield self.store.update_stats_delta( + ts, "user", user_id, "public_rooms", +1 if is_public else -1 + ) + yield self.store.update_stats_delta( + ts, "user", user_id, "private_rooms", -1 if is_public else +1 + ) + + @defer.inlineCallbacks + def _is_public_room(self, room_id): + join_rules = yield self.state.get_current_state(room_id, EventTypes.JoinRules) + history_visibility = yield self.state.get_current_state( + room_id, EventTypes.RoomHistoryVisibility + ) + + if (join_rules and join_rules.content.get("join_rule") == JoinRules.PUBLIC) or ( + ( + history_visibility + and history_visibility.content.get("history_visibility") + == "world_readable" + ) + ): + defer.returnValue(True) + else: + defer.returnValue(False) diff --git a/synapse/server.py b/synapse/server.py index 80d40b9272..9229a68a8d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -72,6 +72,7 @@ from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.set_password import SetPasswordHandler +from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler from synapse.handlers.user_directory import UserDirectoryHandler @@ -139,6 +140,7 @@ class HomeServer(object): 'acme_handler', 'auth_handler', 'device_handler', + 'stats_handler', 'e2e_keys_handler', 'e2e_room_keys_handler', 'event_handler', @@ -191,6 +193,7 @@ class HomeServer(object): REQUIRED_ON_MASTER_STARTUP = [ "user_directory_handler", + "stats_handler" ] # This is overridden in derived application classes @@ -474,6 +477,9 @@ class HomeServer(object): def build_secrets(self): return Secrets() + def build_stats_handler(self): + return StatsHandler(self) + def build_spam_checker(self): return SpamChecker(self) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 7522d3fd57..66675d08ae 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -55,6 +55,7 @@ from .roommember import RoomMemberStore from .search import SearchStore from .signatures import SignatureStore from .state import StateStore +from .stats import StatsStore from .stream import StreamStore from .tags import TagsStore from .transactions import TransactionStore @@ -100,6 +101,7 @@ class DataStore( GroupServerStore, UserErasureStore, MonthlyActiveUsersStore, + StatsStore, RelationsStore, ): def __init__(self, db_conn, hs): diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index adc6cf26b5..83ffae2132 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -611,3 +611,27 @@ class EventsWorkerStore(SQLBaseStore): return res return self.runInteraction("get_rejection_reasons", f) + + def _get_total_state_event_counts_txn(self, txn, room_id): + """ + See get_state_event_counts. + """ + sql = "SELECT COUNT(*) FROM state_events WHERE room_id=?" + txn.execute(sql, (room_id,)) + row = txn.fetchone() + return row[0] if row else 0 + + def get_total_state_event_counts(self, room_id): + """ + Gets the total number of state events in a room. + + Args: + room_id (str) + + Returns: + Deferred[int] + """ + return self.runInteraction( + "get_total_state_event_counts", + self._get_total_state_event_counts_txn, room_id + ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 57df17bcc2..4bd1669458 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -142,6 +142,38 @@ class RoomMemberWorkerStore(EventsWorkerStore): return self.runInteraction("get_room_summary", _get_room_summary_txn) + def _get_user_count_in_room_txn(self, txn, room_id, membership): + """ + See get_user_count_in_room. + """ + sql = ( + "SELECT count(*) FROM room_memberships as m" + " INNER JOIN current_state_events as c" + " ON m.event_id = c.event_id " + " AND m.room_id = c.room_id " + " AND m.user_id = c.state_key" + " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?" + ) + + txn.execute(sql, (room_id, membership)) + row = txn.fetchone() + return row[0] + + def get_user_count_in_room(self, room_id, membership): + """ + Get the user count in a room with a particular membership. + + Args: + room_id (str) + membership (Membership) + + Returns: + Deferred[int] + """ + return self.runInteraction( + "get_users_in_room", self._get_user_count_in_room_txn, room_id, membership + ) + @cached() def get_invited_rooms_for_user(self, user_id): """ Get all the rooms the user is invited to diff --git a/synapse/storage/schema/delta/54/stats.sql b/synapse/storage/schema/delta/54/stats.sql new file mode 100644 index 0000000000..652e58308e --- /dev/null +++ b/synapse/storage/schema/delta/54/stats.sql @@ -0,0 +1,80 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE stats_stream_pos ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT, + CHECK (Lock='X') +); + +INSERT INTO stats_stream_pos (stream_id) VALUES (null); + +CREATE TABLE user_stats ( + user_id TEXT NOT NULL, + ts BIGINT NOT NULL, + bucket_size INT NOT NULL, + public_rooms INT NOT NULL, + private_rooms INT NOT NULL +); + +CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts); + +CREATE TABLE room_stats ( + room_id TEXT NOT NULL, + ts BIGINT NOT NULL, + bucket_size INT NOT NULL, + current_state_events INT NOT NULL, + joined_members INT NOT NULL, + invited_members INT NOT NULL, + left_members INT NOT NULL, + banned_members INT NOT NULL, + state_events INT NOT NULL +); + +CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts); + +-- cache of current room state; useful for the publicRooms list +CREATE TABLE room_state ( + room_id TEXT NOT NULL, + join_rules TEXT, + history_visibility TEXT, + encryption TEXT, + name TEXT, + topic TEXT, + avatar TEXT, + canonical_alias TEXT + -- get aliases straight from the right table +); + +CREATE UNIQUE INDEX room_state_room ON room_state(room_id); + +CREATE TABLE room_stats_earliest_token ( + room_id TEXT NOT NULL, + token BIGINT NOT NULL +); + +CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id); + +-- Set up staging tables +INSERT INTO background_updates (update_name, progress_json) VALUES + ('populate_stats_createtables', '{}'); + +-- Run through each room and update stats +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_rooms', '{}', 'populate_stats_createtables'); + +-- Clean up staging tables +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_cleanup', '{}', 'populate_stats_process_rooms'); diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py index 31a0279b18..5fdb442104 100644 --- a/synapse/storage/state_deltas.py +++ b/synapse/storage/state_deltas.py @@ -84,10 +84,16 @@ class StateDeltasStore(SQLBaseStore): "get_current_state_deltas", get_current_state_deltas_txn ) - def get_max_stream_id_in_current_state_deltas(self): - return self._simple_select_one_onecol( + def _get_max_stream_id_in_current_state_deltas_txn(self, txn): + return self._simple_select_one_onecol_txn( + txn, table="current_state_delta_stream", keyvalues={}, retcol="COALESCE(MAX(stream_id), -1)", - desc="get_max_stream_id_in_current_state_deltas", + ) + + def get_max_stream_id_in_current_state_deltas(self): + return self.runInteraction( + "get_max_stream_id_in_current_state_deltas", + self._get_max_stream_id_in_current_state_deltas_txn, ) diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py new file mode 100644 index 0000000000..71b80a891d --- /dev/null +++ b/synapse/storage/stats.py @@ -0,0 +1,450 @@ +# -*- coding: utf-8 -*- +# Copyright 2018, 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.storage.state_deltas import StateDeltasStore +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + +# these fields track absolutes (e.g. total number of rooms on the server) +ABSOLUTE_STATS_FIELDS = { + "room": ( + "current_state_events", + "joined_members", + "invited_members", + "left_members", + "banned_members", + "state_events", + ), + "user": ("public_rooms", "private_rooms"), +} + +TYPE_TO_ROOM = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")} + +TEMP_TABLE = "_temp_populate_stats" + + +class StatsStore(StateDeltasStore): + def __init__(self, db_conn, hs): + super(StatsStore, self).__init__(db_conn, hs) + + self.server_name = hs.hostname + self.clock = self.hs.get_clock() + self.stats_enabled = hs.config.stats_enabled + self.stats_bucket_size = hs.config.stats_bucket_size + + self.register_background_update_handler( + "populate_stats_createtables", self._populate_stats_createtables + ) + self.register_background_update_handler( + "populate_stats_process_rooms", self._populate_stats_process_rooms + ) + self.register_background_update_handler( + "populate_stats_cleanup", self._populate_stats_cleanup + ) + + @defer.inlineCallbacks + def _populate_stats_createtables(self, progress, batch_size): + + if not self.stats_enabled: + yield self._end_background_update("populate_stats_createtables") + defer.returnValue(1) + + # Get all the rooms that we want to process. + def _make_staging_area(txn): + sql = ( + "CREATE TABLE IF NOT EXISTS " + + TEMP_TABLE + + "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)" + ) + txn.execute(sql) + + sql = ( + "CREATE TABLE IF NOT EXISTS " + + TEMP_TABLE + + "_position(position TEXT NOT NULL)" + ) + txn.execute(sql) + + # Get rooms we want to process from the database + sql = """ + SELECT room_id, count(*) FROM current_state_events + GROUP BY room_id + """ + txn.execute(sql) + rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] + self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + del rooms + + new_pos = yield self.get_max_stream_id_in_current_state_deltas() + yield self.runInteraction("populate_stats_temp_build", _make_staging_area) + yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + self.get_earliest_token_for_room_stats.invalidate_all() + + yield self._end_background_update("populate_stats_createtables") + defer.returnValue(1) + + @defer.inlineCallbacks + def _populate_stats_cleanup(self, progress, batch_size): + """ + Update the user directory stream position, then clean up the old tables. + """ + if not self.stats_enabled: + yield self._end_background_update("populate_stats_cleanup") + defer.returnValue(1) + + position = yield self._simple_select_one_onecol( + TEMP_TABLE + "_position", None, "position" + ) + yield self.update_stats_stream_pos(position) + + def _delete_staging_area(txn): + txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") + txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") + + yield self.runInteraction("populate_stats_cleanup", _delete_staging_area) + + yield self._end_background_update("populate_stats_cleanup") + defer.returnValue(1) + + @defer.inlineCallbacks + def _populate_stats_process_rooms(self, progress, batch_size): + + if not self.stats_enabled: + yield self._end_background_update("populate_stats_process_rooms") + defer.returnValue(1) + + # If we don't have progress filed, delete everything. + if not progress: + yield self.delete_all_stats() + + def _get_next_batch(txn): + # Only fetch 250 rooms, so we don't fetch too many at once, even + # if those 250 rooms have less than batch_size state events. + sql = """ + SELECT room_id, events FROM %s_rooms + ORDER BY events DESC + LIMIT 250 + """ % ( + TEMP_TABLE, + ) + txn.execute(sql) + rooms_to_work_on = txn.fetchall() + + if not rooms_to_work_on: + return None + + # Get how many are left to process, so we can give status on how + # far we are in processing + txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") + progress["remaining"] = txn.fetchone()[0] + + return rooms_to_work_on + + rooms_to_work_on = yield self.runInteraction( + "populate_stats_temp_read", _get_next_batch + ) + + # No more rooms -- complete the transaction. + if not rooms_to_work_on: + yield self._end_background_update("populate_stats_process_rooms") + defer.returnValue(1) + + logger.info( + "Processing the next %d rooms of %d remaining", + (len(rooms_to_work_on), progress["remaining"]), + ) + + # Number of state events we've processed by going through each room + processed_event_count = 0 + + for room_id, event_count in rooms_to_work_on: + + current_state_ids = yield self.get_current_state_ids(room_id) + + join_rules = yield self.get_event( + current_state_ids.get((EventTypes.JoinRules, "")), allow_none=True + ) + history_visibility = yield self.get_event( + current_state_ids.get((EventTypes.RoomHistoryVisibility, "")), + allow_none=True, + ) + encryption = yield self.get_event( + current_state_ids.get((EventTypes.RoomEncryption, "")), allow_none=True + ) + name = yield self.get_event( + current_state_ids.get((EventTypes.Name, "")), allow_none=True + ) + topic = yield self.get_event( + current_state_ids.get((EventTypes.Topic, "")), allow_none=True + ) + avatar = yield self.get_event( + current_state_ids.get((EventTypes.RoomAvatar, "")), allow_none=True + ) + canonical_alias = yield self.get_event( + current_state_ids.get((EventTypes.CanonicalAlias, "")), allow_none=True + ) + + def _or_none(x, arg): + if x: + return x.content.get(arg) + return None + + yield self.update_room_state( + room_id, + { + "join_rules": _or_none(join_rules, "join_rule"), + "history_visibility": _or_none( + history_visibility, "history_visibility" + ), + "encryption": _or_none(encryption, "algorithm"), + "name": _or_none(name, "name"), + "topic": _or_none(topic, "topic"), + "avatar": _or_none(avatar, "url"), + "canonical_alias": _or_none(canonical_alias, "alias"), + }, + ) + + now = self.hs.get_reactor().seconds() + + # quantise time to the nearest bucket + now = (now // self.stats_bucket_size) * self.stats_bucket_size + + def _fetch_data(txn): + + # Get the current token of the room + current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) + + current_state_events = len(current_state_ids) + joined_members = self._get_user_count_in_room_txn( + txn, room_id, Membership.JOIN + ) + invited_members = self._get_user_count_in_room_txn( + txn, room_id, Membership.INVITE + ) + left_members = self._get_user_count_in_room_txn( + txn, room_id, Membership.LEAVE + ) + banned_members = self._get_user_count_in_room_txn( + txn, room_id, Membership.BAN + ) + total_state_events = self._get_total_state_event_counts_txn( + txn, room_id + ) + + self._update_stats_txn( + txn, + "room", + room_id, + now, + { + "bucket_size": self.stats_bucket_size, + "current_state_events": current_state_events, + "joined_members": joined_members, + "invited_members": invited_members, + "left_members": left_members, + "banned_members": banned_members, + "state_events": total_state_events, + }, + ) + self._simple_insert_txn( + txn, + "room_stats_earliest_token", + {"room_id": room_id, "token": current_token}, + ) + + yield self.runInteraction("update_room_stats", _fetch_data) + + # We've finished a room. Delete it from the table. + yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) + # Update the remaining counter. + progress["remaining"] -= 1 + yield self.runInteraction( + "populate_stats", + self._background_update_progress_txn, + "populate_stats_process_rooms", + progress, + ) + + processed_event_count += event_count + + if processed_event_count > batch_size: + # Don't process any more rooms, we've hit our batch size. + defer.returnValue(processed_event_count) + + defer.returnValue(processed_event_count) + + def delete_all_stats(self): + """ + Delete all statistics records. + """ + + def _delete_all_stats_txn(txn): + txn.execute("DELETE FROM room_state") + txn.execute("DELETE FROM room_stats") + txn.execute("DELETE FROM room_stats_earliest_token") + txn.execute("DELETE FROM user_stats") + + return self.runInteraction("delete_all_stats", _delete_all_stats_txn) + + def get_stats_stream_pos(self): + return self._simple_select_one_onecol( + table="stats_stream_pos", + keyvalues={}, + retcol="stream_id", + desc="stats_stream_pos", + ) + + def update_stats_stream_pos(self, stream_id): + return self._simple_update_one( + table="stats_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_stats_stream_pos", + ) + + def update_room_state(self, room_id, fields): + """ + Args: + room_id (str) + fields (dict[str:Any]) + """ + return self._simple_upsert( + table="room_state", + keyvalues={"room_id": room_id}, + values=fields, + desc="update_room_state", + ) + + def get_deltas_for_room(self, room_id, start, size=100): + """ + Get statistics deltas for a given room. + + Args: + room_id (str) + start (int): Pagination start. Number of entries, not timestamp. + size (int): How many entries to return. + + Returns: + Deferred[list[dict]], where the dict has the keys of + ABSOLUTE_STATS_FIELDS["room"] and "ts". + """ + return self._simple_select_list_paginate( + "room_stats", + {"room_id": room_id}, + "ts", + start, + size, + retcols=(list(ABSOLUTE_STATS_FIELDS["room"]) + ["ts"]), + order_direction="DESC", + ) + + def get_all_room_state(self): + return self._simple_select_list( + "room_state", None, retcols=("name", "topic", "canonical_alias") + ) + + @cached() + def get_earliest_token_for_room_stats(self, room_id): + """ + Fetch the "earliest token". This is used by the room stats delta + processor to ignore deltas that have been processed between the + start of the background task and any particular room's stats + being calculated. + + Returns: + Deferred[int] + """ + return self._simple_select_one_onecol( + "room_stats_earliest_token", + {"room_id": room_id}, + retcol="token", + allow_none=True, + ) + + def update_stats(self, stats_type, stats_id, ts, fields): + table, id_col = TYPE_TO_ROOM[stats_type] + return self._simple_upsert( + table=table, + keyvalues={id_col: stats_id, "ts": ts}, + values=fields, + desc="update_stats", + ) + + def _update_stats_txn(self, txn, stats_type, stats_id, ts, fields): + table, id_col = TYPE_TO_ROOM[stats_type] + return self._simple_upsert_txn( + txn, table=table, keyvalues={id_col: stats_id, "ts": ts}, values=fields + ) + + def update_stats_delta(self, ts, stats_type, stats_id, field, value): + def _update_stats_delta(txn): + table, id_col = TYPE_TO_ROOM[stats_type] + + sql = ( + "SELECT * FROM %s" + " WHERE %s=? and ts=(" + " SELECT MAX(ts) FROM %s" + " WHERE %s=?" + ")" + ) % (table, id_col, table, id_col) + txn.execute(sql, (stats_id, stats_id)) + rows = self.cursor_to_dict(txn) + if len(rows) == 0: + # silently skip as we don't have anything to apply a delta to yet. + # this tries to minimise any race between the initial sync and + # subsequent deltas arriving. + return + + current_ts = ts + latest_ts = rows[0]["ts"] + if current_ts < latest_ts: + # This one is in the past, but we're just encountering it now. + # Mark it as part of the current bucket. + current_ts = latest_ts + elif ts != latest_ts: + # we have to copy our absolute counters over to the new entry. + values = { + key: rows[0][key] for key in ABSOLUTE_STATS_FIELDS[stats_type] + } + values[id_col] = stats_id + values["ts"] = ts + values["bucket_size"] = self.stats_bucket_size + + self._simple_insert_txn(txn, table=table, values=values) + + # actually update the new value + if stats_type in ABSOLUTE_STATS_FIELDS[stats_type]: + self._simple_update_txn( + txn, + table=table, + keyvalues={id_col: stats_id, "ts": current_ts}, + updatevalues={field: value}, + ) + else: + sql = ("UPDATE %s SET %s=%s+? WHERE %s=? AND ts=?") % ( + table, + field, + field, + id_col, + ) + txn.execute(sql, (value, stats_id, current_ts)) + + return self.runInteraction("update_stats_delta", _update_stats_delta) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py new file mode 100644 index 0000000000..249aba3d59 --- /dev/null +++ b/tests/handlers/test_stats.py @@ -0,0 +1,251 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests import unittest + + +class StatsRoomTests(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + + self.store = hs.get_datastore() + self.handler = self.hs.get_stats_handler() + + def _add_background_updates(self): + """ + Add the background updates we need to run. + """ + # Ugh, have to reset this flag + self.store._all_done = False + + self.get_success( + self.store._simple_insert( + "background_updates", + {"update_name": "populate_stats_createtables", "progress_json": "{}"}, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_process_rooms", + "progress_json": "{}", + "depends_on": "populate_stats_createtables", + }, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_cleanup", + "progress_json": "{}", + "depends_on": "populate_stats_process_rooms", + }, + ) + ) + + def test_initial_room(self): + """ + The background updates will build the table from scratch. + """ + r = self.get_success(self.store.get_all_room_state()) + self.assertEqual(len(r), 0) + + # Disable stats + self.hs.config.stats_enabled = False + self.handler.stats_enabled = False + + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.helper.send_state( + room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token + ) + + # Stats disabled, shouldn't have done anything + r = self.get_success(self.store.get_all_room_state()) + self.assertEqual(len(r), 0) + + # Enable stats + self.hs.config.stats_enabled = True + self.handler.stats_enabled = True + + # Do the initial population of the user directory via the background update + self._add_background_updates() + + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + r = self.get_success(self.store.get_all_room_state()) + + self.assertEqual(len(r), 1) + self.assertEqual(r[0]["topic"], "foo") + + def test_initial_earliest_token(self): + """ + Ingestion via notify_new_event will ignore tokens that the background + update have already processed. + """ + self.reactor.advance(86401) + + self.hs.config.stats_enabled = False + self.handler.stats_enabled = False + + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + u2 = self.register_user("u2", "pass") + u2_token = self.login("u2", "pass") + + u3 = self.register_user("u3", "pass") + u3_token = self.login("u3", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.helper.send_state( + room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token + ) + + # Begin the ingestion by creating the temp tables. This will also store + # the position that the deltas should begin at, once they take over. + self.hs.config.stats_enabled = True + self.handler.stats_enabled = True + self.store._all_done = False + self.get_success(self.store.update_stats_stream_pos(None)) + + self.get_success( + self.store._simple_insert( + "background_updates", + {"update_name": "populate_stats_createtables", "progress_json": "{}"}, + ) + ) + + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + # Now, before the table is actually ingested, add some more events. + self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) + self.helper.join(room=room_1, user=u2, tok=u2_token) + + # Now do the initial ingestion. + self.get_success( + self.store._simple_insert( + "background_updates", + {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_cleanup", + "progress_json": "{}", + "depends_on": "populate_stats_process_rooms", + }, + ) + ) + + self.store._all_done = False + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + self.reactor.advance(86401) + + # Now add some more events, triggering ingestion. Because of the stream + # position being set to before the events sent in the middle, a simpler + # implementation would reprocess those events, and say there were four + # users, not three. + self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token) + self.helper.join(room=room_1, user=u3, tok=u3_token) + + # Get the deltas! There should be two -- day 1, and day 2. + r = self.get_success(self.store.get_deltas_for_room(room_1, 0)) + + # The oldest has 2 joined members + self.assertEqual(r[-1]["joined_members"], 2) + + # The newest has 3 + self.assertEqual(r[0]["joined_members"], 3) + + def test_incorrect_state_transition(self): + """ + If the state transition is not one of (JOIN, INVITE, LEAVE, BAN) to + (JOIN, INVITE, LEAVE, BAN), an error is raised. + """ + events = { + "a1": {"membership": Membership.LEAVE}, + "a2": {"membership": "not a real thing"}, + } + + def get_event(event_id): + m = Mock() + m.content = events[event_id] + d = defer.Deferred() + self.reactor.callLater(0.0, d.callback, m) + return d + + def get_received_ts(event_id): + return defer.succeed(1) + + self.store.get_received_ts = get_received_ts + self.store.get_event = get_event + + deltas = [ + { + "type": EventTypes.Member, + "state_key": "some_user", + "room_id": "room", + "event_id": "a1", + "prev_event_id": "a2", + "stream_id": "bleb", + } + ] + + f = self.get_failure(self.handler._handle_deltas(deltas), ValueError) + self.assertEqual( + f.value.args[0], "'not a real thing' is not a valid prev_membership" + ) + + # And the other way... + deltas = [ + { + "type": EventTypes.Member, + "state_key": "some_user", + "room_id": "room", + "event_id": "a2", + "prev_event_id": "a1", + "stream_id": "bleb", + } + ] + + f = self.get_failure(self.handler._handle_deltas(deltas), ValueError) + self.assertEqual( + f.value.args[0], "'not a real thing' is not a valid membership" + ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 05b0143c42..f7133fc12e 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -127,3 +127,20 @@ class RestHelper(object): ) return channel.json_body + + def send_state(self, room_id, event_type, body, tok, expect_code=200): + path = "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, event_type) + if tok: + path = path + "?access_token=%s" % tok + + request, channel = make_request( + self.hs.get_reactor(), "PUT", path, json.dumps(body).encode('utf8') + ) + render(request, self.resource, self.hs.get_reactor()) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + return channel.json_body -- cgit 1.5.1 From 85d1e03b9d50c1c64d2742ef3ec4f2dcd2bf7f9f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 23 May 2019 11:17:42 +0100 Subject: Simplifications and comments in do_auth (#5227) I was staring at this function trying to figure out wtf it was actually doing. This is (hopefully) a non-functional refactor which makes it a bit clearer. --- changelog.d/5227.misc | 1 + synapse/handlers/federation.py | 301 +++++++++++++++++++++++---------------- synapse/storage/events_worker.py | 2 +- 3 files changed, 183 insertions(+), 121 deletions(-) create mode 100644 changelog.d/5227.misc (limited to 'synapse/storage') diff --git a/changelog.d/5227.misc b/changelog.d/5227.misc new file mode 100644 index 0000000000..32bd7b6009 --- /dev/null +++ b/changelog.d/5227.misc @@ -0,0 +1 @@ +Simplifications and comments in do_auth. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2202ed699a..cf4fad7de0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2013,15 +2013,44 @@ class FederationHandler(BaseHandler): Args: origin (str): - event (synapse.events.FrozenEvent): + event (synapse.events.EventBase): context (synapse.events.snapshot.EventContext): - auth_events (dict[(str, str)->str]): + auth_events (dict[(str, str)->synapse.events.EventBase]): + Map from (event_type, state_key) to event + + What we expect the event's auth_events to be, based on the event's + position in the dag. I think? maybe?? + + Also NB that this function adds entries to it. + Returns: + defer.Deferred[None] + """ + room_version = yield self.store.get_room_version(event.room_id) + + yield self._update_auth_events_and_context_for_auth( + origin, event, context, auth_events + ) + try: + self.auth.check(room_version, event, auth_events=auth_events) + except AuthError as e: + logger.warn("Failed auth resolution for %r because %s", event, e) + raise e + + @defer.inlineCallbacks + def _update_auth_events_and_context_for_auth( + self, origin, event, context, auth_events + ): + """Helper for do_auth. See there for docs. + + Args: + origin (str): + event (synapse.events.EventBase): + context (synapse.events.snapshot.EventContext): + auth_events (dict[(str, str)->synapse.events.EventBase]): Returns: defer.Deferred[None] """ - # Check if we have all the auth events. - current_state = set(e.event_id for e in auth_events.values()) event_auth_events = set(event.auth_event_ids()) if event.is_state(): @@ -2029,11 +2058,21 @@ class FederationHandler(BaseHandler): else: event_key = None - if event_auth_events - current_state: + # if the event's auth_events refers to events which are not in our + # calculated auth_events, we need to fetch those events from somewhere. + # + # we start by fetching them from the store, and then try calling /event_auth/. + missing_auth = event_auth_events.difference( + e.event_id for e in auth_events.values() + ) + + if missing_auth: # TODO: can we use store.have_seen_events here instead? have_events = yield self.store.get_seen_events_with_rejections( - event_auth_events - current_state + missing_auth ) + logger.debug("Got events %s from store", have_events) + missing_auth.difference_update(have_events.keys()) else: have_events = {} @@ -2042,13 +2081,12 @@ class FederationHandler(BaseHandler): for e in auth_events.values() }) - seen_events = set(have_events.keys()) - - missing_auth = event_auth_events - seen_events - current_state - if missing_auth: - logger.info("Missing auth: %s", missing_auth) # If we don't have all the auth events, we need to get them. + logger.info( + "auth_events contains unknown events: %s", + missing_auth, + ) try: remote_auth_chain = yield self.federation_client.get_event_auth( origin, event.room_id, event.event_id @@ -2089,145 +2127,168 @@ class FederationHandler(BaseHandler): have_events = yield self.store.get_seen_events_with_rejections( event.auth_event_ids() ) - seen_events = set(have_events.keys()) except Exception: # FIXME: logger.exception("Failed to get auth chain") + if event.internal_metadata.is_outlier(): + logger.info("Skipping auth_event fetch for outlier") + return + # FIXME: Assumes we have and stored all the state for all the # prev_events - current_state = set(e.event_id for e in auth_events.values()) - different_auth = event_auth_events - current_state + different_auth = event_auth_events.difference( + e.event_id for e in auth_events.values() + ) - room_version = yield self.store.get_room_version(event.room_id) + if not different_auth: + return - if different_auth and not event.internal_metadata.is_outlier(): - # Do auth conflict res. - logger.info("Different auth: %s", different_auth) - - different_events = yield logcontext.make_deferred_yieldable( - defer.gatherResults([ - logcontext.run_in_background( - self.store.get_event, - d, - allow_none=True, - allow_rejected=False, - ) - for d in different_auth - if d in have_events and not have_events[d] - ], consumeErrors=True) - ).addErrback(unwrapFirstError) - - if different_events: - local_view = dict(auth_events) - remote_view = dict(auth_events) - remote_view.update({ - (d.type, d.state_key): d for d in different_events if d - }) + logger.info( + "auth_events refers to events which are not in our calculated auth " + "chain: %s", + different_auth, + ) + + room_version = yield self.store.get_room_version(event.room_id) - new_state = yield self.state_handler.resolve_events( - room_version, - [list(local_view.values()), list(remote_view.values())], - event + different_events = yield logcontext.make_deferred_yieldable( + defer.gatherResults([ + logcontext.run_in_background( + self.store.get_event, + d, + allow_none=True, + allow_rejected=False, ) + for d in different_auth + if d in have_events and not have_events[d] + ], consumeErrors=True) + ).addErrback(unwrapFirstError) + + if different_events: + local_view = dict(auth_events) + remote_view = dict(auth_events) + remote_view.update({ + (d.type, d.state_key): d for d in different_events if d + }) - auth_events.update(new_state) + new_state = yield self.state_handler.resolve_events( + room_version, + [list(local_view.values()), list(remote_view.values())], + event + ) - current_state = set(e.event_id for e in auth_events.values()) - different_auth = event_auth_events - current_state + logger.info( + "After state res: updating auth_events with new state %s", + { + (d.type, d.state_key): d.event_id for d in new_state.values() + if auth_events.get((d.type, d.state_key)) != d + }, + ) - yield self._update_context_for_auth_events( - event, context, auth_events, event_key, - ) + auth_events.update(new_state) + + different_auth = event_auth_events.difference( + e.event_id for e in auth_events.values() + ) - if different_auth and not event.internal_metadata.is_outlier(): - logger.info("Different auth after resolution: %s", different_auth) + yield self._update_context_for_auth_events( + event, context, auth_events, event_key, + ) - # Only do auth resolution if we have something new to say. - # We can't rove an auth failure. - do_resolution = False + if not different_auth: + # we're done + return - provable = [ - RejectedReason.NOT_ANCESTOR, RejectedReason.NOT_ANCESTOR, - ] + logger.info( + "auth_events still refers to events which are not in the calculated auth " + "chain after state resolution: %s", + different_auth, + ) - for e_id in different_auth: - if e_id in have_events: - if have_events[e_id] in provable: - do_resolution = True - break + # Only do auth resolution if we have something new to say. + # We can't prove an auth failure. + do_resolution = False - if do_resolution: - prev_state_ids = yield context.get_prev_state_ids(self.store) - # 1. Get what we think is the auth chain. - auth_ids = yield self.auth.compute_auth_events( - event, prev_state_ids - ) - local_auth_chain = yield self.store.get_auth_chain( - auth_ids, include_given=True - ) + for e_id in different_auth: + if e_id in have_events: + if have_events[e_id] == RejectedReason.NOT_ANCESTOR: + do_resolution = True + break - try: - # 2. Get remote difference. - result = yield self.federation_client.query_auth( - origin, - event.room_id, - event.event_id, - local_auth_chain, - ) + if not do_resolution: + logger.info( + "Skipping auth resolution due to lack of provable rejection reasons" + ) + return - seen_remotes = yield self.store.have_seen_events( - [e.event_id for e in result["auth_chain"]] - ) + logger.info("Doing auth resolution") - # 3. Process any remote auth chain events we haven't seen. - for ev in result["auth_chain"]: - if ev.event_id in seen_remotes: - continue + prev_state_ids = yield context.get_prev_state_ids(self.store) - if ev.event_id == event.event_id: - continue + # 1. Get what we think is the auth chain. + auth_ids = yield self.auth.compute_auth_events( + event, prev_state_ids + ) + local_auth_chain = yield self.store.get_auth_chain( + auth_ids, include_given=True + ) - try: - auth_ids = ev.auth_event_ids() - auth = { - (e.type, e.state_key): e - for e in result["auth_chain"] - if e.event_id in auth_ids - or event.type == EventTypes.Create - } - ev.internal_metadata.outlier = True + try: + # 2. Get remote difference. + result = yield self.federation_client.query_auth( + origin, + event.room_id, + event.event_id, + local_auth_chain, + ) - logger.debug( - "do_auth %s different_auth: %s", - event.event_id, e.event_id - ) + seen_remotes = yield self.store.have_seen_events( + [e.event_id for e in result["auth_chain"]] + ) - yield self._handle_new_event( - origin, ev, auth_events=auth - ) + # 3. Process any remote auth chain events we haven't seen. + for ev in result["auth_chain"]: + if ev.event_id in seen_remotes: + continue - if ev.event_id in event_auth_events: - auth_events[(ev.type, ev.state_key)] = ev - except AuthError: - pass + if ev.event_id == event.event_id: + continue - except Exception: - # FIXME: - logger.exception("Failed to query auth chain") + try: + auth_ids = ev.auth_event_ids() + auth = { + (e.type, e.state_key): e + for e in result["auth_chain"] + if e.event_id in auth_ids + or event.type == EventTypes.Create + } + ev.internal_metadata.outlier = True + + logger.debug( + "do_auth %s different_auth: %s", + event.event_id, e.event_id + ) - # 4. Look at rejects and their proofs. - # TODO. + yield self._handle_new_event( + origin, ev, auth_events=auth + ) - yield self._update_context_for_auth_events( - event, context, auth_events, event_key, - ) + if ev.event_id in event_auth_events: + auth_events[(ev.type, ev.state_key)] = ev + except AuthError: + pass - try: - self.auth.check(room_version, event, auth_events=auth_events) - except AuthError as e: - logger.warn("Failed auth resolution for %r because %s", event, e) - raise e + except Exception: + # FIXME: + logger.exception("Failed to query auth chain") + + # 4. Look at rejects and their proofs. + # TODO. + + yield self._update_context_for_auth_events( + event, context, auth_events, event_key, + ) @defer.inlineCallbacks def _update_context_for_auth_events(self, event, context, auth_events, diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 83ffae2132..21b353cad3 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -610,7 +610,7 @@ class EventsWorkerStore(SQLBaseStore): return res - return self.runInteraction("get_rejection_reasons", f) + return self.runInteraction("get_seen_events_with_rejections", f) def _get_total_state_event_counts_txn(self, txn, room_id): """ -- cgit 1.5.1 From 2e052110ee0bca17b8e27b6b48ee8b7c64bc94ae Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 23 May 2019 11:45:39 +0100 Subject: Rewrite store_server_verify_key to store several keys at once (#5234) Storing server keys hammered the database a bit. This replaces the implementation which stored a single key, with one which can do many updates at once. --- changelog.d/5234.misc | 1 + synapse/crypto/keyring.py | 59 ++++++++++------------------------------ synapse/storage/keys.py | 65 ++++++++++++++++++++++++++------------------ tests/crypto/test_keyring.py | 14 ++++++++-- tests/storage/test_keys.py | 44 ++++++++++++++++++++---------- 5 files changed, 96 insertions(+), 87 deletions(-) create mode 100644 changelog.d/5234.misc (limited to 'synapse/storage') diff --git a/changelog.d/5234.misc b/changelog.d/5234.misc new file mode 100644 index 0000000000..43fbd6f67c --- /dev/null +++ b/changelog.d/5234.misc @@ -0,0 +1 @@ +Rewrite store_server_verify_key to store several keys at once. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 5cc98542ce..badb5254ea 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -453,10 +453,11 @@ class Keyring(object): raise_from(KeyLookupError("Remote server returned an error"), e) keys = {} + added_keys = [] - responses = query_response["server_keys"] + time_now_ms = self.clock.time_msec() - for response in responses: + for response in query_response["server_keys"]: if ( u"signatures" not in response or perspective_name not in response[u"signatures"] @@ -492,21 +493,13 @@ class Keyring(object): ) server_name = response["server_name"] + added_keys.extend( + (server_name, key_id, key) for key_id, key in processed_response.items() + ) keys.setdefault(server_name, {}).update(processed_response) - yield logcontext.make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.store_keys, - server_name=server_name, - from_server=perspective_name, - verify_keys=response_keys, - ) - for server_name, response_keys in keys.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + yield self.store.store_server_verify_keys( + perspective_name, time_now_ms, added_keys ) defer.returnValue(keys) @@ -519,6 +512,7 @@ class Keyring(object): if requested_key_id in keys: continue + time_now_ms = self.clock.time_msec() try: response = yield self.client.get_json( destination=server_name, @@ -548,12 +542,13 @@ class Keyring(object): requested_ids=[requested_key_id], response_json=response, ) - + yield self.store.store_server_verify_keys( + server_name, + time_now_ms, + ((server_name, key_id, key) for key_id, key in response_keys.items()), + ) keys.update(response_keys) - yield self.store_keys( - server_name=server_name, from_server=server_name, verify_keys=keys - ) defer.returnValue({server_name: keys}) @defer.inlineCallbacks @@ -650,32 +645,6 @@ class Keyring(object): defer.returnValue(response_keys) - def store_keys(self, server_name, from_server, verify_keys): - """Store a collection of verify keys for a given server - Args: - server_name(str): The name of the server the keys are for. - from_server(str): The server the keys were downloaded from. - verify_keys(dict): A mapping of key_id to VerifyKey. - Returns: - A deferred that completes when the keys are stored. - """ - # TODO(markjh): Store whether the keys have expired. - return logcontext.make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.store.store_server_verify_key, - server_name, - server_name, - key.time_added, - key, - ) - for key_id, key in verify_keys.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - ) - @defer.inlineCallbacks def _handle_key_deferred(verify_request): diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 7036541792..3c5f52009b 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -84,38 +84,51 @@ class KeyStore(SQLBaseStore): return self.runInteraction("get_server_verify_keys", _txn) - def store_server_verify_key( - self, server_name, from_server, time_now_ms, verify_key - ): - """Stores a NACL verification key for the given server. + def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): + """Stores NACL verification keys for remote servers. Args: - server_name (str): The name of the server. - from_server (str): Where the verification key was looked up - time_now_ms (int): The time now in milliseconds - verify_key (nacl.signing.VerifyKey): The NACL verify key. + from_server (str): Where the verification keys were looked up + ts_added_ms (int): The time to record that the key was added + verify_keys (iterable[tuple[str, str, nacl.signing.VerifyKey]]): + keys to be stored. Each entry is a triplet of + (server_name, key_id, key). """ - key_id = "%s:%s" % (verify_key.alg, verify_key.version) - - # XXX fix this to not need a lock (#3819) - def _txn(txn): - self._simple_upsert_txn( - txn, - table="server_signature_keys", - keyvalues={"server_name": server_name, "key_id": key_id}, - values={ - "from_server": from_server, - "ts_added_ms": time_now_ms, - "verify_key": db_binary_type(verify_key.encode()), - }, + key_values = [] + value_values = [] + invalidations = [] + for server_name, key_id, verify_key in verify_keys: + key_values.append((server_name, key_id)) + value_values.append( + ( + from_server, + ts_added_ms, + db_binary_type(verify_key.encode()), + ) ) # invalidate takes a tuple corresponding to the params of # _get_server_verify_key. _get_server_verify_key only takes one # param, which is itself the 2-tuple (server_name, key_id). - txn.call_after( - self._get_server_verify_key.invalidate, ((server_name, key_id),) - ) - - return self.runInteraction("store_server_verify_key", _txn) + invalidations.append((server_name, key_id)) + + def _invalidate(res): + f = self._get_server_verify_key.invalidate + for i in invalidations: + f((i, )) + return res + + return self.runInteraction( + "store_server_verify_keys", + self._simple_upsert_many_txn, + table="server_signature_keys", + key_names=("server_name", "key_id"), + key_values=key_values, + value_names=( + "from_server", + "ts_added_ms", + "verify_key", + ), + value_values=value_values, + ).addCallback(_invalidate) def store_server_keys_json( self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 3c79d4afe7..bcffe53a91 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -192,8 +192,18 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.datastore.store_server_verify_key( - "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1) + key1_id = "%s:%s" % (key1.alg, key1.version) + + r = self.hs.datastore.store_server_verify_keys( + "server9", + time.time() * 1000, + [ + ( + "server9", + key1_id, + signedjson.key.get_verify_key(key1), + ), + ], ) self.get_success(r) json1 = {} diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 6bfaa00fe9..71ad7aee32 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -31,23 +31,32 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_server_verify_keys(self): store = self.hs.get_datastore() - d = store.store_server_verify_key("server1", "from_server", 0, KEY_1) - self.get_success(d) - d = store.store_server_verify_key("server1", "from_server", 0, KEY_2) + key_id_1 = "ed25519:key1" + key_id_2 = "ed25519:KEY_ID_2" + d = store.store_server_verify_keys( + "from_server", + 10, + [ + ("server1", key_id_1, KEY_1), + ("server1", key_id_2, KEY_2), + ], + ) self.get_success(d) d = store.get_server_verify_keys( - [ - ("server1", "ed25519:key1"), - ("server1", "ed25519:key2"), - ("server1", "ed25519:key3"), - ] + [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")] ) res = self.get_success(d) self.assertEqual(len(res.keys()), 3) - self.assertEqual(res[("server1", "ed25519:key1")].version, "key1") - self.assertEqual(res[("server1", "ed25519:key2")].version, "key2") + res1 = res[("server1", key_id_1)] + self.assertEqual(res1, KEY_1) + self.assertEqual(res1.version, "key1") + + res2 = res[("server1", key_id_2)] + self.assertEqual(res2, KEY_2) + # version comes from the ID it was stored with + self.assertEqual(res2.version, "KEY_ID_2") # non-existent result gives None self.assertIsNone(res[("server1", "ed25519:key3")]) @@ -60,9 +69,14 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): key_id_1 = "ed25519:key1" key_id_2 = "ed25519:key2" - d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1) - self.get_success(d) - d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2) + d = store.store_server_verify_keys( + "from_server", + 0, + [ + ("srv1", key_id_1, KEY_1), + ("srv1", key_id_2, KEY_2), + ], + ) self.get_success(d) d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) @@ -81,7 +95,9 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): new_key_2 = signedjson.key.get_verify_key( signedjson.key.generate_signing_key("key2") ) - d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2) + d = store.store_server_verify_keys( + "from_server", 10, [("srv1", key_id_2, new_key_2)] + ) self.get_success(d) d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) -- cgit 1.5.1 From b75537beaf841089f9f07c9dbed04a7a420a8b1f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 3 Apr 2019 18:10:24 +0100 Subject: Store key validity time in the storage layer This is a first step to checking that the key is valid at the required moment. The idea here is that, rather than passing VerifyKey objects in and out of the storage layer, we instead pass FetchKeyResult objects, which simply wrap the VerifyKey and add a valid_until_ts field. --- changelog.d/5237.misc | 1 + synapse/crypto/keyring.py | 47 +++++++++++++++------- synapse/storage/keys.py | 31 +++++++++----- .../delta/54/add_validity_to_server_keys.sql | 23 +++++++++++ tests/crypto/test_keyring.py | 22 ++++++---- tests/storage/test_keys.py | 44 +++++++++++++------- 6 files changed, 122 insertions(+), 46 deletions(-) create mode 100644 changelog.d/5237.misc create mode 100644 synapse/storage/schema/delta/54/add_validity_to_server_keys.sql (limited to 'synapse/storage') diff --git a/changelog.d/5237.misc b/changelog.d/5237.misc new file mode 100644 index 0000000000..f4fe3b821b --- /dev/null +++ b/changelog.d/5237.misc @@ -0,0 +1 @@ +Store key validity time in the storage layer. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 9d629b2238..14a27288fd 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -20,7 +20,6 @@ from collections import namedtuple from six import raise_from from six.moves import urllib -import nacl.signing from signedjson.key import ( decode_verify_key_bytes, encode_verify_key_base64, @@ -43,6 +42,7 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.storage.keys import FetchKeyResult from synapse.util import logcontext, unwrapFirstError from synapse.util.logcontext import ( LoggingContext, @@ -307,11 +307,15 @@ class Keyring(object): # complete this VerifyKeyRequest. result_keys = results.get(server_name, {}) for key_id in verify_request.key_ids: - key = result_keys.get(key_id) - if key: + fetch_key_result = result_keys.get(key_id) + if fetch_key_result: with PreserveLoggingContext(): verify_request.deferred.callback( - (server_name, key_id, key) + ( + server_name, + key_id, + fetch_key_result.verify_key, + ) ) break else: @@ -348,12 +352,12 @@ class Keyring(object): def get_keys_from_store(self, server_name_and_key_ids): """ Args: - server_name_and_key_ids (iterable(Tuple[str, iterable[str]]): + server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]): list of (server_name, iterable[key_id]) tuples to fetch keys for Returns: - Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from - server_name -> key_id -> VerifyKey + Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: + map from server_name -> key_id -> FetchKeyResult """ keys_to_fetch = ( (server_name, key_id) @@ -430,6 +434,18 @@ class Keyring(object): def get_server_verify_key_v2_indirect( self, server_names_and_key_ids, perspective_name, perspective_keys ): + """ + Args: + server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]): + list of (server_name, iterable[key_id]) tuples to fetch keys for + perspective_name (str): name of the notary server to query for the keys + perspective_keys (dict[str, VerifyKey]): map of key_id->key for the + notary server + + Returns: + Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map + from server_name -> key_id -> FetchKeyResult + """ # TODO(mark): Set the minimum_valid_until_ts to that needed by # the events being validated or the current time if validating # an incoming request. @@ -506,7 +522,7 @@ class Keyring(object): @defer.inlineCallbacks def get_server_verify_key_v2_direct(self, server_name, key_ids): - keys = {} # type: dict[str, nacl.signing.VerifyKey] + keys = {} # type: dict[str, FetchKeyResult] for requested_key_id in key_ids: if requested_key_id in keys: @@ -583,9 +599,9 @@ class Keyring(object): actually in the response Returns: - Deferred[dict[str, nacl.signing.VerifyKey]]: - map from key_id to key object + Deferred[dict[str, FetchKeyResult]]: map from key_id to result object """ + ts_valid_until_ms = response_json[u"valid_until_ts"] # start by extracting the keys from the response, since they may be required # to validate the signature on the response. @@ -595,7 +611,9 @@ class Keyring(object): key_base64 = key_data["key"] key_bytes = decode_base64(key_base64) verify_key = decode_verify_key_bytes(key_id, key_bytes) - verify_keys[key_id] = verify_key + verify_keys[key_id] = FetchKeyResult( + verify_key=verify_key, valid_until_ts=ts_valid_until_ms + ) # TODO: improve this signature checking server_name = response_json["server_name"] @@ -606,7 +624,7 @@ class Keyring(object): ) verify_signed_json( - response_json, server_name, verify_keys[key_id] + response_json, server_name, verify_keys[key_id].verify_key ) for key_id, key_data in response_json["old_verify_keys"].items(): @@ -614,7 +632,9 @@ class Keyring(object): key_base64 = key_data["key"] key_bytes = decode_base64(key_base64) verify_key = decode_verify_key_bytes(key_id, key_bytes) - verify_keys[key_id] = verify_key + verify_keys[key_id] = FetchKeyResult( + verify_key=verify_key, valid_until_ts=key_data["expired_ts"] + ) # re-sign the json with our own key, so that it is ready if we are asked to # give it out as a notary server @@ -623,7 +643,6 @@ class Keyring(object): ) signed_key_json_bytes = encode_canonical_json(signed_key_json) - ts_valid_until_ms = signed_key_json[u"valid_until_ts"] # for reasons I don't quite understand, we store this json for the key ids we # requested, as well as those we got. diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 3c5f52009b..5300720dbb 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -19,6 +19,7 @@ import logging import six +import attr from signedjson.key import decode_verify_key_bytes from synapse.util import batch_iter @@ -36,6 +37,12 @@ else: db_binary_type = memoryview +@attr.s(slots=True, frozen=True) +class FetchKeyResult(object): + verify_key = attr.ib() # VerifyKey: the key itself + valid_until_ts = attr.ib() # int: how long we can use this key for + + class KeyStore(SQLBaseStore): """Persistence for signature verification keys """ @@ -54,8 +61,8 @@ class KeyStore(SQLBaseStore): iterable of (server_name, key-id) tuples to fetch keys for Returns: - Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]: - map from (server_name, key_id) -> VerifyKey, or None if the key is + Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: + map from (server_name, key_id) -> FetchKeyResult, or None if the key is unknown """ keys = {} @@ -65,17 +72,19 @@ class KeyStore(SQLBaseStore): # batch_iter always returns tuples so it's safe to do len(batch) sql = ( - "SELECT server_name, key_id, verify_key FROM server_signature_keys " - "WHERE 1=0" + "SELECT server_name, key_id, verify_key, ts_valid_until_ms " + "FROM server_signature_keys WHERE 1=0" ) + " OR (server_name=? AND key_id=?)" * len(batch) txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) for row in txn: - server_name, key_id, key_bytes = row - keys[(server_name, key_id)] = decode_verify_key_bytes( - key_id, bytes(key_bytes) + server_name, key_id, key_bytes, ts_valid_until_ms = row + res = FetchKeyResult( + verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), + valid_until_ts=ts_valid_until_ms, ) + keys[(server_name, key_id)] = res def _txn(txn): for batch in batch_iter(server_name_and_key_ids, 50): @@ -89,20 +98,21 @@ class KeyStore(SQLBaseStore): Args: from_server (str): Where the verification keys were looked up ts_added_ms (int): The time to record that the key was added - verify_keys (iterable[tuple[str, str, nacl.signing.VerifyKey]]): + verify_keys (iterable[tuple[str, str, FetchKeyResult]]): keys to be stored. Each entry is a triplet of (server_name, key_id, key). """ key_values = [] value_values = [] invalidations = [] - for server_name, key_id, verify_key in verify_keys: + for server_name, key_id, fetch_result in verify_keys: key_values.append((server_name, key_id)) value_values.append( ( from_server, ts_added_ms, - db_binary_type(verify_key.encode()), + fetch_result.valid_until_ts, + db_binary_type(fetch_result.verify_key.encode()), ) ) # invalidate takes a tuple corresponding to the params of @@ -125,6 +135,7 @@ class KeyStore(SQLBaseStore): value_names=( "from_server", "ts_added_ms", + "ts_valid_until_ms", "verify_key", ), value_values=value_values, diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql new file mode 100644 index 0000000000..c01aa9d2d9 --- /dev/null +++ b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql @@ -0,0 +1,23 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* When we can use this key until, before we have to refresh it. */ +ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT; + +UPDATE server_signature_keys SET ts_valid_until_ms = ( + SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE + skj.server_name = server_signature_keys.server_name AND + skj.key_id = server_signature_keys.key_id +); diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index bcffe53a91..83de32b05d 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -25,6 +25,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.crypto import keyring from synapse.crypto.keyring import KeyLookupError +from synapse.storage.keys import FetchKeyResult from synapse.util import logcontext from synapse.util.logcontext import LoggingContext @@ -201,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ( "server9", key1_id, - signedjson.key.get_verify_key(key1), + FetchKeyResult(signedjson.key.get_verify_key(key1), 1000), ), ], ) @@ -251,9 +252,10 @@ class KeyringTestCase(unittest.HomeserverTestCase): server_name_and_key_ids = [(SERVER_NAME, ("key1",))] keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids)) k = keys[SERVER_NAME][testverifykey_id] - self.assertEqual(k, testverifykey) - self.assertEqual(k.alg, "ed25519") - self.assertEqual(k.version, "ver1") + self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) + self.assertEqual(k.verify_key, testverifykey) + self.assertEqual(k.verify_key.alg, "ed25519") + self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) @@ -321,9 +323,10 @@ class KeyringTestCase(unittest.HomeserverTestCase): keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids)) self.assertIn(SERVER_NAME, keys) k = keys[SERVER_NAME][testverifykey_id] - self.assertEqual(k, testverifykey) - self.assertEqual(k.alg, "ed25519") - self.assertEqual(k.version, "ver1") + self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) + self.assertEqual(k.verify_key, testverifykey) + self.assertEqual(k.verify_key.alg, "ed25519") + self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) @@ -346,7 +349,10 @@ class KeyringTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def run_in_context(f, *args, **kwargs): - with LoggingContext("testctx"): + with LoggingContext("testctx") as ctx: + # we set the "request" prop to make it easier to follow what's going on in the + # logs. + ctx.request = "testctx" rv = yield f(*args, **kwargs) defer.returnValue(rv) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 71ad7aee32..e07ff01201 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -17,6 +17,8 @@ import signedjson.key from twisted.internet.defer import Deferred +from synapse.storage.keys import FetchKeyResult + import tests.unittest KEY_1 = signedjson.key.decode_verify_key_base64( @@ -37,8 +39,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): "from_server", 10, [ - ("server1", key_id_1, KEY_1), - ("server1", key_id_2, KEY_2), + ("server1", key_id_1, FetchKeyResult(KEY_1, 100)), + ("server1", key_id_2, FetchKeyResult(KEY_2, 200)), ], ) self.get_success(d) @@ -50,13 +52,15 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): self.assertEqual(len(res.keys()), 3) res1 = res[("server1", key_id_1)] - self.assertEqual(res1, KEY_1) - self.assertEqual(res1.version, "key1") + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.verify_key.version, "key1") + self.assertEqual(res1.valid_until_ts, 100) res2 = res[("server1", key_id_2)] - self.assertEqual(res2, KEY_2) + self.assertEqual(res2.verify_key, KEY_2) # version comes from the ID it was stored with - self.assertEqual(res2.version, "KEY_ID_2") + self.assertEqual(res2.verify_key.version, "KEY_ID_2") + self.assertEqual(res2.valid_until_ts, 200) # non-existent result gives None self.assertIsNone(res[("server1", "ed25519:key3")]) @@ -73,8 +77,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): "from_server", 0, [ - ("srv1", key_id_1, KEY_1), - ("srv1", key_id_2, KEY_2), + ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)), + ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)), ], ) self.get_success(d) @@ -82,26 +86,38 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) res = self.get_success(d) self.assertEqual(len(res.keys()), 2) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) - self.assertEqual(res[("srv1", key_id_2)], KEY_2) + + res1 = res[("srv1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("srv1", key_id_2)] + self.assertEqual(res2.verify_key, KEY_2) + self.assertEqual(res2.valid_until_ts, 200) # we should be able to look up the same thing again without a db hit res = store.get_server_verify_keys([("srv1", key_id_1)]) if isinstance(res, Deferred): res = self.successResultOf(res) self.assertEqual(len(res.keys()), 1) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) + self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) new_key_2 = signedjson.key.get_verify_key( signedjson.key.generate_signing_key("key2") ) d = store.store_server_verify_keys( - "from_server", 10, [("srv1", key_id_2, new_key_2)] + "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))] ) self.get_success(d) d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) res = self.get_success(d) self.assertEqual(len(res.keys()), 2) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) - self.assertEqual(res[("srv1", key_id_2)], new_key_2) + + res1 = res[("srv1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("srv1", key_id_2)] + self.assertEqual(res2.verify_key, new_key_2) + self.assertEqual(res2.valid_until_ts, 300) -- cgit 1.5.1 From bc4b2ecf70bc3965cbbf1daee52bf7577e219d7b Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Sat, 25 May 2019 12:02:48 -0600 Subject: Fix logging for room stats background update --- synapse/storage/stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index 71b80a891d..eb0ced5b5e 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -169,7 +169,7 @@ class StatsStore(StateDeltasStore): logger.info( "Processing the next %d rooms of %d remaining", - (len(rooms_to_work_on), progress["remaining"]), + len(rooms_to_work_on), progress["remaining"], ) # Number of state events we've processed by going through each room -- cgit 1.5.1 From ba17de7fbc29700163b23363ae0e03f8a01ef274 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 28 May 2019 10:11:38 +0100 Subject: Fix schema update for account validity --- .../storage/schema/delta/54/account_validity.sql | 27 ------------------- .../delta/54/account_validity_with_renewal.sql | 30 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 27 deletions(-) delete mode 100644 synapse/storage/schema/delta/54/account_validity.sql create mode 100644 synapse/storage/schema/delta/54/account_validity_with_renewal.sql (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/54/account_validity.sql b/synapse/storage/schema/delta/54/account_validity.sql deleted file mode 100644 index 2357626000..0000000000 --- a/synapse/storage/schema/delta/54/account_validity.sql +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -DROP TABLE IF EXISTS account_validity; - --- Track what users are in public rooms. -CREATE TABLE IF NOT EXISTS account_validity ( - user_id TEXT PRIMARY KEY, - expiration_ts_ms BIGINT NOT NULL, - email_sent BOOLEAN NOT NULL, - renewal_token TEXT -); - -CREATE INDEX account_validity_email_sent_idx ON account_validity(email_sent, expiration_ts_ms) -CREATE UNIQUE INDEX account_validity_renewal_string_idx ON account_validity(renewal_token) diff --git a/synapse/storage/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/schema/delta/54/account_validity_with_renewal.sql new file mode 100644 index 0000000000..0adb2ad55e --- /dev/null +++ b/synapse/storage/schema/delta/54/account_validity_with_renewal.sql @@ -0,0 +1,30 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We previously changed the schema for this table without renaming the file, which means +-- that some databases might still be using the old schema. This ensures Synapse uses the +-- right schema for the table. +DROP TABLE IF EXISTS account_validity; + +-- Track what users are in public rooms. +CREATE TABLE IF NOT EXISTS account_validity ( + user_id TEXT PRIMARY KEY, + expiration_ts_ms BIGINT NOT NULL, + email_sent BOOLEAN NOT NULL, + renewal_token TEXT +); + +CREATE INDEX account_validity_email_sent_idx ON account_validity(email_sent, expiration_ts_ms) +CREATE UNIQUE INDEX account_validity_renewal_string_idx ON account_validity(renewal_token) -- cgit 1.5.1 From 52839886d664576831462e033b88e5aba4c019e3 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 28 May 2019 16:47:42 +0100 Subject: Allow configuring a range for the account validity startup job When enabling the account validity feature, Synapse will look at startup for registered account without an expiration date, and will set one equals to 'now + validity_period' for them. On large servers, it can mean that a large number of users will have the same expiration date, which means that they will all be sent a renewal email at the same time, which isn't ideal. In order to mitigate this, this PR allows server admins to define a 'max_delta' so that the expiration date is a random value in the [now + validity_period ; now + validity_period + max_delta] range. This allows renewal emails to be progressively sent over a configured period instead of being sent all in one big batch. --- synapse/config/registration.py | 11 +++++++++++ synapse/storage/_base.py | 23 +++++++++++++++++++++-- tests/rest/client/v2_alpha/test_register.py | 21 +++++++++++++++++++++ 3 files changed, 53 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 693288f938..b4fd4af368 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -39,6 +39,10 @@ class AccountValidityConfig(Config): else: self.renew_email_subject = "Renew your %(app)s account" + self.startup_job_max_delta = self.parse_duration( + config.get("startup_job_max_delta", 0), + ) + if self.renew_by_email_enabled and "public_baseurl" not in synapse_config: raise ConfigError("Can't send renewal emails without 'public_baseurl'") @@ -131,11 +135,18 @@ class RegistrationConfig(Config): # after that the validity period changes and Synapse is restarted, the users' # expiration dates won't be updated unless their account is manually renewed. # + # If set, the ``startup_job_max_delta`` optional setting will make the startup job + # described above set a random expiration date between t + period and + # t + period + startup_job_max_delta, t being the date and time at which the job + # sets the expiration date for a given user. This is useful for server admins that + # want to avoid Synapse sending a lot of renewal emails at once. + # #account_validity: # enabled: True # period: 6w # renew_at: 1w # renew_email_subject: "Renew your %%(app)s account" + # startup_job_max_delta: 2d # The user must provide all of the below types of 3PID when registering. # diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index fa6839ceca..40802fd3dc 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -16,6 +16,7 @@ # limitations under the License. import itertools import logging +import random import sys import threading import time @@ -247,6 +248,8 @@ class SQLBaseStore(object): self._check_safe_to_upsert, ) + self.rand = random.SystemRandom() + if self._account_validity.enabled: self._clock.call_later( 0.0, @@ -308,21 +311,37 @@ class SQLBaseStore(object): res = self.cursor_to_dict(txn) if res: for user in res: - self.set_expiration_date_for_user_txn(txn, user["name"]) + self.set_expiration_date_for_user_txn( + txn, + user["name"], + use_delta=True, + ) yield self.runInteraction( "get_users_with_no_expiration_date", select_users_with_no_expiration_date_txn, ) - def set_expiration_date_for_user_txn(self, txn, user_id): + def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): """Sets an expiration date to the account with the given user ID. Args: user_id (str): User ID to set an expiration date for. + use_delta (bool): If set to False, the expiration date for the user will be + now + validity period. If set to True, this expiration date will be a + random value in the [now + period; now + period + max_delta] range, + max_delta being the configured value for the size of the range, unless + delta is 0, in which case it sets it to now + period. """ now_ms = self._clock.time_msec() expiration_ts = now_ms + self._account_validity.period + + if use_delta and self._account_validity.startup_job_max_delta: + expiration_ts = self.rand.randrange( + expiration_ts, + expiration_ts + self._account_validity.startup_job_max_delta, + ) + self._simple_insert_txn( txn, "account_validity", diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index d4a1d4d50c..7603440fd8 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -436,6 +436,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.validity_period = 10 + self.max_delta = 10 config = self.default_config() @@ -459,8 +460,28 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): """ user_id = self.register_user("kermit", "user") + self.hs.config.account_validity.startup_job_max_delta = 0 + now_ms = self.hs.clock.time_msec() self.get_success(self.store._set_expiration_date_when_missing()) res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) self.assertEqual(res, now_ms + self.validity_period) + + def test_background_job_with_max_delta(self): + """ + Tests the same thing as test_background_job, except that it sets the + startup_job_max_delta parameter and checks that the expiration date is within the + allowed range. + """ + user_id = self.register_user("kermit_delta", "user") + + self.hs.config.account_validity.startup_job_max_delta = self.max_delta + + now_ms = self.hs.clock.time_msec() + self.get_success(self.store._set_expiration_date_when_missing()) + + res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + + self.assertLessEqual(res, now_ms + self.validity_period + self.delta) + self.assertGreaterEqual(res, now_ms + self.validity_period) -- cgit 1.5.1 From 58c8ed5b0dbfe0556f11985a61c0e13bbe61d93c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 May 2019 11:56:24 +0100 Subject: Correctly filter out extremities with soft failed prevs (#5274) When we receive a soft failed event we, correctly, *do not* update the forward extremity table with the event. However, if we later receive an event that references the soft failed event we then need to remove the soft failed events prev events from the forward extremities table, otherwise we just build up forward extremities. Fixes #5269 --- changelog.d/5274.bugfix | 1 + synapse/storage/events.py | 82 +++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 changelog.d/5274.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/5274.bugfix b/changelog.d/5274.bugfix new file mode 100644 index 0000000000..9e14d20289 --- /dev/null +++ b/changelog.d/5274.bugfix @@ -0,0 +1 @@ +Fix bug where we leaked extremities when we soft failed events, leading to performance degradation. diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 2ffc27ff41..6e9f3d1dc0 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -554,10 +554,18 @@ class EventsStore( e_id for event in new_events for e_id in event.prev_event_ids() ) - # Finally, remove any events which are prev_events of any existing events. + # Remove any events which are prev_events of any existing events. existing_prevs = yield self._get_events_which_are_prevs(result) result.difference_update(existing_prevs) + # Finally handle the case where the new events have soft-failed prev + # events. If they do we need to remove them and their prev events, + # otherwise we end up with dangling extremities. + existing_prevs = yield self._get_prevs_before_rejected( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + result.difference_update(existing_prevs) + defer.returnValue(result) @defer.inlineCallbacks @@ -573,7 +581,7 @@ class EventsStore( """ results = [] - def _get_events(txn, batch): + def _get_events_which_are_prevs_txn(txn, batch): sql = """ SELECT prev_event_id, internal_metadata FROM event_edges @@ -596,10 +604,78 @@ class EventsStore( ) for chunk in batch_iter(event_ids, 100): - yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk) + yield self.runInteraction( + "_get_events_which_are_prevs", + _get_events_which_are_prevs_txn, + chunk, + ) defer.returnValue(results) + @defer.inlineCallbacks + def _get_prevs_before_rejected(self, event_ids): + """Get soft-failed ancestors to remove from the extremities. + + Given a set of events, find all those that have been soft-failed or + rejected. Returns those soft failed/rejected events and their prev + events (whether soft-failed/rejected or not), and recurses up the + prev-event graph until it finds no more soft-failed/rejected events. + + This is used to find extremities that are ancestors of new events, but + are separated by soft failed events. + + Args: + event_ids (Iterable[str]): Events to find prev events for. Note + that these must have already been persisted. + + Returns: + Deferred[set[str]] + """ + + # The set of event_ids to return. This includes all soft-failed events + # and their prev events. + existing_prevs = set() + + def _get_prevs_before_rejected_txn(txn, batch): + to_recursively_check = batch + + while to_recursively_check: + sql = """ + SELECT + event_id, prev_event_id, internal_metadata, + rejections.event_id IS NOT NULL + FROM event_edges + INNER JOIN events USING (event_id) + LEFT JOIN rejections USING (event_id) + LEFT JOIN event_json USING (event_id) + WHERE + event_id IN (%s) + AND NOT events.outlier + """ % ( + ",".join("?" for _ in to_recursively_check), + ) + + txn.execute(sql, to_recursively_check) + to_recursively_check = [] + + for event_id, prev_event_id, metadata, rejected in txn: + if prev_event_id in existing_prevs: + continue + + soft_failed = json.loads(metadata).get("soft_failed") + if soft_failed or rejected: + to_recursively_check.append(prev_event_id) + existing_prevs.add(prev_event_id) + + for chunk in batch_iter(event_ids, 100): + yield self.runInteraction( + "_get_prevs_before_rejected", + _get_prevs_before_rejected_txn, + chunk, + ) + + defer.returnValue(existing_prevs) + @defer.inlineCallbacks def _get_new_state_after_events( self, room_id, events_context, old_latest_event_ids, new_latest_event_ids -- cgit 1.5.1 From d79c9994f416ee5dab27a277fa729ffa5ee74ccc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 28 May 2019 18:52:41 +0100 Subject: Add DB bg update to cleanup extremities. Due to #5269 we may have extremities in our DB that we shouldn't have, so lets add a cleanup task such to remove those. --- synapse/storage/events.py | 186 +++++++++++++++++++++ .../schema/delta/54/delete_forward_extremities.sql | 19 +++ 2 files changed, 205 insertions(+) create mode 100644 synapse/storage/schema/delta/54/delete_forward_extremities.sql (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 6e9f3d1dc0..a9be143bd5 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -221,6 +221,7 @@ class EventsStore( ): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" + EVENT_DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" def __init__(self, db_conn, hs): super(EventsStore, self).__init__(db_conn, hs) @@ -252,6 +253,11 @@ class EventsStore( psql_only=True, ) + self.register_background_update_handler( + self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES, + self._cleanup_extremities_bg_update, + ) + self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() @@ -2341,6 +2347,186 @@ class EventsStore( get_all_updated_current_state_deltas_txn, ) + @defer.inlineCallbacks + def _cleanup_extremities_bg_update(self, progress, batch_size): + """Background update to clean out extremities that should have been + deleted previously. + + Mainly used to deal with the aftermath of #5269. + """ + + # This works by first copying all existing forward extremities into the + # `_extremities_to_check` table at start up, and then checking each + # event in that table whether we have any descendants that are not + # soft-failed/rejected. If that is the case then we delete that event + # from the forward extremities table. + # + # For efficiency, we do this in batches by recursively pulling out all + # descendants of a batch until we find the non soft-failed/rejected + # events, i.e. the set of descendants whose chain of prev events back + # to the batch of extremities are all soft-failed or rejected. + # Typically, we won't find any such events as extremities will rarely + # have any descendants, but if they do then we should delete those + # extremities. + + def _cleanup_extremities_bg_update_txn(txn): + # The set of extremity event IDs that we're checking this round + original_set = set() + + # A dict[str, set[str]] of event ID to their prev events. + graph = {} + + # The set of descendants of the original set that are not rejected + # nor soft-failed. Ancestors of these events should be removed + # from the forward extremities table. + non_rejected_leaves = set() + + # Set of event IDs that have been soft failed, and for which we + # should check if they have descendants which haven't been soft + # failed. + soft_failed_events_to_lookup = set() + + # First, we get `batch_size` events from the table, pulling out + # their prev events, if any, and their prev events rejection status. + txn.execute( + """SELECT prev_event_id, event_id, internal_metadata, + rejections.event_id IS NOT NULL, events.outlier + FROM ( + SELECT event_id AS prev_event_id + FROM _extremities_to_check + LIMIT ? + ) AS f + LEFT JOIN event_edges USING (prev_event_id) + LEFT JOIN events USING (event_id) + LEFT JOIN event_json USING (event_id) + LEFT JOIN rejections USING (event_id) + """, (batch_size,) + ) + + for prev_event_id, event_id, metadata, rejected, outlier in txn: + original_set.add(prev_event_id) + + if not event_id or outlier: + # Common case where the forward extremity doesn't have any + # descendants. + continue + + graph.setdefault(event_id, set()).add(prev_event_id) + + soft_failed = False + if metadata: + soft_failed = json.loads(metadata).get("soft_failed") + + if soft_failed or rejected: + soft_failed_events_to_lookup.add(event_id) + else: + non_rejected_leaves.add(event_id) + + # Now we recursively check all the soft-failed descendants we + # found above in the same way, until we have nothing left to + # check. + while soft_failed_events_to_lookup: + # We only want to do 100 at a time, so we split given list + # into two. + batch = list(soft_failed_events_to_lookup) + to_check, to_defer = batch[:100], batch[100:] + soft_failed_events_to_lookup = set(to_defer) + + sql = """SELECT prev_event_id, event_id, internal_metadata, + rejections.event_id IS NOT NULL + FROM event_edges + INNER JOIN events USING (event_id) + INNER JOIN event_json USING (event_id) + LEFT JOIN rejections USING (event_id) + WHERE + prev_event_id IN (%s) + AND NOT events.outlier + """ % ( + ",".join("?" for _ in to_check), + ) + txn.execute(sql, to_check) + + for prev_event_id, event_id, metadata, rejected in txn: + if event_id in graph: + # Already handled this event previously, but we still + # want to record the edge. + graph.setdefault(event_id, set()).add(prev_event_id) + logger.info("Already handled") + continue + + graph.setdefault(event_id, set()).add(prev_event_id) + + soft_failed = json.loads(metadata).get("soft_failed") + if soft_failed or rejected: + soft_failed_events_to_lookup.add(event_id) + else: + non_rejected_leaves.add(event_id) + + # We have a set of non-soft-failed descendants, so we recurse up + # the graph to find all ancestors and add them to the set of event + # IDs that we can delete from forward extremities table. + to_delete = set() + while non_rejected_leaves: + event_id = non_rejected_leaves.pop() + prev_event_ids = graph.get(event_id, set()) + non_rejected_leaves.update(prev_event_ids) + to_delete.update(prev_event_ids) + + to_delete.intersection_update(original_set) + + logger.info("Deleting up to %d forward extremities", len(to_delete)) + + self._simple_delete_many_txn( + txn=txn, + table="event_forward_extremities", + column="event_id", + iterable=to_delete, + keyvalues={}, + ) + + if to_delete: + # We now need to invalidate the caches of these rooms + rows = self._simple_select_many_txn( + txn, + table="events", + column="event_id", + iterable=to_delete, + keyvalues={}, + retcols=("room_id",) + ) + for row in rows: + txn.call_after( + self.get_latest_event_ids_in_room.invalidate, + (row["room_id"],) + ) + + self._simple_delete_many_txn( + txn=txn, + table="_extremities_to_check", + column="event_id", + iterable=original_set, + keyvalues={}, + ) + + return len(original_set) + + num_handled = yield self.runInteraction( + "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn, + ) + + if not num_handled: + yield self._end_background_update(self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES) + + def _drop_table_txn(txn): + txn.execute("DROP TABLE _extremities_to_check") + + yield self.runInteraction( + "_cleanup_extremities_bg_update_drop_table", + _drop_table_txn, + ) + + defer.returnValue(num_handled) + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/schema/delta/54/delete_forward_extremities.sql new file mode 100644 index 0000000000..7056bd1d00 --- /dev/null +++ b/synapse/storage/schema/delta/54/delete_forward_extremities.sql @@ -0,0 +1,19 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('delete_soft_failed_extremities', '{}'); + +CREATE TABLE _extremities_to_check AS SELECT event_id FROM event_forward_extremities; -- cgit 1.5.1 From 7e8e683754cdc606a1440832d9b1eb47f930ddee Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 May 2019 11:58:32 +0100 Subject: Log actual number of entries deleted --- synapse/storage/_base.py | 12 +++++++++--- synapse/storage/events.py | 6 ++++-- 2 files changed, 13 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index fa6839ceca..3fe827cd43 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -1261,7 +1261,8 @@ class SQLBaseStore(object): " AND ".join("%s = ?" % (k,) for k in keyvalues), ) - return txn.execute(sql, list(keyvalues.values())) + txn.execute(sql, list(keyvalues.values())) + return txn.rowcount def _simple_delete_many(self, table, column, iterable, keyvalues, desc): return self.runInteraction( @@ -1280,9 +1281,12 @@ class SQLBaseStore(object): column : column name to test for inclusion against `iterable` iterable : list keyvalues : dict of column names and values to select the rows with + + Returns: + int: Number rows deleted """ if not iterable: - return + return 0 sql = "DELETE FROM %s" % table @@ -1297,7 +1301,9 @@ class SQLBaseStore(object): if clauses: sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) - return txn.execute(sql, values) + txn.execute(sql, values) + + return txn.rowcount def _get_cache_dict( self, db_conn, table, entity_column, stream_column, max_value, limit=100000 diff --git a/synapse/storage/events.py b/synapse/storage/events.py index a9be143bd5..a9664928ca 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -2476,7 +2476,7 @@ class EventsStore( logger.info("Deleting up to %d forward extremities", len(to_delete)) - self._simple_delete_many_txn( + deleted = self._simple_delete_many_txn( txn=txn, table="event_forward_extremities", column="event_id", @@ -2484,7 +2484,9 @@ class EventsStore( keyvalues={}, ) - if to_delete: + logger.info("Deleted %d forward extremities", deleted) + + if deleted: # We now need to invalidate the caches of these rooms rows = self._simple_select_many_txn( txn, -- cgit 1.5.1 From 46c8f7a5170d04dfa6ad02c69667d4aa48635231 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 30 May 2019 01:47:16 +1000 Subject: Implement the SHHS complexity API (#5216) --- changelog.d/5216.misc | 1 + synapse/api/urls.py | 1 + synapse/federation/transport/server.py | 31 +++++++++++- synapse/rest/admin/__init__.py | 12 +++-- synapse/storage/events_worker.py | 50 ++++++++++++++++++- tests/federation/test_complexity.py | 90 ++++++++++++++++++++++++++++++++++ 6 files changed, 180 insertions(+), 5 deletions(-) create mode 100644 changelog.d/5216.misc create mode 100644 tests/federation/test_complexity.py (limited to 'synapse/storage') diff --git a/changelog.d/5216.misc b/changelog.d/5216.misc new file mode 100644 index 0000000000..dbfa29475f --- /dev/null +++ b/changelog.d/5216.misc @@ -0,0 +1 @@ +Synapse will now serve the experimental "room complexity" API endpoint. diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 3c6bddff7a..e16c386a14 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -26,6 +26,7 @@ CLIENT_API_PREFIX = "/_matrix/client" FEDERATION_PREFIX = "/_matrix/federation" FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1" FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2" +FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable" STATIC_PREFIX = "/_matrix/static" WEB_CLIENT_PREFIX = "/_matrix/client" CONTENT_REPO_PREFIX = "/_matrix/content" diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 385eda2dca..d0efc4e0d3 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -23,7 +23,11 @@ from twisted.internet import defer import synapse from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.room_versions import RoomVersions -from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX +from synapse.api.urls import ( + FEDERATION_UNSTABLE_PREFIX, + FEDERATION_V1_PREFIX, + FEDERATION_V2_PREFIX, +) from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( @@ -1304,6 +1308,30 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): defer.returnValue((200, new_content)) +class RoomComplexityServlet(BaseFederationServlet): + """ + Indicates to other servers how complex (and therefore likely + resource-intensive) a public room this server knows about is. + """ + PATH = "/rooms/(?P[^/]*)/complexity" + PREFIX = FEDERATION_UNSTABLE_PREFIX + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, room_id): + + store = self.handler.hs.get_datastore() + + is_public = yield store.is_room_world_readable_or_publicly_joinable( + room_id + ) + + if not is_public: + raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM) + + complexity = yield store.get_room_complexity(room_id) + defer.returnValue((200, complexity)) + + FEDERATION_SERVLET_CLASSES = ( FederationSendServlet, FederationEventServlet, @@ -1327,6 +1355,7 @@ FEDERATION_SERVLET_CLASSES = ( FederationThirdPartyInviteExchangeServlet, On3pidBindServlet, FederationVersionServlet, + RoomComplexityServlet, ) OPENID_SERVLET_CLASSES = ( diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 744d85594f..d6c4dcdb18 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -822,10 +822,16 @@ class AdminRestResource(JsonResource): def __init__(self, hs): JsonResource.__init__(self, hs, canonical_json=False) + register_servlets(hs, self) - register_servlets_for_client_rest_resource(hs, self) - SendServerNoticeServlet(hs).register(self) - VersionServlet(hs).register(self) + +def register_servlets(hs, http_server): + """ + Register all the admin servlets. + """ + register_servlets_for_client_rest_resource(hs, http_server) + SendServerNoticeServlet(hs).register(http_server) + VersionServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 21b353cad3..b56c83e460 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import division + import itertools import logging from collections import namedtuple @@ -614,7 +616,7 @@ class EventsWorkerStore(SQLBaseStore): def _get_total_state_event_counts_txn(self, txn, room_id): """ - See get_state_event_counts. + See get_total_state_event_counts. """ sql = "SELECT COUNT(*) FROM state_events WHERE room_id=?" txn.execute(sql, (room_id,)) @@ -635,3 +637,49 @@ class EventsWorkerStore(SQLBaseStore): "get_total_state_event_counts", self._get_total_state_event_counts_txn, room_id ) + + def _get_current_state_event_counts_txn(self, txn, room_id): + """ + See get_current_state_event_counts. + """ + sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?" + txn.execute(sql, (room_id,)) + row = txn.fetchone() + return row[0] if row else 0 + + def get_current_state_event_counts(self, room_id): + """ + Gets the current number of state events in a room. + + Args: + room_id (str) + + Returns: + Deferred[int] + """ + return self.runInteraction( + "get_current_state_event_counts", + self._get_current_state_event_counts_txn, room_id + ) + + @defer.inlineCallbacks + def get_room_complexity(self, room_id): + """ + Get a rough approximation of the complexity of the room. This is used by + remote servers to decide whether they wish to join the room or not. + Higher complexity value indicates that being in the room will consume + more resources. + + Args: + room_id (str) + + Returns: + Deferred[dict[str:int]] of complexity version to complexity. + """ + state_events = yield self.get_current_state_event_counts(room_id) + + # Call this one "v1", so we can introduce new ones as we want to develop + # it. + complexity_v1 = round(state_events / 500, 2) + + defer.returnValue({"v1": complexity_v1}) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py new file mode 100644 index 0000000000..1e3e5aec66 --- /dev/null +++ b/tests/federation/test_complexity.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Matrix.org Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.federation.transport import server +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.util.ratelimitutils import FederationRateLimiter + +from tests import unittest + + +class RoomComplexityTests(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self, name='test'): + config = super(RoomComplexityTests, self).default_config(name=name) + config["limit_large_remote_room_joins"] = True + config["limit_large_remote_room_complexity"] = 0.05 + return config + + def prepare(self, reactor, clock, homeserver): + class Authenticator(object): + def authenticate_request(self, request, content): + return defer.succeed("otherserver.nottld") + + ratelimiter = FederationRateLimiter( + clock, + FederationRateLimitConfig( + window_size=1, + sleep_limit=1, + sleep_msec=1, + reject_limit=1000, + concurrent_requests=1000, + ), + ) + server.register_servlets( + homeserver, self.resource, Authenticator(), ratelimiter + ) + + def test_complexity_simple(self): + + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.helper.send_state( + room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token + ) + + # Get the room complexity + request, channel = self.make_request( + "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) + ) + self.render(request) + self.assertEquals(200, channel.code) + complexity = channel.json_body["v1"] + self.assertTrue(complexity > 0, complexity) + + # Artificially raise the complexity + store = self.hs.get_datastore() + store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23) + + # Get the room complexity again -- make sure it's our artificial value + request, channel = self.make_request( + "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) + ) + self.render(request) + self.assertEquals(200, channel.code) + complexity = channel.json_body["v1"] + self.assertEqual(complexity, 1.23) -- cgit 1.5.1 From 640fcbb07f8dc7d89465734f009d8e0a458c2b17 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 10:55:55 +0100 Subject: Fixup comments and logging --- synapse/storage/events.py | 21 ++++++++++++--------- .../schema/delta/54/delete_forward_extremities.sql | 3 +++ 2 files changed, 15 insertions(+), 9 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index a9664928ca..418d88b8dc 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -2387,7 +2387,8 @@ class EventsStore( soft_failed_events_to_lookup = set() # First, we get `batch_size` events from the table, pulling out - # their prev events, if any, and their prev events rejection status. + # their successor events, if any, and their successor events + # rejection status. txn.execute( """SELECT prev_event_id, event_id, internal_metadata, rejections.event_id IS NOT NULL, events.outlier @@ -2450,11 +2451,10 @@ class EventsStore( if event_id in graph: # Already handled this event previously, but we still # want to record the edge. - graph.setdefault(event_id, set()).add(prev_event_id) - logger.info("Already handled") + graph[event_id].add(prev_event_id) continue - graph.setdefault(event_id, set()).add(prev_event_id) + graph[event_id] = {prev_event_id} soft_failed = json.loads(metadata).get("soft_failed") if soft_failed or rejected: @@ -2474,8 +2474,6 @@ class EventsStore( to_delete.intersection_update(original_set) - logger.info("Deleting up to %d forward extremities", len(to_delete)) - deleted = self._simple_delete_many_txn( txn=txn, table="event_forward_extremities", @@ -2484,7 +2482,11 @@ class EventsStore( keyvalues={}, ) - logger.info("Deleted %d forward extremities", deleted) + logger.info( + "Deleted %d forward extremities of %d checked, to clean up #5269", + deleted, + len(original_set), + ) if deleted: # We now need to invalidate the caches of these rooms @@ -2496,10 +2498,11 @@ class EventsStore( keyvalues={}, retcols=("room_id",) ) - for row in rows: + room_ids = set(row["room_id"] for row in rows) + for room_id in room_ids: txn.call_after( self.get_latest_event_ids_in_room.invalidate, - (row["room_id"],) + (room_id,) ) self._simple_delete_many_txn( diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/schema/delta/54/delete_forward_extremities.sql index 7056bd1d00..aa40f13da7 100644 --- a/synapse/storage/schema/delta/54/delete_forward_extremities.sql +++ b/synapse/storage/schema/delta/54/delete_forward_extremities.sql @@ -13,7 +13,10 @@ * limitations under the License. */ +-- Start a background job to cleanup extremities that were incorrectly added +-- by bug #5269. INSERT INTO background_updates (update_name, progress_json) VALUES ('delete_soft_failed_extremities', '{}'); +DROP TABLE IF EXISTS _extremities_to_check; -- To make this delta schema file idempotent. CREATE TABLE _extremities_to_check AS SELECT event_id FROM event_forward_extremities; -- cgit 1.5.1 From 5c1ece0ffcd803eb4bf8e5748d3e2633426e00a0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 11:22:59 +0100 Subject: Move event background updates to a separate file --- synapse/storage/__init__.py | 2 + synapse/storage/events.py | 371 +------------------------------- synapse/storage/events_bg_updates.py | 401 +++++++++++++++++++++++++++++++++++ 3 files changed, 405 insertions(+), 369 deletions(-) create mode 100644 synapse/storage/events_bg_updates.py (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 66675d08ae..71316f7d09 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -36,6 +36,7 @@ from .engines import PostgresEngine from .event_federation import EventFederationStore from .event_push_actions import EventPushActionsStore from .events import EventsStore +from .events_bg_updates import EventsBackgroundUpdatesStore from .filtering import FilteringStore from .group_server import GroupServerStore from .keys import KeyStore @@ -66,6 +67,7 @@ logger = logging.getLogger(__name__) class DataStore( + EventsBackgroundUpdatesStore, RoomMemberStore, RoomStore, RegistrationStore, diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 418d88b8dc..f9162be9b9 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -219,47 +220,11 @@ class EventsStore( EventsWorkerStore, BackgroundUpdateStore, ): - EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" - EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" - EVENT_DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" def __init__(self, db_conn, hs): super(EventsStore, self).__init__(db_conn, hs) - self.register_background_update_handler( - self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts - ) - self.register_background_update_handler( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, - self._background_reindex_fields_sender, - ) - - self.register_background_index_update( - "event_contains_url_index", - index_name="event_contains_url_index", - table="events", - columns=["room_id", "topological_ordering", "stream_ordering"], - where_clause="contains_url = true AND outlier = false", - ) - - # an event_id index on event_search is useful for the purge_history - # api. Plus it means we get to enforce some integrity with a UNIQUE - # clause - self.register_background_index_update( - "event_search_event_id_idx", - index_name="event_search_event_id_idx", - table="event_search", - columns=["event_id"], - unique=True, - psql_only=True, - ) - - self.register_background_update_handler( - self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES, - self._cleanup_extremities_bg_update, - ) self._event_persist_queue = _EventPeristenceQueue() - self._state_resolution_handler = hs.get_state_resolution_handler() @defer.inlineCallbacks @@ -1585,153 +1550,6 @@ class EventsStore( ret = yield self.runInteraction("count_daily_active_rooms", _count) defer.returnValue(ret) - @defer.inlineCallbacks - def _background_reindex_fields_sender(self, progress, batch_size): - target_min_stream_id = progress["target_min_stream_id_inclusive"] - max_stream_id = progress["max_stream_id_exclusive"] - rows_inserted = progress.get("rows_inserted", 0) - - INSERT_CLUMP_SIZE = 1000 - - def reindex_txn(txn): - sql = ( - "SELECT stream_ordering, event_id, json FROM events" - " INNER JOIN event_json USING (event_id)" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - rows = txn.fetchall() - if not rows: - return 0 - - min_stream_id = rows[-1][0] - - update_rows = [] - for row in rows: - try: - event_id = row[1] - event_json = json.loads(row[2]) - sender = event_json["sender"] - content = event_json["content"] - - contains_url = "url" in content - if contains_url: - contains_url &= isinstance(content["url"], text_type) - except (KeyError, AttributeError): - # If the event is missing a necessary field then - # skip over it. - continue - - update_rows.append((sender, contains_url, event_id)) - - sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" - - for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): - clump = update_rows[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows), - } - - self._background_update_progress_txn( - txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress - ) - - return len(rows) - - result = yield self.runInteraction( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn - ) - - if not result: - yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) - - defer.returnValue(result) - - @defer.inlineCallbacks - def _background_reindex_origin_server_ts(self, progress, batch_size): - target_min_stream_id = progress["target_min_stream_id_inclusive"] - max_stream_id = progress["max_stream_id_exclusive"] - rows_inserted = progress.get("rows_inserted", 0) - - INSERT_CLUMP_SIZE = 1000 - - def reindex_search_txn(txn): - sql = ( - "SELECT stream_ordering, event_id FROM events" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - rows = txn.fetchall() - if not rows: - return 0 - - min_stream_id = rows[-1][0] - event_ids = [row[1] for row in rows] - - rows_to_update = [] - - chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] - for chunk in chunks: - ev_rows = self._simple_select_many_txn( - txn, - table="event_json", - column="event_id", - iterable=chunk, - retcols=["event_id", "json"], - keyvalues={}, - ) - - for row in ev_rows: - event_id = row["event_id"] - event_json = json.loads(row["json"]) - try: - origin_server_ts = event_json["origin_server_ts"] - except (KeyError, AttributeError): - # If the event is missing a necessary field then - # skip over it. - continue - - rows_to_update.append((origin_server_ts, event_id)) - - sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" - - for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): - clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows_to_update), - } - - self._background_update_progress_txn( - txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress - ) - - return len(rows_to_update) - - result = yield self.runInteraction( - self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn - ) - - if not result: - yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) - - defer.returnValue(result) - def get_current_backfill_token(self): """The current minimum token that backfilled events have reached""" return -self._backfill_id_gen.get_current_token() @@ -2347,191 +2165,6 @@ class EventsStore( get_all_updated_current_state_deltas_txn, ) - @defer.inlineCallbacks - def _cleanup_extremities_bg_update(self, progress, batch_size): - """Background update to clean out extremities that should have been - deleted previously. - - Mainly used to deal with the aftermath of #5269. - """ - - # This works by first copying all existing forward extremities into the - # `_extremities_to_check` table at start up, and then checking each - # event in that table whether we have any descendants that are not - # soft-failed/rejected. If that is the case then we delete that event - # from the forward extremities table. - # - # For efficiency, we do this in batches by recursively pulling out all - # descendants of a batch until we find the non soft-failed/rejected - # events, i.e. the set of descendants whose chain of prev events back - # to the batch of extremities are all soft-failed or rejected. - # Typically, we won't find any such events as extremities will rarely - # have any descendants, but if they do then we should delete those - # extremities. - - def _cleanup_extremities_bg_update_txn(txn): - # The set of extremity event IDs that we're checking this round - original_set = set() - - # A dict[str, set[str]] of event ID to their prev events. - graph = {} - - # The set of descendants of the original set that are not rejected - # nor soft-failed. Ancestors of these events should be removed - # from the forward extremities table. - non_rejected_leaves = set() - - # Set of event IDs that have been soft failed, and for which we - # should check if they have descendants which haven't been soft - # failed. - soft_failed_events_to_lookup = set() - - # First, we get `batch_size` events from the table, pulling out - # their successor events, if any, and their successor events - # rejection status. - txn.execute( - """SELECT prev_event_id, event_id, internal_metadata, - rejections.event_id IS NOT NULL, events.outlier - FROM ( - SELECT event_id AS prev_event_id - FROM _extremities_to_check - LIMIT ? - ) AS f - LEFT JOIN event_edges USING (prev_event_id) - LEFT JOIN events USING (event_id) - LEFT JOIN event_json USING (event_id) - LEFT JOIN rejections USING (event_id) - """, (batch_size,) - ) - - for prev_event_id, event_id, metadata, rejected, outlier in txn: - original_set.add(prev_event_id) - - if not event_id or outlier: - # Common case where the forward extremity doesn't have any - # descendants. - continue - - graph.setdefault(event_id, set()).add(prev_event_id) - - soft_failed = False - if metadata: - soft_failed = json.loads(metadata).get("soft_failed") - - if soft_failed or rejected: - soft_failed_events_to_lookup.add(event_id) - else: - non_rejected_leaves.add(event_id) - - # Now we recursively check all the soft-failed descendants we - # found above in the same way, until we have nothing left to - # check. - while soft_failed_events_to_lookup: - # We only want to do 100 at a time, so we split given list - # into two. - batch = list(soft_failed_events_to_lookup) - to_check, to_defer = batch[:100], batch[100:] - soft_failed_events_to_lookup = set(to_defer) - - sql = """SELECT prev_event_id, event_id, internal_metadata, - rejections.event_id IS NOT NULL - FROM event_edges - INNER JOIN events USING (event_id) - INNER JOIN event_json USING (event_id) - LEFT JOIN rejections USING (event_id) - WHERE - prev_event_id IN (%s) - AND NOT events.outlier - """ % ( - ",".join("?" for _ in to_check), - ) - txn.execute(sql, to_check) - - for prev_event_id, event_id, metadata, rejected in txn: - if event_id in graph: - # Already handled this event previously, but we still - # want to record the edge. - graph[event_id].add(prev_event_id) - continue - - graph[event_id] = {prev_event_id} - - soft_failed = json.loads(metadata).get("soft_failed") - if soft_failed or rejected: - soft_failed_events_to_lookup.add(event_id) - else: - non_rejected_leaves.add(event_id) - - # We have a set of non-soft-failed descendants, so we recurse up - # the graph to find all ancestors and add them to the set of event - # IDs that we can delete from forward extremities table. - to_delete = set() - while non_rejected_leaves: - event_id = non_rejected_leaves.pop() - prev_event_ids = graph.get(event_id, set()) - non_rejected_leaves.update(prev_event_ids) - to_delete.update(prev_event_ids) - - to_delete.intersection_update(original_set) - - deleted = self._simple_delete_many_txn( - txn=txn, - table="event_forward_extremities", - column="event_id", - iterable=to_delete, - keyvalues={}, - ) - - logger.info( - "Deleted %d forward extremities of %d checked, to clean up #5269", - deleted, - len(original_set), - ) - - if deleted: - # We now need to invalidate the caches of these rooms - rows = self._simple_select_many_txn( - txn, - table="events", - column="event_id", - iterable=to_delete, - keyvalues={}, - retcols=("room_id",) - ) - room_ids = set(row["room_id"] for row in rows) - for room_id in room_ids: - txn.call_after( - self.get_latest_event_ids_in_room.invalidate, - (room_id,) - ) - - self._simple_delete_many_txn( - txn=txn, - table="_extremities_to_check", - column="event_id", - iterable=original_set, - keyvalues={}, - ) - - return len(original_set) - - num_handled = yield self.runInteraction( - "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn, - ) - - if not num_handled: - yield self._end_background_update(self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES) - - def _drop_table_txn(txn): - txn.execute("DROP TABLE _extremities_to_check") - - yield self.runInteraction( - "_cleanup_extremities_bg_update_drop_table", - _drop_table_txn, - ) - - defer.returnValue(num_handled) - AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py new file mode 100644 index 0000000000..2eba106abf --- /dev/null +++ b/synapse/storage/events_bg_updates.py @@ -0,0 +1,401 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import text_type + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.storage.background_updates import BackgroundUpdateStore + +logger = logging.getLogger(__name__) + + +class EventsBackgroundUpdatesStore(BackgroundUpdateStore): + + EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" + EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" + EVENT_DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" + + def __init__(self, db_conn, hs): + super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) + + self.register_background_update_handler( + self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts + ) + self.register_background_update_handler( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, + self._background_reindex_fields_sender, + ) + + self.register_background_index_update( + "event_contains_url_index", + index_name="event_contains_url_index", + table="events", + columns=["room_id", "topological_ordering", "stream_ordering"], + where_clause="contains_url = true AND outlier = false", + ) + + # an event_id index on event_search is useful for the purge_history + # api. Plus it means we get to enforce some integrity with a UNIQUE + # clause + self.register_background_index_update( + "event_search_event_id_idx", + index_name="event_search_event_id_idx", + table="event_search", + columns=["event_id"], + unique=True, + psql_only=True, + ) + + self.register_background_update_handler( + self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES, + self._cleanup_extremities_bg_update, + ) + + @defer.inlineCallbacks + def _background_reindex_fields_sender(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + INSERT_CLUMP_SIZE = 1000 + + def reindex_txn(txn): + sql = ( + "SELECT stream_ordering, event_id, json FROM events" + " INNER JOIN event_json USING (event_id)" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + min_stream_id = rows[-1][0] + + update_rows = [] + for row in rows: + try: + event_id = row[1] + event_json = json.loads(row[2]) + sender = event_json["sender"] + content = event_json["content"] + + contains_url = "url" in content + if contains_url: + contains_url &= isinstance(content["url"], text_type) + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + update_rows.append((sender, contains_url, event_id)) + + sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" + + for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): + clump = update_rows[index : index + INSERT_CLUMP_SIZE] + txn.executemany(sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows), + } + + self._background_update_progress_txn( + txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress + ) + + return len(rows) + + result = yield self.runInteraction( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn + ) + + if not result: + yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) + + defer.returnValue(result) + + @defer.inlineCallbacks + def _background_reindex_origin_server_ts(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + INSERT_CLUMP_SIZE = 1000 + + def reindex_search_txn(txn): + sql = ( + "SELECT stream_ordering, event_id FROM events" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + min_stream_id = rows[-1][0] + event_ids = [row[1] for row in rows] + + rows_to_update = [] + + chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] + for chunk in chunks: + ev_rows = self._simple_select_many_txn( + txn, + table="event_json", + column="event_id", + iterable=chunk, + retcols=["event_id", "json"], + keyvalues={}, + ) + + for row in ev_rows: + event_id = row["event_id"] + event_json = json.loads(row["json"]) + try: + origin_server_ts = event_json["origin_server_ts"] + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + rows_to_update.append((origin_server_ts, event_id)) + + sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" + + for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): + clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] + txn.executemany(sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows_to_update), + } + + self._background_update_progress_txn( + txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress + ) + + return len(rows_to_update) + + result = yield self.runInteraction( + self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn + ) + + if not result: + yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) + + defer.returnValue(result) + + @defer.inlineCallbacks + def _cleanup_extremities_bg_update(self, progress, batch_size): + """Background update to clean out extremities that should have been + deleted previously. + + Mainly used to deal with the aftermath of #5269. + """ + + # This works by first copying all existing forward extremities into the + # `_extremities_to_check` table at start up, and then checking each + # event in that table whether we have any descendants that are not + # soft-failed/rejected. If that is the case then we delete that event + # from the forward extremities table. + # + # For efficiency, we do this in batches by recursively pulling out all + # descendants of a batch until we find the non soft-failed/rejected + # events, i.e. the set of descendants whose chain of prev events back + # to the batch of extremities are all soft-failed or rejected. + # Typically, we won't find any such events as extremities will rarely + # have any descendants, but if they do then we should delete those + # extremities. + + def _cleanup_extremities_bg_update_txn(txn): + # The set of extremity event IDs that we're checking this round + original_set = set() + + # A dict[str, set[str]] of event ID to their prev events. + graph = {} + + # The set of descendants of the original set that are not rejected + # nor soft-failed. Ancestors of these events should be removed + # from the forward extremities table. + non_rejected_leaves = set() + + # Set of event IDs that have been soft failed, and for which we + # should check if they have descendants which haven't been soft + # failed. + soft_failed_events_to_lookup = set() + + # First, we get `batch_size` events from the table, pulling out + # their successor events, if any, and their successor events + # rejection status. + txn.execute( + """SELECT prev_event_id, event_id, internal_metadata, + rejections.event_id IS NOT NULL, events.outlier + FROM ( + SELECT event_id AS prev_event_id + FROM _extremities_to_check + LIMIT ? + ) AS f + LEFT JOIN event_edges USING (prev_event_id) + LEFT JOIN events USING (event_id) + LEFT JOIN event_json USING (event_id) + LEFT JOIN rejections USING (event_id) + """, (batch_size,) + ) + + for prev_event_id, event_id, metadata, rejected, outlier in txn: + original_set.add(prev_event_id) + + if not event_id or outlier: + # Common case where the forward extremity doesn't have any + # descendants. + continue + + graph.setdefault(event_id, set()).add(prev_event_id) + + soft_failed = False + if metadata: + soft_failed = json.loads(metadata).get("soft_failed") + + if soft_failed or rejected: + soft_failed_events_to_lookup.add(event_id) + else: + non_rejected_leaves.add(event_id) + + # Now we recursively check all the soft-failed descendants we + # found above in the same way, until we have nothing left to + # check. + while soft_failed_events_to_lookup: + # We only want to do 100 at a time, so we split given list + # into two. + batch = list(soft_failed_events_to_lookup) + to_check, to_defer = batch[:100], batch[100:] + soft_failed_events_to_lookup = set(to_defer) + + sql = """SELECT prev_event_id, event_id, internal_metadata, + rejections.event_id IS NOT NULL + FROM event_edges + INNER JOIN events USING (event_id) + INNER JOIN event_json USING (event_id) + LEFT JOIN rejections USING (event_id) + WHERE + prev_event_id IN (%s) + AND NOT events.outlier + """ % ( + ",".join("?" for _ in to_check), + ) + txn.execute(sql, to_check) + + for prev_event_id, event_id, metadata, rejected in txn: + if event_id in graph: + # Already handled this event previously, but we still + # want to record the edge. + graph[event_id].add(prev_event_id) + continue + + graph[event_id] = {prev_event_id} + + soft_failed = json.loads(metadata).get("soft_failed") + if soft_failed or rejected: + soft_failed_events_to_lookup.add(event_id) + else: + non_rejected_leaves.add(event_id) + + # We have a set of non-soft-failed descendants, so we recurse up + # the graph to find all ancestors and add them to the set of event + # IDs that we can delete from forward extremities table. + to_delete = set() + while non_rejected_leaves: + event_id = non_rejected_leaves.pop() + prev_event_ids = graph.get(event_id, set()) + non_rejected_leaves.update(prev_event_ids) + to_delete.update(prev_event_ids) + + to_delete.intersection_update(original_set) + + deleted = self._simple_delete_many_txn( + txn=txn, + table="event_forward_extremities", + column="event_id", + iterable=to_delete, + keyvalues={}, + ) + + logger.info( + "Deleted %d forward extremities of %d checked, to clean up #5269", + deleted, + len(original_set), + ) + + if deleted: + # We now need to invalidate the caches of these rooms + rows = self._simple_select_many_txn( + txn, + table="events", + column="event_id", + iterable=to_delete, + keyvalues={}, + retcols=("room_id",) + ) + room_ids = set(row["room_id"] for row in rows) + for room_id in room_ids: + txn.call_after( + self.get_latest_event_ids_in_room.invalidate, + (room_id,) + ) + + self._simple_delete_many_txn( + txn=txn, + table="_extremities_to_check", + column="event_id", + iterable=original_set, + keyvalues={}, + ) + + return len(original_set) + + num_handled = yield self.runInteraction( + "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn, + ) + + if not num_handled: + yield self._end_background_update(self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES) + + def _drop_table_txn(txn): + txn.execute("DROP TABLE _extremities_to_check") + + yield self.runInteraction( + "_cleanup_extremities_bg_update_drop_table", + _drop_table_txn, + ) + + defer.returnValue(num_handled) -- cgit 1.5.1 From 468bd090ff354f27597e08b54a969f22afbfad43 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 11:24:42 +0100 Subject: Rename constant --- synapse/storage/events_bg_updates.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py index 2eba106abf..22aac1393d 100644 --- a/synapse/storage/events_bg_updates.py +++ b/synapse/storage/events_bg_updates.py @@ -30,7 +30,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" - EVENT_DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" + DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" def __init__(self, db_conn, hs): super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) @@ -64,7 +64,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) self.register_background_update_handler( - self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES, + self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update, ) @@ -388,7 +388,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not num_handled: - yield self._end_background_update(self.EVENT_DELETE_SOFT_FAILED_EXTREMITIES) + yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES) def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") -- cgit 1.5.1 From cb967e2346096d7e647c757e3b57093549746f14 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 14:06:42 +0100 Subject: Update synapse/storage/events_bg_updates.py Co-Authored-By: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- synapse/storage/events_bg_updates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py index 22aac1393d..75c1935bf3 100644 --- a/synapse/storage/events_bg_updates.py +++ b/synapse/storage/events_bg_updates.py @@ -255,7 +255,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): soft_failed_events_to_lookup = set() # First, we get `batch_size` events from the table, pulling out - # their successor events, if any, and their successor events + # their successor events, if any, and the successor events' # rejection status. txn.execute( """SELECT prev_event_id, event_id, internal_metadata, -- cgit 1.5.1 From 6cdfb0207e2de72286a7a8d3b3c417c2808e90dc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 14:54:56 +0100 Subject: Add index to temp table --- synapse/storage/schema/delta/54/delete_forward_extremities.sql | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/schema/delta/54/delete_forward_extremities.sql index aa40f13da7..b062ec840c 100644 --- a/synapse/storage/schema/delta/54/delete_forward_extremities.sql +++ b/synapse/storage/schema/delta/54/delete_forward_extremities.sql @@ -20,3 +20,4 @@ INSERT INTO background_updates (update_name, progress_json) VALUES DROP TABLE IF EXISTS _extremities_to_check; -- To make this delta schema file idempotent. CREATE TABLE _extremities_to_check AS SELECT event_id FROM event_forward_extremities; +CREATE INDEX _extremities_to_check_id ON _extremities_to_check(event_id); -- cgit 1.5.1 From 54d50fbfdf8c39d92c36291a572419cee6b9b916 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 15:15:13 +0100 Subject: Get events all at once --- synapse/storage/stats.py | 59 +++++++++++++++++++++--------------------------- 1 file changed, 26 insertions(+), 33 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index eb0ced5b5e..99b4af5555 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -179,46 +179,39 @@ class StatsStore(StateDeltasStore): current_state_ids = yield self.get_current_state_ids(room_id) - join_rules = yield self.get_event( - current_state_ids.get((EventTypes.JoinRules, "")), allow_none=True + join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) + history_visibility_id = current_state_ids.get( + (EventTypes.RoomHistoryVisibility, "") ) - history_visibility = yield self.get_event( - current_state_ids.get((EventTypes.RoomHistoryVisibility, "")), - allow_none=True, - ) - encryption = yield self.get_event( - current_state_ids.get((EventTypes.RoomEncryption, "")), allow_none=True - ) - name = yield self.get_event( - current_state_ids.get((EventTypes.Name, "")), allow_none=True - ) - topic = yield self.get_event( - current_state_ids.get((EventTypes.Topic, "")), allow_none=True - ) - avatar = yield self.get_event( - current_state_ids.get((EventTypes.RoomAvatar, "")), allow_none=True - ) - canonical_alias = yield self.get_event( - current_state_ids.get((EventTypes.CanonicalAlias, "")), allow_none=True - ) - - def _or_none(x, arg): - if x: - return x.content.get(arg) + encryption_id = current_state_ids.get((EventTypes.RoomEncryption, "")) + name_id = current_state_ids.get((EventTypes.Name, "")) + topic_id = current_state_ids.get((EventTypes.Topic, "")) + avatar_id = current_state_ids.get((EventTypes.RoomAvatar, "")) + canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, "")) + + state_events = yield self.get_events([ + join_rules_id, history_visibility_id, encryption_id, name_id, + topic_id, avatar_id, canonical_alias_id, + ]) + + def _get_or_none(event_id, arg): + event = state_events.get(event_id) + if event: + return event.content.get(arg) return None yield self.update_room_state( room_id, { - "join_rules": _or_none(join_rules, "join_rule"), - "history_visibility": _or_none( - history_visibility, "history_visibility" + "join_rules": _get_or_none(join_rules_id, "join_rule"), + "history_visibility": _get_or_none( + history_visibility_id, "history_visibility" ), - "encryption": _or_none(encryption, "algorithm"), - "name": _or_none(name, "name"), - "topic": _or_none(topic, "topic"), - "avatar": _or_none(avatar, "url"), - "canonical_alias": _or_none(canonical_alias, "alias"), + "encryption": _get_or_none(encryption_id, "algorithm"), + "name": _get_or_none(name_id, "name"), + "topic": _get_or_none(topic_id, "topic"), + "avatar": _get_or_none(avatar_id, "url"), + "canonical_alias": _get_or_none(canonical_alias_id, "alias"), }, ) -- cgit 1.5.1 From 04710cc2d71127b1f416e87f7a4aea3ce6d93410 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 15:22:32 +0100 Subject: Fetch membership counts all at once --- synapse/storage/roommember.py | 33 +++++++++++---------------------- synapse/storage/stats.py | 23 +++++++---------------- 2 files changed, 18 insertions(+), 38 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 4bd1669458..7617913326 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -142,26 +142,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): return self.runInteraction("get_room_summary", _get_room_summary_txn) - def _get_user_count_in_room_txn(self, txn, room_id, membership): + def _get_user_counts_in_room_txn(self, txn, room_id): """ - See get_user_count_in_room. - """ - sql = ( - "SELECT count(*) FROM room_memberships as m" - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id " - " AND m.room_id = c.room_id " - " AND m.user_id = c.state_key" - " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?" - ) - - txn.execute(sql, (room_id, membership)) - row = txn.fetchone() - return row[0] - - def get_user_count_in_room(self, room_id, membership): - """ - Get the user count in a room with a particular membership. + Get the user count in a room by membership. Args: room_id (str) @@ -170,9 +153,15 @@ class RoomMemberWorkerStore(EventsWorkerStore): Returns: Deferred[int] """ - return self.runInteraction( - "get_users_in_room", self._get_user_count_in_room_txn, room_id, membership - ) + sql = """ + SELECT m.membership, count(*) FROM room_memberships as m + INNER JOIN current_state_events as c USING(event_id) + WHERE c.type = 'm.room.member' AND c.room_id = ? + GROUP BY m.membership + """ + + txn.execute(sql, (room_id,)) + return {row[0]: row[1] for row in txn} @cached() def get_invited_rooms_for_user(self, user_id): diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index 99b4af5555..727f60b3bd 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -226,18 +226,9 @@ class StatsStore(StateDeltasStore): current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) current_state_events = len(current_state_ids) - joined_members = self._get_user_count_in_room_txn( - txn, room_id, Membership.JOIN - ) - invited_members = self._get_user_count_in_room_txn( - txn, room_id, Membership.INVITE - ) - left_members = self._get_user_count_in_room_txn( - txn, room_id, Membership.LEAVE - ) - banned_members = self._get_user_count_in_room_txn( - txn, room_id, Membership.BAN - ) + + membership_counts = self._get_user_counts_in_room_txn(txn, room_id) + total_state_events = self._get_total_state_event_counts_txn( txn, room_id ) @@ -250,10 +241,10 @@ class StatsStore(StateDeltasStore): { "bucket_size": self.stats_bucket_size, "current_state_events": current_state_events, - "joined_members": joined_members, - "invited_members": invited_members, - "left_members": left_members, - "banned_members": banned_members, + "joined_members": membership_counts.get(Membership.JOIN, 0), + "invited_members": membership_counts.get(Membership.INVITE, 0), + "left_members": membership_counts.get(Membership.LEAVE, 0), + "banned_members": membership_counts.get(Membership.BAN, 0), "state_events": total_state_events, }, ) -- cgit 1.5.1 From e2c46ed851599dc08cc8a822e07c0d4f9a050ee2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 15:26:38 +0100 Subject: Move deletion from table inside txn --- synapse/storage/stats.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index 727f60b3bd..a99637d4b4 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -254,10 +254,13 @@ class StatsStore(StateDeltasStore): {"room_id": room_id, "token": current_token}, ) + # We've finished a room. Delete it from the table. + self._simple_delete_one_txn( + txn, TEMP_TABLE + "_rooms", {"room_id": room_id}, + ) + yield self.runInteraction("update_room_stats", _fetch_data) - # We've finished a room. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) # Update the remaining counter. progress["remaining"] -= 1 yield self.runInteraction( -- cgit 1.5.1 From 5ac75fc9a2d80ddf2974d281c381f82515606403 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 15:26:55 +0100 Subject: Join against events to use its room_id index --- synapse/storage/events_worker.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index b56c83e460..1782428048 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -618,7 +618,12 @@ class EventsWorkerStore(SQLBaseStore): """ See get_total_state_event_counts. """ - sql = "SELECT COUNT(*) FROM state_events WHERE room_id=?" + # We join against the events table as that has an index on room_id + sql = """ + SELECT COUNT(*) FROM state_events + INNER JOIN events USING (room_id, event_id) + WHERE room_id=? + """ txn.execute(sql, (room_id,)) row = txn.fetchone() return row[0] if row else 0 -- cgit 1.5.1 From 847b9dcd1c9d7d7a43333e85f69dc78471095475 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 31 May 2019 09:54:46 +0100 Subject: Make max_delta equal to period * 10% --- synapse/config/registration.py | 15 ++++----------- synapse/storage/_base.py | 7 +++---- tests/rest/client/v2_alpha/test_register.py | 18 +----------------- 3 files changed, 8 insertions(+), 32 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/config/registration.py b/synapse/config/registration.py index b4fd4af368..1835b4b1f3 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -39,9 +39,7 @@ class AccountValidityConfig(Config): else: self.renew_email_subject = "Renew your %(app)s account" - self.startup_job_max_delta = self.parse_duration( - config.get("startup_job_max_delta", 0), - ) + self.startup_job_max_delta = self.period * 10. / 100. if self.renew_by_email_enabled and "public_baseurl" not in synapse_config: raise ConfigError("Can't send renewal emails without 'public_baseurl'") @@ -133,20 +131,15 @@ class RegistrationConfig(Config): # This means that, if a validity period is set, and Synapse is restarted (it will # then derive an expiration date from the current validity period), and some time # after that the validity period changes and Synapse is restarted, the users' - # expiration dates won't be updated unless their account is manually renewed. - # - # If set, the ``startup_job_max_delta`` optional setting will make the startup job - # described above set a random expiration date between t + period and - # t + period + startup_job_max_delta, t being the date and time at which the job - # sets the expiration date for a given user. This is useful for server admins that - # want to avoid Synapse sending a lot of renewal emails at once. + # expiration dates won't be updated unless their account is manually renewed. This + # date will be randomly selected within a range [now + period ; now + period + d], + # where d is equal to 10% of the validity period. # #account_validity: # enabled: True # period: 6w # renew_at: 1w # renew_email_subject: "Renew your %%(app)s account" - # startup_job_max_delta: 2d # The user must provide all of the below types of 3PID when registering. # diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 40802fd3dc..7f944ec717 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -329,14 +329,13 @@ class SQLBaseStore(object): user_id (str): User ID to set an expiration date for. use_delta (bool): If set to False, the expiration date for the user will be now + validity period. If set to True, this expiration date will be a - random value in the [now + period; now + period + max_delta] range, - max_delta being the configured value for the size of the range, unless - delta is 0, in which case it sets it to now + period. + random value in the [now + period; now + period + d] range, d being a + delta equal to 10% of the validity period. """ now_ms = self._clock.time_msec() expiration_ts = now_ms + self._account_validity.period - if use_delta and self._account_validity.startup_job_max_delta: + if use_delta: expiration_ts = self.rand.randrange( expiration_ts, expiration_ts + self._account_validity.startup_job_max_delta, diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 68654e25ab..711628ded1 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -436,7 +436,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.validity_period = 10 - self.max_delta = 10 + self.max_delta = self.validity_period * 10. / 100. config = self.default_config() @@ -453,22 +453,6 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): return self.hs def test_background_job(self): - """ - Tests whether the account validity startup background job does the right thing, - which is sticking an expiration date to every account that doesn't already have - one. - """ - user_id = self.register_user("kermit", "user") - - self.hs.config.account_validity.startup_job_max_delta = 0 - - now_ms = self.hs.clock.time_msec() - self.get_success(self.store._set_expiration_date_when_missing()) - - res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) - self.assertEqual(res, now_ms + self.validity_period) - - def test_background_job_with_max_delta(self): """ Tests the same thing as test_background_job, except that it sets the startup_job_max_delta parameter and checks that the expiration date is within the -- cgit 1.5.1 From 5037326d6624d1d1780a0536d19ff79e275f8735 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2019 15:37:57 +0100 Subject: Add indices. Remove room_ids accidentally added We have to do this by re-inserting a background update and recreating tables, as the tables only get created during a background update and will later be deleted. We also make sure that we remove any entries that should have been removed but weren't due to a race that has been fixed in a previous commit. --- synapse/storage/schema/delta/54/stats2.sql | 28 ++++++++++++++++++++ synapse/storage/stats.py | 41 ++++++++++++++++++++---------- 2 files changed, 56 insertions(+), 13 deletions(-) create mode 100644 synapse/storage/schema/delta/54/stats2.sql (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/54/stats2.sql b/synapse/storage/schema/delta/54/stats2.sql new file mode 100644 index 0000000000..3b2d48447f --- /dev/null +++ b/synapse/storage/schema/delta/54/stats2.sql @@ -0,0 +1,28 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This delta file gets run after `54/stats.sql` delta. + +-- We want to add some indices to the temporary stats table, so we re-insert +-- 'populate_stats_createtables' if we are still processing the rooms update. +INSERT INTO background_updates (update_name, progress_json) + SELECT 'populate_stats_createtables', '{}' + WHERE + 'populate_stats_process_rooms' IN ( + SELECT update_name FROM background_updates + ) + AND 'populate_stats_createtables' NOT IN ( -- don't insert if already exists + SELECT update_name FROM background_updates + ); diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index a99637d4b4..1c0b183a56 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -18,6 +18,7 @@ import logging from twisted.internet import defer from synapse.api.constants import EventTypes, Membership +from synapse.storage.prepare_database import get_statements from synapse.storage.state_deltas import StateDeltasStore from synapse.util.caches.descriptors import cached @@ -69,12 +70,25 @@ class StatsStore(StateDeltasStore): # Get all the rooms that we want to process. def _make_staging_area(txn): - sql = ( - "CREATE TABLE IF NOT EXISTS " - + TEMP_TABLE - + "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)" - ) - txn.execute(sql) + # Create the temporary tables + stmts = get_statements(""" + -- We just recreate the table, we'll be reinserting the + -- correct entries again later anyway. + DROP TABLE IF EXISTS {temp}_rooms; + + CREATE TABLE IF NOT EXISTS {temp}_rooms( + room_id TEXT NOT NULL, + events BIGINT NOT NULL + ); + + CREATE INDEX {temp}_rooms_events + ON {temp}_rooms(events); + CREATE INDEX {temp}_rooms_id + ON {temp}_rooms(room_id); + """.format(temp=TEMP_TABLE).splitlines()) + + for statement in stmts: + txn.execute(statement) sql = ( "CREATE TABLE IF NOT EXISTS " @@ -83,15 +97,16 @@ class StatsStore(StateDeltasStore): ) txn.execute(sql) - # Get rooms we want to process from the database + # Get rooms we want to process from the database, only adding + # those that we haven't (i.e. those not in room_stats_earliest_token) sql = """ - SELECT room_id, count(*) FROM current_state_events - GROUP BY room_id - """ + INSERT INTO %s_rooms (room_id, events) + SELECT c.room_id, count(*) FROM current_state_events AS c + LEFT JOIN room_stats_earliest_token AS t USING (room_id) + WHERE t.room_id IS NULL + GROUP BY c.room_id + """ % (TEMP_TABLE,) txn.execute(sql) - rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] - self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) - del rooms new_pos = yield self.get_max_stream_id_in_current_state_deltas() yield self.runInteraction("populate_stats_temp_build", _make_staging_area) -- cgit 1.5.1 From 4d794dae210ce30e87d8a6b9ee2f9b481cadf539 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 31 May 2019 11:09:34 +0100 Subject: Move delta from +10% to -10% --- synapse/config/registration.py | 2 +- synapse/storage/_base.py | 4 ++-- tests/rest/client/v2_alpha/test_register.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 4af825a2ab..aad3400819 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -132,7 +132,7 @@ class RegistrationConfig(Config): # then derive an expiration date from the current validity period), and some time # after that the validity period changes and Synapse is restarted, the users' # expiration dates won't be updated unless their account is manually renewed. This - # date will be randomly selected within a range [now + period ; now + period + d], + # date will be randomly selected within a range [now + period - d ; now + period], # where d is equal to 10%% of the validity period. # #account_validity: diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 7f944ec717..086318a530 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -329,7 +329,7 @@ class SQLBaseStore(object): user_id (str): User ID to set an expiration date for. use_delta (bool): If set to False, the expiration date for the user will be now + validity period. If set to True, this expiration date will be a - random value in the [now + period; now + period + d] range, d being a + random value in the [now + period - d ; now + period] range, d being a delta equal to 10% of the validity period. """ now_ms = self._clock.time_msec() @@ -337,8 +337,8 @@ class SQLBaseStore(object): if use_delta: expiration_ts = self.rand.randrange( + expiration_ts - self._account_validity.startup_job_max_delta, expiration_ts, - expiration_ts + self._account_validity.startup_job_max_delta, ) self._simple_insert_txn( diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 711628ded1..0cb6a363d6 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -467,5 +467,5 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) - self.assertLessEqual(res, now_ms + self.validity_period + self.max_delta) - self.assertGreaterEqual(res, now_ms + self.validity_period) + self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) + self.assertLessEqual(res, now_ms + self.validity_period) -- cgit 1.5.1 From fa4b54aca57bebc94e2b763abdae79343a08f969 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 3 Jun 2019 17:06:54 +0100 Subject: Ignore room state with null bytes in for room stats --- synapse/storage/stats.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index 1c0b183a56..1f39ef211a 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -328,6 +328,21 @@ class StatsStore(StateDeltasStore): room_id (str) fields (dict[str:Any]) """ + + # For whatever reason some of the fields may contain null bytes, which + # postgres isn't a fan of, so we replace those fields with null. + for col in ( + "join_rules", + "history_visibility", + "encryption", + "name", + "topic", + "avatar", + "canonical_alias" + ): + if "\0" in fields.get(col, ""): + fields[col] = None + return self._simple_upsert( table="room_state", keyvalues={"room_id": room_id}, -- cgit 1.5.1 From 0a56966f7d4879f9d517c96a3c714accdce4e17f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 3 Jun 2019 17:42:52 +0100 Subject: Fix --- synapse/storage/stats.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index 1f39ef211a..ff266b09b0 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -340,7 +340,8 @@ class StatsStore(StateDeltasStore): "avatar", "canonical_alias" ): - if "\0" in fields.get(col, ""): + field = fields.get(col) + if field and "\0" in field: fields[col] = None return self._simple_upsert( -- cgit 1.5.1