diff options
43 files changed, 961 insertions, 544 deletions
diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist index 8ed8eef1a3..cda5c84e94 100644 --- a/.buildkite/worker-blacklist +++ b/.buildkite/worker-blacklist @@ -3,10 +3,6 @@ Message history can be paginated -m.room.history_visibility == "world_readable" allows/forbids appropriately for Guest users - -m.room.history_visibility == "world_readable" allows/forbids appropriately for Real users - Can re-join room if re-invited /upgrade creates a new room diff --git a/.gitignore b/.gitignore index a84c41b0c9..f6168a8819 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ _trial_temp*/ /*.log /*.log.config /*.pid +/.python-version /*.signing.key /env/ /homeserver*.yaml diff --git a/changelog.d/5754.feature b/changelog.d/5754.feature new file mode 100644 index 0000000000..c1a09a4dce --- /dev/null +++ b/changelog.d/5754.feature @@ -0,0 +1 @@ +Synapse will no longer serve any media repo admin endpoints when `enable_media_repo` is set to False in the configuration. If a media repo worker is used, the admin APIs relating to the media repo will be served from it instead. \ No newline at end of file diff --git a/changelog.d/5788.bugfix b/changelog.d/5788.bugfix new file mode 100644 index 0000000000..5632f3cb99 --- /dev/null +++ b/changelog.d/5788.bugfix @@ -0,0 +1 @@ +Correctly handle redactions of redactions. diff --git a/changelog.d/5798.bugfix b/changelog.d/5798.bugfix new file mode 100644 index 0000000000..7db2c37af5 --- /dev/null +++ b/changelog.d/5798.bugfix @@ -0,0 +1 @@ +Return 404 instead of 403 when accessing /rooms/{roomId}/event/{eventId} for an event without the appropriate permissions. diff --git a/changelog.d/5801.misc b/changelog.d/5801.misc new file mode 100644 index 0000000000..e19854de82 --- /dev/null +++ b/changelog.d/5801.misc @@ -0,0 +1 @@ +Don't allow clients to send tombstone events that reference the room it's sent in. diff --git a/changelog.d/5805.misc b/changelog.d/5805.misc new file mode 100644 index 0000000000..352cb3db04 --- /dev/null +++ b/changelog.d/5805.misc @@ -0,0 +1 @@ +Deny sending well known state types as non-state events. diff --git a/changelog.d/5806.bugfix b/changelog.d/5806.bugfix new file mode 100644 index 0000000000..c5ca0f5629 --- /dev/null +++ b/changelog.d/5806.bugfix @@ -0,0 +1 @@ +Fix error when trying to login as a deactivated user when using a worker to handle login. diff --git a/changelog.d/5807.feature b/changelog.d/5807.feature new file mode 100644 index 0000000000..8b7d29a23c --- /dev/null +++ b/changelog.d/5807.feature @@ -0,0 +1 @@ +Allow defining HTML templates to serve the user on account renewal attempt when using the account validity feature. diff --git a/changelog.d/5808.misc b/changelog.d/5808.misc new file mode 100644 index 0000000000..cac3fd34d1 --- /dev/null +++ b/changelog.d/5808.misc @@ -0,0 +1 @@ +Handle incorrectly encoded query params correctly by returning a 400. diff --git a/changelog.d/5810.misc b/changelog.d/5810.misc new file mode 100644 index 0000000000..0a5ccbbb3f --- /dev/null +++ b/changelog.d/5810.misc @@ -0,0 +1 @@ +Return 502 not 500 when failing to reach any remote server. diff --git a/changelog.d/5825.bugfix b/changelog.d/5825.bugfix new file mode 100644 index 0000000000..fb2c6f821d --- /dev/null +++ b/changelog.d/5825.bugfix @@ -0,0 +1 @@ +Fix bug where user `/sync` stream could get wedged in rare circumstances. diff --git a/changelog.d/5826.misc b/changelog.d/5826.misc new file mode 100644 index 0000000000..9abed11bbe --- /dev/null +++ b/changelog.d/5826.misc @@ -0,0 +1 @@ +Reduce global pauses in the events stream caused by expensive state resolution during persistence. diff --git a/changelog.d/5836.misc b/changelog.d/5836.misc new file mode 100644 index 0000000000..18f2488201 --- /dev/null +++ b/changelog.d/5836.misc @@ -0,0 +1 @@ +Add a lower bound to well-known lookup cache time to avoid repeated lookups. diff --git a/changelog.d/5839.bugfix b/changelog.d/5839.bugfix new file mode 100644 index 0000000000..5775bfa653 --- /dev/null +++ b/changelog.d/5839.bugfix @@ -0,0 +1 @@ +The purge_remote_media.sh script was fixed. diff --git a/changelog.d/5843.misc b/changelog.d/5843.misc new file mode 100644 index 0000000000..e7e7d572b7 --- /dev/null +++ b/changelog.d/5843.misc @@ -0,0 +1 @@ +Whitelist history visbility sytests in worker mode tests. diff --git a/contrib/purge_api/purge_remote_media.sh b/contrib/purge_api/purge_remote_media.sh index 99c07c663d..77220d3bd5 100644 --- a/contrib/purge_api/purge_remote_media.sh +++ b/contrib/purge_api/purge_remote_media.sh @@ -51,4 +51,4 @@ TOKEN=$(sql "SELECT token FROM access_tokens WHERE user_id='$ADMIN' ORDER BY id # finally start pruning media: ############################################################################### set -x # for debugging the generated string -curl --header "Authorization: Bearer $TOKEN" -v POST "$API_URL/admin/purge_media_cache/?before_ts=$UNIX_TIMESTAMP" +curl --header "Authorization: Bearer $TOKEN" -X POST "$API_URL/admin/purge_media_cache/?before_ts=$UNIX_TIMESTAMP" diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 08316597fa..0c6be30e51 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -565,6 +565,13 @@ log_config: "CONFDIR/SERVERNAME.log.config" +## Media Store ## + +# Enable the media store service in the Synapse master. Uncomment the +# following if you are using a separate media store worker. +# +#enable_media_repo: false + # Directory where uploaded images and attachments are stored. # media_store_path: "DATADIR/media_store" @@ -802,6 +809,16 @@ uploads_path: "DATADIR/uploads" # period: 6w # renew_at: 1w # renew_email_subject: "Renew your %(app)s account" +# # Directory in which Synapse will try to find the HTML files to serve to the +# # user when trying to renew an account. Optional, defaults to +# # synapse/res/templates. +# template_dir: "res/templates" +# # HTML to be displayed to the user after they successfully renewed their +# # account. Optional. +# account_renewed_html_path: "account_renewed.html" +# # HTML to be displayed when the user tries to renew an account with an invalid +# # renewal token. Optional. +# invalid_token_html_path: "invalid_token.html" # Time that a user's session remains valid for, after they log in. # diff --git a/docs/workers.rst b/docs/workers.rst index 7b2d2db533..e11e117418 100644 --- a/docs/workers.rst +++ b/docs/workers.rst @@ -206,6 +206,13 @@ Handles the media repository. It can handle all endpoints starting with:: /_matrix/media/ +And the following regular expressions matching media-specific administration +APIs:: + + ^/_synapse/admin/v1/purge_media_cache$ + ^/_synapse/admin/v1/room/.*/media$ + ^/_synapse/admin/v1/quarantine_media/.*$ + You should also set ``enable_media_repo: False`` in the shared configuration file to stop the main synapse running background jobs related to managing the media repository. diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index ea26f29acb..3a168577c7 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -26,6 +26,7 @@ from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging +from synapse.http.server import JsonResource from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy @@ -35,6 +36,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer from synapse.storage.engines import create_engine @@ -71,6 +73,12 @@ class MediaRepositoryServer(HomeServer): resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "media": media_repo = self.get_media_repository_resource() + + # We need to serve the admin servlets for media on the + # worker. + admin_resource = JsonResource(self, canonical_json=False) + register_servlets_for_media_repo(self, admin_resource) + resources.update( { MEDIA_PREFIX: media_repo, @@ -78,6 +86,7 @@ class MediaRepositoryServer(HomeServer): CONTENT_REPO_PREFIX: ContentRepoResource( self, self.config.uploads_path ), + "/_synapse/admin": admin_resource, } ) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index c3de7a4e32..e2bee3c116 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from distutils.util import strtobool +import pkg_resources + from synapse.config._base import Config, ConfigError from synapse.types import RoomAlias from synapse.util.stringutils import random_string_with_symbols @@ -41,8 +44,36 @@ class AccountValidityConfig(Config): self.startup_job_max_delta = self.period * 10.0 / 100.0 - if self.renew_by_email_enabled and "public_baseurl" not in synapse_config: - raise ConfigError("Can't send renewal emails without 'public_baseurl'") + if self.renew_by_email_enabled: + if "public_baseurl" not in synapse_config: + raise ConfigError("Can't send renewal emails without 'public_baseurl'") + + template_dir = config.get("template_dir") + + if not template_dir: + template_dir = pkg_resources.resource_filename("synapse", "res/templates") + + if "account_renewed_html_path" in config: + file_path = os.path.join(template_dir, config["account_renewed_html_path"]) + + self.account_renewed_html_content = self.read_file( + file_path, "account_validity.account_renewed_html_path" + ) + else: + self.account_renewed_html_content = ( + "<html><body>Your account has been successfully renewed.</body><html>" + ) + + if "invalid_token_html_path" in config: + file_path = os.path.join(template_dir, config["invalid_token_html_path"]) + + self.invalid_token_html_content = self.read_file( + file_path, "account_validity.invalid_token_html_path" + ) + else: + self.invalid_token_html_content = ( + "<html><body>Invalid renewal token.</body><html>" + ) class RegistrationConfig(Config): @@ -145,6 +176,16 @@ class RegistrationConfig(Config): # period: 6w # renew_at: 1w # renew_email_subject: "Renew your %%(app)s account" + # # Directory in which Synapse will try to find the HTML files to serve to the + # # user when trying to renew an account. Optional, defaults to + # # synapse/res/templates. + # template_dir: "res/templates" + # # HTML to be displayed to the user after they successfully renewed their + # # account. Optional. + # account_renewed_html_path: "account_renewed.html" + # # HTML to be displayed when the user tries to renew an account with an invalid + # # renewal token. Optional. + # invalid_token_html_path: "invalid_token.html" # Time that a user's session remains valid for, after they log in. # diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 80a628d9b0..db39697e45 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os from collections import namedtuple @@ -87,6 +88,18 @@ def parse_thumbnail_requirements(thumbnail_sizes): class ContentRepositoryConfig(Config): def read_config(self, config, **kwargs): + + # Only enable the media repo if either the media repo is enabled or the + # current worker app is the media repo. + if ( + self.enable_media_repo is False + and config.worker_app != "synapse.app.media_repository" + ): + self.can_load_media_repo = False + return + else: + self.can_load_media_repo = True + self.max_upload_size = self.parse_size(config.get("max_upload_size", "10M")) self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M")) self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M")) @@ -202,6 +215,13 @@ class ContentRepositoryConfig(Config): return ( r""" + ## Media Store ## + + # Enable the media store service in the Synapse master. Uncomment the + # following if you are using a separate media store worker. + # + #enable_media_repo: false + # Directory where uploaded images and attachments are stored. # media_store_path: "%(media_store)s" diff --git a/synapse/events/validator.py b/synapse/events/validator.py index f7ffd1d561..272426e105 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -95,10 +95,10 @@ class EventValidator(object): elif event.type == EventTypes.Topic: self._ensure_strings(event.content, ["topic"]) - + self._ensure_state_event(event) elif event.type == EventTypes.Name: self._ensure_strings(event.content, ["name"]) - + self._ensure_state_event(event) elif event.type == EventTypes.Member: if "membership" not in event.content: raise SynapseError(400, "Content has not membership key") @@ -106,9 +106,25 @@ class EventValidator(object): if event.content["membership"] not in Membership.LIST: raise SynapseError(400, "Invalid membership key") + self._ensure_state_event(event) + elif event.type == EventTypes.Tombstone: + if "replacement_room" not in event.content: + raise SynapseError(400, "Content has no replacement_room key") + + if event.content["replacement_room"] == event.room_id: + raise SynapseError( + 400, "Tombstone cannot reference the room it was sent in" + ) + + self._ensure_state_event(event) + def _ensure_strings(self, d, keys): for s in keys: if s not in d: raise SynapseError(400, "'%s' not in content" % (s,)) if not isinstance(d[s], string_types): raise SynapseError(400, "'%s' not a string type" % (s,)) + + def _ensure_state_event(self, event): + if not event.is_state(): + raise SynapseError(400, "'%s' must be state events" % (event.type,)) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 6e03ce21af..bec3080895 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -511,9 +511,8 @@ class FederationClient(FederationBase): The [Deferred] result of callback, if it succeeds Raises: - SynapseError if the chosen remote server returns a 300/400 code. - - RuntimeError if no servers were reachable. + SynapseError if the chosen remote server returns a 300/400 code, or + no servers were reachable. """ for destination in destinations: if destination == self.server_name: @@ -538,7 +537,7 @@ class FederationClient(FederationBase): except Exception: logger.warn("Failed to %s via %s", description, destination, exc_info=1) - raise RuntimeError("Failed to %s via any server" % (description,)) + raise SynapseError(502, "Failed to %s via any server" % (description,)) def make_membership_event( self, destinations, room_id, user_id, membership, content, params diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 930204e2d0..34574f1a12 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -226,11 +226,19 @@ class AccountValidityHandler(object): Args: renewal_token (str): Token sent with the renewal request. + Returns: + bool: Whether the provided token is valid. """ - user_id = yield self.store.get_user_from_renewal_token(renewal_token) + try: + user_id = yield self.store.get_user_from_renewal_token(renewal_token) + except StoreError: + defer.returnValue(False) + logger.debug("Renewing an account for user %s", user_id) yield self.renew_account_for_user(user_id) + defer.returnValue(True) + @defer.inlineCallbacks def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False): """Renews the account attached to a given user by pushing back the diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4007284e5b..98da2318a0 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -781,9 +781,17 @@ class SyncHandler(object): lazy_load_members=lazy_load_members, ) elif batch.limited: - state_at_timeline_start = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter - ) + if batch: + state_at_timeline_start = yield self.store.get_state_ids_for_event( + batch.events[0].event_id, state_filter=state_filter + ) + else: + # Its not clear how we get here, but empirically we do + # (#5407). Logging has been added elsewhere to try and + # figure out where this state comes from. + state_at_timeline_start = yield self.get_state_at( + room_id, stream_position=now_token, state_filter=state_filter + ) # for now, we disable LL for gappy syncs - see # https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346 @@ -803,9 +811,17 @@ class SyncHandler(object): room_id, stream_position=since_token, state_filter=state_filter ) - current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter - ) + if batch: + current_state_ids = yield self.store.get_state_ids_for_event( + batch.events[-1].event_id, state_filter=state_filter + ) + else: + # Its not clear how we get here, but empirically we do + # (#5407). Logging has been added elsewhere to try and + # figure out where this state comes from. + current_state_ids = yield self.get_state_at( + room_id, stream_position=now_token, state_filter=state_filter + ) state_ids = _calculate_state( timeline_contains=timeline_state, @@ -1755,6 +1771,21 @@ class SyncHandler(object): newly_joined_room=newly_joined, ) + if not batch and batch.limited: + # This resulted in #5407, which is weird, so lets log! We do it + # here as we have the maximum amount of information. + user_id = sync_result_builder.sync_config.user.to_string() + logger.info( + "Issue #5407: Found limited batch with no events. user %s, room %s," + " sync_config %s, newly_joined %s, events %s, batch %s.", + user_id, + room_id, + sync_config, + newly_joined, + events, + batch, + ) + if newly_joined: # debug for https://github.com/matrix-org/synapse/issues/4422 issue4422_logger.debug( diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index a0d5139839..71a15f434d 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -12,10 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json + import logging -import random -import time import attr from netaddr import IPAddress @@ -24,31 +22,16 @@ from zope.interface import implementer from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody -from twisted.web.http import stringToDatetime +from twisted.web.client import URI, Agent, HTTPConnectionPool from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list +from synapse.http.federation.well_known_resolver import WellKnownResolver from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock -from synapse.util.caches.ttlcache import TTLCache -from synapse.util.metrics import Measure - -# period to cache .well-known results for by default -WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600 - -# jitter to add to the .well-known default cache ttl -WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60 - -# period to cache failure to fetch .well-known for -WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600 - -# cap for .well-known cache period -WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 logger = logging.getLogger(__name__) -well_known_cache = TTLCache("well-known") @implementer(IAgent) @@ -78,7 +61,7 @@ class MatrixFederationAgent(object): reactor, tls_client_options_factory, _srv_resolver=None, - _well_known_cache=well_known_cache, + _well_known_cache=None, ): self._reactor = reactor self._clock = Clock(reactor) @@ -93,20 +76,15 @@ class MatrixFederationAgent(object): self._pool.maxPersistentPerHost = 5 self._pool.cachedConnectionTimeout = 2 * 60 - _well_known_agent = RedirectAgent( - Agent( + self._well_known_resolver = WellKnownResolver( + self._reactor, + agent=Agent( self._reactor, pool=self._pool, contextFactory=tls_client_options_factory, - ) + ), + well_known_cache=_well_known_cache, ) - self._well_known_agent = _well_known_agent - - # our cache of .well-known lookup results, mapping from server name - # to delegated name. The values can be: - # `bytes`: a valid server-name - # `None`: there is no (valid) .well-known here - self._well_known_cache = _well_known_cache @defer.inlineCallbacks def request(self, method, uri, headers=None, bodyProducer=None): @@ -217,7 +195,10 @@ class MatrixFederationAgent(object): if lookup_well_known: # try a .well-known lookup - well_known_server = yield self._get_well_known(parsed_uri.host) + well_known_result = yield self._well_known_resolver.get_well_known( + parsed_uri.host + ) + well_known_server = well_known_result.delegated_server if well_known_server: # if we found a .well-known, start again, but don't do another @@ -280,85 +261,6 @@ class MatrixFederationAgent(object): target_port=port, ) - @defer.inlineCallbacks - def _get_well_known(self, server_name): - """Attempt to fetch and parse a .well-known file for the given server - - Args: - server_name (bytes): name of the server, from the requested url - - Returns: - Deferred[bytes|None]: either the new server name, from the .well-known, or - None if there was no .well-known file. - """ - try: - result = self._well_known_cache[server_name] - except KeyError: - # TODO: should we linearise so that we don't end up doing two .well-known - # requests for the same server in parallel? - with Measure(self._clock, "get_well_known"): - result, cache_period = yield self._do_get_well_known(server_name) - - if cache_period > 0: - self._well_known_cache.set(server_name, result, cache_period) - - return result - - @defer.inlineCallbacks - def _do_get_well_known(self, server_name): - """Actually fetch and parse a .well-known, without checking the cache - - Args: - server_name (bytes): name of the server, from the requested url - - Returns: - Deferred[Tuple[bytes|None|object],int]: - result, cache period, where result is one of: - - the new server name from the .well-known (as a `bytes`) - - None if there was no .well-known file. - - INVALID_WELL_KNOWN if the .well-known was invalid - """ - uri = b"https://%s/.well-known/matrix/server" % (server_name,) - uri_str = uri.decode("ascii") - logger.info("Fetching %s", uri_str) - try: - response = yield make_deferred_yieldable( - self._well_known_agent.request(b"GET", uri) - ) - body = yield make_deferred_yieldable(readBody(response)) - if response.code != 200: - raise Exception("Non-200 response %s" % (response.code,)) - - parsed_body = json.loads(body.decode("utf-8")) - logger.info("Response from .well-known: %s", parsed_body) - if not isinstance(parsed_body, dict): - raise Exception("not a dict") - if "m.server" not in parsed_body: - raise Exception("Missing key 'm.server'") - except Exception as e: - logger.info("Error fetching %s: %s", uri_str, e) - - # add some randomness to the TTL to avoid a stampeding herd every hour - # after startup - cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD - cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) - return (None, cache_period) - - result = parsed_body["m.server"].encode("ascii") - - cache_period = _cache_period_from_headers( - response.headers, time_now=self._reactor.seconds - ) - if cache_period is None: - cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD - # add some randomness to the TTL to avoid a stampeding herd every 24 hours - # after startup - cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) - else: - cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD) - - return (result, cache_period) - @implementer(IStreamClientEndpoint) class LoggingHostnameEndpoint(object): @@ -374,44 +276,6 @@ class LoggingHostnameEndpoint(object): return self.ep.connect(protocol_factory) -def _cache_period_from_headers(headers, time_now=time.time): - cache_controls = _parse_cache_control(headers) - - if b"no-store" in cache_controls: - return 0 - - if b"max-age" in cache_controls: - try: - max_age = int(cache_controls[b"max-age"]) - return max_age - except ValueError: - pass - - expires = headers.getRawHeaders(b"expires") - if expires is not None: - try: - expires_date = stringToDatetime(expires[-1]) - return expires_date - time_now() - except ValueError: - # RFC7234 says 'A cache recipient MUST interpret invalid date formats, - # especially the value "0", as representing a time in the past (i.e., - # "already expired"). - return 0 - - return None - - -def _parse_cache_control(headers): - cache_controls = {} - for hdr in headers.getRawHeaders(b"cache-control", []): - for directive in hdr.split(b","): - splits = [x.strip() for x in directive.split(b"=", 1)] - k = splits[0].lower() - v = splits[1] if len(splits) > 1 else None - cache_controls[k] = v - return cache_controls - - @attr.s class _RoutingResult(object): """The result returned by `_route_matrix_uri`. diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py new file mode 100644 index 0000000000..d2866ff67d --- /dev/null +++ b/synapse/http/federation/well_known_resolver.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import random +import time + +import attr + +from twisted.internet import defer +from twisted.web.client import RedirectAgent, readBody +from twisted.web.http import stringToDatetime + +from synapse.logging.context import make_deferred_yieldable +from synapse.util import Clock +from synapse.util.caches.ttlcache import TTLCache +from synapse.util.metrics import Measure + +# period to cache .well-known results for by default +WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600 + +# jitter to add to the .well-known default cache ttl +WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60 + +# period to cache failure to fetch .well-known for +WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600 + +# cap for .well-known cache period +WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 + +# lower bound for .well-known cache period +WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60 + +logger = logging.getLogger(__name__) + + +_well_known_cache = TTLCache("well-known") + + +@attr.s(slots=True, frozen=True) +class WellKnownLookupResult(object): + delegated_server = attr.ib() + + +class WellKnownResolver(object): + """Handles well-known lookups for matrix servers. + """ + + def __init__(self, reactor, agent, well_known_cache=None): + self._reactor = reactor + self._clock = Clock(reactor) + + if well_known_cache is None: + well_known_cache = _well_known_cache + + self._well_known_cache = well_known_cache + self._well_known_agent = RedirectAgent(agent) + + @defer.inlineCallbacks + def get_well_known(self, server_name): + """Attempt to fetch and parse a .well-known file for the given server + + Args: + server_name (bytes): name of the server, from the requested url + + Returns: + Deferred[WellKnownLookupResult]: The result of the lookup + """ + try: + result = self._well_known_cache[server_name] + except KeyError: + # TODO: should we linearise so that we don't end up doing two .well-known + # requests for the same server in parallel? + with Measure(self._clock, "get_well_known"): + result, cache_period = yield self._do_get_well_known(server_name) + + if cache_period > 0: + self._well_known_cache.set(server_name, result, cache_period) + + return WellKnownLookupResult(delegated_server=result) + + @defer.inlineCallbacks + def _do_get_well_known(self, server_name): + """Actually fetch and parse a .well-known, without checking the cache + + Args: + server_name (bytes): name of the server, from the requested url + + Returns: + Deferred[Tuple[bytes|None|object],int]: + result, cache period, where result is one of: + - the new server name from the .well-known (as a `bytes`) + - None if there was no .well-known file. + - INVALID_WELL_KNOWN if the .well-known was invalid + """ + uri = b"https://%s/.well-known/matrix/server" % (server_name,) + uri_str = uri.decode("ascii") + logger.info("Fetching %s", uri_str) + try: + response = yield make_deferred_yieldable( + self._well_known_agent.request(b"GET", uri) + ) + body = yield make_deferred_yieldable(readBody(response)) + if response.code != 200: + raise Exception("Non-200 response %s" % (response.code,)) + + parsed_body = json.loads(body.decode("utf-8")) + logger.info("Response from .well-known: %s", parsed_body) + if not isinstance(parsed_body, dict): + raise Exception("not a dict") + if "m.server" not in parsed_body: + raise Exception("Missing key 'm.server'") + except Exception as e: + logger.info("Error fetching %s: %s", uri_str, e) + + # add some randomness to the TTL to avoid a stampeding herd every hour + # after startup + cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD + cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) + return (None, cache_period) + + result = parsed_body["m.server"].encode("ascii") + + cache_period = _cache_period_from_headers( + response.headers, time_now=self._reactor.seconds + ) + if cache_period is None: + cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD + # add some randomness to the TTL to avoid a stampeding herd every 24 hours + # after startup + cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) + else: + cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD) + cache_period = max(cache_period, WELL_KNOWN_MIN_CACHE_PERIOD) + + return (result, cache_period) + + +def _cache_period_from_headers(headers, time_now=time.time): + cache_controls = _parse_cache_control(headers) + + if b"no-store" in cache_controls: + return 0 + + if b"max-age" in cache_controls: + try: + max_age = int(cache_controls[b"max-age"]) + return max_age + except ValueError: + pass + + expires = headers.getRawHeaders(b"expires") + if expires is not None: + try: + expires_date = stringToDatetime(expires[-1]) + return expires_date - time_now() + except ValueError: + # RFC7234 says 'A cache recipient MUST interpret invalid date formats, + # especially the value "0", as representing a time in the past (i.e., + # "already expired"). + return 0 + + return None + + +def _parse_cache_control(headers): + cache_controls = {} + for hdr in headers.getRawHeaders(b"cache-control", []): + for directive in hdr.split(b","): + splits = [x.strip() for x in directive.split(b"=", 1)] + k = splits[0].lower() + v = splits[1] if len(splits) > 1 else None + cache_controls[k] = v + return cache_controls diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index f0ca7d9aba..fd07bf7b8e 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -166,7 +166,12 @@ def parse_string_from_args( value = args[name][0] if encoding: - value = value.decode(encoding) + try: + value = value.decode(encoding) + except ValueError: + raise SynapseError( + 400, "Query parameter %r must be %s" % (name, encoding) + ) if allowed_values is not None and value not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html new file mode 100644 index 0000000000..894da030af --- /dev/null +++ b/synapse/res/templates/account_renewed.html @@ -0,0 +1 @@ +<html><body>Your account has been successfully renewed.</body><html> diff --git a/synapse/res/templates/invalid_token.html b/synapse/res/templates/invalid_token.html new file mode 100644 index 0000000000..6bd2b98364 --- /dev/null +++ b/synapse/res/templates/invalid_token.html @@ -0,0 +1 @@ +<html><body>Invalid renewal token.</body><html> diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 0a7d9b81b2..5720cab425 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -27,7 +27,7 @@ from twisted.internet import defer import synapse from synapse.api.constants import Membership, UserTypes -from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import JsonResource from synapse.http.servlet import ( RestServlet, @@ -36,7 +36,12 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.rest.admin._base import assert_requester_is_admin, assert_user_is_admin +from synapse.rest.admin._base import ( + assert_requester_is_admin, + assert_user_is_admin, + historical_admin_path_patterns, +) +from synapse.rest.admin.media import register_servlets_for_media_repo from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.types import UserID, create_requester from synapse.util.versionstring import get_version_string @@ -44,28 +49,6 @@ from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) -def historical_admin_path_patterns(path_regex): - """Returns the list of patterns for an admin endpoint, including historical ones - - This is a backwards-compatibility hack. Previously, the Admin API was exposed at - various paths under /_matrix/client. This function returns a list of patterns - matching those paths (as well as the new one), so that existing scripts which rely - on the endpoints being available there are not broken. - - Note that this should only be used for existing endpoints: new ones should just - register for the /_synapse/admin path. - """ - return list( - re.compile(prefix + path_regex) - for prefix in ( - "^/_synapse/admin/v1", - "^/_matrix/client/api/v1/admin", - "^/_matrix/client/unstable/admin", - "^/_matrix/client/r0/admin", - ) - ) - - class UsersRestServlet(RestServlet): PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)") @@ -255,25 +238,6 @@ class WhoisRestServlet(RestServlet): return (200, ret) -class PurgeMediaCacheRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/purge_media_cache") - - def __init__(self, hs): - self.media_repository = hs.get_media_repository() - self.auth = hs.get_auth() - - @defer.inlineCallbacks - def on_POST(self, request): - yield assert_requester_is_admin(self.auth, request) - - before_ts = parse_integer(request, "before_ts", required=True) - logger.info("before_ts: %r", before_ts) - - ret = yield self.media_repository.delete_old_remote_media(before_ts) - - return (200, ret) - - class PurgeHistoryRestServlet(RestServlet): PATTERNS = historical_admin_path_patterns( "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?" @@ -542,50 +506,6 @@ class ShutdownRoomRestServlet(RestServlet): ) -class QuarantineMediaInRoom(RestServlet): - """Quarantines all media in a room so that no one can download it via - this server. - """ - - PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)") - - def __init__(self, hs): - self.store = hs.get_datastore() - self.auth = hs.get_auth() - - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) - yield assert_user_is_admin(self.auth, requester.user) - - num_quarantined = yield self.store.quarantine_media_ids_in_room( - room_id, requester.user.to_string() - ) - - return (200, {"num_quarantined": num_quarantined}) - - -class ListMediaInRoom(RestServlet): - """Lists all of the media in a given room. - """ - - PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media") - - def __init__(self, hs): - self.store = hs.get_datastore() - - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - if not is_admin: - raise AuthError(403, "You are not a server admin") - - local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id) - - return (200, {"local": local_mxcs, "remote": remote_mxcs}) - - class ResetPasswordRestServlet(RestServlet): """Post request to allow an administrator reset password for a user. This needs user to have administrator access in Synapse. @@ -825,7 +745,6 @@ def register_servlets(hs, http_server): def register_servlets_for_client_rest_resource(hs, http_server): """Register only the servlets which need to be exposed on /_matrix/client/xxx""" WhoisRestServlet(hs).register(http_server) - PurgeMediaCacheRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) @@ -834,10 +753,13 @@ def register_servlets_for_client_rest_resource(hs, http_server): GetUsersPaginatedRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) ShutdownRoomRestServlet(hs).register(http_server) - QuarantineMediaInRoom(hs).register(http_server) - ListMediaInRoom(hs).register(http_server) UserRegisterServlet(hs).register(http_server) DeleteGroupAdminRestServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) + + # Load the media repo ones if we're using them. + if hs.config.can_load_media_repo: + register_servlets_for_media_repo(hs, http_server) + # don't add more things here: new servlets should only be exposed on # /_synapse/admin so should not go here. Instead register them in AdminRestResource. diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index 881d67b89c..5a9b08d3ef 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -12,11 +12,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import re + from twisted.internet import defer from synapse.api.errors import AuthError +def historical_admin_path_patterns(path_regex): + """Returns the list of patterns for an admin endpoint, including historical ones + + This is a backwards-compatibility hack. Previously, the Admin API was exposed at + various paths under /_matrix/client. This function returns a list of patterns + matching those paths (as well as the new one), so that existing scripts which rely + on the endpoints being available there are not broken. + + Note that this should only be used for existing endpoints: new ones should just + register for the /_synapse/admin path. + """ + return list( + re.compile(prefix + path_regex) + for prefix in ( + "^/_synapse/admin/v1", + "^/_matrix/client/api/v1/admin", + "^/_matrix/client/unstable/admin", + "^/_matrix/client/r0/admin", + ) + ) + + @defer.inlineCallbacks def assert_requester_is_admin(auth, request): """Verify that the requester is an admin user diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py new file mode 100644 index 0000000000..824df919f2 --- /dev/null +++ b/synapse/rest/admin/media.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.errors import AuthError +from synapse.http.servlet import RestServlet, parse_integer +from synapse.rest.admin._base import ( + assert_requester_is_admin, + assert_user_is_admin, + historical_admin_path_patterns, +) + +logger = logging.getLogger(__name__) + + +class QuarantineMediaInRoom(RestServlet): + """Quarantines all media in a room so that no one can download it via + this server. + """ + + PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)") + + def __init__(self, hs): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + @defer.inlineCallbacks + def on_POST(self, request, room_id): + requester = yield self.auth.get_user_by_req(request) + yield assert_user_is_admin(self.auth, requester.user) + + num_quarantined = yield self.store.quarantine_media_ids_in_room( + room_id, requester.user.to_string() + ) + + return (200, {"num_quarantined": num_quarantined}) + + +class ListMediaInRoom(RestServlet): + """Lists all of the media in a given room. + """ + + PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media") + + def __init__(self, hs): + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_GET(self, request, room_id): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + if not is_admin: + raise AuthError(403, "You are not a server admin") + + local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id) + + return (200, {"local": local_mxcs, "remote": remote_mxcs}) + + +class PurgeMediaCacheRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/purge_media_cache") + + def __init__(self, hs): + self.media_repository = hs.get_media_repository() + self.auth = hs.get_auth() + + @defer.inlineCallbacks + def on_POST(self, request): + yield assert_requester_is_admin(self.auth, request) + + before_ts = parse_integer(request, "before_ts", required=True) + logger.info("before_ts: %r", before_ts) + + ret = yield self.media_repository.delete_old_remote_media(before_ts) + + return (200, ret) + + +def register_servlets_for_media_repo(hs, http_server): + """ + Media repo specific APIs. + """ + PurgeMediaCacheRestServlet(hs).register(http_server) + QuarantineMediaInRoom(hs).register(http_server) + ListMediaInRoom(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 6fe1eddcce..4b2344e696 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -568,14 +568,22 @@ class RoomEventServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - event = yield self.event_handler.get_event(requester.user, room_id, event_id) + try: + event = yield self.event_handler.get_event( + requester.user, room_id, event_id + ) + except AuthError: + # This endpoint is supposed to return a 404 when the requester does + # not have permission to access the event + # https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-r0-rooms-roomid-event-eventid + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() if event: event = yield self._event_serializer.serialize_event(event, time_now) return (200, event) - else: - return (404, "Event not found.") + + return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) class RoomEventContextServlet(RestServlet): diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 133c61900a..33f6a23028 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -42,6 +42,8 @@ class AccountValidityRenewServlet(RestServlet): self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() + self.success_html = hs.config.account_validity.account_renewed_html_content + self.failure_html = hs.config.account_validity.invalid_token_html_content @defer.inlineCallbacks def on_GET(self, request): @@ -49,16 +51,23 @@ class AccountValidityRenewServlet(RestServlet): raise SynapseError(400, "Missing renewal token") renewal_token = request.args[b"token"][0] - yield self.account_activity_handler.renew_account(renewal_token.decode("utf8")) + token_valid = yield self.account_activity_handler.renew_account( + renewal_token.decode("utf8") + ) + + if token_valid: + status_code = 200 + response = self.success_html + else: + status_code = 404 + response = self.failure_html - request.setResponseCode(200) + request.setResponseCode(status_code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader( - b"Content-Length", b"%d" % (len(AccountValidityRenewServlet.SUCCESS_HTML),) - ) - request.write(AccountValidityRenewServlet.SUCCESS_HTML) + request.setHeader(b"Content-Length", b"%d" % (len(response),)) + request.write(response.encode("utf8")) finish_request(request) - return None + defer.returnValue(None) class AccountValiditySendMailServlet(RestServlet): @@ -87,7 +96,7 @@ class AccountValiditySendMailServlet(RestServlet): user_id = requester.user.to_string() yield self.account_activity_handler.send_renewal_email_to_user(user_id) - return (200, {}) + defer.returnValue((200, {})) def register_servlets(hs, http_server): diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 92beefa176..cf5759e9a6 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.config._base import ConfigError from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.async_helpers import Linearizer @@ -753,8 +754,11 @@ class MediaRepositoryResource(Resource): """ def __init__(self, hs): - Resource.__init__(self) + # If we're not configured to use it, raise if we somehow got here. + if not hs.config.can_load_media_repo: + raise ConfigError("Synapse is not configured to use a media repo.") + super().__init__() media_repo = hs.get_media_repository() self.putChild(b"upload", UploadResource(hs, media_repo)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 88c0180116..ac876287fc 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -364,147 +364,161 @@ class EventsStore( if not events_and_contexts: return - if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( - len(events_and_contexts) - ) - else: - stream_ordering_manager = self._stream_id_gen.get_next_mult( - len(events_and_contexts) - ) - - with stream_ordering_manager as stream_orderings: - for (event, context), stream in zip(events_and_contexts, stream_orderings): - event.internal_metadata.stream_ordering = stream - - chunks = [ - events_and_contexts[x : x + 100] - for x in range(0, len(events_and_contexts), 100) - ] - - for chunk in chunks: - # We can't easily parallelize these since different chunks - # might contain the same event. :( + chunks = [ + events_and_contexts[x : x + 100] + for x in range(0, len(events_and_contexts), 100) + ] - # NB: Assumes that we are only persisting events for one room - # at a time. + for chunk in chunks: + # We can't easily parallelize these since different chunks + # might contain the same event. :( - # map room_id->list[event_ids] giving the new forward - # extremities in each room - new_forward_extremeties = {} + # NB: Assumes that we are only persisting events for one room + # at a time. - # map room_id->(type,state_key)->event_id tracking the full - # state in each room after adding these events. - # This is simply used to prefill the get_current_state_ids - # cache - current_state_for_room = {} + # map room_id->list[event_ids] giving the new forward + # extremities in each room + new_forward_extremeties = {} - # map room_id->(to_delete, to_insert) where to_delete is a list - # of type/state keys to remove from current state, and to_insert - # is a map (type,key)->event_id giving the state delta in each - # room - state_delta_for_room = {} + # map room_id->(type,state_key)->event_id tracking the full + # state in each room after adding these events. + # This is simply used to prefill the get_current_state_ids + # cache + current_state_for_room = {} - if not backfilled: - with Measure(self._clock, "_calculate_state_and_extrem"): - # Work out the new "current state" for each room. - # We do this by working out what the new extremities are and then - # calculating the state from that. - events_by_room = {} - for event, context in chunk: - events_by_room.setdefault(event.room_id, []).append( - (event, context) - ) + # map room_id->(to_delete, to_insert) where to_delete is a list + # of type/state keys to remove from current state, and to_insert + # is a map (type,key)->event_id giving the state delta in each + # room + state_delta_for_room = {} - for room_id, ev_ctx_rm in iteritems(events_by_room): - latest_event_ids = yield self.get_latest_event_ids_in_room( - room_id - ) - new_latest_event_ids = yield self._calculate_new_extremities( - room_id, ev_ctx_rm, latest_event_ids + if not backfilled: + with Measure(self._clock, "_calculate_state_and_extrem"): + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. + events_by_room = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) + ) + + for room_id, ev_ctx_rm in iteritems(events_by_room): + latest_event_ids = yield self.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = yield self._calculate_new_extremities( + room_id, ev_ctx_rm, latest_event_ids + ) + + latest_event_ids = set(latest_event_ids) + if new_latest_event_ids == latest_event_ids: + # No change in extremities, so no change in state + continue + + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + + new_forward_extremeties[room_id] = new_latest_event_ids + + len_1 = ( + len(latest_event_ids) == 1 + and len(new_latest_event_ids) == 1 + ) + if len_1: + all_single_prev_not_state = all( + len(event.prev_event_ids()) == 1 + and not event.is_state() + for event, ctx in ev_ctx_rm ) - - latest_event_ids = set(latest_event_ids) - if new_latest_event_ids == latest_event_ids: - # No change in extremities, so no change in state + # Don't bother calculating state if they're just + # a long chain of single ancestor non-state events. + if all_single_prev_not_state: continue - # there should always be at least one forward extremity. - # (except during the initial persistence of the send_join - # results, in which case there will be no existing - # extremities, so we'll `continue` above and skip this bit.) - assert new_latest_event_ids, "No forward extremities left!" - - new_forward_extremeties[room_id] = new_latest_event_ids - - len_1 = ( - len(latest_event_ids) == 1 - and len(new_latest_event_ids) == 1 + state_delta_counter.inc() + if len(new_latest_event_ids) == 1: + state_delta_single_event_counter.inc() + + # This is a fairly handwavey check to see if we could + # have guessed what the delta would have been when + # processing one of these events. + # What we're interested in is if the latest extremities + # were the same when we created the event as they are + # now. When this server creates a new event (as opposed + # to receiving it over federation) it will use the + # forward extremities as the prev_events, so we can + # guess this by looking at the prev_events and checking + # if they match the current forward extremities. + for ev, _ in ev_ctx_rm: + prev_event_ids = set(ev.prev_event_ids()) + if latest_event_ids == prev_event_ids: + state_delta_reuse_delta_counter.inc() + break + + logger.info("Calculating state delta for room %s", room_id) + with Measure( + self._clock, "persist_events.get_new_state_after_events" + ): + res = yield self._get_new_state_after_events( + room_id, + ev_ctx_rm, + latest_event_ids, + new_latest_event_ids, ) - if len_1: - all_single_prev_not_state = all( - len(event.prev_event_ids()) == 1 - and not event.is_state() - for event, ctx in ev_ctx_rm - ) - # Don't bother calculating state if they're just - # a long chain of single ancestor non-state events. - if all_single_prev_not_state: - continue - - state_delta_counter.inc() - if len(new_latest_event_ids) == 1: - state_delta_single_event_counter.inc() - - # This is a fairly handwavey check to see if we could - # have guessed what the delta would have been when - # processing one of these events. - # What we're interested in is if the latest extremities - # were the same when we created the event as they are - # now. When this server creates a new event (as opposed - # to receiving it over federation) it will use the - # forward extremities as the prev_events, so we can - # guess this by looking at the prev_events and checking - # if they match the current forward extremities. - for ev, _ in ev_ctx_rm: - prev_event_ids = set(ev.prev_event_ids()) - if latest_event_ids == prev_event_ids: - state_delta_reuse_delta_counter.inc() - break - - logger.info("Calculating state delta for room %s", room_id) + current_state, delta_ids = res + + # If either are not None then there has been a change, + # and we need to work out the delta (or use that + # given) + if delta_ids is not None: + # If there is a delta we know that we've + # only added or replaced state, never + # removed keys entirely. + state_delta_for_room[room_id] = ([], delta_ids) + elif current_state is not None: with Measure( - self._clock, "persist_events.get_new_state_after_events" + self._clock, "persist_events.calculate_state_delta" ): - res = yield self._get_new_state_after_events( - room_id, - ev_ctx_rm, - latest_event_ids, - new_latest_event_ids, + delta = yield self._calculate_state_delta( + room_id, current_state ) - current_state, delta_ids = res - - # If either are not None then there has been a change, - # and we need to work out the delta (or use that - # given) - if delta_ids is not None: - # If there is a delta we know that we've - # only added or replaced state, never - # removed keys entirely. - state_delta_for_room[room_id] = ([], delta_ids) - elif current_state is not None: - with Measure( - self._clock, "persist_events.calculate_state_delta" - ): - delta = yield self._calculate_state_delta( - room_id, current_state - ) - state_delta_for_room[room_id] = delta - - # If we have the current_state then lets prefill - # the cache with it. - if current_state is not None: - current_state_for_room[room_id] = current_state + state_delta_for_room[room_id] = delta + + # If we have the current_state then lets prefill + # the cache with it. + if current_state is not None: + current_state_for_room[room_id] = current_state + + # We want to calculate the stream orderings as late as possible, as + # we only notify after all events with a lesser stream ordering have + # been persisted. I.e. if we spend 10s inside the with block then + # that will delay all subsequent events from being notified about. + # Hence why we do it down here rather than wrapping the entire + # function. + # + # Its safe to do this after calculating the state deltas etc as we + # only need to protect the *persistence* of the events. This is to + # ensure that queries of the form "fetch events since X" don't + # return events and stream positions after events that are still in + # flight, as otherwise subsequent requests "fetch event since Y" + # will not return those events. + # + # Note: Multiple instances of this function cannot be in flight at + # the same time for the same room. + if backfilled: + stream_ordering_manager = self._backfill_id_gen.get_next_mult( + len(chunk) + ) + else: + stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk)) + + with stream_ordering_manager as stream_orderings: + for (event, context), stream in zip(chunk, stream_orderings): + event.internal_metadata.stream_ordering = stream yield self.runInteraction( "persist_events", diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 79680ee856..c6fa7f82fd 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -29,12 +29,7 @@ from synapse.api.room_versions import EventFormatVersions from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.utils import prune_event -from synapse.logging.context import ( - LoggingContext, - PreserveLoggingContext, - make_deferred_yieldable, - run_in_background, -) +from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import get_domain_from_id from synapse.util import batch_iter @@ -342,13 +337,12 @@ class EventsWorkerStore(SQLBaseStore): log_ctx = LoggingContext.current_context() log_ctx.record_event_fetch(len(missing_events_ids)) - # Note that _enqueue_events is also responsible for turning db rows + # Note that _get_events_from_db is also responsible for turning db rows # into FrozenEvents (via _get_event_from_row), which involves seeing if # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - # _enqueue_events is a bit of a rubbish name but naming is hard. - missing_events = yield self._enqueue_events( + missing_events = yield self._get_events_from_db( missing_events_ids, allow_rejected=allow_rejected ) @@ -421,28 +415,28 @@ class EventsWorkerStore(SQLBaseStore): The fetch requests. Each entry consists of a list of event ids to be fetched, and a deferred to be completed once the events have been fetched. + + The deferreds are callbacked with a dictionary mapping from event id + to event row. Note that it may well contain additional events that + were not part of this request. """ with Measure(self._clock, "_fetch_event_list"): try: - event_id_lists = list(zip(*event_list))[0] - event_ids = [item for sublist in event_id_lists for item in sublist] + events_to_fetch = set( + event_id for events, _ in event_list for event_id in events + ) row_dict = self._new_transaction( - conn, "do_fetch", [], [], self._fetch_event_rows, event_ids + conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch ) # We only want to resolve deferreds from the main thread - def fire(lst, res): - for ids, d in lst: - if not d.called: - try: - with PreserveLoggingContext(): - d.callback([res[i] for i in ids if i in res]) - except Exception: - logger.exception("Failed to callback") + def fire(): + for _, d in event_list: + d.callback(row_dict) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire, event_list, row_dict) + self.hs.get_reactor().callFromThread(fire) except Exception as e: logger.exception("do_fetch") @@ -457,13 +451,98 @@ class EventsWorkerStore(SQLBaseStore): self.hs.get_reactor().callFromThread(fire, event_list, e) @defer.inlineCallbacks - def _enqueue_events(self, events, allow_rejected=False): + def _get_events_from_db(self, event_ids, allow_rejected=False): + """Fetch a bunch of events from the database. + + Returned events will be added to the cache for future lookups. + + Args: + event_ids (Iterable[str]): The event_ids of the events to fetch + allow_rejected (bool): Whether to include rejected events + + Returns: + Deferred[Dict[str, _EventCacheEntry]]: + map from event id to result. May return extra events which + weren't asked for. + """ + fetched_events = {} + events_to_fetch = event_ids + + while events_to_fetch: + row_map = yield self._enqueue_events(events_to_fetch) + + # we need to recursively fetch any redactions of those events + redaction_ids = set() + for event_id in events_to_fetch: + row = row_map.get(event_id) + fetched_events[event_id] = row + if row: + redaction_ids.update(row["redactions"]) + + events_to_fetch = redaction_ids.difference(fetched_events.keys()) + if events_to_fetch: + logger.debug("Also fetching redaction events %s", events_to_fetch) + + # build a map from event_id to EventBase + event_map = {} + for event_id, row in fetched_events.items(): + if not row: + continue + assert row["event_id"] == event_id + + rejected_reason = row["rejected_reason"] + + if not allow_rejected and rejected_reason: + continue + + d = json.loads(row["json"]) + internal_metadata = json.loads(row["internal_metadata"]) + + format_version = row["format_version"] + if format_version is None: + # This means that we stored the event before we had the concept + # of a event format version, so it must be a V1 event. + format_version = EventFormatVersions.V1 + + original_ev = event_type_from_format_version(format_version)( + event_dict=d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + event_map[event_id] = original_ev + + # finally, we can decide whether each one nededs redacting, and build + # the cache entries. + result_map = {} + for event_id, original_ev in event_map.items(): + redactions = fetched_events[event_id]["redactions"] + redacted_event = self._maybe_redact_event_row( + original_ev, redactions, event_map + ) + + cache_entry = _EventCacheEntry( + event=original_ev, redacted_event=redacted_event + ) + + self._get_event_cache.prefill((event_id,), cache_entry) + result_map[event_id] = cache_entry + + return result_map + + @defer.inlineCallbacks + def _enqueue_events(self, events): """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. + + Args: + events (Iterable[str]): events to be fetched. + + Returns: + Deferred[Dict[str, Dict]]: map from event id to row data from the database. + May contain events that weren't requested. """ - if not events: - return {} events_d = defer.Deferred() with self._event_fetch_lock: @@ -482,32 +561,12 @@ class EventsWorkerStore(SQLBaseStore): "fetch_events", self.runWithConnection, self._do_fetch ) - logger.debug("Loading %d events", len(events)) + logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): - rows = yield events_d - logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) - - if not allow_rejected: - rows[:] = [r for r in rows if r["rejected_reason"] is None] - - res = yield make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self._get_event_from_row, - row["internal_metadata"], - row["json"], - row["redactions"], - rejected_reason=row["rejected_reason"], - format_version=row["format_version"], - ) - for row in rows - ], - consumeErrors=True, - ) - ) + row_map = yield events_d + logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) - return {e.event.event_id: e for e in res if e} + return row_map def _fetch_event_rows(self, txn, event_ids): """Fetch event rows from the database @@ -580,50 +639,7 @@ class EventsWorkerStore(SQLBaseStore): return event_dict - @defer.inlineCallbacks - def _get_event_from_row( - self, internal_metadata, js, redactions, format_version, rejected_reason=None - ): - """Parse an event row which has been read from the database - - Args: - internal_metadata (str): json-encoded internal_metadata column - js (str): json-encoded event body from event_json - redactions (list[str]): a list of the events which claim to have redacted - this event, from the redactions table - format_version: (str): the 'format_version' column - rejected_reason (str|None): the reason this event was rejected, if any - - Returns: - _EventCacheEntry - """ - with Measure(self._clock, "_get_event_from_row"): - d = json.loads(js) - internal_metadata = json.loads(internal_metadata) - - if format_version is None: - # This means that we stored the event before we had the concept - # of a event format version, so it must be a V1 event. - format_version = EventFormatVersions.V1 - - original_ev = event_type_from_format_version(format_version)( - event_dict=d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - redacted_event = yield self._maybe_redact_event_row(original_ev, redactions) - - cache_entry = _EventCacheEntry( - event=original_ev, redacted_event=redacted_event - ) - - self._get_event_cache.prefill((original_ev.event_id,), cache_entry) - - return cache_entry - - @defer.inlineCallbacks - def _maybe_redact_event_row(self, original_ev, redactions): + def _maybe_redact_event_row(self, original_ev, redactions, event_map): """Given an event object and a list of possible redacting event ids, determine whether to honour any of those redactions and if so return a redacted event. @@ -631,6 +647,8 @@ class EventsWorkerStore(SQLBaseStore): Args: original_ev (EventBase): redactions (iterable[str]): list of event ids of potential redaction events + event_map (dict[str, EventBase]): other events which have been fetched, in + which we can look up the redaaction events. Map from event id to event. Returns: Deferred[EventBase|None]: if the event should be redacted, a pruned @@ -640,15 +658,9 @@ class EventsWorkerStore(SQLBaseStore): # we choose to ignore redactions of m.room.create events. return None - if original_ev.type == "m.room.redaction": - # ... and redaction events - return None - - redaction_map = yield self._get_events_from_cache_or_db(redactions) - for redaction_id in redactions: - redaction_entry = redaction_map.get(redaction_id) - if not redaction_entry: + redaction_event = event_map.get(redaction_id) + if not redaction_event or redaction_event.rejected_reason: # we don't have the redaction event, or the redaction event was not # authorized. logger.debug( @@ -658,7 +670,6 @@ class EventsWorkerStore(SQLBaseStore): ) continue - redaction_event = redaction_entry.event if redaction_event.room_id != original_ev.room_id: logger.debug( "%s was redacted by %s but redaction was in a different room!", diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 999c10a308..55e4e84d71 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -569,6 +569,27 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_id_servers_user_bound", ) + @cachedInlineCallbacks() + def get_user_deactivated_status(self, user_id): + """Retrieve the value for the `deactivated` property for the provided user. + + Args: + user_id (str): The ID of the user to retrieve the status for. + + Returns: + defer.Deferred(bool): The requested value. + """ + + res = yield self._simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="deactivated", + desc="get_user_deactivated_status", + ) + + # Convert the integer into a boolean. + return res == 1 + class RegistrationStore( RegistrationWorkerStore, background_updates.BackgroundUpdateStore @@ -1317,24 +1338,3 @@ class RegistrationStore( user_id, deactivated, ) - - @cachedInlineCallbacks() - def get_user_deactivated_status(self, user_id): - """Retrieve the value for the `deactivated` property for the provided user. - - Args: - user_id (str): The ID of the user to retrieve the status for. - - Returns: - defer.Deferred(bool): The requested value. - """ - - res = yield self._simple_select_one_onecol( - table="users", - keyvalues={"name": user_id}, - retcol="deactivated", - desc="get_user_deactivated_status", - ) - - # Convert the integer into a boolean. - return res == 1 diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 4255add097..1435baede2 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -25,17 +25,19 @@ from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOpti from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.web._newclient import ResponseNeverReceived +from twisted.web.client import Agent from twisted.web.http import HTTPChannel from twisted.web.http_headers import Headers from twisted.web.iweb import IPolicyForHTTPS from synapse.config.homeserver import HomeServerConfig from synapse.crypto.context_factory import ClientTLSOptionsFactory -from synapse.http.federation.matrix_federation_agent import ( - MatrixFederationAgent, +from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent +from synapse.http.federation.srv_resolver import Server +from synapse.http.federation.well_known_resolver import ( + WellKnownResolver, _cache_period_from_headers, ) -from synapse.http.federation.srv_resolver import Server from synapse.logging.context import LoggingContext from synapse.util.caches.ttlcache import TTLCache @@ -79,9 +81,10 @@ class MatrixFederationAgentTests(TestCase): self._config = config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") + self.tls_factory = ClientTLSOptionsFactory(config) self.agent = MatrixFederationAgent( reactor=self.reactor, - tls_client_options_factory=ClientTLSOptionsFactory(config), + tls_client_options_factory=self.tls_factory, _srv_resolver=self.mock_resolver, _well_known_cache=self.well_known_cache, ) @@ -928,20 +931,16 @@ class MatrixFederationAgentTests(TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - @defer.inlineCallbacks - def do_get_well_known(self, serv): - try: - result = yield self.agent._get_well_known(serv) - logger.info("Result from well-known fetch: %s", result) - except Exception as e: - logger.warning("Error fetching well-known: %s", e) - raise - return result - def test_well_known_cache(self): + well_known_resolver = WellKnownResolver( + self.reactor, + Agent(self.reactor, contextFactory=self.tls_factory), + well_known_cache=self.well_known_cache, + ) + self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.do_get_well_known(b"testserv") + fetch_d = well_known_resolver.get_well_known(b"testserv") # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -953,26 +952,26 @@ class MatrixFederationAgentTests(TestCase): well_known_server = self._handle_well_known_connection( client_factory, expected_sni=b"testserv", - response_headers={b"Cache-Control": b"max-age=10"}, + response_headers={b"Cache-Control": b"max-age=1000"}, content=b'{ "m.server": "target-server" }', ) r = self.successResultOf(fetch_d) - self.assertEqual(r, b"target-server") + self.assertEqual(r.delegated_server, b"target-server") # close the tcp connection well_known_server.loseConnection() # repeat the request: it should hit the cache - fetch_d = self.do_get_well_known(b"testserv") + fetch_d = well_known_resolver.get_well_known(b"testserv") r = self.successResultOf(fetch_d) - self.assertEqual(r, b"target-server") + self.assertEqual(r.delegated_server, b"target-server") # expire the cache - self.reactor.pump((10.0,)) + self.reactor.pump((1000.0,)) # now it should connect again - fetch_d = self.do_get_well_known(b"testserv") + fetch_d = well_known_resolver.get_well_known(b"testserv") self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) @@ -986,7 +985,7 @@ class MatrixFederationAgentTests(TestCase): ) r = self.successResultOf(fetch_d) - self.assertEqual(r, b"other-server") + self.assertEqual(r.delegated_server, b"other-server") class TestCachePeriodFromHeaders(TestCase): diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 89a3f95c0a..bb867150f4 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -323,6 +323,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "renew_at": 172800000, # Time in ms for 2 days "renew_by_email_enabled": True, "renew_email_subject": "Renew your account", + "account_renewed_html_path": "account_renewed.html", + "invalid_token_html_path": "invalid_token.html", } # Email config. @@ -373,6 +375,19 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) + # Check that we're getting HTML back. + content_type = None + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type = header[1] + self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + + # Check that the HTML we're getting is the one we expect on a successful renewal. + expected_html = self.hs.config.account_validity.account_renewed_html_content + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + # Move 3 days forward. If the renewal failed, every authed request with # our access token should be denied from now, otherwise they should # succeed. @@ -381,6 +396,28 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) + def test_renewal_invalid_token(self): + # Hit the renewal endpoint with an invalid token and check that it behaves as + # expected, i.e. that it responds with 404 Not Found and the correct HTML. + url = "/_matrix/client/unstable/account_validity/renew?token=123" + request, channel = self.make_request(b"GET", url) + self.render(request) + self.assertEquals(channel.result["code"], b"404", channel.result) + + # Check that we're getting HTML back. + content_type = None + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type = header[1] + self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + + # Check that the HTML we're getting is the one we expect when using an + # invalid/unknown token. + expected_html = self.hs.config.account_validity.invalid_token_html_content + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + def test_manual_email_send(self): self.email_attempts = [] diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 8488b6edc8..d961b81d48 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -17,6 +17,8 @@ from mock import Mock +from twisted.internet import defer + from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.types import RoomID, UserID @@ -216,3 +218,71 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, event.unsigned["redacted_because"], ) + + def test_circular_redaction(self): + redaction_event_id1 = "$redaction1_id:test" + redaction_event_id2 = "$redaction2_id:test" + + class EventIdManglingBuilder: + def __init__(self, base_builder, event_id): + self._base_builder = base_builder + self._event_id = event_id + + @defer.inlineCallbacks + def build(self, prev_event_ids): + built_event = yield self._base_builder.build(prev_event_ids) + built_event.event_id = self._event_id + built_event._event_dict["event_id"] = self._event_id + return built_event + + @property + def room_id(self): + return self._base_builder.room_id + + event_1, context_1 = self.get_success( + self.event_creation_handler.create_new_client_event( + EventIdManglingBuilder( + self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "test"}, + "redacts": redaction_event_id2, + }, + ), + redaction_event_id1, + ) + ) + ) + + self.get_success(self.store.persist_event(event_1, context_1)) + + event_2, context_2 = self.get_success( + self.event_creation_handler.create_new_client_event( + EventIdManglingBuilder( + self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "test"}, + "redacts": redaction_event_id1, + }, + ), + redaction_event_id2, + ) + ) + ) + self.get_success(self.store.persist_event(event_2, context_2)) + + # fetch one of the redactions + fetched = self.get_success(self.store.get_event(redaction_event_id1)) + + # it should have been redacted + self.assertEqual(fetched.unsigned["redacted_by"], redaction_event_id2) + self.assertEqual( + fetched.unsigned["redacted_because"].event_id, redaction_event_id2 + ) |