author     Richard van der Hoff <richard@matrix.org>  2019-08-13 10:52:19 +0100
committer  Richard van der Hoff <richard@matrix.org>  2019-08-13 10:52:19 +0100
commit     be362cb8f8157683196e15af7bd4d207fc522e3a
tree       7cad94bc789b0efd76e2eea73f334cb34080ace3
parent     Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes
parent     Merge pull request #5826 from matrix-org/erikj/reduce_event_pauses
download   synapse-be362cb8f8157683196e15af7bd4d207fc522e3a.tar.xz

Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
 .buildkite/worker-blacklist                       |   4
 changelog.d/5788.bugfix                           |   1
 changelog.d/5798.bugfix                           |   1
 changelog.d/5807.feature                          |   1
 changelog.d/5825.bugfix                           |   1
 changelog.d/5826.misc                             |   1
 changelog.d/5839.bugfix                           |   1
 changelog.d/5843.misc                             |   1
 contrib/purge_api/purge_remote_media.sh           |   2
 docs/sample_config.yaml                           |  10
 synapse/config/registration.py                    |  45
 synapse/handlers/account_validity.py              |  10
 synapse/handlers/sync.py                          |  43
 synapse/res/templates/account_renewed.html        |   1
 synapse/res/templates/invalid_token.html          |   1
 synapse/rest/client/v1/room.py                    |  14
 synapse/rest/client/v2_alpha/account_validity.py  |  25
 synapse/storage/events.py                         | 270
 synapse/storage/events_worker.py                  | 213
 tests/rest/client/v2_alpha/test_register.py       |  37
 tests/storage/test_redaction.py                   |  70
 21 files changed, 498 insertions(+), 254 deletions(-)
diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist

index 8ed8eef1a3..cda5c84e94 100644 --- a/.buildkite/worker-blacklist +++ b/.buildkite/worker-blacklist
@@ -3,10 +3,6 @@ Message history can be paginated -m.room.history_visibility == "world_readable" allows/forbids appropriately for Guest users - -m.room.history_visibility == "world_readable" allows/forbids appropriately for Real users - Can re-join room if re-invited /upgrade creates a new room diff --git a/changelog.d/5788.bugfix b/changelog.d/5788.bugfix new file mode 100644
index 0000000000..5632f3cb99 --- /dev/null +++ b/changelog.d/5788.bugfix
@@ -0,0 +1 @@ +Correctly handle redactions of redactions. diff --git a/changelog.d/5798.bugfix b/changelog.d/5798.bugfix new file mode 100644
index 0000000000..7db2c37af5 --- /dev/null +++ b/changelog.d/5798.bugfix
@@ -0,0 +1 @@ +Return 404 instead of 403 when accessing /rooms/{roomId}/event/{eventId} for an event without the appropriate permissions. diff --git a/changelog.d/5807.feature b/changelog.d/5807.feature new file mode 100644
index 0000000000..8b7d29a23c --- /dev/null +++ b/changelog.d/5807.feature
@@ -0,0 +1 @@ +Allow defining HTML templates to serve the user on account renewal attempt when using the account validity feature. diff --git a/changelog.d/5825.bugfix b/changelog.d/5825.bugfix new file mode 100644
index 0000000000..fb2c6f821d --- /dev/null +++ b/changelog.d/5825.bugfix
@@ -0,0 +1 @@ +Fix bug where user `/sync` stream could get wedged in rare circumstances. diff --git a/changelog.d/5826.misc b/changelog.d/5826.misc new file mode 100644
index 0000000000..9abed11bbe --- /dev/null +++ b/changelog.d/5826.misc
@@ -0,0 +1 @@ +Reduce global pauses in the events stream caused by expensive state resolution during persistence. diff --git a/changelog.d/5839.bugfix b/changelog.d/5839.bugfix new file mode 100644
index 0000000000..5775bfa653 --- /dev/null +++ b/changelog.d/5839.bugfix
@@ -0,0 +1 @@ +The purge_remote_media.sh script was fixed. diff --git a/changelog.d/5843.misc b/changelog.d/5843.misc new file mode 100644
index 0000000000..e7e7d572b7 --- /dev/null +++ b/changelog.d/5843.misc
@@ -0,0 +1 @@ +Whitelist history visbility sytests in worker mode tests. diff --git a/contrib/purge_api/purge_remote_media.sh b/contrib/purge_api/purge_remote_media.sh
index 99c07c663d..77220d3bd5 100644 --- a/contrib/purge_api/purge_remote_media.sh +++ b/contrib/purge_api/purge_remote_media.sh
@@ -51,4 +51,4 @@ TOKEN=$(sql "SELECT token FROM access_tokens WHERE user_id='$ADMIN' ORDER BY id # finally start pruning media: ############################################################################### set -x # for debugging the generated string -curl --header "Authorization: Bearer $TOKEN" -v POST "$API_URL/admin/purge_media_cache/?before_ts=$UNIX_TIMESTAMP" +curl --header "Authorization: Bearer $TOKEN" -X POST "$API_URL/admin/purge_media_cache/?before_ts=$UNIX_TIMESTAMP" diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
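The one-character change to contrib/purge_api/purge_remote_media.sh above matters because "curl -v POST <url>" treats the bare word POST as an extra URL and still issues a GET to the real endpoint; only "-X POST" changes the request method, so the purge was never actually being triggered. For illustration, a minimal sketch (not part of the commit) of making the same admin call from Python with requests; API_URL and ADMIN_TOKEN are placeholders, and the /admin/purge_media_cache path and before_ts parameter are the ones used by the script:

# Minimal sketch: the same purge call made from Python with requests.
# API_URL and ADMIN_TOKEN are placeholders for your deployment.
import time
import requests

API_URL = "https://matrix.example.com/_matrix/client/r0"  # placeholder
ADMIN_TOKEN = "MDAx..."                                    # placeholder admin access token

# purge remote media cached before "now" (before_ts is a millisecond timestamp here)
before_ts = int(time.time() * 1000)

resp = requests.post(
    API_URL + "/admin/purge_media_cache/",
    params={"before_ts": before_ts},
    headers={"Authorization": "Bearer " + ADMIN_TOKEN},
)
resp.raise_for_status()
print(resp.status_code, resp.text)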
index 08316597fa..1b206fe6bf 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml
@@ -802,6 +802,16 @@ uploads_path: "DATADIR/uploads" # period: 6w # renew_at: 1w # renew_email_subject: "Renew your %(app)s account" +# # Directory in which Synapse will try to find the HTML files to serve to the +# # user when trying to renew an account. Optional, defaults to +# # synapse/res/templates. +# template_dir: "res/templates" +# # HTML to be displayed to the user after they successfully renewed their +# # account. Optional. +# account_renewed_html_path: "account_renewed.html" +# # HTML to be displayed when the user tries to renew an account with an invalid +# # renewal token. Optional. +# invalid_token_html_path: "invalid_token.html" # Time that a user's session remains valid for, after they log in. # diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index c3de7a4e32..e2bee3c116 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py
@@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from distutils.util import strtobool +import pkg_resources + from synapse.config._base import Config, ConfigError from synapse.types import RoomAlias from synapse.util.stringutils import random_string_with_symbols @@ -41,8 +44,36 @@ class AccountValidityConfig(Config): self.startup_job_max_delta = self.period * 10.0 / 100.0 - if self.renew_by_email_enabled and "public_baseurl" not in synapse_config: - raise ConfigError("Can't send renewal emails without 'public_baseurl'") + if self.renew_by_email_enabled: + if "public_baseurl" not in synapse_config: + raise ConfigError("Can't send renewal emails without 'public_baseurl'") + + template_dir = config.get("template_dir") + + if not template_dir: + template_dir = pkg_resources.resource_filename("synapse", "res/templates") + + if "account_renewed_html_path" in config: + file_path = os.path.join(template_dir, config["account_renewed_html_path"]) + + self.account_renewed_html_content = self.read_file( + file_path, "account_validity.account_renewed_html_path" + ) + else: + self.account_renewed_html_content = ( + "<html><body>Your account has been successfully renewed.</body><html>" + ) + + if "invalid_token_html_path" in config: + file_path = os.path.join(template_dir, config["invalid_token_html_path"]) + + self.invalid_token_html_content = self.read_file( + file_path, "account_validity.invalid_token_html_path" + ) + else: + self.invalid_token_html_content = ( + "<html><body>Invalid renewal token.</body><html>" + ) class RegistrationConfig(Config): @@ -145,6 +176,16 @@ class RegistrationConfig(Config): # period: 6w # renew_at: 1w # renew_email_subject: "Renew your %%(app)s account" + # # Directory in which Synapse will try to find the HTML files to serve to the + # # user when trying to renew an account. Optional, defaults to + # # synapse/res/templates. + # template_dir: "res/templates" + # # HTML to be displayed to the user after they successfully renewed their + # # account. Optional. + # account_renewed_html_path: "account_renewed.html" + # # HTML to be displayed when the user tries to renew an account with an invalid + # # renewal token. Optional. + # invalid_token_html_path: "invalid_token.html" # Time that a user's session remains valid for, after they log in. # diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
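The change to synapse/config/registration.py above boils down to a small resolution rule: use the configured template_dir if given, otherwise fall back to the res/templates directory bundled with the synapse package, and only read a file from disk when an explicit account_renewed_html_path (or invalid_token_html_path) is configured. A standalone sketch of that rule, not the exact Synapse code: read_html stands in for Config.read_file, and the default string is equivalent to the bundled account_renewed.html template.

# Sketch of the template resolution added to AccountValidityConfig.
import os
import pkg_resources

DEFAULT_RENEWED_HTML = "<html><body>Your account has been successfully renewed.</body></html>"

def read_html(path):
    with open(path, "r", encoding="utf-8") as f:
        return f.read()

def resolve_renewed_html(config):
    template_dir = config.get("template_dir")
    if not template_dir:
        # default to the templates bundled with the synapse package
        template_dir = pkg_resources.resource_filename("synapse", "res/templates")

    if "account_renewed_html_path" in config:
        # explicit template configured: read it relative to template_dir
        return read_html(os.path.join(template_dir, config["account_renewed_html_path"]))

    # nothing configured: use the built-in default page
    return DEFAULT_RENEWED_HTML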
index 930204e2d0..34574f1a12 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py
@@ -226,11 +226,19 @@ class AccountValidityHandler(object): Args: renewal_token (str): Token sent with the renewal request. + Returns: + bool: Whether the provided token is valid. """ - user_id = yield self.store.get_user_from_renewal_token(renewal_token) + try: + user_id = yield self.store.get_user_from_renewal_token(renewal_token) + except StoreError: + defer.returnValue(False) + logger.debug("Renewing an account for user %s", user_id) yield self.renew_account_for_user(user_id) + defer.returnValue(True) + @defer.inlineCallbacks def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False): """Renews the account attached to a given user by pushing back the diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
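The handler change above is a small control-flow shift: an unknown renewal token used to surface as an unhandled StoreError, and it now becomes a False return value that the servlet can turn into a 404 page. A simplified sketch of the pattern, with store and renew_for_user standing in for the real handler collaborators and a local StoreError standing in for synapse.api.errors.StoreError:

# Simplified sketch of AccountValidityHandler.renew_account's new behaviour:
# an unknown token is reported as False instead of raising; a known token
# renews the account and reports True.
from twisted.internet import defer

class StoreError(Exception):
    """Stand-in for synapse.api.errors.StoreError."""

@defer.inlineCallbacks
def renew_account(store, renew_for_user, renewal_token):
    try:
        user_id = yield store.get_user_from_renewal_token(renewal_token)
    except StoreError:
        # No user owns this token: report it as invalid rather than erroring.
        defer.returnValue(False)

    yield renew_for_user(user_id)
    defer.returnValue(True)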
index 5c9150c9f8..d2b924e76c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -784,9 +784,17 @@ class SyncHandler(object): lazy_load_members=lazy_load_members, ) elif batch.limited: - state_at_timeline_start = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter - ) + if batch: + state_at_timeline_start = yield self.store.get_state_ids_for_event( + batch.events[0].event_id, state_filter=state_filter + ) + else: + # Its not clear how we get here, but empirically we do + # (#5407). Logging has been added elsewhere to try and + # figure out where this state comes from. + state_at_timeline_start = yield self.get_state_at( + room_id, stream_position=now_token, state_filter=state_filter + ) # for now, we disable LL for gappy syncs - see # https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346 @@ -806,9 +814,17 @@ class SyncHandler(object): room_id, stream_position=since_token, state_filter=state_filter ) - current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter - ) + if batch: + current_state_ids = yield self.store.get_state_ids_for_event( + batch.events[-1].event_id, state_filter=state_filter + ) + else: + # Its not clear how we get here, but empirically we do + # (#5407). Logging has been added elsewhere to try and + # figure out where this state comes from. + current_state_ids = yield self.get_state_at( + room_id, stream_position=now_token, state_filter=state_filter + ) state_ids = _calculate_state( timeline_contains=timeline_state, @@ -1758,6 +1774,21 @@ class SyncHandler(object): newly_joined_room=newly_joined, ) + if not batch and batch.limited: + # This resulted in #5407, which is weird, so lets log! We do it + # here as we have the maximum amount of information. + user_id = sync_result_builder.sync_config.user.to_string() + logger.info( + "Issue #5407: Found limited batch with no events. user %s, room %s," + " sync_config %s, newly_joined %s, events %s, batch %s.", + user_id, + room_id, + sync_config, + newly_joined, + events, + batch, + ) + if newly_joined: # debug for https://github.com/matrix-org/synapse/issues/4422 issue4422_logger.debug( diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html new file mode 100644
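Both hunks in synapse/handlers/sync.py guard the same assumption: a "limited" timeline batch was assumed to contain at least one event, and indexing batch.events[0] / batch.events[-1] wedged the /sync stream when it did not (#5407). The fix falls back to the state at the current stream position, and the extra logging block records as much context as possible when the odd case is hit. A simplified sketch of the guard, with get_state_ids_for_event and get_state_at standing in for the store/handler methods used in the diff:

# Simplified sketch of the guard added in SyncHandler: only index into the
# batch when it actually has events, otherwise derive the state from the
# stream position so an empty-but-limited batch (#5407) doesn't crash sync.
def state_at_timeline_start(batch, get_state_ids_for_event, get_state_at,
                            room_id, now_token, state_filter):
    if batch:
        # Normal case: state as of the first event in the timeline batch.
        return get_state_ids_for_event(batch.events[0].event_id,
                                       state_filter=state_filter)
    # Empirically the batch can be empty even when marked "limited";
    # fall back to the state at the current position instead of crashing.
    return get_state_at(room_id, stream_position=now_token,
                        state_filter=state_filter)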
index 0000000000..894da030af --- /dev/null +++ b/synapse/res/templates/account_renewed.html
@@ -0,0 +1 @@ +<html><body>Your account has been successfully renewed.</body><html> diff --git a/synapse/res/templates/invalid_token.html b/synapse/res/templates/invalid_token.html new file mode 100644
index 0000000000..6bd2b98364 --- /dev/null +++ b/synapse/res/templates/invalid_token.html
@@ -0,0 +1 @@ +<html><body>Invalid renewal token.</body><html> diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 6fe1eddcce..4b2344e696 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py
@@ -568,14 +568,22 @@ class RoomEventServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - event = yield self.event_handler.get_event(requester.user, room_id, event_id) + try: + event = yield self.event_handler.get_event( + requester.user, room_id, event_id + ) + except AuthError: + # This endpoint is supposed to return a 404 when the requester does + # not have permission to access the event + # https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-r0-rooms-roomid-event-eventid + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() if event: event = yield self._event_serializer.serialize_event(event, time_now) return (200, event) - else: - return (404, "Event not found.") + + return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) class RoomEventContextServlet(RestServlet): diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
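The servlet change above exists because the client-server spec (r0.5.0, linked in the comment) wants /rooms/{roomId}/event/{eventId} to answer 404 both when the event is missing and when the caller may not see it, so the status code does not leak whether the event exists. A minimal sketch of the translation; get_event and serialize stand in for the handler and serializer used by the servlet, and the error classes are the ones from synapse.api.errors as in the diff:

# Sketch of the AuthError -> 404 translation in RoomEventServlet.on_GET.
from twisted.internet import defer
from synapse.api.errors import AuthError, Codes, SynapseError

@defer.inlineCallbacks
def on_get_event(get_event, serialize, now_ms, requester, room_id, event_id):
    try:
        event = yield get_event(requester.user, room_id, event_id)
    except AuthError:
        # Report "not found" rather than "forbidden" so callers cannot use
        # the status code to probe whether an event exists.
        raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)

    if event:
        serialized = yield serialize(event, now_ms)
        defer.returnValue((200, serialized))

    # Event genuinely not found.
    raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)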
index 133c61900a..33f6a23028 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -42,6 +42,8 @@ class AccountValidityRenewServlet(RestServlet): self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() + self.success_html = hs.config.account_validity.account_renewed_html_content + self.failure_html = hs.config.account_validity.invalid_token_html_content @defer.inlineCallbacks def on_GET(self, request): @@ -49,16 +51,23 @@ class AccountValidityRenewServlet(RestServlet): raise SynapseError(400, "Missing renewal token") renewal_token = request.args[b"token"][0] - yield self.account_activity_handler.renew_account(renewal_token.decode("utf8")) + token_valid = yield self.account_activity_handler.renew_account( + renewal_token.decode("utf8") + ) + + if token_valid: + status_code = 200 + response = self.success_html + else: + status_code = 404 + response = self.failure_html - request.setResponseCode(200) + request.setResponseCode(status_code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader( - b"Content-Length", b"%d" % (len(AccountValidityRenewServlet.SUCCESS_HTML),) - ) - request.write(AccountValidityRenewServlet.SUCCESS_HTML) + request.setHeader(b"Content-Length", b"%d" % (len(response),)) + request.write(response.encode("utf8")) finish_request(request) - return None + defer.returnValue(None) class AccountValiditySendMailServlet(RestServlet): @@ -87,7 +96,7 @@ class AccountValiditySendMailServlet(RestServlet): user_id = requester.user.to_string() yield self.account_activity_handler.send_renewal_email_to_user(user_id) - return (200, {}) + defer.returnValue((200, {})) def register_servlets(hs, http_server): diff --git a/synapse/storage/events.py b/synapse/storage/events.py
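With the handler now reporting token validity, the servlet's remaining job is to pick the configured page and status code and write the response itself (previously it always wrote the static SUCCESS_HTML with a 200). A minimal sketch of that response path on a Twisted request object; note it computes Content-Length from the encoded bytes, and request.finish() stands in for synapse's finish_request helper:

# Sketch of the response branch in AccountValidityRenewServlet.on_GET:
# choose status and template from the handler's boolean, then write the
# HTML with explicit headers. `request` is a twisted.web Request.
def write_renewal_response(request, token_valid, success_html, failure_html):
    if token_valid:
        status_code, page = 200, success_html
    else:
        status_code, page = 404, failure_html

    body = page.encode("utf8")
    request.setResponseCode(status_code)
    request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
    request.setHeader(b"Content-Length", b"%d" % (len(body),))
    request.write(body)
    request.finish()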
index 88c0180116..ac876287fc 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py
@@ -364,147 +364,161 @@ class EventsStore( if not events_and_contexts: return - if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( - len(events_and_contexts) - ) - else: - stream_ordering_manager = self._stream_id_gen.get_next_mult( - len(events_and_contexts) - ) - - with stream_ordering_manager as stream_orderings: - for (event, context), stream in zip(events_and_contexts, stream_orderings): - event.internal_metadata.stream_ordering = stream - - chunks = [ - events_and_contexts[x : x + 100] - for x in range(0, len(events_and_contexts), 100) - ] - - for chunk in chunks: - # We can't easily parallelize these since different chunks - # might contain the same event. :( + chunks = [ + events_and_contexts[x : x + 100] + for x in range(0, len(events_and_contexts), 100) + ] - # NB: Assumes that we are only persisting events for one room - # at a time. + for chunk in chunks: + # We can't easily parallelize these since different chunks + # might contain the same event. :( - # map room_id->list[event_ids] giving the new forward - # extremities in each room - new_forward_extremeties = {} + # NB: Assumes that we are only persisting events for one room + # at a time. - # map room_id->(type,state_key)->event_id tracking the full - # state in each room after adding these events. - # This is simply used to prefill the get_current_state_ids - # cache - current_state_for_room = {} + # map room_id->list[event_ids] giving the new forward + # extremities in each room + new_forward_extremeties = {} - # map room_id->(to_delete, to_insert) where to_delete is a list - # of type/state keys to remove from current state, and to_insert - # is a map (type,key)->event_id giving the state delta in each - # room - state_delta_for_room = {} + # map room_id->(type,state_key)->event_id tracking the full + # state in each room after adding these events. + # This is simply used to prefill the get_current_state_ids + # cache + current_state_for_room = {} - if not backfilled: - with Measure(self._clock, "_calculate_state_and_extrem"): - # Work out the new "current state" for each room. - # We do this by working out what the new extremities are and then - # calculating the state from that. - events_by_room = {} - for event, context in chunk: - events_by_room.setdefault(event.room_id, []).append( - (event, context) - ) + # map room_id->(to_delete, to_insert) where to_delete is a list + # of type/state keys to remove from current state, and to_insert + # is a map (type,key)->event_id giving the state delta in each + # room + state_delta_for_room = {} - for room_id, ev_ctx_rm in iteritems(events_by_room): - latest_event_ids = yield self.get_latest_event_ids_in_room( - room_id - ) - new_latest_event_ids = yield self._calculate_new_extremities( - room_id, ev_ctx_rm, latest_event_ids + if not backfilled: + with Measure(self._clock, "_calculate_state_and_extrem"): + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. 
+ events_by_room = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) + ) + + for room_id, ev_ctx_rm in iteritems(events_by_room): + latest_event_ids = yield self.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = yield self._calculate_new_extremities( + room_id, ev_ctx_rm, latest_event_ids + ) + + latest_event_ids = set(latest_event_ids) + if new_latest_event_ids == latest_event_ids: + # No change in extremities, so no change in state + continue + + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + + new_forward_extremeties[room_id] = new_latest_event_ids + + len_1 = ( + len(latest_event_ids) == 1 + and len(new_latest_event_ids) == 1 + ) + if len_1: + all_single_prev_not_state = all( + len(event.prev_event_ids()) == 1 + and not event.is_state() + for event, ctx in ev_ctx_rm ) - - latest_event_ids = set(latest_event_ids) - if new_latest_event_ids == latest_event_ids: - # No change in extremities, so no change in state + # Don't bother calculating state if they're just + # a long chain of single ancestor non-state events. + if all_single_prev_not_state: continue - # there should always be at least one forward extremity. - # (except during the initial persistence of the send_join - # results, in which case there will be no existing - # extremities, so we'll `continue` above and skip this bit.) - assert new_latest_event_ids, "No forward extremities left!" - - new_forward_extremeties[room_id] = new_latest_event_ids - - len_1 = ( - len(latest_event_ids) == 1 - and len(new_latest_event_ids) == 1 + state_delta_counter.inc() + if len(new_latest_event_ids) == 1: + state_delta_single_event_counter.inc() + + # This is a fairly handwavey check to see if we could + # have guessed what the delta would have been when + # processing one of these events. + # What we're interested in is if the latest extremities + # were the same when we created the event as they are + # now. When this server creates a new event (as opposed + # to receiving it over federation) it will use the + # forward extremities as the prev_events, so we can + # guess this by looking at the prev_events and checking + # if they match the current forward extremities. + for ev, _ in ev_ctx_rm: + prev_event_ids = set(ev.prev_event_ids()) + if latest_event_ids == prev_event_ids: + state_delta_reuse_delta_counter.inc() + break + + logger.info("Calculating state delta for room %s", room_id) + with Measure( + self._clock, "persist_events.get_new_state_after_events" + ): + res = yield self._get_new_state_after_events( + room_id, + ev_ctx_rm, + latest_event_ids, + new_latest_event_ids, ) - if len_1: - all_single_prev_not_state = all( - len(event.prev_event_ids()) == 1 - and not event.is_state() - for event, ctx in ev_ctx_rm - ) - # Don't bother calculating state if they're just - # a long chain of single ancestor non-state events. - if all_single_prev_not_state: - continue - - state_delta_counter.inc() - if len(new_latest_event_ids) == 1: - state_delta_single_event_counter.inc() - - # This is a fairly handwavey check to see if we could - # have guessed what the delta would have been when - # processing one of these events. 
- # What we're interested in is if the latest extremities - # were the same when we created the event as they are - # now. When this server creates a new event (as opposed - # to receiving it over federation) it will use the - # forward extremities as the prev_events, so we can - # guess this by looking at the prev_events and checking - # if they match the current forward extremities. - for ev, _ in ev_ctx_rm: - prev_event_ids = set(ev.prev_event_ids()) - if latest_event_ids == prev_event_ids: - state_delta_reuse_delta_counter.inc() - break - - logger.info("Calculating state delta for room %s", room_id) + current_state, delta_ids = res + + # If either are not None then there has been a change, + # and we need to work out the delta (or use that + # given) + if delta_ids is not None: + # If there is a delta we know that we've + # only added or replaced state, never + # removed keys entirely. + state_delta_for_room[room_id] = ([], delta_ids) + elif current_state is not None: with Measure( - self._clock, "persist_events.get_new_state_after_events" + self._clock, "persist_events.calculate_state_delta" ): - res = yield self._get_new_state_after_events( - room_id, - ev_ctx_rm, - latest_event_ids, - new_latest_event_ids, + delta = yield self._calculate_state_delta( + room_id, current_state ) - current_state, delta_ids = res - - # If either are not None then there has been a change, - # and we need to work out the delta (or use that - # given) - if delta_ids is not None: - # If there is a delta we know that we've - # only added or replaced state, never - # removed keys entirely. - state_delta_for_room[room_id] = ([], delta_ids) - elif current_state is not None: - with Measure( - self._clock, "persist_events.calculate_state_delta" - ): - delta = yield self._calculate_state_delta( - room_id, current_state - ) - state_delta_for_room[room_id] = delta - - # If we have the current_state then lets prefill - # the cache with it. - if current_state is not None: - current_state_for_room[room_id] = current_state + state_delta_for_room[room_id] = delta + + # If we have the current_state then lets prefill + # the cache with it. + if current_state is not None: + current_state_for_room[room_id] = current_state + + # We want to calculate the stream orderings as late as possible, as + # we only notify after all events with a lesser stream ordering have + # been persisted. I.e. if we spend 10s inside the with block then + # that will delay all subsequent events from being notified about. + # Hence why we do it down here rather than wrapping the entire + # function. + # + # Its safe to do this after calculating the state deltas etc as we + # only need to protect the *persistence* of the events. This is to + # ensure that queries of the form "fetch events since X" don't + # return events and stream positions after events that are still in + # flight, as otherwise subsequent requests "fetch event since Y" + # will not return those events. + # + # Note: Multiple instances of this function cannot be in flight at + # the same time for the same room. + if backfilled: + stream_ordering_manager = self._backfill_id_gen.get_next_mult( + len(chunk) + ) + else: + stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk)) + + with stream_ordering_manager as stream_orderings: + for (event, context), stream in zip(chunk, stream_orderings): + event.internal_metadata.stream_ordering = stream yield self.runInteraction( "persist_events", diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
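Most of the churn in synapse/storage/events.py is re-indentation; the substantive change (PR #5826 in this merge) is when stream orderings are allocated. Previously the `with stream_ordering_manager` block wrapped the whole persistence path, so a slow state resolution for one chunk of events held back notification of every later event; now the extremity and state-delta work runs first, and orderings are taken per chunk just before the write, as the new comment in the diff explains. A simplified sketch of the pattern, with id_gen, calculate_state_and_extremities and persist standing in for the real stream-ID generator, the Measure-wrapped calculation, and the "persist_events" database transaction:

# Simplified sketch of the reordering shown above: do the expensive
# per-chunk state/extremity calculation first, and only then take stream
# orderings, holding them just long enough to stamp the events and persist.
def persist_chunk(chunk, calculate_state_and_extremities, id_gen, persist):
    # Slow part: state resolution etc. This must not run while we hold
    # allocated-but-unpersisted orderings, because "fetch events since X"
    # queries must never skip over in-flight events.
    state_delta_for_room, new_forward_extremities = calculate_state_and_extremities(chunk)

    # Fast part: allocate orderings and persist straight away.
    with id_gen.get_next_mult(len(chunk)) as stream_orderings:
        for (event, context), stream in zip(chunk, stream_orderings):
            event.internal_metadata.stream_ordering = stream
        persist(chunk, state_delta_for_room, new_forward_extremities)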
index 79680ee856..c6fa7f82fd 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py
@@ -29,12 +29,7 @@ from synapse.api.room_versions import EventFormatVersions from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.utils import prune_event -from synapse.logging.context import ( - LoggingContext, - PreserveLoggingContext, - make_deferred_yieldable, - run_in_background, -) +from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import get_domain_from_id from synapse.util import batch_iter @@ -342,13 +337,12 @@ class EventsWorkerStore(SQLBaseStore): log_ctx = LoggingContext.current_context() log_ctx.record_event_fetch(len(missing_events_ids)) - # Note that _enqueue_events is also responsible for turning db rows + # Note that _get_events_from_db is also responsible for turning db rows # into FrozenEvents (via _get_event_from_row), which involves seeing if # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - # _enqueue_events is a bit of a rubbish name but naming is hard. - missing_events = yield self._enqueue_events( + missing_events = yield self._get_events_from_db( missing_events_ids, allow_rejected=allow_rejected ) @@ -421,28 +415,28 @@ class EventsWorkerStore(SQLBaseStore): The fetch requests. Each entry consists of a list of event ids to be fetched, and a deferred to be completed once the events have been fetched. + + The deferreds are callbacked with a dictionary mapping from event id + to event row. Note that it may well contain additional events that + were not part of this request. """ with Measure(self._clock, "_fetch_event_list"): try: - event_id_lists = list(zip(*event_list))[0] - event_ids = [item for sublist in event_id_lists for item in sublist] + events_to_fetch = set( + event_id for events, _ in event_list for event_id in events + ) row_dict = self._new_transaction( - conn, "do_fetch", [], [], self._fetch_event_rows, event_ids + conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch ) # We only want to resolve deferreds from the main thread - def fire(lst, res): - for ids, d in lst: - if not d.called: - try: - with PreserveLoggingContext(): - d.callback([res[i] for i in ids if i in res]) - except Exception: - logger.exception("Failed to callback") + def fire(): + for _, d in event_list: + d.callback(row_dict) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire, event_list, row_dict) + self.hs.get_reactor().callFromThread(fire) except Exception as e: logger.exception("do_fetch") @@ -457,13 +451,98 @@ class EventsWorkerStore(SQLBaseStore): self.hs.get_reactor().callFromThread(fire, event_list, e) @defer.inlineCallbacks - def _enqueue_events(self, events, allow_rejected=False): + def _get_events_from_db(self, event_ids, allow_rejected=False): + """Fetch a bunch of events from the database. + + Returned events will be added to the cache for future lookups. + + Args: + event_ids (Iterable[str]): The event_ids of the events to fetch + allow_rejected (bool): Whether to include rejected events + + Returns: + Deferred[Dict[str, _EventCacheEntry]]: + map from event id to result. May return extra events which + weren't asked for. 
+ """ + fetched_events = {} + events_to_fetch = event_ids + + while events_to_fetch: + row_map = yield self._enqueue_events(events_to_fetch) + + # we need to recursively fetch any redactions of those events + redaction_ids = set() + for event_id in events_to_fetch: + row = row_map.get(event_id) + fetched_events[event_id] = row + if row: + redaction_ids.update(row["redactions"]) + + events_to_fetch = redaction_ids.difference(fetched_events.keys()) + if events_to_fetch: + logger.debug("Also fetching redaction events %s", events_to_fetch) + + # build a map from event_id to EventBase + event_map = {} + for event_id, row in fetched_events.items(): + if not row: + continue + assert row["event_id"] == event_id + + rejected_reason = row["rejected_reason"] + + if not allow_rejected and rejected_reason: + continue + + d = json.loads(row["json"]) + internal_metadata = json.loads(row["internal_metadata"]) + + format_version = row["format_version"] + if format_version is None: + # This means that we stored the event before we had the concept + # of a event format version, so it must be a V1 event. + format_version = EventFormatVersions.V1 + + original_ev = event_type_from_format_version(format_version)( + event_dict=d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + event_map[event_id] = original_ev + + # finally, we can decide whether each one nededs redacting, and build + # the cache entries. + result_map = {} + for event_id, original_ev in event_map.items(): + redactions = fetched_events[event_id]["redactions"] + redacted_event = self._maybe_redact_event_row( + original_ev, redactions, event_map + ) + + cache_entry = _EventCacheEntry( + event=original_ev, redacted_event=redacted_event + ) + + self._get_event_cache.prefill((event_id,), cache_entry) + result_map[event_id] = cache_entry + + return result_map + + @defer.inlineCallbacks + def _enqueue_events(self, events): """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. + + Args: + events (Iterable[str]): events to be fetched. + + Returns: + Deferred[Dict[str, Dict]]: map from event id to row data from the database. + May contain events that weren't requested. 
""" - if not events: - return {} events_d = defer.Deferred() with self._event_fetch_lock: @@ -482,32 +561,12 @@ class EventsWorkerStore(SQLBaseStore): "fetch_events", self.runWithConnection, self._do_fetch ) - logger.debug("Loading %d events", len(events)) + logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): - rows = yield events_d - logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) - - if not allow_rejected: - rows[:] = [r for r in rows if r["rejected_reason"] is None] - - res = yield make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self._get_event_from_row, - row["internal_metadata"], - row["json"], - row["redactions"], - rejected_reason=row["rejected_reason"], - format_version=row["format_version"], - ) - for row in rows - ], - consumeErrors=True, - ) - ) + row_map = yield events_d + logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) - return {e.event.event_id: e for e in res if e} + return row_map def _fetch_event_rows(self, txn, event_ids): """Fetch event rows from the database @@ -580,50 +639,7 @@ class EventsWorkerStore(SQLBaseStore): return event_dict - @defer.inlineCallbacks - def _get_event_from_row( - self, internal_metadata, js, redactions, format_version, rejected_reason=None - ): - """Parse an event row which has been read from the database - - Args: - internal_metadata (str): json-encoded internal_metadata column - js (str): json-encoded event body from event_json - redactions (list[str]): a list of the events which claim to have redacted - this event, from the redactions table - format_version: (str): the 'format_version' column - rejected_reason (str|None): the reason this event was rejected, if any - - Returns: - _EventCacheEntry - """ - with Measure(self._clock, "_get_event_from_row"): - d = json.loads(js) - internal_metadata = json.loads(internal_metadata) - - if format_version is None: - # This means that we stored the event before we had the concept - # of a event format version, so it must be a V1 event. - format_version = EventFormatVersions.V1 - - original_ev = event_type_from_format_version(format_version)( - event_dict=d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - redacted_event = yield self._maybe_redact_event_row(original_ev, redactions) - - cache_entry = _EventCacheEntry( - event=original_ev, redacted_event=redacted_event - ) - - self._get_event_cache.prefill((original_ev.event_id,), cache_entry) - - return cache_entry - - @defer.inlineCallbacks - def _maybe_redact_event_row(self, original_ev, redactions): + def _maybe_redact_event_row(self, original_ev, redactions, event_map): """Given an event object and a list of possible redacting event ids, determine whether to honour any of those redactions and if so return a redacted event. @@ -631,6 +647,8 @@ class EventsWorkerStore(SQLBaseStore): Args: original_ev (EventBase): redactions (iterable[str]): list of event ids of potential redaction events + event_map (dict[str, EventBase]): other events which have been fetched, in + which we can look up the redaaction events. Map from event id to event. Returns: Deferred[EventBase|None]: if the event should be redacted, a pruned @@ -640,15 +658,9 @@ class EventsWorkerStore(SQLBaseStore): # we choose to ignore redactions of m.room.create events. return None - if original_ev.type == "m.room.redaction": - # ... 
and redaction events - return None - - redaction_map = yield self._get_events_from_cache_or_db(redactions) - for redaction_id in redactions: - redaction_entry = redaction_map.get(redaction_id) - if not redaction_entry: + redaction_event = event_map.get(redaction_id) + if not redaction_event or redaction_event.rejected_reason: # we don't have the redaction event, or the redaction event was not # authorized. logger.debug( @@ -658,7 +670,6 @@ class EventsWorkerStore(SQLBaseStore): ) continue - redaction_event = redaction_entry.event if redaction_event.room_id != original_ev.room_id: logger.debug( "%s was redacted by %s but redaction was in a different room!", diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
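The rewrite in synapse/storage/events_worker.py replaces the per-row _get_event_from_row processing with a two-phase approach: fetch all requested rows (plus, iteratively, any redaction events they reference), build every EventBase, and only then decide redactions with the full event_map in hand. Below is a simplified sketch of the first phase; the difference() against already-fetched ids is what lets a pair of events that redact each other (the new test case) terminate instead of recursing forever. fetch_rows stands in for _enqueue_events, and each row is assumed to carry the "redactions" list produced by _fetch_event_rows:

# Simplified sketch of the fetch phase of _get_events_from_db.
def fetch_with_redactions(event_ids, fetch_rows):
    fetched = {}          # event_id -> row (or None if not in the db)
    to_fetch = set(event_ids)

    while to_fetch:
        row_map = fetch_rows(to_fetch)

        redaction_ids = set()
        for event_id in to_fetch:
            row = row_map.get(event_id)
            fetched[event_id] = row
            if row:
                redaction_ids.update(row["redactions"])

        # Only chase redaction events we haven't fetched yet: this is what
        # makes two events that redact each other terminate after two
        # passes rather than looping forever.
        to_fetch = redaction_ids.difference(fetched.keys())

    return fetched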
index 89a3f95c0a..bb867150f4 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py
@@ -323,6 +323,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "renew_at": 172800000, # Time in ms for 2 days "renew_by_email_enabled": True, "renew_email_subject": "Renew your account", + "account_renewed_html_path": "account_renewed.html", + "invalid_token_html_path": "invalid_token.html", } # Email config. @@ -373,6 +375,19 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) + # Check that we're getting HTML back. + content_type = None + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type = header[1] + self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + + # Check that the HTML we're getting is the one we expect on a successful renewal. + expected_html = self.hs.config.account_validity.account_renewed_html_content + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + # Move 3 days forward. If the renewal failed, every authed request with # our access token should be denied from now, otherwise they should # succeed. @@ -381,6 +396,28 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) + def test_renewal_invalid_token(self): + # Hit the renewal endpoint with an invalid token and check that it behaves as + # expected, i.e. that it responds with 404 Not Found and the correct HTML. + url = "/_matrix/client/unstable/account_validity/renew?token=123" + request, channel = self.make_request(b"GET", url) + self.render(request) + self.assertEquals(channel.result["code"], b"404", channel.result) + + # Check that we're getting HTML back. + content_type = None + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type = header[1] + self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + + # Check that the HTML we're getting is the one we expect when using an + # invalid/unknown token. + expected_html = self.hs.config.account_validity.invalid_token_html_content + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + def test_manual_email_send(self): self.email_attempts = [] diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
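The new assertions in tests/rest/client/v2_alpha/test_register.py read the response headers off the fake channel as (name, value) byte-string pairs; a tiny helper equivalent to the loop used in both tests (channel.result and its "headers" key are as used in the test code above):

# Tiny helper equivalent to the header-scanning loop in the new tests.
def get_header(result, name=b"Content-Type"):
    for header_name, value in result.get("headers", []):
        if header_name == name:
            return value
    return None

# e.g. assertEqual(get_header(channel.result), b"text/html; charset=utf-8")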
index 8488b6edc8..d961b81d48 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py
@@ -17,6 +17,8 @@ from mock import Mock +from twisted.internet import defer + from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.types import RoomID, UserID @@ -216,3 +218,71 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, event.unsigned["redacted_because"], ) + + def test_circular_redaction(self): + redaction_event_id1 = "$redaction1_id:test" + redaction_event_id2 = "$redaction2_id:test" + + class EventIdManglingBuilder: + def __init__(self, base_builder, event_id): + self._base_builder = base_builder + self._event_id = event_id + + @defer.inlineCallbacks + def build(self, prev_event_ids): + built_event = yield self._base_builder.build(prev_event_ids) + built_event.event_id = self._event_id + built_event._event_dict["event_id"] = self._event_id + return built_event + + @property + def room_id(self): + return self._base_builder.room_id + + event_1, context_1 = self.get_success( + self.event_creation_handler.create_new_client_event( + EventIdManglingBuilder( + self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "test"}, + "redacts": redaction_event_id2, + }, + ), + redaction_event_id1, + ) + ) + ) + + self.get_success(self.store.persist_event(event_1, context_1)) + + event_2, context_2 = self.get_success( + self.event_creation_handler.create_new_client_event( + EventIdManglingBuilder( + self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "test"}, + "redacts": redaction_event_id1, + }, + ), + redaction_event_id2, + ) + ) + ) + self.get_success(self.store.persist_event(event_2, context_2)) + + # fetch one of the redactions + fetched = self.get_success(self.store.get_event(redaction_event_id1)) + + # it should have been redacted + self.assertEqual(fetched.unsigned["redacted_by"], redaction_event_id2) + self.assertEqual( + fetched.unsigned["redacted_because"].event_id, redaction_event_id2 + )
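The new test builds two redaction events that redact each other (by forcing their event IDs), persists both, and checks that fetching one of them yields a properly redacted event. The decision it ultimately exercises lives in the reworked _maybe_redact_event_row: with the event_map from the fetch loop in hand, honour a redaction only if it was actually fetched, was not rejected, and is in the same room; the old special case that skipped redactions of redaction events is gone, which is the "correctly handle redactions of redactions" fix. A simplified sketch of that decision; prune_event stands in for synapse.events.utils.prune_event, and the unsigned bookkeeping mirrors the redacted_by / redacted_because fields the test asserts on:

# Simplified sketch of the redaction decision exercised by the test: given
# an event and its candidate redaction ids, apply the first redaction that
# was fetched, not rejected, and in the same room.
def maybe_redact(original_ev, redaction_ids, event_map, prune_event):
    if original_ev.type == "m.room.create":
        # redactions of create events are deliberately ignored
        return None

    for redaction_id in redaction_ids:
        redaction_event = event_map.get(redaction_id)
        if not redaction_event or redaction_event.rejected_reason:
            # redaction missing from the fetch, or not authorised
            continue
        if redaction_event.room_id != original_ev.room_id:
            # redaction sent in a different room; ignore it
            continue

        redacted = prune_event(original_ev)
        redacted.unsigned["redacted_by"] = redaction_id
        redacted.unsigned["redacted_because"] = redaction_event
        return redacted

    # no valid redaction found; return the event unredacted
    return None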