From 95e41f368b19996872a1661d7066670fe65f1eba Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 22 Jun 2020 08:04:14 -0400 Subject: Allow local media to be marked as safe from being quarantined. (#7718) --- tests/rest/admin/test_admin.py | 137 +++++++++++++++++++---------------------- 1 file changed, 65 insertions(+), 72 deletions(-) (limited to 'tests') diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 977615ebef..b1a4decced 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -220,6 +220,24 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): return hs + def _ensure_quarantined(self, admin_user_tok, server_and_media_id): + """Ensure a piece of media is quarantined when trying to access it.""" + request, channel = self.make_request( + "GET", server_and_media_id, shorthand=False, access_token=admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id + ), + ) + def test_quarantine_media_requires_admin(self): self.register_user("nonadmin", "pass", admin=False) non_admin_user_tok = self.login("nonadmin", "pass") @@ -292,24 +310,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.code), msg=channel.result["body"]) # Attempt to access the media - request, channel = self.make_request( - "GET", - server_name_and_media_id, - shorthand=False, - access_token=admin_user_tok, - ) - request.render(self.download_resource) - self.pump(1.0) - - # Should be quarantined - self.assertEqual( - 404, - int(channel.code), - msg=( - "Expected to receive a 404 on accessing quarantined media: %s" - % server_name_and_media_id - ), - ) + self._ensure_quarantined(admin_user_tok, server_name_and_media_id) def test_quarantine_all_media_in_room(self, override_url_template=None): self.register_user("room_admin", "pass", admin=True) @@ -371,45 +372,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): server_and_media_id_2 = mxc_2[6:] # Test that we cannot download any of the media anymore - request, channel = self.make_request( - "GET", - server_and_media_id_1, - shorthand=False, - access_token=non_admin_user_tok, - ) - request.render(self.download_resource) - self.pump(1.0) - - # Should be quarantined - self.assertEqual( - 404, - int(channel.code), - msg=( - "Expected to receive a 404 on accessing quarantined media: %s" - % server_and_media_id_1 - ), - ) - - request, channel = self.make_request( - "GET", - server_and_media_id_2, - shorthand=False, - access_token=non_admin_user_tok, - ) - request.render(self.download_resource) - self.pump(1.0) - - # Should be quarantined - self.assertEqual( - 404, - int(channel.code), - msg=( - "Expected to receive a 404 on accessing quarantined media: %s" - % server_and_media_id_2 - ), - ) + self._ensure_quarantined(admin_user_tok, server_and_media_id_1) + self._ensure_quarantined(admin_user_tok, server_and_media_id_2) - def test_quaraantine_all_media_in_room_deprecated_api_path(self): + def test_quarantine_all_media_in_room_deprecated_api_path(self): # Perform the above test with the deprecated API path self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s") @@ -449,25 +415,52 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ) # Attempt to access each piece of media + self._ensure_quarantined(admin_user_tok, 
server_and_media_id_1) + self._ensure_quarantined(admin_user_tok, server_and_media_id_2) + + def test_cannot_quarantine_safe_media(self): + self.register_user("user_admin", "pass", admin=True) + admin_user_tok = self.login("user_admin", "pass") + + non_admin_user = self.register_user("user_nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("user_nonadmin", "pass") + + # Upload some media + response_1 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + response_2 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + + # Extract media IDs + server_and_media_id_1 = response_1["content_uri"][6:] + server_and_media_id_2 = response_2["content_uri"][6:] + + # Mark the second item as safe from quarantine. + _, media_id_2 = server_and_media_id_2.split("/") + self.get_success(self.store.mark_local_media_as_safe(media_id_2)) + + # Quarantine all media by this user + url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( + non_admin_user + ) request, channel = self.make_request( - "GET", - server_and_media_id_1, - shorthand=False, - access_token=non_admin_user_tok, + "POST", url.encode("ascii"), access_token=admin_user_tok, ) - request.render(self.download_resource) + self.render(request) self.pump(1.0) - - # Should be quarantined + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - 404, - int(channel.code), - msg=( - "Expected to receive a 404 on accessing quarantined media: %s" - % server_and_media_id_1, - ), + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 1}, + "Expected 1 quarantined item", ) + # Attempt to access each piece of media, the first should fail, the + # second should succeed. + self._ensure_quarantined(admin_user_tok, server_and_media_id_1) + # Attempt to access each piece of media request, channel = self.make_request( "GET", @@ -478,12 +471,12 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): request.render(self.download_resource) self.pump(1.0) - # Should be quarantined + # Shouldn't be quarantined self.assertEqual( - 404, + 200, int(channel.code), msg=( - "Expected to receive a 404 on accessing quarantined media: %s" + "Expected to receive a 200 on accessing not-quarantined media: %s" % server_and_media_id_2 ), ) -- cgit 1.5.1 From 6920e58136671f086536332bdd6844dff0d4b429 Mon Sep 17 00:00:00 2001 From: Sorunome Date: Wed, 24 Jun 2020 11:23:55 +0200 Subject: add org.matrix.login.jwt so that m.login.jwt can be deprecated (#7675) --- changelog.d/7675.removal | 1 + synapse/rest/client/v1/login.py | 5 ++++- tests/rest/client/v1/test_login.py | 10 +++++++--- 3 files changed, 12 insertions(+), 4 deletions(-) create mode 100644 changelog.d/7675.removal (limited to 'tests') diff --git a/changelog.d/7675.removal b/changelog.d/7675.removal new file mode 100644 index 0000000000..2500e2c578 --- /dev/null +++ b/changelog.d/7675.removal @@ -0,0 +1 @@ +Deprecate `m.login.jwt` login method in favour of `org.matrix.login.jwt`, as `m.login.jwt` is not part of the Matrix spec. 
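
As context for the login.py change that follows: a homeserver with JWT login enabled will accept and advertise both the renamed type and the deprecated one, and a client opts into the new flow simply by sending the new type string. The snippet below is an illustrative sketch only, not part of the patch; the endpoint is the standard client-server login API and the token value is a placeholder for a JWT issued by whatever service you use.

import json

# A login body using the renamed flow; older clients may keep sending the
# deprecated "m.login.jwt" type until it is removed.
login_body = {
    "type": "org.matrix.login.jwt",
    "token": "<JWT issued by your identity provider>",  # placeholder value
}

# POST this JSON to /_matrix/client/r0/login on the homeserver.
print(json.dumps(login_body))
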
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index c2c9a9c3aa..bf0f9bd077 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -81,7 +81,8 @@ class LoginRestServlet(RestServlet): CAS_TYPE = "m.login.cas" SSO_TYPE = "m.login.sso" TOKEN_TYPE = "m.login.token" - JWT_TYPE = "m.login.jwt" + JWT_TYPE = "org.matrix.login.jwt" + JWT_TYPE_DEPRECATED = "m.login.jwt" def __init__(self, hs): super(LoginRestServlet, self).__init__() @@ -116,6 +117,7 @@ class LoginRestServlet(RestServlet): flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) + flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED}) if self.cas_enabled: # we advertise CAS for backwards compat, though MSC1721 renamed it @@ -149,6 +151,7 @@ class LoginRestServlet(RestServlet): try: if self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE + or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): result = await self.do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 9033f09fd2..fd97999956 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -526,7 +526,9 @@ class JWTTestCase(unittest.HomeserverTestCase): return jwt.encode(token, secret, "HS256").decode("ascii") def jwt_login(self, *args): - params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)}) + params = json.dumps( + {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} + ) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) return channel @@ -568,7+570,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["error"], "Invalid JWT") def test_login_no_token(self): - params = json.dumps({"type": "m.login.jwt"}) + params = json.dumps({"type": "org.matrix.login.jwt"}) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) self.assertEqual(channel.result["code"], b"401", channel.result) @@ -640,7 +642,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): return jwt.encode(token, secret, "RS256").decode("ascii") def jwt_login(self, *args): - params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)}) + params = json.dumps( + {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} + ) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) return channel -- cgit 1.5.1 From 0e0a2817a29391fd777f7ee683dc03d63cf40302 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 24 Jun 2020 18:48:18 +0100 Subject: Yield during large v2 state res. (#7735) State res v2 across large data sets can be very CPU intensive, and if all the relevant events are in the cache the algorithm will run from start to finish within a single reactor tick. This can result in blocking the reactor tick for several seconds, which can have major repercussions on other requests. To fix this we simply add the occasional `sleep(0)` during iterations to yield execution until the next reactor tick.
The aim is to only do this for large data sets so that we don't impact otherwise quick resolutions. --- changelog.d/7735.bugfix | 1 + synapse/handlers/federation.py | 1 + synapse/state/__init__.py | 6 ++++- synapse/state/v2.py | 56 ++++++++++++++++++++++++++++++++++-------- tests/state/test_v2.py | 9 +++++++ 5 files changed, 62 insertions(+), 11 deletions(-) create mode 100644 changelog.d/7735.bugfix (limited to 'tests') diff --git a/changelog.d/7735.bugfix b/changelog.d/7735.bugfix new file mode 100644 index 0000000000..86959a5ca4 --- /dev/null +++ b/changelog.d/7735.bugfix @@ -0,0 +1 @@ +Fix large state resolutions from stalling Synapse for seconds at a time. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 873f6bc39f..3828ff0ef0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -376,6 +376,7 @@ class FederationHandler(BaseHandler): room_version = await self.store.get_room_version_id(room_id) state_map = await resolve_events_with_store( + self.clock, room_id, room_version, state_maps, diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 50fd843f66..495d9f04c8 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -32,6 +32,7 @@ from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import StateMap +from synapse.util import Clock from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func @@ -414,6 +415,7 @@ class StateHandler(object): with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + self.clock, event.room_id, room_version, state_set_ids, @@ -516,6 +518,7 @@ class StateResolutionHandler(object): logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + self.clock, room_id, room_version, list(state_groups_ids.values()), @@ -589,6 +592,7 @@ def _make_state_cache_entry(new_state, state_groups_ids): def resolve_events_with_store( + clock: Clock, room_id: str, room_version: str, state_sets: List[StateMap[str]], @@ -625,7 +629,7 @@ def resolve_events_with_store( ) else: return v2.resolve_events_with_store( - room_id, room_version, state_sets, event_map, state_res_store + clock, room_id, room_version, state_sets, event_map, state_res_store ) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 57eadce4e6..7181ecda9a 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -27,12 +27,20 @@ from synapse.api.errors import AuthError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.types import StateMap +from synapse.util import Clock logger = logging.getLogger(__name__) +# We want to yield to the reactor occasionally during state res when dealing +# with large data sets, so that we don't exhaust the reactor. This is done by +# yielding to reactor during loops every N iterations.
+_YIELD_AFTER_ITERATIONS = 100 + + @defer.inlineCallbacks def resolve_events_with_store( + clock: Clock, room_id: str, room_version: str, state_sets: List[StateMap[str]], @@ -42,13 +50,11 @@ def resolve_events_with_store( """Resolves the state using the v2 state resolution algorithm Args: + clock room_id: the room we are working in - room_version: The room version - state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be @@ -113,7 +119,7 @@ def resolve_events_with_store( ) sorted_power_events = yield _reverse_topological_power_sort( - room_id, power_events, event_map, state_res_store, full_conflicted_set + clock, room_id, power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) @@ -142,7 +148,7 @@ def resolve_events_with_store( pl = resolved_state.get((EventTypes.PowerLevels, ""), None) leftover_events = yield _mainline_sort( - room_id, leftover_events, pl, event_map, state_res_store + clock, room_id, leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") @@ -317,12 +323,13 @@ def _add_event_and_auth_chain_to_graph( @defer.inlineCallbacks def _reverse_topological_power_sort( - room_id, event_ids, event_map, state_res_store, auth_diff + clock, room_id, event_ids, event_map, state_res_store, auth_diff ): """Returns a list of the event_ids sorted by reverse topological ordering, and then by power level and origin_server_ts Args: + clock (Clock) room_id (str): the room we are working in event_ids (list[str]): The events to sort event_map (dict[str,FrozenEvent]) @@ -334,18 +341,28 @@ def _reverse_topological_power_sort( """ graph = {} - for event_id in event_ids: + for idx, event_id in enumerate(event_ids, start=1): yield _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff ) + # We yield occasionally when we're working with large data sets to + # ensure that we don't block the reactor loop for too long. + if idx % _YIELD_AFTER_ITERATIONS == 0: + yield clock.sleep(0) + event_to_pl = {} - for event_id in graph: + for idx, event_id in enumerate(graph, start=1): pl = yield _get_power_level_for_sender( room_id, event_id, event_map, state_res_store ) event_to_pl[event_id] = pl + # We yield occasionally when we're working with large data sets to + # ensure that we don't block the reactor loop for too long. + if idx % _YIELD_AFTER_ITERATIONS == 0: + yield clock.sleep(0) + def _get_power_order(event_id): ev = event_map[event_id] pl = event_to_pl[event_id] @@ -423,12 +440,13 @@ def _iterative_auth_checks( @defer.inlineCallbacks def _mainline_sort( - room_id, event_ids, resolved_power_event_id, event_map, state_res_store + clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store ): """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id Args: + clock (Clock) room_id (str): room we're working in event_ids (list[str]): Events to sort resolved_power_event_id (str): The final resolved power level event ID @@ -438,8 +456,14 @@ def _mainline_sort( Returns: Deferred[list[str]]: The sorted list """ + if not event_ids: + # It's possible for there to be no event IDs here to sort, so we can + # skip calculating the mainline in that case. 
+ return [] + mainline = [] pl = resolved_power_event_id + idx = 0 while pl: mainline.append(pl) pl_ev = yield _get_event(room_id, pl, event_map, state_res_store) @@ -453,17 +477,29 @@ def _mainline_sort( pl = aid break + # We yield occasionally when we're working with large data sets to + # ensure that we don't block the reactor loop for too long. + if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0: + yield clock.sleep(0) + + idx += 1 + mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))} event_ids = list(event_ids) order_map = {} - for ev_id in event_ids: + for idx, ev_id in enumerate(event_ids, start=1): depth = yield _get_mainline_depth_for_event( event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) + # We yield occasionally when we're working with large data sets to + # ensure that we don't block the reactor loop for too long. + if idx % _YIELD_AFTER_ITERATIONS == 0: + yield clock.sleep(0) + event_ids.sort(key=lambda ev_id: order_map[ev_id]) return event_ids diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index cdc347bc53..38f9b423ef 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -17,6 +17,8 @@ import itertools import attr +from twisted.internet import defer + from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.event_auth import auth_types_for_event @@ -41,6 +43,11 @@ MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN} ORIGIN_SERVER_TS = 0 +class FakeClock: + def sleep(self, msec): + return defer.succeed(None) + + class FakeEvent(object): """A fake event we use as a convenience. @@ -417,6 +424,7 @@ class StateTestCase(unittest.TestCase): state_before = dict(state_at_event[prev_events[0]]) else: state_d = resolve_events_with_store( + FakeClock(), ROOM_ID, RoomVersions.V2.identifier, [state_at_event[n] for n in prev_events], @@ -565,6 +573,7 @@ class SimpleParamStateTestCase(unittest.TestCase): # Test that we correctly handle passing `None` as the event_map state_d = resolve_events_with_store( + FakeClock(), ROOM_ID, RoomVersions.V2.identifier, [self.state_at_bob, self.state_at_charlie], -- cgit 1.5.1
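
The yielding pattern introduced in the last commit can be distilled into a small standalone sketch. This is illustrative only, not Synapse's actual code: process_events, expensive_cpu_work and _sleep are made-up names, while _YIELD_AFTER_ITERATIONS reuses the constant defined in the patch. The sketch assumes plain Twisted with the same @defer.inlineCallbacks style used in synapse/state/v2.py.

from twisted.internet import defer, reactor, task

_YIELD_AFTER_ITERATIONS = 100


def _sleep(seconds):
    # Rough equivalent of Synapse's clock.sleep(): a Deferred that fires
    # after the given delay, letting other queued reactor work run first.
    return task.deferLater(reactor, seconds, lambda: None)


def expensive_cpu_work(event_id):
    # Hypothetical stand-in for the per-event work done during state res.
    return hash(event_id)


@defer.inlineCallbacks
def process_events(event_ids):
    results = []
    for idx, event_id in enumerate(event_ids, start=1):
        results.append(expensive_cpu_work(event_id))
        # Hand control back to the reactor occasionally so one large batch
        # cannot monopolise a single reactor tick and starve other requests.
        if idx % _YIELD_AFTER_ITERATIONS == 0:
            yield _sleep(0)
    return results

Sleeping for zero seconds is enough: deferLater schedules the wake-up with callLater(0), so the generator resumes on a later reactor iteration, after any work that was already queued has had a chance to run. Small inputs never reach the modulo branch, which matches the commit's stated aim of leaving quick resolutions unaffected.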