From d6fb96e056f79de220d8d59429d89a61498e9af3 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 7 Dec 2021 16:51:53 +0000 Subject: Fix case in `wait_for_background_updates` where `self.store` does not exist (#11331) Pull the DataStore from the HomeServer instance, which always exists. --- tests/unittest.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/unittest.py b/tests/unittest.py index eea0903f05..1431848367 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -331,16 +331,13 @@ class HomeserverTestCase(TestCase): time.sleep(0.01) def wait_for_background_updates(self) -> None: - """Block until all background database updates have completed. - - Note that callers must ensure there's a store property created on the - testcase. - """ + """Block until all background database updates have completed.""" + store = self.hs.get_datastore() while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() + store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db_pool.updates.do_next_background_update(False), by=0.1 + store.db_pool.updates.do_next_background_update(False), by=0.1 ) def make_homeserver(self, reactor, clock): -- cgit 1.5.1 From 8541809cb952ebf0da2a95dd93eccd5644dab49d Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 8 Dec 2021 05:01:38 -0500 Subject: Send and handle cross-signing messages using the stable prefix. (#10520) --- changelog.d/10520.misc | 1 + synapse/handlers/e2e_keys.py | 8 ++++++-- synapse/storage/databases/main/devices.py | 4 +++- tests/federation/test_federation_sender.py | 5 +++-- 4 files changed, 13 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10520.misc (limited to 'tests') diff --git a/changelog.d/10520.misc b/changelog.d/10520.misc new file mode 100644 index 0000000000..a911e165da --- /dev/null +++ b/changelog.d/10520.misc @@ -0,0 +1 @@ +Send and handle cross-signing messages using the stable prefix. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 60c11e3d21..b2554bda04 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -65,8 +65,12 @@ class E2eKeysHandler: else: # Only register this edu handler on master as it requires writing # device updates to the db - # - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + federation_registry.register_edu_handler( + "m.signing_key_update", + self._edu_updater.incoming_signing_key_update, + ) + # also handle the unstable version + # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( "org.matrix.signing_key_update", self._edu_updater.incoming_signing_key_update, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index d5a4a661cd..838a2a6a3d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -274,7 +274,9 @@ class DeviceWorkerStore(SQLBaseStore): # add the updated cross-signing keys to the results list for user_id, result in cross_signing_keys_by_user.items(): result["user_id"] = user_id - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + results.append(("m.signing_key_update", result)) + # also send the unstable version + # FIXME: remove this when enough servers have upgraded results.append(("org.matrix.signing_key_update", result)) return now_stream_id, results diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b457dad6d2..b2376e2db9 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -266,7 +266,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): ) # expect signing key update edu - self.assertEqual(len(self.edus), 1) + self.assertEqual(len(self.edus), 2) + self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update") self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update") # sign the devices @@ -491,7 +492,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): ) -> None: """Check that the txn has an EDU with a signing key update.""" edus = txn["edus"] - self.assertEqual(len(edus), 1) + self.assertEqual(len(edus), 2) def generate_and_upload_device_signing_key( self, user_id: str, device_id: str -- cgit 1.5.1 From 365e9482fe18b293f55f78e5f5d2d1107a1d95e1 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 8 Dec 2021 14:54:47 +0000 Subject: Use HTTPStatus constants in place of literals in `tests.rest.client.test_auth`. (#11520) --- changelog.d/11520.misc | 1 + tests/rest/client/test_auth.py | 134 ++++++++++++++++++++++++++--------------- 2 files changed, 88 insertions(+), 47 deletions(-) create mode 100644 changelog.d/11520.misc (limited to 'tests') diff --git a/changelog.d/11520.misc b/changelog.d/11520.misc new file mode 100644 index 0000000000..2d84120e19 --- /dev/null +++ b/changelog.d/11520.misc @@ -0,0 +1 @@ +Use HTTPStatus constants in place of literals in `tests.rest.client.test_auth`. \ No newline at end of file diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 72bbc87b4a..27cb856b0a 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -85,7 +85,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", "auth/m.login.recaptcha/fallback/web?session=" + session ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) channel = self.make_request( "POST", @@ -104,7 +104,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec channel = self.register( - 401, + HTTPStatus.UNAUTHORIZED, {"username": "user", "type": "m.login.password", "password": "bar"}, ) @@ -116,15 +116,17 @@ class FallbackAuthTests(unittest.HomeserverTestCase): ) # Complete the recaptcha step. - self.recaptcha(session, 200) + self.recaptcha(session, HTTPStatus.OK) # also complete the dummy auth - self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) + self.register( + HTTPStatus.OK, {"auth": {"session": session, "type": "m.login.dummy"}} + ) # Now we should have fulfilled a complete auth flow, including # the recaptcha fallback step, we can then send a # request to the register API with the session in the authdict. - channel = self.register(200, {"auth": {"session": session}}) + channel = self.register(HTTPStatus.OK, {"auth": {"session": session}}) # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") @@ -137,7 +139,8 @@ class FallbackAuthTests(unittest.HomeserverTestCase): # will be used.) # Returns a 401 as per the spec channel = self.register( - 401, {"username": "user", "type": "m.login.password", "password": "bar"} + HTTPStatus.UNAUTHORIZED, + {"username": "user", "type": "m.login.password", "password": "bar"}, ) # Grab the session @@ -231,7 +234,9 @@ class UIAuthTests(unittest.HomeserverTestCase): """ # Attempt to delete this device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # Grab the session session = channel.json_body["session"] @@ -242,7 +247,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -260,14 +265,16 @@ class UIAuthTests(unittest.HomeserverTestCase): UIA - check that still works. """ - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) session = channel.json_body["session"] # Make another request providing the UI auth flow. self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -293,7 +300,9 @@ class UIAuthTests(unittest.HomeserverTestCase): # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_devices(401, {"devices": [self.device_id]}) + channel = self.delete_devices( + HTTPStatus.UNAUTHORIZED, {"devices": [self.device_id]} + ) # Grab the session session = channel.json_body["session"] @@ -303,7 +312,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow, but try to delete the # second device. self.delete_devices( - 200, + HTTPStatus.OK, { "devices": ["dev2"], "auth": { @@ -324,7 +333,9 @@ class UIAuthTests(unittest.HomeserverTestCase): # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # Grab the session session = channel.json_body["session"] @@ -338,7 +349,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, "dev2", - 403, + HTTPStatus.FORBIDDEN, { "auth": { "type": "m.login.password", @@ -361,13 +372,13 @@ class UIAuthTests(unittest.HomeserverTestCase): self.login("test", self.user_pass, "dev3") # Attempt to delete a device. This works since the user just logged in. - self.delete_device(self.user_tok, "dev2", 200) + self.delete_device(self.user_tok, "dev2", HTTPStatus.OK) # Move the clock forward past the validation timeout. self.reactor.advance(6) # Deleting another devices throws the user into UI auth. - channel = self.delete_device(self.user_tok, "dev3", 401) + channel = self.delete_device(self.user_tok, "dev3", HTTPStatus.UNAUTHORIZED) # Grab the session session = channel.json_body["session"] @@ -378,7 +389,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, "dev3", - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -393,7 +404,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # due to re-using the previous session. # # Note that *no auth* information is provided, not even a session iD! - self.delete_device(self.user_tok, self.device_id, 200) + self.delete_device(self.user_tok, self.device_id, HTTPStatus.OK) @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) @@ -413,7 +424,9 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # check that SSO is offered flows = channel.json_body["flows"] @@ -426,13 +439,13 @@ class UIAuthTests(unittest.HomeserverTestCase): ) # that should serve a confirmation page - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # and now the delete request should succeed. self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, body={"auth": {"session": session_id}}, ) @@ -445,13 +458,15 @@ class UIAuthTests(unittest.HomeserverTestCase): # now call the device deletion API: we should get the option to auth with SSO # and not password. - channel = self.delete_device(user_tok, device_id, 401) + channel = self.delete_device(user_tok, device_id, HTTPStatus.UNAUTHORIZED) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) def test_does_not_offer_sso_for_password_user(self): - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.password"]}]) @@ -463,7 +478,9 @@ class UIAuthTests(unittest.HomeserverTestCase): login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] # we have no particular expectations of ordering here @@ -480,7 +497,9 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertEqual(login_resp["user_id"], self.user) # start a UI Auth flow by attempting to delete a device - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] self.assertIn({"stages": ["m.login.sso"]}, flows) @@ -496,7 +515,10 @@ class UIAuthTests(unittest.HomeserverTestCase): # ... and the delete op should now fail with a 403 self.delete_device( - self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} + self.user_tok, + self.device_id, + HTTPStatus.FORBIDDEN, + body={"auth": {"session": session_id}}, ) @@ -551,7 +573,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): login_without_refresh = self.make_request( "POST", "/_matrix/client/r0/login", body ) - self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result) + self.assertEqual( + login_without_refresh.code, HTTPStatus.OK, login_without_refresh.result + ) self.assertNotIn("refresh_token", login_without_refresh.json_body) login_with_refresh = self.make_request( @@ -559,7 +583,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", {"refresh_token": True, **body}, ) - self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) + self.assertEqual( + login_with_refresh.code, HTTPStatus.OK, login_with_refresh.result + ) self.assertIn("refresh_token", login_with_refresh.json_body) self.assertIn("expires_in_ms", login_with_refresh.json_body) @@ -577,7 +603,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): }, ) self.assertEqual( - register_without_refresh.code, 200, register_without_refresh.result + register_without_refresh.code, + HTTPStatus.OK, + register_without_refresh.result, ) self.assertNotIn("refresh_token", register_without_refresh.json_body) @@ -591,7 +619,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "refresh_token": True, }, ) - self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) + self.assertEqual( + register_with_refresh.code, HTTPStatus.OK, register_with_refresh.result + ) self.assertIn("refresh_token", register_with_refresh.json_body) self.assertIn("expires_in_ms", register_with_refresh.json_body) @@ -610,14 +640,14 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) refresh_response = self.make_request( "POST", "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertIn("access_token", refresh_response.json_body) self.assertIn("refresh_token", refresh_response.json_body) self.assertIn("expires_in_ms", refresh_response.json_body) @@ -648,7 +678,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) self.assertApproximates( login_response.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -658,7 +688,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertApproximates( refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -705,7 +735,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", {"refresh_token": True, **body}, ) - self.assertEqual(login_response1.code, 200, login_response1.result) + self.assertEqual(login_response1.code, HTTPStatus.OK, login_response1.result) self.assertApproximates( login_response1.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -716,7 +746,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response2.code, 200, login_response2.result) + self.assertEqual(login_response2.code, HTTPStatus.OK, login_response2.result) nonrefreshable_access_token = login_response2.json_body["access_token"] # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry) @@ -818,7 +848,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) refresh_token = login_response.json_body["refresh_token"] # Advance shy of 2 minutes into the future @@ -826,7 +856,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # Refresh our session. The refresh token should still be valid right now. refresh_response = self.use_refresh_token(refresh_token) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertIn( "refresh_token", refresh_response.json_body, @@ -846,7 +876,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This should fail because the refresh token's lifetime has also been # diminished as our session expired. refresh_response = self.use_refresh_token(refresh_token) - self.assertEqual(refresh_response.code, 403, refresh_response.result) + self.assertEqual( + refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result + ) def test_refresh_token_invalidation(self): """Refresh tokens are invalidated after first use of the next token. @@ -875,7 +907,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) # This first refresh should work properly first_refresh_response = self.make_request( @@ -884,7 +916,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - first_refresh_response.code, 200, first_refresh_response.result + first_refresh_response.code, HTTPStatus.OK, first_refresh_response.result ) # This one as well, since the token in the first one was never used @@ -894,7 +926,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - second_refresh_response.code, 200, second_refresh_response.result + second_refresh_response.code, HTTPStatus.OK, second_refresh_response.result ) # This one should not, since the token from the first refresh is not valid anymore @@ -904,7 +936,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": first_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( - third_refresh_response.code, 401, third_refresh_response.result + third_refresh_response.code, + HTTPStatus.UNAUTHORIZED, + third_refresh_response.result, ) # The associated access token should also be invalid @@ -913,7 +947,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/account/whoami", access_token=first_refresh_response.json_body["access_token"], ) - self.assertEqual(whoami_response.code, 401, whoami_response.result) + self.assertEqual( + whoami_response.code, HTTPStatus.UNAUTHORIZED, whoami_response.result + ) # But all other tokens should work (they will expire after some time) for access_token in [ @@ -923,7 +959,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): whoami_response = self.make_request( "GET", "/_matrix/client/r0/account/whoami", access_token=access_token ) - self.assertEqual(whoami_response.code, 200, whoami_response.result) + self.assertEqual( + whoami_response.code, HTTPStatus.OK, whoami_response.result + ) # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail fourth_refresh_response = self.make_request( @@ -932,7 +970,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - fourth_refresh_response.code, 403, fourth_refresh_response.result + fourth_refresh_response.code, + HTTPStatus.FORBIDDEN, + fourth_refresh_response.result, ) # But refreshing from the last valid refresh token still works @@ -942,5 +982,5 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": second_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( - fifth_refresh_response.code, 200, fifth_refresh_response.result + fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result ) -- cgit 1.5.1 From 7ecaa3b976b04dc5b2c6786aa18845016b80dd01 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 8 Dec 2021 17:59:40 +0100 Subject: Clean up `synapse.rest.admin` (#11535) --- changelog.d/11535.misc | 1 + synapse/rest/admin/__init__.py | 4 +- synapse/rest/admin/background_updates.py | 16 +++---- synapse/rest/admin/devices.py | 20 ++++----- synapse/rest/admin/event_reports.py | 2 - synapse/rest/admin/federation.py | 2 +- synapse/rest/admin/groups.py | 2 +- synapse/rest/admin/media.py | 60 ++++++++----------------- synapse/rest/admin/registration_tokens.py | 3 -- synapse/rest/admin/rooms.py | 70 +++++++++-------------------- synapse/rest/admin/server_notice_servlet.py | 4 +- synapse/rest/admin/statistics.py | 22 ++++----- synapse/rest/admin/username_available.py | 2 +- synapse/rest/admin/users.py | 51 ++++++++++----------- tests/rest/admin/test_statistics.py | 2 +- 15 files changed, 96 insertions(+), 165 deletions(-) create mode 100644 changelog.d/11535.misc (limited to 'tests') diff --git a/changelog.d/11535.misc b/changelog.d/11535.misc new file mode 100644 index 0000000000..580ac354ab --- /dev/null +++ b/changelog.d/11535.misc @@ -0,0 +1 @@ +Clean up `synapse.rest.admin`. \ No newline at end of file diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c499afd4be..701c609c12 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -108,7 +108,7 @@ class VersionServlet(RestServlet): class PurgeHistoryRestServlet(RestServlet): PATTERNS = admin_patterns( - "/purge_history/(?P[^/]*)(/(?P[^/]+))?" + "/purge_history/(?P[^/]*)(/(?P[^/]*))?$" ) def __init__(self, hs: "HomeServer"): @@ -195,7 +195,7 @@ class PurgeHistoryRestServlet(RestServlet): class PurgeHistoryStatusRestServlet(RestServlet): - PATTERNS = admin_patterns("/purge_history_status/(?P[^/]+)") + PATTERNS = admin_patterns("/purge_history_status/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.pagination_handler = hs.get_pagination_handler() diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py index 479672d4d5..6ec00ce0b9 100644 --- a/synapse/rest/admin/background_updates.py +++ b/synapse/rest/admin/background_updates.py @@ -22,7 +22,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, ) from synapse.http.site import SynapseRequest -from synapse.rest.admin._base import admin_patterns, assert_user_is_admin +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.types import JsonDict if TYPE_CHECKING: @@ -41,8 +41,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet): self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) @@ -51,8 +50,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet): return HTTPStatus.OK, {"enabled": enabled} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) body = parse_json_object_from_request(request) @@ -84,8 +82,7 @@ class BackgroundUpdateRestServlet(RestServlet): self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) @@ -111,15 +108,14 @@ class BackgroundUpdateRestServlet(RestServlet): class BackgroundUpdateStartJobRestServlet(RestServlet): """Allows to start specific background updates""" - PATTERNS = admin_patterns("/background_updates/start_job") + PATTERNS = admin_patterns("/background_updates/start_job$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() self._store = hs.get_datastore() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) body = parse_json_object_from_request(request) assert_params_in_dict(body, ["job_name"]) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 2e5a6600d3..062a33d28d 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -42,10 +42,10 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str, device_id: str @@ -53,7 +53,7 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -71,7 +71,7 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -87,7 +87,7 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -109,14 +109,10 @@ class DevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/devices$", "v2") def __init__(self, hs: "HomeServer"): - """ - Args: - hs: server - """ - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -124,7 +120,7 @@ class DevicesRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -144,10 +140,10 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/delete_devices$", "v2") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_POST( self, request: SynapseRequest, user_id: str @@ -155,7 +151,7 @@ class DeleteDevicesRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index 5ee8b11110..38477f8ead 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -52,7 +52,6 @@ class EventReportsRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -115,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 744687be35..50d88c9109 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -100,7 +100,7 @@ class DestinationsRestServlet(RestServlet): 200 OK with details of a destination if success otherwise an error. """ - PATTERNS = admin_patterns("/federation/destinations/(?P[^/]+)$") + PATTERNS = admin_patterns("/federation/destinations/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py index a27110388f..cd697e180e 100644 --- a/synapse/rest/admin/groups.py +++ b/synapse/rest/admin/groups.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class DeleteGroupAdminRestServlet(RestServlet): """Allows deleting of local groups""" - PATTERNS = admin_patterns("/delete_group/(?P[^/]*)") + PATTERNS = admin_patterns("/delete_group/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.group_server = hs.get_groups_server_handler() diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 9e23e2d8fc..7236e4027f 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -17,7 +17,7 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, Tuple -from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest @@ -41,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet): """ PATTERNS = [ - *admin_patterns("/room/(?P[^/]+)/media/quarantine$"), + *admin_patterns("/room/(?P[^/]*)/media/quarantine$"), # This path kept around for legacy reasons - *admin_patterns("/quarantine_media/(?P[^/]+)"), + *admin_patterns("/quarantine_media/(?P[^/]*)$"), ] def __init__(self, hs: "HomeServer"): @@ -71,7 +71,7 @@ class QuarantineMediaByUser(RestServlet): this server. """ - PATTERNS = admin_patterns("/user/(?P[^/]+)/media/quarantine$") + PATTERNS = admin_patterns("/user/(?P[^/]*)/media/quarantine$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -99,7 +99,7 @@ class QuarantineMediaByID(RestServlet): """ PATTERNS = admin_patterns( - "/media/quarantine/(?P[^/]+)/(?P[^/]+)" + "/media/quarantine/(?P[^/]*)/(?P[^/]*)$" ) def __init__(self, hs: "HomeServer"): @@ -128,7 +128,7 @@ class UnquarantineMediaByID(RestServlet): """ PATTERNS = admin_patterns( - "/media/unquarantine/(?P[^/]+)/(?P[^/]+)" + "/media/unquarantine/(?P[^/]*)/(?P[^/]*)$" ) def __init__(self, hs: "HomeServer"): @@ -138,8 +138,7 @@ class UnquarantineMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, server_name: str, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info( "Remove from quarantine local media by ID: %s/%s", server_name, media_id @@ -154,7 +153,7 @@ class UnquarantineMediaByID(RestServlet): class ProtectMediaByID(RestServlet): """Protect local media from being quarantined.""" - PATTERNS = admin_patterns("/media/protect/(?P[^/]+)") + PATTERNS = admin_patterns("/media/protect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -163,8 +162,7 @@ class ProtectMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info("Protecting local media by ID: %s", media_id) @@ -177,7 +175,7 @@ class ProtectMediaByID(RestServlet): class UnprotectMediaByID(RestServlet): """Unprotect local media from being quarantined.""" - PATTERNS = admin_patterns("/media/unprotect/(?P[^/]+)") + PATTERNS = admin_patterns("/media/unprotect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -186,8 +184,7 @@ class UnprotectMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info("Unprotecting local media by ID: %s", media_id) @@ -200,7 +197,7 @@ class UnprotectMediaByID(RestServlet): class ListMediaInRoom(RestServlet): """Lists all of the media in a given room.""" - PATTERNS = admin_patterns("/room/(?P[^/]+)/media$") + PATTERNS = admin_patterns("/room/(?P[^/]*)/media$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -209,10 +206,7 @@ class ListMediaInRoom(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - is_admin = await self.auth.is_server_admin(requester.user) - if not is_admin: - raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") + await assert_requester_is_admin(self.auth, request) local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id) @@ -254,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet): class DeleteMediaByID(RestServlet): """Delete local media by a given ID. Removes it from this server.""" - PATTERNS = admin_patterns("/media/(?P[^/]+)/(?P[^/]+)") + PATTERNS = admin_patterns("/media/(?P[^/]*)/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -286,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet): timestamp and size. """ - PATTERNS = admin_patterns("/media/(?P[^/]+)/delete$") + PATTERNS = admin_patterns("/media/(?P[^/]*)/delete$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -353,7 +347,7 @@ class UserMediaRestServlet(RestServlet): media that exist given for this user """ - PATTERNS = admin_patterns("/users/(?P[^/]+)/media$") + PATTERNS = admin_patterns("/users/(?P[^/]*)/media$") def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine @@ -403,16 +397,7 @@ class UserMediaRestServlet(RestServlet): request, "order_by", default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), + allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) direction = parse_string( request, "dir", default="f", allowed_values=("f", "b") @@ -470,16 +455,7 @@ class UserMediaRestServlet(RestServlet): request, "order_by", default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), + allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) direction = parse_string( request, "dir", default="f", allowed_values=("f", "b") diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 891b98c088..04948b6408 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -70,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -109,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens/new$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -260,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.clock = hs.get_clock() self.auth = hs.get_auth() self.store = hs.get_datastore() diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 829e86675a..17c6df1cc8 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -61,7 +61,7 @@ class RoomRestV2Servlet(RestServlet): If 'purge' is true, it will remove all traces of a room from the database. """ - PATTERNS = admin_patterns("/rooms/(?P[^/]+)$", "v2") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -123,7 +123,7 @@ class RoomRestV2Servlet(RestServlet): class DeleteRoomStatusByRoomIdRestServlet(RestServlet): """Get the status of the delete room background task.""" - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/delete_status$", "v2") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/delete_status$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -160,7 +160,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet): class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): """Get the status of the delete room background task.""" - PATTERNS = admin_patterns("/rooms/delete_status/(?P[^/]+)$", "v2") + PATTERNS = admin_patterns("/rooms/delete_status/(?P[^/]*)$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -193,35 +193,17 @@ class ListRoomRestServlet(RestServlet): self.admin_handler = hs.get_admin_handler() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) # Extract query parameters start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) - order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value) - if order_by not in ( - RoomSortOrder.ALPHABETICAL.value, - RoomSortOrder.SIZE.value, - RoomSortOrder.NAME.value, - RoomSortOrder.CANONICAL_ALIAS.value, - RoomSortOrder.JOINED_MEMBERS.value, - RoomSortOrder.JOINED_LOCAL_MEMBERS.value, - RoomSortOrder.VERSION.value, - RoomSortOrder.CREATOR.value, - RoomSortOrder.ENCRYPTION.value, - RoomSortOrder.FEDERATABLE.value, - RoomSortOrder.PUBLIC.value, - RoomSortOrder.JOIN_RULES.value, - RoomSortOrder.GUEST_ACCESS.value, - RoomSortOrder.HISTORY_VISIBILITY.value, - RoomSortOrder.STATE_EVENTS.value, - ): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown value for order_by: %s" % (order_by,), - errcode=Codes.INVALID_PARAM, - ) + order_by = parse_string( + request, + "order_by", + default=RoomSortOrder.NAME.value, + allowed_values=[sort_order.value for sort_order in RoomSortOrder], + ) search_term = parse_string(request, "search_term", encoding="utf-8") if search_term == "": @@ -292,10 +274,9 @@ class RoomRestServlet(RestServlet): TODO: Add on_POST to allow room creation without joining the room """ - PATTERNS = admin_patterns("/rooms/(?P[^/]+)$") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.room_shutdown_handler = hs.get_room_shutdown_handler() @@ -397,10 +378,9 @@ class RoomMembersRestServlet(RestServlet): Get members list of a room. """ - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/members") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/members$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -424,10 +404,9 @@ class RoomStateRestServlet(RestServlet): Get full state within a room. """ - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/state") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/state$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -436,8 +415,7 @@ class RoomStateRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room(room_id) if not ret: @@ -454,14 +432,14 @@ class RoomStateRestServlet(RestServlet): class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): - PATTERNS = admin_patterns("/join/(?P[^/]*)") + PATTERNS = admin_patterns("/join/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() self.state_handler = hs.get_state_handler() + self.is_mine = hs.is_mine async def on_POST( self, request: SynapseRequest, room_identifier: str @@ -477,7 +455,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): assert_params_in_dict(content, ["user_id"]) target_user = UserID.from_string(content["user_id"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "This endpoint can only be used with local users", @@ -542,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): } """ - PATTERNS = admin_patterns("/rooms/(?P[^/]*)/make_room_admin") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/make_room_admin$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_creation_handler = hs.get_event_creation_handler() @@ -688,19 +665,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): GET /_synapse/admin/v1/rooms//forward_extremities """ - PATTERNS = admin_patterns("/rooms/(?P[^/]*)/forward_extremities") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/forward_extremities$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() async def on_DELETE( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) room_id, _ = await self.resolve_room_id(room_identifier) @@ -710,8 +685,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): async def on_GET( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) room_id, _ = await self.resolve_room_id(room_identifier) @@ -793,7 +767,7 @@ class BlockRoomRestServlet(RestServlet): On GET: Get blocking status of room and user who has blocked this room. """ - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/block$") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/block$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index b295fb078b..15da9cd881 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -52,11 +52,11 @@ class SendServerNoticeServlet(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.server_notices_manager = hs.get_server_notices_manager() self.admin_handler = hs.get_admin_handler() self.txns = HttpTransactionCache(hs) + self.is_mine = hs.is_mine def register(self, json_resource: HttpServer) -> None: PATTERN = "/send_server_notice" @@ -88,7 +88,7 @@ class SendServerNoticeServlet(RestServlet): ) target_user = UserID.from_string(body["user_id"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users" ) diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index ca41fd45f2..7a6546372e 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -37,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet): PATTERNS = admin_patterns("/statistics/users/media$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -45,19 +44,16 @@ class UserMediaStatisticsRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) order_by = parse_string( - request, "order_by", default=UserSortOrder.USER_ID.value + request, + "order_by", + default=UserSortOrder.USER_ID.value, + allowed_values=( + UserSortOrder.MEDIA_LENGTH.value, + UserSortOrder.MEDIA_COUNT.value, + UserSortOrder.USER_ID.value, + UserSortOrder.DISPLAYNAME.value, + ), ) - if order_by not in ( - UserSortOrder.MEDIA_LENGTH.value, - UserSortOrder.MEDIA_COUNT.value, - UserSortOrder.USER_ID.value, - UserSortOrder.DISPLAYNAME.value, - ): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown value for order_by: %s" % (order_by,), - errcode=Codes.INVALID_PARAM, - ) start = parse_integer(request, "from", default=0) if start < 0: diff --git a/synapse/rest/admin/username_available.py b/synapse/rest/admin/username_available.py index 2bf1472967..5353dc3682 100644 --- a/synapse/rest/admin/username_available.py +++ b/synapse/rest/admin/username_available.py @@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet): } """ - PATTERNS = admin_patterns("/username_available") + PATTERNS = admin_patterns("/username_available$") def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2a60b602b1..db678da4cf 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -126,7 +125,7 @@ class UsersRestServletV2(RestServlet): class UserRestServletV2(RestServlet): - PATTERNS = admin_patterns("/users/(?P[^/]+)$", "v2") + PATTERNS = admin_patterns("/users/(?P[^/]*)$", "v2") """Get request to list user details. This needs user to have administrator access in Synapse. @@ -414,7 +413,7 @@ class UserRegisterServlet(RestServlet): nonce to the time it was generated, in int seconds. """ - PATTERNS = admin_patterns("/register") + PATTERNS = admin_patterns("/register$") NONCE_TIMEOUT = 60 def __init__(self, hs: "HomeServer"): @@ -561,9 +560,9 @@ class WhoisRestServlet(RestServlet): ] def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -575,7 +574,7 @@ class WhoisRestServlet(RestServlet): if target_user != auth_user: await assert_user_is_admin(self.auth, auth_user) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user") ret = await self.admin_handler.get_whois(target_user) @@ -584,7 +583,7 @@ class WhoisRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): - PATTERNS = admin_patterns("/deactivate/(?P[^/]*)") + PATTERNS = admin_patterns("/deactivate/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self._deactivate_account_handler = hs.get_deactivate_account_handler() @@ -630,7 +629,6 @@ class AccountValidityRenewServlet(RestServlet): PATTERNS = admin_patterns("/account_validity/validity$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() @@ -674,11 +672,10 @@ class ResetPasswordRestServlet(RestServlet): 200 OK with empty object if success otherwise an error. """ - PATTERNS = admin_patterns("/reset_password/(?P[^/]*)") + PATTERNS = admin_patterns("/reset_password/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler() @@ -718,12 +715,12 @@ class SearchUsersRestServlet(RestServlet): 200 OK with json object {list[dict[str, Any]], count} or empty object. """ - PATTERNS = admin_patterns("/search_users/(?P[^/]*)") + PATTERNS = admin_patterns("/search_users/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, target_user_id: str @@ -740,7 +737,7 @@ class SearchUsersRestServlet(RestServlet): # if not is_admin and target_user != auth_user: # raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user") term = parse_string(request, "term", required=True) @@ -779,9 +776,9 @@ class UserAdminServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/admin$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -790,7 +787,7 @@ class UserAdminServlet(RestServlet): target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be admins of this homeserver", @@ -813,7 +810,7 @@ class UserAdminServlet(RestServlet): assert_params_in_dict(body, ["admin"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be admins of this homeserver", @@ -834,7 +831,7 @@ class UserMembershipRestServlet(RestServlet): Get room list of an user. """ - PATTERNS = admin_patterns("/users/(?P[^/]+)/joined_rooms$") + PATTERNS = admin_patterns("/users/(?P[^/]*)/joined_rooms$") def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine @@ -909,10 +906,10 @@ class UserTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/login$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.is_mine_id = hs.is_mine_id async def on_POST( self, request: SynapseRequest, user_id: str @@ -921,7 +918,7 @@ class UserTokenRestServlet(RestServlet): await assert_user_is_admin(self.auth, requester.user) auth_user = requester.user - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be logged in as" ) @@ -975,19 +972,19 @@ class ShadowBanRestServlet(RestServlet): {} """ - PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban") + PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" ) @@ -1001,7 +998,7 @@ class ShadowBanRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" ) @@ -1027,19 +1024,19 @@ class RateLimitRestServlet(RestServlet): } """ - PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit") + PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id async def on_GET( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") if not await self.store.get_user_by_id(user_id): @@ -1068,7 +1065,7 @@ class RateLimitRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" ) @@ -1113,7 +1110,7 @@ class RateLimitRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" ) diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 7cb8ec57ba..f6e85fdaad 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -92,7 +92,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): channel.code, msg=channel.json_body, ) - self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # negative from channel = self.make_request( -- cgit 1.5.1 From b3bcacf3c1c72bfadeb46fe4d0198ca155a8c615 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 9 Dec 2021 12:23:34 +0100 Subject: Add missing `errcode` to `parse_string` and `parse_boolean` (#11542) --- changelog.d/11542.misc | 1 + synapse/http/servlet.py | 4 ++-- tests/rest/admin/test_federation.py | 4 ++-- tests/rest/admin/test_media.py | 2 +- tests/rest/admin/test_statistics.py | 2 +- tests/rest/admin/test_user.py | 12 ++++++------ 6 files changed, 13 insertions(+), 12 deletions(-) create mode 100644 changelog.d/11542.misc (limited to 'tests') diff --git a/changelog.d/11542.misc b/changelog.d/11542.misc new file mode 100644 index 0000000000..f614165037 --- /dev/null +++ b/changelog.d/11542.misc @@ -0,0 +1 @@ +Add missing `errcode` to `parse_string` and `parse_boolean`. \ No newline at end of file diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 6dd9b9ad03..1627225f28 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -246,7 +246,7 @@ def parse_boolean_from_args( message = ( "Boolean query parameter %r must be one of ['true', 'false']" ) % (name,) - raise SynapseError(400, message) + raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) else: if required: message = "Missing boolean query parameter %r" % (name,) @@ -414,7 +414,7 @@ def _parse_string_value( name, ", ".join(repr(v) for v in allowed_values), ) - raise SynapseError(400, message) + raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) else: return value_str diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 5188499ef2..d1cd5b0751 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -95,7 +95,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order channel = self.make_request( @@ -105,7 +105,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid destination channel = self.make_request( diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 81e578fd26..3f727788ce 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -360,7 +360,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel.code, msg=channel.json_body, ) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']", channel.json_body["error"], diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index f6e85fdaad..7cb8ec57ba 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -92,7 +92,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): channel.code, msg=channel.json_body, ) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from channel = self.make_request( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4fedd5fd08..294d429ce1 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -608,7 +608,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid deactivated channel = self.make_request( @@ -618,7 +618,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # unkown order_by channel = self.make_request( @@ -628,7 +628,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order channel = self.make_request( @@ -638,7 +638,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_limit(self): """ @@ -2896,7 +2896,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order channel = self.make_request( @@ -2906,7 +2906,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative limit channel = self.make_request( -- cgit 1.5.1 From b47d10dc46e4644c432f017d5b2129ff7a349166 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 9 Dec 2021 06:41:27 -0500 Subject: Support unprefixed versions of fallback key property names. (#11541) --- changelog.d/11541.misc | 1 + synapse/handlers/e2e_keys.py | 4 +++- synapse/rest/client/sync.py | 3 +++ tests/handlers/test_e2e_keys.py | 30 +++++++++++++++++++++++++----- 4 files changed, 32 insertions(+), 6 deletions(-) create mode 100644 changelog.d/11541.misc (limited to 'tests') diff --git a/changelog.d/11541.misc b/changelog.d/11541.misc new file mode 100644 index 0000000000..31c72c2a20 --- /dev/null +++ b/changelog.d/11541.misc @@ -0,0 +1 @@ +Support unprefixed versions of fallback key property names. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index b2554bda04..14360b4e40 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -580,7 +580,9 @@ class E2eKeysHandler: log_kv( {"message": "Did not update one_time_keys", "reason": "no keys given"} ) - fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None) + fallback_keys = keys.get("fallback_keys") or keys.get( + "org.matrix.msc2732.fallback_keys" + ) if fallback_keys and isinstance(fallback_keys, dict): log_kv( { diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 88e4f5e063..dd90ffa123 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -293,6 +293,9 @@ class SyncRestServlet(RestServlet): response[ "org.matrix.msc2732.device_unused_fallback_key_types" ] = sync_result.device_unused_fallback_key_types + response[ + "device_unused_fallback_key_types" + ] = sync_result.device_unused_fallback_key_types if joined: response["rooms"][Membership.JOIN] = joined diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index f0723892e4..ddcf3ee348 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -161,8 +161,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def test_fallback_key(self): local_user = "@boris:" + self.hs.hostname device_id = "xyz" - fallback_key = {"alg1:k1": "key1"} - fallback_key2 = {"alg1:k2": "key2"} + fallback_key = {"alg1:k1": "fallback_key1"} + fallback_key2 = {"alg1:k2": "fallback_key2"} + fallback_key3 = {"alg1:k2": "fallback_key3"} otk = {"alg1:k2": "key2"} # we shouldn't have any unused fallback keys yet @@ -175,7 +176,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): self.handler.upload_keys_for_user( local_user, device_id, - {"org.matrix.msc2732.fallback_keys": fallback_key}, + {"fallback_keys": fallback_key}, ) ) @@ -220,7 +221,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): self.handler.upload_keys_for_user( local_user, device_id, - {"org.matrix.msc2732.fallback_keys": fallback_key}, + {"fallback_keys": fallback_key}, ) ) @@ -234,7 +235,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): self.handler.upload_keys_for_user( local_user, device_id, - {"org.matrix.msc2732.fallback_keys": fallback_key2}, + {"fallback_keys": fallback_key2}, ) ) @@ -271,6 +272,25 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}}, ) + # using the unstable prefix should also set the fallback key + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"org.matrix.msc2732.fallback_keys": fallback_key3}, + ) + ) + + res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, + ) + def test_replace_master_key(self): """uploading a new signing key should make the old signing key unavailable""" local_user = "@boris:" + self.hs.hostname -- cgit 1.5.1 From 3b8872299aac25a7e3ee5a9e00564105aa6de237 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Dec 2021 13:16:01 -0500 Subject: Do not allow cross-room relations, per MSC2674. (#11516) --- changelog.d/11516.bugfix | 1 + synapse/events/utils.py | 11 ++- synapse/rest/client/relations.py | 7 +- synapse/storage/databases/main/events.py | 8 +- synapse/storage/databases/main/relations.py | 36 ++++++--- tests/rest/client/test_relations.py | 115 ++++++++++++++++++++++++++++ 6 files changed, 161 insertions(+), 17 deletions(-) create mode 100644 changelog.d/11516.bugfix (limited to 'tests') diff --git a/changelog.d/11516.bugfix b/changelog.d/11516.bugfix new file mode 100644 index 0000000000..22bba93671 --- /dev/null +++ b/changelog.d/11516.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where relations from other rooms could be included in the bundled aggregations of an event. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 84ef69df67..3da432c1df 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -454,23 +454,26 @@ class EventClientSerializer: return event_id = event.event_id + room_id = event.room_id # The bundled aggregations to include. aggregations = {} - annotations = await self.store.get_aggregation_groups_for_event(event_id) + annotations = await self.store.get_aggregation_groups_for_event( + event_id, room_id + ) if annotations.chunk: aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() references = await self.store.get_relations_for_event( - event_id, RelationTypes.REFERENCE, direction="f" + event_id, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: aggregations[RelationTypes.REFERENCE] = references.to_dict() edit = None if event.type == EventTypes.Message: - edit = await self.store.get_applicable_edit(event_id) + edit = await self.store.get_applicable_edit(event_id, room_id) if edit: # If there is an edit replace the content, preserving existing @@ -503,7 +506,7 @@ class EventClientSerializer: ( thread_count, latest_thread_event, - ) = await self.store.get_thread_summary(event_id) + ) = await self.store.get_thread_summary(event_id, room_id) if latest_thread_event: aggregations[RelationTypes.THREAD] = { # Don't bundle aggregations as this could recurse forever. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index fc4e6921c5..ffa37ef06c 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -212,6 +212,7 @@ class RelationPaginationServlet(RestServlet): pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, + room_id=room_id, relation_type=relation_type, event_type=event_type, limit=limit, @@ -317,6 +318,7 @@ class RelationAggregationPaginationServlet(RestServlet): pagination_chunk = await self.store.get_aggregation_groups_for_event( event_id=parent_id, + room_id=room_id, event_type=event_type, limit=limit, from_token=from_token, @@ -383,7 +385,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet): # This checks that a) the event exists and b) the user is allowed to # view it. - await self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -402,6 +406,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): result = await self.store.get_relations_for_event( event_id=parent_id, + room_id=room_id, relation_type=relation_type, event_type=event_type, aggregation_key=key, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 4e528612ea..f1f4ce5e07 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1780,10 +1780,14 @@ class PersistEventsStore: ) if rel_type == RelationTypes.REPLACE: - txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + txn.call_after( + self.store.get_applicable_edit.invalidate, (parent_id, event.room_id) + ) if rel_type == RelationTypes.THREAD: - txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) + txn.call_after( + self.store.get_thread_summary.invalidate, (parent_id, event.room_id) + ) def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): """Handles keeping track of insertion events and edges/connections. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 0a43acda07..3368a8b084 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_relations_for_event( self, event_id: str, + room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, aggregation_key: Optional[str] = None, @@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. aggregation_key: Only fetch events with this aggregation key, if given. @@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore): the form `{"event_id": "..."}`. """ - where_clause = ["relates_to_id = ?"] - where_args: List[Union[str, int]] = [event_id] + where_clause = ["relates_to_id = ?", "room_id = ?"] + where_args: List[Union[str, int]] = [event_id, room_id] if relation_type is not None: where_clause.append("relation_type = ?") @@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_aggregation_groups_for_event( self, event_id: str, + room_id: str, event_type: Optional[str] = None, limit: int = 5, direction: str = "b", @@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. event_type: Only fetch events with this event type, if given. limit: Only fetch the `limit` groups. direction: Whether to fetch the highest count first (`"b"`) or @@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore): `type`, `key` and `count` fields. """ - where_clause = ["relates_to_id = ?", "relation_type = ?"] - where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION] + where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"] + where_args: List[Union[str, int]] = [ + event_id, + room_id, + RelationTypes.ANNOTATION, + ] if event_type: where_clause.append("type = ?") @@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore): ) @cached() - async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: + async def get_applicable_edit( + self, event_id: str, room_id: str + ) -> Optional[EventBase]: """Get the most recent edit (if any) that has happened for the given event. @@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: The original event ID + room_id: The original event's room ID Returns: The most recent edit, if any. @@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore): WHERE relates_to_id = ? AND relation_type = ? + AND edit.room_id = ? AND edit.type = 'm.room.message' ORDER by edit.origin_server_ts DESC, edit.event_id DESC LIMIT 1 """ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]: - txn.execute(sql, (event_id, RelationTypes.REPLACE)) + txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id)) row = txn.fetchone() if row: return row[0] @@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore): @cached() async def get_thread_summary( - self, event_id: str + self, event_id: str, room_id: str ) -> Tuple[int, Optional[EventBase]]: """Get the number of threaded replies, the senders of those replies, and the latest reply (if any) for the given event. Args: - event_id: The original event ID + event_id: Summarize the thread related to this event ID. + room_id: The room the event belongs to. Returns: The number of items in the thread and the most recent response, if any. @@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore): INNER JOIN events USING (event_id) WHERE relates_to_id = ? + AND room_id = ? AND relation_type = ? ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT 1 """ - txn.execute(sql, (event_id, RelationTypes.THREAD)) + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) row = txn.fetchone() if row is None: return 0, None @@ -378,11 +392,13 @@ class RelationsWorkerStore(SQLBaseStore): sql = """ SELECT COALESCE(COUNT(event_id), 0) FROM event_relations + INNER JOIN events USING (event_id) WHERE relates_to_id = ? + AND room_id = ? AND relation_type = ? """ - txn.execute(sql, (event_id, RelationTypes.THREAD)) + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) count = txn.fetchone()[0] # type: ignore[index] return count, latest_event_id diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 397c12c2a6..55f4f0b1d0 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -16,6 +16,7 @@ import itertools import urllib.parse from typing import Dict, List, Optional, Tuple +from unittest.mock import patch from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin @@ -23,6 +24,8 @@ from synapse.rest.client import login, register, relations, room, sync from tests import unittest from tests.server import FakeChannel +from tests.test_utils import make_awaitable +from tests.test_utils.event_injection import inject_event class RelationsTestCase(unittest.HomeserverTestCase): @@ -651,6 +654,118 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, ) + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_ignore_invalid_room(self): + """Test that we ignore invalid relations over federation.""" + # Create another room and send a message in it. + room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) + res = self.helper.send(room2, body="Hi!", tok=self.user_token) + parent_id = res["event_id"] + + # Disable the validation to pretend this came over federation. + with patch( + "synapse.handlers.message.EventCreationHandler._validate_event_relation", + new=lambda self, event: make_awaitable(None), + ): + # Generate a various relations from a different room. + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.reaction", + sender=self.user_id, + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": parent_id, + "key": "A", + } + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "m.relates_to": { + "rel_type": RelationTypes.REFERENCE, + "event_id": parent_id, + }, + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": parent_id, + }, + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "new_content": { + "body": "new content", + "msgtype": "m.text", + }, + "m.relates_to": { + "rel_type": RelationTypes.REPLACE, + "event_id": parent_id, + }, + }, + ) + ) + + # They should be ignored when fetching relations. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) + + # And when fetching aggregations. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) + + # And for bundled aggregations. + channel = self.make_request( + "GET", + f"/rooms/{room2}/event/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertNotIn("m.relations", channel.json_body["unsigned"]) + def test_edit(self): """Test that a simple edit works.""" -- cgit 1.5.1 From 9562f0c2f1bd3489bfbb64fddbbd21ed657b44dd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 10 Dec 2021 07:17:28 -0500 Subject: Ensure emails are canonicalized before fetching associated user. (#11547) This should fix pushers with an email in non-canonical form is used as the pushkey. --- changelog.d/11547.bugfix | 1 + synapse/push/pusherpool.py | 5 ++++- synapse/storage/databases/main/monthly_active_users.py | 3 ++- synapse/storage/databases/main/registration.py | 3 ++- tests/rest/admin/test_user.py | 3 ++- 5 files changed, 11 insertions(+), 4 deletions(-) create mode 100644 changelog.d/11547.bugfix (limited to 'tests') diff --git a/changelog.d/11547.bugfix b/changelog.d/11547.bugfix new file mode 100644 index 0000000000..3950c4c8d3 --- /dev/null +++ b/changelog.d/11547.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.17.0 where a pusher created for an email with capital letters would fail to be created. diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 26735447a6..7912311d24 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -27,6 +27,7 @@ from synapse.push.pusher import PusherFactory from synapse.replication.http.push import ReplicationRemovePusherRestServlet from synapse.types import JsonDict, RoomStreamToken from synapse.util.async_helpers import concurrently_execute +from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -113,7 +114,9 @@ class PusherPool: """ if kind == "email": - email_owner = await self.store.get_user_id_by_threepid("email", pushkey) + email_owner = await self.store.get_user_id_by_threepid( + "email", canonicalise_email(pushkey) + ) if email_owner != user_id: raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index b5284e4f67..3c98ef876f 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -18,6 +18,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.util.caches.descriptors import cached +from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -103,7 +104,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): : self.hs.config.server.max_mau_value ]: user_id = await self.hs.get_datastore().get_user_id_by_threepid( - tp["medium"], tp["address"] + tp["medium"], canonicalise_email(tp["address"]) ) if user_id: users.append(user_id) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e1ddf06916..86c3425716 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -856,7 +856,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Args: medium: threepid medium e.g. email - address: threepid address e.g. me@example.com + address: threepid address e.g. me@example.com. This must already be + in canonical form. Returns: The user ID or None if no user id/threepid mapping exists diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 294d429ce1..eea675991c 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1550,7 +1550,8 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Create user body = { "password": "abc123", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + # Note that the given email is not in canonical form. + "threepids": [{"medium": "email", "address": "Bob@bob.bob"}], } channel = self.make_request( -- cgit 1.5.1 From 8391bd6ab59387845bae77130dc0ca437c37fb8e Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 10 Dec 2021 20:59:20 -0600 Subject: Test to ensure we share the same `state_group` across the whole historical batch (MSC2716) (#11487) Part of MSC2716: https://github.com/matrix-org/matrix-doc/pull/2716 We did some work on making sure the `state_groups` were shared in https://github.com/matrix-org/synapse/pull/10975 --- changelog.d/11487.misc | 1 + tests/rest/client/test_room_batch.py | 180 +++++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 changelog.d/11487.misc create mode 100644 tests/rest/client/test_room_batch.py (limited to 'tests') diff --git a/changelog.d/11487.misc b/changelog.d/11487.misc new file mode 100644 index 0000000000..376b9078be --- /dev/null +++ b/changelog.d/11487.misc @@ -0,0 +1 @@ +Add test to ensure we share the same `state_group` across the whole historical batch when using the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint. diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py new file mode 100644 index 0000000000..721454c187 --- /dev/null +++ b/tests/rest/client/test_room_batch.py @@ -0,0 +1,180 @@ +import logging +from typing import List, Tuple +from unittest.mock import Mock, patch + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventContentFields, EventTypes +from synapse.appservice import ApplicationService +from synapse.rest import admin +from synapse.rest.client import login, register, room, room_batch +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests import unittest + +logger = logging.getLogger(__name__) + + +def _create_join_state_events_for_batch_send_request( + virtual_user_ids: List[str], + insert_time: int, +) -> List[JsonDict]: + return [ + { + "type": EventTypes.Member, + "sender": virtual_user_id, + "origin_server_ts": insert_time, + "content": { + "membership": "join", + "displayname": "display-name-for-%s" % (virtual_user_id,), + }, + "state_key": virtual_user_id, + } + for virtual_user_id in virtual_user_ids + ] + + +def _create_message_events_for_batch_send_request( + virtual_user_id: str, insert_time: int, count: int +) -> List[JsonDict]: + return [ + { + "type": EventTypes.Message, + "sender": virtual_user_id, + "origin_server_ts": insert_time, + "content": { + "msgtype": "m.text", + "body": "Historical %d" % (i), + EventContentFields.MSC2716_HISTORICAL: True, + }, + } + for i in range(count) + ] + + +class RoomBatchTestCase(unittest.HomeserverTestCase): + """Test importing batches of historical messages.""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room_batch.register_servlets, + room.register_servlets, + register.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + self.appservice = ApplicationService( + token="i_am_an_app_service", + hostname="test", + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + + mock_load_appservices = Mock(return_value=[self.appservice]) + with patch( + "synapse.storage.databases.main.appservice.load_appservices", + mock_load_appservices, + ): + hs = self.setup_test_homeserver(config=config) + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.clock = clock + self.storage = hs.get_storage() + + self.virtual_user_id = self.register_appservice_user( + "as_user_potato", self.appservice.token + ) + + def _create_test_room(self) -> Tuple[str, str, str, str]: + room_id = self.helper.create_room_as( + self.appservice.sender, tok=self.appservice.token + ) + + res_a = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "A", + }, + tok=self.appservice.token, + ) + event_id_a = res_a["event_id"] + + res_b = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "B", + }, + tok=self.appservice.token, + ) + event_id_b = res_b["event_id"] + + res_c = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "C", + }, + tok=self.appservice.token, + ) + event_id_c = res_c["event_id"] + + return room_id, event_id_a, event_id_b, event_id_c + + @unittest.override_config({"experimental_features": {"msc2716_enabled": True}}) + def test_same_state_groups_for_whole_historical_batch(self): + """Make sure that when using the `/batch_send` endpoint to import a + bunch of historical messages, it re-uses the same `state_group` across + the whole batch. This is an easy optimization to make sure we're getting + right because the state for the whole batch is contained in + `state_events_at_start` and can be shared across everything. + """ + + time_before_room = int(self.clock.time_msec()) + room_id, event_id_a, _, _ = self._create_test_room() + + channel = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2716/rooms/%s/batch_send?prev_event_id=%s" + % (room_id, event_id_a), + content={ + "events": _create_message_events_for_batch_send_request( + self.virtual_user_id, time_before_room, 3 + ), + "state_events_at_start": _create_join_state_events_for_batch_send_request( + [self.virtual_user_id], time_before_room + ), + }, + access_token=self.appservice.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + # Get the historical event IDs that we just imported + historical_event_ids = channel.json_body["event_ids"] + self.assertEqual(len(historical_event_ids), 3) + + # Fetch the state_groups + state_group_map = self.get_success( + self.storage.state.get_state_groups_ids(room_id, historical_event_ids) + ) + + # We expect all of the historical events to be using the same state_group + # so there should only be a single state_group here! + self.assertEqual( + len(state_group_map.keys()), + 1, + "Expected a single state_group to be returned by saw state_groups=%s" + % (state_group_map.keys(),), + ) -- cgit 1.5.1 From aa8708ebed74b03bdebd7e20ddf070c6fd620db1 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 10 Dec 2021 23:08:51 -0600 Subject: Allow events to be created with no `prev_events` (MSC2716) (#11243) The event still needs to have `auth_events` defined to be valid. Split out from https://github.com/matrix-org/synapse/pull/11114 --- changelog.d/11243.misc | 1 + synapse/handlers/message.py | 24 +++++++--- synapse/handlers/room_member.py | 3 +- tests/handlers/test_message.py | 103 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 124 insertions(+), 7 deletions(-) create mode 100644 changelog.d/11243.misc (limited to 'tests') diff --git a/changelog.d/11243.misc b/changelog.d/11243.misc new file mode 100644 index 0000000000..5ef7fe16d4 --- /dev/null +++ b/changelog.d/11243.misc @@ -0,0 +1 @@ +Allow specific, experimental events to be created without `prev_events`. Used by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716). diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 87f671708c..38409fef38 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -496,6 +496,7 @@ class EventCreationHandler: require_consent: bool = True, outlier: bool = False, historical: bool = False, + allow_no_prev_events: bool = False, depth: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: """ @@ -607,6 +608,7 @@ class EventCreationHandler: prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, depth=depth, + allow_no_prev_events=allow_no_prev_events, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -882,6 +884,7 @@ class EventCreationHandler: prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, + allow_no_prev_events: bool = False, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client @@ -912,6 +915,7 @@ class EventCreationHandler: full_state_ids_at_event = None if auth_event_ids is not None: # If auth events are provided, prev events must be also. + # prev_event_ids could be an empty array though. assert prev_event_ids is not None # Copy the full auth state before it stripped down @@ -943,14 +947,22 @@ class EventCreationHandler: else: prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) - # we now ought to have some prev_events (unless it's a create event). - # - # do a quick sanity check here, rather than waiting until we've created the + # Do a quick sanity check here, rather than waiting until we've created the # event and then try to auth it (which fails with a somewhat confusing "No # create event in auth events") - assert ( - builder.type == EventTypes.Create or len(prev_event_ids) > 0 - ), "Attempting to create an event with no prev_events" + if allow_no_prev_events: + # We allow events with no `prev_events` but it better have some `auth_events` + assert ( + builder.type == EventTypes.Create + # Allow an event to have empty list of prev_event_ids + # only if it has auth_event_ids. + or auth_event_ids + ), "Attempting to create a non-m.room.create event with no prev_events or auth_event_ids" + else: + # we now ought to have some prev_events (unless it's a create event). + assert ( + builder.type == EventTypes.Create or prev_event_ids + ), "Attempting to create a non-m.room.create event with no prev_events" event = await builder.build( prev_event_ids=prev_event_ids, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a6dbff637f..447e3ce571 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -658,7 +658,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if block_invite: raise SynapseError(403, "Invites have been disabled on this server") - if prev_event_ids: + # An empty prev_events list is allowed as long as the auth_event_ids are present + if prev_event_ids is not None: return await self._local_membership_update( requester=requester, target=target, diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 8a8d369fac..5816295d8b 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -23,6 +23,7 @@ from synapse.types import create_requester from synapse.util.stringutils import random_string from tests import unittest +from tests.test_utils.event_injection import create_event logger = logging.getLogger(__name__) @@ -51,6 +52,24 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.requester = create_requester(self.user_id, access_token_id=self.token_id) + def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]: + # Create a member event we can use as an auth_event + memberEvent, memberEventContext = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.room.member", + sender=self.requester.user.to_string(), + state_key=self.requester.user.to_string(), + content={"membership": "join"}, + ) + ) + self.get_success( + self.persist_event_storage.persist_event(memberEvent, memberEventContext) + ) + + return memberEvent, memberEventContext + def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]: """Create a new event with the given transaction ID. All events produced by this method will be considered duplicates. @@ -156,6 +175,90 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events), 2) self.assertEqual(events[0].event_id, events[1].event_id) + def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self): + """When we set allow_no_prev_events=True, should be able to create a + event without any prev_events (only auth_events). + """ + # Create a member event we can use as an auth_event + memberEvent, _ = self._create_and_persist_member_event() + + # Try to create the event with empty prev_events bit with some auth_events + event, _ = self.get_success( + self.handler.create_event( + self.requester, + { + "type": EventTypes.Message, + "room_id": self.room_id, + "sender": self.requester.user.to_string(), + "content": {"msgtype": "m.text", "body": random_string(5)}, + }, + # Empty prev_events is the key thing we're testing here + prev_event_ids=[], + # But with some auth_events + auth_event_ids=[memberEvent.event_id], + # Allow no prev_events! + allow_no_prev_events=True, + ) + ) + self.assertIsNotNone(event) + + def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events( + self, + ): + """When we set allow_no_prev_events=False, shouldn't be able to create a + event without any prev_events even if it has auth_events. Expect an + exception to be raised. + """ + # Create a member event we can use as an auth_event + memberEvent, _ = self._create_and_persist_member_event() + + # Try to create the event with empty prev_events but with some auth_events + self.get_failure( + self.handler.create_event( + self.requester, + { + "type": EventTypes.Message, + "room_id": self.room_id, + "sender": self.requester.user.to_string(), + "content": {"msgtype": "m.text", "body": random_string(5)}, + }, + # Empty prev_events is the key thing we're testing here + prev_event_ids=[], + # But with some auth_events + auth_event_ids=[memberEvent.event_id], + # We expect the test to fail because empty prev_events are not + # allowed here! + allow_no_prev_events=False, + ), + AssertionError, + ) + + def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events( + self, + ): + """When we set allow_no_prev_events=True, should be able to create a + event without any prev_events or auth_events. Expect an exception to be + raised. + """ + # Try to create the event with empty prev_events and empty auth_events + self.get_failure( + self.handler.create_event( + self.requester, + { + "type": EventTypes.Message, + "room_id": self.room_id, + "sender": self.requester.user.to_string(), + "content": {"msgtype": "m.text", "body": random_string(5)}, + }, + prev_event_ids=[], + # The event should be rejected when there are no auth_events + auth_event_ids=[], + # Allow no prev_events! + allow_no_prev_events=True, + ), + AssertionError, + ) + class ServerAclValidationTestCase(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From eb39da6782b57c939450839097f32a14cba3ebfc Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 13 Dec 2021 12:55:07 -0500 Subject: Move HTML parsing to a separate file for URL previews. (#11566) * Splits the logic for parsing HTML from the resource handling code. * Fix a circular import in the oEmbed code (which uses the HTML parsing code). * Renames some of the HTML parsing methods to: * Make it clear which methods are "internal" to the module. * Clarify what the methods do. --- changelog.d/11566.misc | 1 + synapse/rest/media/v1/oembed.py | 5 +- synapse/rest/media/v1/preview_html.py | 397 ++++++++++++++++++++++++++ synapse/rest/media/v1/preview_url_resource.py | 383 +------------------------ tests/rest/media/v1/test_url_preview.py | 1 + tests/test_preview.py | 46 +-- 6 files changed, 432 insertions(+), 401 deletions(-) create mode 100644 changelog.d/11566.misc create mode 100644 synapse/rest/media/v1/preview_html.py (limited to 'tests') diff --git a/changelog.d/11566.misc b/changelog.d/11566.misc new file mode 100644 index 0000000000..c48e73cd48 --- /dev/null +++ b/changelog.d/11566.misc @@ -0,0 +1 @@ +Split the HTML parsing code from the URL preview resource code. diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py index 2a59552c20..cce1527ed9 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, List, Optional import attr +from synapse.rest.media.v1.preview_html import parse_html_description from synapse.types import JsonDict from synapse.util import json_decoder @@ -245,8 +246,6 @@ def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> if video_urls: open_graph_response["og:video"] = video_urls[0] - from synapse.rest.media.v1.preview_url_resource import _calc_description - - description = _calc_description(tree) + description = parse_html_description(tree) if description: open_graph_response["og:description"] = description diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py new file mode 100644 index 0000000000..30b067dd42 --- /dev/null +++ b/synapse/rest/media/v1/preview_html.py @@ -0,0 +1,397 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import codecs +import itertools +import logging +import re +from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union +from urllib import parse as urlparse + +if TYPE_CHECKING: + from lxml import etree + +logger = logging.getLogger(__name__) + +_charset_match = re.compile( + br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I +) +_xml_encoding_match = re.compile( + br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I +) +_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) + + +def _normalise_encoding(encoding: str) -> Optional[str]: + """Use the Python codec's name as the normalised entry.""" + try: + return codecs.lookup(encoding).name + except LookupError: + return None + + +def _get_html_media_encodings( + body: bytes, content_type: Optional[str] +) -> Iterable[str]: + """ + Get potential encoding of the body based on the (presumably) HTML body or the content-type header. + + The precedence used for finding a character encoding is: + + 1. tag with a charset declared. + 2. The XML document's character encoding attribute. + 3. The Content-Type header. + 4. Fallback to utf-8. + 5. Fallback to windows-1252. + + This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector. + + Args: + body: The HTML document, as bytes. + content_type: The Content-Type header. + + Returns: + The character encoding of the body, as a string. + """ + # There's no point in returning an encoding more than once. + attempted_encodings: Set[str] = set() + + # Limit searches to the first 1kb, since it ought to be at the top. + body_start = body[:1024] + + # Check if it has an encoding set in a meta tag. + match = _charset_match.search(body_start) + if match: + encoding = _normalise_encoding(match.group(1).decode("ascii")) + if encoding: + attempted_encodings.add(encoding) + yield encoding + + # TODO Support + + # Check if it has an XML document with an encoding. + match = _xml_encoding_match.match(body_start) + if match: + encoding = _normalise_encoding(match.group(1).decode("ascii")) + if encoding and encoding not in attempted_encodings: + attempted_encodings.add(encoding) + yield encoding + + # Check the HTTP Content-Type header for a character set. + if content_type: + content_match = _content_type_match.match(content_type) + if content_match: + encoding = _normalise_encoding(content_match.group(1)) + if encoding and encoding not in attempted_encodings: + attempted_encodings.add(encoding) + yield encoding + + # Finally, fallback to UTF-8, then windows-1252. + for fallback in ("utf-8", "cp1252"): + if fallback not in attempted_encodings: + yield fallback + + +def decode_body( + body: bytes, uri: str, content_type: Optional[str] = None +) -> Optional["etree.Element"]: + """ + This uses lxml to parse the HTML document. + + Args: + body: The HTML document, as bytes. + uri: The URI used to download the body. + content_type: The Content-Type header. + + Returns: + The parsed HTML body, or None if an error occurred during processed. + """ + # If there's no body, nothing useful is going to be found. + if not body: + return None + + # The idea here is that multiple encodings are tried until one works. + # Unfortunately the result is never used and then LXML will decode the string + # again with the found encoding. + for encoding in _get_html_media_encodings(body, content_type): + try: + body.decode(encoding) + except Exception: + pass + else: + break + else: + logger.warning("Unable to decode HTML body for %s", uri) + return None + + from lxml import etree + + # Create an HTML parser. + parser = etree.HTMLParser(recover=True, encoding=encoding) + + # Attempt to parse the body. Returns None if the body was successfully + # parsed, but no tree was found. + return etree.fromstring(body, parser) + + +def parse_html_to_open_graph( + tree: "etree.Element", media_uri: str +) -> Dict[str, Optional[str]]: + """ + Parse the HTML document into an Open Graph response. + + This uses lxml to search the HTML document for Open Graph data (or + synthesizes it from the document). + + Args: + tree: The parsed HTML document. + media_url: The URI used to download the body. + + Returns: + The Open Graph response as a dictionary. + """ + + # if we see any image URLs in the OG response, then spider them + # (although the client could choose to do this by asking for previews of those + # URLs to avoid DoSing the server) + + # "og:type" : "video", + # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", + # "og:site_name" : "YouTube", + # "og:video:type" : "application/x-shockwave-flash", + # "og:description" : "Fun stuff happening here", + # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", + # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", + # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", + # "og:video:width" : "1280" + # "og:video:height" : "720", + # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", + + og: Dict[str, Optional[str]] = {} + for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): + if "content" in tag.attrib: + # if we've got more than 50 tags, someone is taking the piss + if len(og) >= 50: + logger.warning("Skipping OG for page with too many 'og:' tags") + return {} + og[tag.attrib["property"]] = tag.attrib["content"] + + # TODO: grab article: meta tags too, e.g.: + + # "article:publisher" : "https://www.facebook.com/thethudonline" /> + # "article:author" content="https://www.facebook.com/thethudonline" /> + # "article:tag" content="baby" /> + # "article:section" content="Breaking News" /> + # "article:published_time" content="2016-03-31T19:58:24+00:00" /> + # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> + + if "og:title" not in og: + # do some basic spidering of the HTML + title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") + if title and title[0].text is not None: + og["og:title"] = title[0].text.strip() + else: + og["og:title"] = None + + if "og:image" not in og: + # TODO: extract a favicon failing all else + meta_image = tree.xpath( + "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" + ) + if meta_image: + og["og:image"] = rebase_url(meta_image[0], media_uri) + else: + # TODO: consider inlined CSS styles as well as width & height attribs + images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") + images = sorted( + images, + key=lambda i: ( + -1 * float(i.attrib["width"]) * float(i.attrib["height"]) + ), + ) + if not images: + images = tree.xpath("//img[@src]") + if images: + og["og:image"] = images[0].attrib["src"] + + if "og:description" not in og: + meta_description = tree.xpath( + "//*/meta" + "[translate(@name, 'DESCRIPTION', 'description')='description']" + "/@content" + ) + if meta_description: + og["og:description"] = meta_description[0] + else: + og["og:description"] = parse_html_description(tree) + elif og["og:description"]: + # This must be a non-empty string at this point. + assert isinstance(og["og:description"], str) + og["og:description"] = summarize_paragraphs([og["og:description"]]) + + # TODO: delete the url downloads to stop diskfilling, + # as we only ever cared about its OG + return og + + +def parse_html_description(tree: "etree.Element") -> Optional[str]: + """ + Calculate a text description based on an HTML document. + + Grabs any text nodes which are inside the tag, unless they are within + an HTML5 semantic markup tag (
,