From 845732be450b3f9c991df35b2f07d600a0eca6dd Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 4 Aug 2022 11:02:29 +0200 Subject: Fix rooms not being properly excluded from incremental sync (#13408) --- tests/rest/client/test_sync.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) (limited to 'tests/rest') diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index b085c50356..ae16184828 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -948,3 +948,24 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase): self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"]) self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"]) + + def test_incremental_sync(self) -> None: + """Tests that activity in the room is properly filtered out of incremental + syncs. + """ + channel = self.make_request("GET", "/sync", access_token=self.tok) + self.assertEqual(channel.code, 200, channel.result) + next_batch = channel.json_body["next_batch"] + + self.helper.send(self.excluded_room_id, tok=self.tok) + self.helper.send(self.included_room_id, tok=self.tok) + + channel = self.make_request( + "GET", + f"/sync?since={next_batch}", + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"]) -- cgit 1.5.1 From e2ed1b7155bbd38d4a2752073056c112464b3644 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 5 Aug 2022 16:59:09 +0200 Subject: Use literals in place of `HTTPStatus` constants in tests (#13463) --- changelog.d/13463.misc | 1 + tests/federation/test_complexity.py | 5 +- tests/federation/transport/test_knocking.py | 5 +- tests/handlers/test_deactivate_account.py | 3 +- tests/handlers/test_message.py | 2 +- tests/handlers/test_room_member.py | 3 +- tests/http/server/_base.py | 3 +- tests/rest/client/test_filter.py | 14 +-- tests/rest/client/test_login.py | 127 ++++++++++++++-------------- tests/rest/client/test_redactions.py | 4 +- tests/rest/client/test_register.py | 94 ++++++++++---------- tests/rest/client/test_report_event.py | 4 +- tests/rest/client/test_third_party_rules.py | 22 ++--- tests/rest/client/utils.py | 28 +++--- tests/rest/test_health.py | 4 +- tests/rest/test_well_known.py | 12 ++- tests/test_server.py | 26 +++--- tests/test_terms_auth.py | 6 +- 18 files changed, 172 insertions(+), 191 deletions(-) create mode 100644 changelog.d/13463.misc (limited to 'tests/rest') diff --git a/changelog.d/13463.misc b/changelog.d/13463.misc new file mode 100644 index 0000000000..a4c8691144 --- /dev/null +++ b/changelog.d/13463.misc @@ -0,0 +1 @@ +Use literals in place of `HTTPStatus` constants in tests. diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index c6dd99316a..9f1115dd23 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from unittest.mock import Mock from synapse.api.errors import Codes, SynapseError @@ -51,7 +50,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEqual(HTTPStatus.OK, channel.code) + self.assertEqual(200, channel.code) complexity = channel.json_body["v1"] self.assertTrue(complexity > 0, complexity) @@ -63,7 +62,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEqual(HTTPStatus.OK, channel.code) + self.assertEqual(200, channel.code) complexity = channel.json_body["v1"] self.assertEqual(complexity, 1.23) diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 0d048207b7..d21c11b716 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict -from http import HTTPStatus from typing import Dict, List from synapse.api.constants import EventTypes, JoinRules, Membership @@ -256,7 +255,7 @@ class FederationKnockingTestCase( RoomVersions.V7.identifier, ), ) - self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Note: We don't expect the knock membership event to be sent over federation as # part of the stripped room state, as the knocking homeserver already has that @@ -294,7 +293,7 @@ class FederationKnockingTestCase( % (room_id, signed_knock_event.event_id), signed_knock_event_json, ) - self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Check that we got the stripped room state in return room_state_events = channel.json_body["knock_state_events"] diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py index 7586e472b5..ff9f2e8edb 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from typing import Any, Dict from twisted.test.proto_helpers import MemoryReactor @@ -58,7 +57,7 @@ class DeactivateAccountTestCase(HomeserverTestCase): access_token=self.token, ) - self.assertEqual(req.code, HTTPStatus.OK, req) + self.assertEqual(req.code, 200, req) def test_global_account_data_deleted_upon_deactivation(self) -> None: """ diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 44da96c792..986b50ce0c 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -314,4 +314,4 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", path, content={}, access_token=self.access_token ) - self.assertEqual(int(channel.result["code"]), 403) + self.assertEqual(channel.code, 403) diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index 254e7e4b80..b4e1405aee 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -1,4 +1,3 @@ -from http import HTTPStatus from unittest.mock import Mock, patch from twisted.test.proto_helpers import MemoryReactor @@ -260,7 +259,7 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestC f"/_matrix/client/v3/rooms/{self.room_id}/join", access_token=self.bob_token, ) - self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) # wait for join to arrive over replication self.replicate() diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 994d8880b0..5726e60cee 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -15,7 +15,6 @@ import inspect import itertools import logging -from http import HTTPStatus from typing import ( Any, Callable, @@ -78,7 +77,7 @@ def test_disconnect( if expect_cancellation: expected_code = HTTP_STATUS_REQUEST_CANCELLED else: - expected_code = HTTPStatus.OK + expected_code = 200 request = channel.request if channel.is_finished(): diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 823e8ab8c4..afc8d641be 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -43,7 +43,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.EXAMPLE_FILTER_JSON, ) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"filter_id": "0"}) filter = self.get_success( self.store.get_user_filter(user_localpart="apple", filter_id=0) @@ -58,7 +58,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.EXAMPLE_FILTER_JSON, ) - self.assertEqual(channel.result["code"], b"403") + self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) def test_add_filter_non_local_user(self) -> None: @@ -71,7 +71,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.hs.is_mine = _is_mine - self.assertEqual(channel.result["code"], b"403") + self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) def test_get_filter(self) -> None: @@ -85,7 +85,7 @@ class FilterTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) ) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, self.EXAMPLE_FILTER) def test_get_filter_non_existant(self) -> None: @@ -93,7 +93,7 @@ class FilterTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) ) - self.assertEqual(channel.result["code"], b"404") + self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) # Currently invalid params do not have an appropriate errcode @@ -103,7 +103,7 @@ class FilterTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) ) - self.assertEqual(channel.result["code"], b"400") + self.assertEqual(channel.code, 400) # No ID also returns an invalid_id error def test_get_filter_no_id(self) -> None: @@ -111,4 +111,4 @@ class FilterTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) ) - self.assertEqual(channel.result["code"], b"400") + self.assertEqual(channel.code, 400) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index a2958f6959..e2a4d98275 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -13,7 +13,6 @@ # limitations under the License. import time import urllib.parse -from http import HTTPStatus from typing import Any, Dict, List, Optional from unittest.mock import Mock from urllib.parse import urlencode @@ -134,10 +133,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -152,7 +151,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config( { @@ -179,10 +178,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -197,7 +196,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config( { @@ -224,10 +223,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -242,7 +241,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) @override_config({"session_lifetime": "24h"}) def test_soft_logout(self) -> None: @@ -250,7 +249,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we shouldn't be able to make requests without an access token channel = self.make_request(b"GET", TEST_URL) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") # log in as normal @@ -261,20 +260,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + self.assertEqual(channel.code, 200, channel.result) access_token = channel.json_body["access_token"] device_id = channel.json_body["device_id"] # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) + self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) @@ -288,7 +287,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # more requests with the expired token should still return a soft-logout self.reactor.advance(3600) channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) + self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) @@ -296,7 +295,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self._delete_device(access_token_2, "kermit", "monkey", device_id) channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) + self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], False) @@ -307,7 +306,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token ) - self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) + self.assertEqual(channel.code, 401, channel.result) # check it's a UI-Auth fail self.assertEqual( set(channel.json_body.keys()), @@ -330,7 +329,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): access_token=access_token, content={"auth": auth}, ) - self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + self.assertEqual(channel.code, 200, channel.result) @override_config({"session_lifetime": "24h"}) def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: @@ -341,20 +340,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) + self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) # Now try to hard logout this session channel = self.make_request(b"POST", "/logout", access_token=access_token) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config({"session_lifetime": "24h"}) def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( @@ -367,20 +366,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) + self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) # Now try to hard log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=access_token) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_login_with_overly_long_device_id_fails(self) -> None: self.register_user("mickey", "cheese") @@ -466,7 +465,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_get_login_flows(self) -> None: """GET /login should return password and SSO flows""" channel = self.make_request("GET", "/_matrix/client/r0/login") - self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + self.assertEqual(channel.code, 200, channel.result) expected_flow_types = [ "m.login.cas", @@ -494,14 +493,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker channel = self._make_sso_redirect_request(None) - self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) + self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers uri = location_headers[0] # hitting that picker should give us some HTML channel = self.make_request("GET", uri) - self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + self.assertEqual(channel.code, 200, channel.result) # parse the form to check it has fields assumed elsewhere in this class html = channel.result["body"].decode("utf-8") @@ -530,7 +529,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): + "&idp=cas", shorthand=False, ) - self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) + self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers cas_uri = location_headers[0] @@ -555,7 +554,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=saml", ) - self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) + self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers saml_uri = location_headers[0] @@ -579,7 +578,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=oidc", ) - self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) + self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers oidc_uri = location_headers[0] @@ -606,7 +605,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) # that should serve a confirmation page - self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + self.assertEqual(channel.code, 200, channel.result) content_type_headers = channel.headers.getRawHeaders("Content-Type") assert content_type_headers self.assertTrue(content_type_headers[-1].startswith("text/html")) @@ -634,7 +633,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): "/login", content={"type": "m.login.token", "token": login_token}, ) - self.assertEqual(chan.code, HTTPStatus.OK, chan.result) + self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.json_body["user_id"], "@user1:test") def test_multi_sso_redirect_to_unknown(self) -> None: @@ -643,18 +642,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", ) - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual(channel.code, 400, channel.result) def test_client_idp_redirect_to_unknown(self) -> None: """If the client tries to pick an unknown IdP, return a 404""" channel = self._make_sso_redirect_request("xxx") - self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) + self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" channel = self._make_sso_redirect_request("oidc") - self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) + self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers oidc_uri = location_headers[0] @@ -765,7 +764,7 @@ class CASTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", cas_ticket_url) # Test that the response is HTML. - self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + self.assertEqual(channel.code, 200, channel.result) content_type_header_value = "" for header in channel.result.get("headers", []): if header[0] == b"Content-Type": @@ -878,17 +877,17 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_jwt_valid_registered(self) -> None: self.register_user("kermit", "monkey") channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") def test_login_jwt_valid_unregistered(self) -> None: channel = self.jwt_login({"sub": "frog"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, "notsecret") - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -897,7 +896,7 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_jwt_expired(self) -> None: channel = self.jwt_login({"sub": "frog", "exp": 864000}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -907,7 +906,7 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_jwt_not_before(self) -> None: now = int(time.time()) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -916,7 +915,7 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_no_sub(self) -> None: channel = self.jwt_login({"username": "root"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") @@ -925,12 +924,12 @@ class JWTTestCase(unittest.HomeserverTestCase): """Test validating the issuer claim.""" # A valid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") # An invalid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -939,7 +938,7 @@ class JWTTestCase(unittest.HomeserverTestCase): # Not providing an issuer. channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -949,7 +948,7 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_iss_no_config(self) -> None: """Test providing an issuer claim without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) @@ -957,12 +956,12 @@ class JWTTestCase(unittest.HomeserverTestCase): """Test validating the audience claim.""" # A valid audience. channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") # An invalid audience. channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -971,7 +970,7 @@ class JWTTestCase(unittest.HomeserverTestCase): # Not providing an audience. channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -981,7 +980,7 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_aud_no_config(self) -> None: """Test providing an audience without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -991,20 +990,20 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_default_sub(self) -> None: """Test reading user ID from the default subject claim.""" channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) def test_login_custom_sub(self) -> None: """Test reading user ID from a custom subject claim.""" channel = self.jwt_login({"username": "frog"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") def test_login_no_token(self) -> None: params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") @@ -1086,12 +1085,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): def test_login_jwt_valid(self) -> None: channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -1152,7 +1151,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_login_appservice_user_bot(self) -> None: """Test that the appservice bot can use /login""" @@ -1166,7 +1165,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_login_appservice_wrong_user(self) -> None: """Test that non-as users cannot login with the as token""" @@ -1180,7 +1179,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) def test_login_appservice_wrong_as(self) -> None: """Test that as users cannot login with wrong as token""" @@ -1194,7 +1193,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.another_service.token ) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) def test_login_appservice_no_token(self) -> None: """Test that users must provide a token when using the appservice @@ -1208,7 +1207,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) @skip_unless(HAS_OIDC, "requires OIDC") @@ -1246,7 +1245,7 @@ class UsernamePickerTestCase(HomeserverTestCase): ) # that should redirect to the username picker - self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) + self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers picker_url = location_headers[0] @@ -1290,7 +1289,7 @@ class UsernamePickerTestCase(HomeserverTestCase): ("Content-Length", str(len(content))), ], ) - self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result) + self.assertEqual(chan.code, 302, chan.result) location_headers = chan.headers.getRawHeaders("Location") assert location_headers @@ -1300,7 +1299,7 @@ class UsernamePickerTestCase(HomeserverTestCase): path=location_headers[0], custom_headers=[("Cookie", "username_mapping_session=" + session_id)], ) - self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result) + self.assertEqual(chan.code, 302, chan.result) location_headers = chan.headers.getRawHeaders("Location") assert location_headers @@ -1325,5 +1324,5 @@ class UsernamePickerTestCase(HomeserverTestCase): "/login", content={"type": "m.login.token", "token": login_token}, ) - self.assertEqual(chan.code, HTTPStatus.OK, chan.result) + self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index 7401b5e0c0..be4c67d68e 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -76,12 +76,12 @@ class RedactionsTestCase(HomeserverTestCase): path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) channel = self.make_request("POST", path, content={}, access_token=access_token) - self.assertEqual(int(channel.result["code"]), expect_code) + self.assertEqual(channel.code, expect_code) return channel.json_body def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]: channel = self.make_request("GET", "sync", access_token=self.mod_access_token) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) room_sync = channel.json_body["rooms"]["join"][room_id] return room_sync["timeline"]["events"] diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index f8e64ce6ac..ab4277dd31 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -70,7 +70,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) @@ -91,7 +91,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, 400, msg=channel.result) def test_POST_appservice_registration_invalid(self) -> None: self.appservice = None # no application service exists @@ -100,20 +100,20 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) def test_POST_bad_password(self) -> None: request_data = {"username": "kermit", "password": 666} channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, 400, msg=channel.result) self.assertEqual(channel.json_body["error"], "Invalid password") def test_POST_bad_username(self) -> None: request_data = {"username": 777, "password": "monkey"} channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, 400, msg=channel.result) self.assertEqual(channel.json_body["error"], "Invalid username") def test_POST_user_valid(self) -> None: @@ -132,7 +132,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) @override_config({"enable_registration": False}) @@ -142,7 +142,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["error"], "Registration has been disabled") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -153,7 +153,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self) -> None: @@ -161,7 +161,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["error"], "Guest access is disabled") @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) @@ -171,16 +171,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", url, b"{}") if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self) -> None: @@ -194,16 +194,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url, request_data) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config({"registration_requires_token": True}) def test_POST_registration_requires_token(self) -> None: @@ -231,7 +231,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Request without auth to get flows and session channel = self.make_request(b"POST", self.url, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) flows = channel.json_body["flows"] # Synapse adds a dummy stage to differentiate flows where otherwise one # flow would be a subset of another flow. @@ -248,7 +248,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "session": session, } channel = self.make_request(b"POST", self.url, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -263,7 +263,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) # Check the `completed` counter has been incremented and pending is 0 @@ -293,21 +293,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "session": session, } channel = self.make_request(b"POST", self.url, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM) self.assertEqual(channel.json_body["completed"], []) # Test with non-string (invalid) params["auth"]["token"] = 1234 channel = self.make_request(b"POST", self.url, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["completed"], []) # Test with unknown token (invalid) params["auth"]["token"] = "1234" channel = self.make_request(b"POST", self.url, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -361,7 +361,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "session": session2, } channel = self.make_request(b"POST", self.url, params2) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -381,7 +381,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Check auth still fails when using token with session2 channel = self.make_request(b"POST", self.url, params2) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -415,7 +415,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "session": session, } channel = self.make_request(b"POST", self.url, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -570,7 +570,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_advertised_flows(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) flows = channel.json_body["flows"] # with the stock config, we only expect the dummy flow @@ -593,7 +593,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) flows = channel.json_body["flows"] self.assertCountEqual( @@ -625,7 +625,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) def test_advertised_flows_no_msisdn_email_required(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) flows = channel.json_body["flows"] # with the stock config, we expect all four combinations of 3pid @@ -797,13 +797,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -823,12 +823,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/account_validity/validity" request_data = {"user_id": user_id} channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_manual_expire(self) -> None: user_id = self.register_user("kermit", "monkey") @@ -844,12 +844,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): "enable_renewal_emails": False, } channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -868,18 +868,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): "enable_renewal_emails": False, } channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Try to log the user out channel = self.make_request(b"POST", "/logout", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Log the user in again (allowed for expired accounts) tok = self.login("kermit", "monkey") # Try to log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): @@ -954,7 +954,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -972,7 +972,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # Move 1 day forward. Try to renew with the same token again. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -992,14 +992,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_renewal_invalid_token(self) -> None: # Hit the renewal endpoint with an invalid token and check that it behaves as # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"404", channel.result) + self.assertEqual(channel.code, 404, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -1023,7 +1023,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1096,7 +1096,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1176,7 +1176,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["valid"], True) def test_GET_token_invalid(self) -> None: @@ -1185,7 +1185,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["valid"], False) @override_config( @@ -1201,10 +1201,10 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): ) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -1212,4 +1212,4 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py index ad0d0209f7..7cb1017a4a 100644 --- a/tests/rest/client/test_report_event.py +++ b/tests/rest/client/test_report_event.py @@ -77,6 +77,4 @@ class ReportEventTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.report_path, data, access_token=self.other_user_tok ) - self.assertEqual( - response_status, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEqual(response_status, channel.code, msg=channel.result["body"]) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 18a7195409..3325d43a2f 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -155,7 +155,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {}, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) callback.assert_called_once() @@ -173,7 +173,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {}, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, channel.result) def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None: """ @@ -211,7 +211,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): access_token=self.tok, ) # Check the error code - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, channel.result) # Check the JSON body has had the `nasty` key injected self.assertEqual( channel.json_body, @@ -260,7 +260,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {"x": "x"}, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) event_id = channel.json_body["event_id"] # ... and check that it got modified @@ -269,7 +269,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) ev = channel.json_body self.assertEqual(ev["content"]["x"], "y") @@ -298,7 +298,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): }, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) orig_event_id = channel.json_body["event_id"] channel = self.make_request( @@ -315,7 +315,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): }, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) edited_event_id = channel.json_body["event_id"] # ... and check that they both got modified @@ -324,7 +324,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id), access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) ev = channel.json_body self.assertEqual(ev["content"]["body"], "ORIGINAL BODY") @@ -333,7 +333,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id), access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) ev = channel.json_body self.assertEqual(ev["content"]["body"], "EDITED BODY") @@ -379,7 +379,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): }, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) event_id = channel.json_body["event_id"] @@ -388,7 +388,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) self.assertIn("foo", channel.json_body["content"].keys()) self.assertEqual(channel.json_body["content"]["foo"], "bar") diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 105d418698..dd26145bf8 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -140,7 +140,7 @@ class RestHelper: custom_headers=custom_headers, ) - assert channel.result["code"] == b"%d" % expect_code, channel.result + assert channel.code == expect_code, channel.result self.auth_user_id = temp_id if expect_code == HTTPStatus.OK: @@ -213,11 +213,9 @@ class RestHelper: data, ) - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], ) @@ -312,11 +310,9 @@ class RestHelper: data, ) - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], ) @@ -396,11 +392,9 @@ class RestHelper: custom_headers=custom_headers, ) - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], ) @@ -449,11 +443,9 @@ class RestHelper: channel = make_request(self.hs.get_reactor(), self.site, method, path, content) - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], ) @@ -545,7 +537,7 @@ class RestHelper: assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], ) diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py index da325955f8..c0a2501742 100644 --- a/tests/rest/test_health.py +++ b/tests/rest/test_health.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus - from synapse.rest.health import HealthResource from tests import unittest @@ -26,5 +24,5 @@ class HealthCheckTests(unittest.HomeserverTestCase): def test_health(self) -> None: channel = self.make_request("GET", "/health", shorthand=False) - self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.code, 200) self.assertEqual(channel.result["body"], b"OK") diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index d8faafec75..2091b08d89 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus - from twisted.web.resource import Resource from synapse.rest.well_known import well_known_resource @@ -38,7 +36,7 @@ class WellKnownTests(unittest.HomeserverTestCase): "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.code, 200) self.assertEqual( channel.json_body, { @@ -57,7 +55,7 @@ class WellKnownTests(unittest.HomeserverTestCase): "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) + self.assertEqual(channel.code, 404) @unittest.override_config( { @@ -71,7 +69,7 @@ class WellKnownTests(unittest.HomeserverTestCase): "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.code, 200) self.assertEqual( channel.json_body, { @@ -87,7 +85,7 @@ class WellKnownTests(unittest.HomeserverTestCase): "GET", "/.well-known/matrix/server", shorthand=False ) - self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.code, 200) self.assertEqual( channel.json_body, {"m.server": "test:443"}, @@ -97,4 +95,4 @@ class WellKnownTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", "/.well-known/matrix/server", shorthand=False ) - self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) + self.assertEqual(channel.code, 404) diff --git a/tests/test_server.py b/tests/test_server.py index 2fe4411401..d2b2d8344a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -104,7 +104,7 @@ class JsonResourceTests(unittest.TestCase): self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo" ) - self.assertEqual(channel.result["code"], b"500") + self.assertEqual(channel.code, 500) def test_callback_indirect_exception(self) -> None: """ @@ -130,7 +130,7 @@ class JsonResourceTests(unittest.TestCase): self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo" ) - self.assertEqual(channel.result["code"], b"500") + self.assertEqual(channel.code, 500) def test_callback_synapseerror(self) -> None: """ @@ -150,7 +150,7 @@ class JsonResourceTests(unittest.TestCase): self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo" ) - self.assertEqual(channel.result["code"], b"403") + self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["error"], "Forbidden!!one!") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -174,7 +174,7 @@ class JsonResourceTests(unittest.TestCase): self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar" ) - self.assertEqual(channel.result["code"], b"400") + self.assertEqual(channel.code, 400) self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") @@ -203,7 +203,7 @@ class JsonResourceTests(unittest.TestCase): self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo" ) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) self.assertNotIn("body", channel.result) @@ -242,7 +242,7 @@ class OptionsResourceTests(unittest.TestCase): def test_unknown_options_request(self) -> None: """An OPTIONS requests to an unknown URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/foo/") - self.assertEqual(channel.result["code"], b"204") + self.assertEqual(channel.code, 204) self.assertNotIn("body", channel.result) # Ensure the correct CORS headers have been added @@ -262,7 +262,7 @@ class OptionsResourceTests(unittest.TestCase): def test_known_options_request(self) -> None: """An OPTIONS requests to an known URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/res/") - self.assertEqual(channel.result["code"], b"204") + self.assertEqual(channel.code, 204) self.assertNotIn("body", channel.result) # Ensure the correct CORS headers have been added @@ -282,12 +282,12 @@ class OptionsResourceTests(unittest.TestCase): def test_unknown_request(self) -> None: """A non-OPTIONS request to an unknown URL should 404.""" channel = self._make_request(b"GET", b"/foo/") - self.assertEqual(channel.result["code"], b"404") + self.assertEqual(channel.code, 404) def test_known_request(self) -> None: """A non-OPTIONS request to an known URL should query the proper resource.""" channel = self._make_request(b"GET", b"/res/") - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) self.assertEqual(channel.result["body"], b"/res/") @@ -314,7 +314,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): self.reactor, FakeSite(res, self.reactor), b"GET", b"/path" ) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) body = channel.result["body"] self.assertEqual(body, b"response") @@ -334,7 +334,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): self.reactor, FakeSite(res, self.reactor), b"GET", b"/path" ) - self.assertEqual(channel.result["code"], b"301") + self.assertEqual(channel.code, 301) headers = channel.result["headers"] location_headers = [v for k, v in headers if k == b"Location"] self.assertEqual(location_headers, [b"/look/an/eagle"]) @@ -357,7 +357,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): self.reactor, FakeSite(res, self.reactor), b"GET", b"/path" ) - self.assertEqual(channel.result["code"], b"304") + self.assertEqual(channel.code, 304) headers = channel.result["headers"] location_headers = [v for k, v in headers if k == b"Location"] self.assertEqual(location_headers, [b"/no/over/there"]) @@ -378,7 +378,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/path" ) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) self.assertNotIn("body", channel.result) diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index d3c13cf14c..abd7459a8c 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -53,7 +53,7 @@ class TermsTestCase(unittest.HomeserverTestCase): request_data = {"username": "kermit", "password": "monkey"} channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, channel.result) self.assertTrue(channel.json_body is not None) self.assertIsInstance(channel.json_body["session"], str) @@ -96,7 +96,7 @@ class TermsTestCase(unittest.HomeserverTestCase): # We don't bother checking that the response is correct - we'll leave that to # other tests. We just want to make sure we're on the right path. - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, channel.result) # Finish the UI auth for terms request_data = { @@ -112,7 +112,7 @@ class TermsTestCase(unittest.HomeserverTestCase): # We're interested in getting a response that looks like a successful # registration, not so much that the details are exactly what we want. - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) self.assertTrue(channel.json_body is not None) self.assertIsInstance(channel.json_body["user_id"], str) -- cgit 1.5.1 From ab18441573dc14cea1fe4082b2a89b9d392a4b9f Mon Sep 17 00:00:00 2001 From: Šimon Brandner Date: Fri, 5 Aug 2022 17:09:33 +0200 Subject: Support stable identifiers for MSC2285: private read receipts. (#13273) This adds support for the stable identifiers of MSC2285 while continuing to support the unstable identifiers behind the configuration flag. These will be removed in a future version. --- changelog.d/13273.feature | 1 + synapse/api/constants.py | 3 +- synapse/config/experimental.py | 2 +- synapse/handlers/initial_sync.py | 11 +-- synapse/handlers/receipts.py | 36 ++++++--- synapse/replication/tcp/client.py | 5 +- synapse/rest/client/notifications.py | 7 +- synapse/rest/client/read_marker.py | 8 +- synapse/rest/client/receipts.py | 10 ++- synapse/rest/client/versions.py | 1 + .../storage/databases/main/event_push_actions.py | 85 ++++++++++++++++++---- tests/handlers/test_receipts.py | 58 +++++++++++---- tests/rest/client/test_sync.py | 58 ++++++++++----- tests/storage/test_receipts.py | 55 +++++++++----- 14 files changed, 246 insertions(+), 94 deletions(-) create mode 100644 changelog.d/13273.feature (limited to 'tests/rest') diff --git a/changelog.d/13273.feature b/changelog.d/13273.feature new file mode 100644 index 0000000000..53110d74e9 --- /dev/null +++ b/changelog.d/13273.feature @@ -0,0 +1 @@ +Add support for stable prefixes for [MSC2285 (private read receipts)](https://github.com/matrix-org/matrix-spec-proposals/pull/2285). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 789859e69e..1d46fb0e43 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -257,7 +257,8 @@ class GuestAccess: class ReceiptTypes: READ: Final = "m.read" - READ_PRIVATE: Final = "org.matrix.msc2285.read.private" + READ_PRIVATE: Final = "m.read.private" + UNSTABLE_READ_PRIVATE: Final = "org.matrix.msc2285.read.private" FULLY_READ: Final = "m.fully_read" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index c2ecd977cd..7d17c958bb 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -32,7 +32,7 @@ class ExperimentalConfig(Config): # MSC2716 (importing historical messages) self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) - # MSC2285 (private read receipts) + # MSC2285 (unstable private read receipts) self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False) # MSC3244 (room version capabilities) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 85b472f250..6484e47e5f 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -143,8 +143,8 @@ class InitialSyncHandler: joined_rooms, to_key=int(now_token.receipt_key), ) - if self.hs.config.experimental.msc2285_enabled: - receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id) + + receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id) tags_by_room = await self.store.get_tags_for_user(user_id) @@ -456,11 +456,8 @@ class InitialSyncHandler: ) if not receipts: return [] - if self.hs.config.experimental.msc2285_enabled: - receipts = ReceiptEventSource.filter_out_private_receipts( - receipts, user_id - ) - return receipts + + return ReceiptEventSource.filter_out_private_receipts(receipts, user_id) presence, receipts, (messages, token) = await make_deferred_yieldable( gather_results( diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 43d2882b0a..d4a866b346 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -163,7 +163,10 @@ class ReceiptsHandler: if not is_new: return - if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE: + if self.federation_sender and receipt_type not in ( + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ): await self.federation_sender.send_read_receipt(receipt) @@ -203,24 +206,38 @@ class ReceiptEventSource(EventSource[int, JsonDict]): for event_id, orig_event_content in room.get("content", {}).items(): event_content = orig_event_content # If there are private read receipts, additional logic is necessary. - if ReceiptTypes.READ_PRIVATE in event_content: + if ( + ReceiptTypes.READ_PRIVATE in event_content + or ReceiptTypes.UNSTABLE_READ_PRIVATE in event_content + ): # Make a copy without private read receipts to avoid leaking # other user's private read receipts.. event_content = { receipt_type: receipt_value for receipt_type, receipt_value in event_content.items() - if receipt_type != ReceiptTypes.READ_PRIVATE + if receipt_type + not in ( + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ) } # Copy the current user's private read receipt from the # original content, if it exists. - user_private_read_receipt = orig_event_content[ - ReceiptTypes.READ_PRIVATE - ].get(user_id, None) + user_private_read_receipt = orig_event_content.get( + ReceiptTypes.READ_PRIVATE, {} + ).get(user_id, None) if user_private_read_receipt: event_content[ReceiptTypes.READ_PRIVATE] = { user_id: user_private_read_receipt } + user_unstable_private_read_receipt = orig_event_content.get( + ReceiptTypes.UNSTABLE_READ_PRIVATE, {} + ).get(user_id, None) + if user_unstable_private_read_receipt: + event_content[ReceiptTypes.UNSTABLE_READ_PRIVATE] = { + user_id: user_unstable_private_read_receipt + } # Include the event if there is at least one non-private read # receipt or the current user has a private read receipt. @@ -256,10 +273,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]): room_ids, from_key=from_key, to_key=to_key ) - if self.config.experimental.msc2285_enabled: - events = ReceiptEventSource.filter_out_private_receipts( - events, user.to_string() - ) + events = ReceiptEventSource.filter_out_private_receipts( + events, user.to_string() + ) return events, to_key diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e4f2201c92..1ed7230e32 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -416,7 +416,10 @@ class FederationSenderHandler: if not self._is_mine_id(receipt.user_id): continue # Private read receipts never get sent over federation. - if receipt.receipt_type == ReceiptTypes.READ_PRIVATE: + if receipt.receipt_type in ( + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ): continue receipt_info = ReadReceipt( receipt.room_id, diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 24bc7c9095..a73322a6a4 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -58,7 +58,12 @@ class NotificationsServlet(RestServlet): ) receipts_by_room = await self.store.get_receipts_for_user_with_orderings( - user_id, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + user_id, + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ], ) notif_event_ids = [pa.event_id for pa in push_actions] diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 8896f2df50..aaad8b233f 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -40,9 +40,13 @@ class ReadMarkerRestServlet(RestServlet): self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() - self._known_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ} + self._known_receipt_types = { + ReceiptTypes.READ, + ReceiptTypes.FULLY_READ, + ReceiptTypes.READ_PRIVATE, + } if hs.config.experimental.msc2285_enabled: - self._known_receipt_types.add(ReceiptTypes.READ_PRIVATE) + self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE) async def on_POST( self, request: SynapseRequest, room_id: str diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 409bfd43c1..c6108fc5eb 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -44,11 +44,13 @@ class ReceiptRestServlet(RestServlet): self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() - self._known_receipt_types = {ReceiptTypes.READ} + self._known_receipt_types = { + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.FULLY_READ, + } if hs.config.experimental.msc2285_enabled: - self._known_receipt_types.update( - (ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ) - ) + self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE) async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 0366986755..c9a830cbac 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -94,6 +94,7 @@ class VersionsRestServlet(RestServlet): # Supports the busy presence state described in MSC3026. "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, # Supports receiving private read receipts as per MSC2285 + "org.matrix.msc2285.stable": True, # TODO: Remove when MSC2285 becomes a part of the spec "org.matrix.msc2285": self.config.experimental.msc2285_enabled, # Supports filtering of /publicRooms by room type as per MSC3827 "org.matrix.msc3827.stable": True, diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 5db70f9a60..161aad0f89 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -80,7 +80,7 @@ import attr from synapse.api.constants import ReceiptTypes from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, @@ -259,7 +259,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn, user_id, room_id, - receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), + receipt_types=( + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ), ) stream_ordering = None @@ -448,6 +452,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will be ordered by ascending stream_ordering. The list will have between 0~limit entries. """ + # find rooms that have a read receipt in them and return the next # push actions def get_after_receipt( @@ -455,7 +460,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) -> List[Tuple[str, str, int, str, bool]]: # find rooms that have a read receipt in them and return the next # push actions - sql = """ + + receipt_types_clause, args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + ( + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ), + ) + + sql = f""" SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight FROM ( @@ -463,10 +479,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas MAX(stream_ordering) as stream_ordering FROM events INNER JOIN receipts_linearized USING (room_id, event_id) - WHERE receipt_type = 'm.read' AND user_id = ? + WHERE {receipt_types_clause} AND user_id = ? GROUP BY room_id ) AS rl, - event_push_actions AS ep + event_push_actions AS ep WHERE ep.room_id = rl.room_id AND ep.stream_ordering > rl.stream_ordering @@ -476,7 +492,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas AND ep.notif = 1 ORDER BY ep.stream_ordering ASC LIMIT ? """ - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + args.extend( + (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) + ) txn.execute(sql, args) return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) @@ -490,7 +508,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def get_no_receipt( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool]]: - sql = """ + receipt_types_clause, args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + ( + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ), + ) + + sql = f""" SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight FROM event_push_actions AS ep @@ -498,7 +526,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas WHERE ep.room_id NOT IN ( SELECT room_id FROM receipts_linearized - WHERE receipt_type = 'm.read' AND user_id = ? + WHERE {receipt_types_clause} AND user_id = ? GROUP BY room_id ) AND ep.user_id = ? @@ -507,7 +535,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas AND ep.notif = 1 ORDER BY ep.stream_ordering ASC LIMIT ? """ - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + args.extend( + (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) + ) txn.execute(sql, args) return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) @@ -557,12 +587,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will be ordered by descending received_ts. The list will have between 0~limit entries. """ + # find rooms that have a read receipt in them and return the most recent # push actions def get_after_receipt( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool, int]]: - sql = """ + receipt_types_clause, args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + ( + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ), + ) + + sql = f""" SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight, e.received_ts FROM ( @@ -570,7 +611,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas MAX(stream_ordering) as stream_ordering FROM events INNER JOIN receipts_linearized USING (room_id, event_id) - WHERE receipt_type = 'm.read' AND user_id = ? + WHERE {receipt_types_clause} AND user_id = ? GROUP BY room_id ) AS rl, event_push_actions AS ep @@ -584,7 +625,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas AND ep.notif = 1 ORDER BY ep.stream_ordering DESC LIMIT ? """ - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + args.extend( + (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) + ) txn.execute(sql, args) return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) @@ -598,7 +641,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def get_no_receipt( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool, int]]: - sql = """ + receipt_types_clause, args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + ( + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ), + ) + + sql = f""" SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight, e.received_ts FROM event_push_actions AS ep @@ -606,7 +659,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas WHERE ep.room_id NOT IN ( SELECT room_id FROM receipts_linearized - WHERE receipt_type = 'm.read' AND user_id = ? + WHERE {receipt_types_clause} AND user_id = ? GROUP BY room_id ) AND ep.user_id = ? @@ -615,7 +668,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas AND ep.notif = 1 ORDER BY ep.stream_ordering DESC LIMIT ? """ - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + args.extend( + (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) + ) txn.execute(sql, args) return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index a95868b5c0..5f70a2db79 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -15,6 +15,8 @@ from copy import deepcopy from typing import List +from parameterized import parameterized + from synapse.api.constants import EduTypes, ReceiptTypes from synapse.types import JsonDict @@ -25,13 +27,16 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.event_source = hs.get_event_sources().sources.receipt - def test_filters_out_private_receipt(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_filters_out_private_receipt(self, receipt_type: str) -> None: self._test_filters_private( [ { "content": { "$1435641916114394fHBLK:matrix.org": { - ReceiptTypes.READ_PRIVATE: { + receipt_type: { "@rikj:jki.re": { "ts": 1436451550453, } @@ -45,13 +50,18 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): [], ) - def test_filters_out_private_receipt_and_ignores_rest(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_filters_out_private_receipt_and_ignores_rest( + self, receipt_type: str + ) -> None: self._test_filters_private( [ { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - ReceiptTypes.READ_PRIVATE: { + receipt_type: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -84,13 +94,18 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest( + self, receipt_type: str + ) -> None: self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - ReceiptTypes.READ_PRIVATE: { + receipt_type: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -125,7 +140,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_handles_empty_event(self): + def test_handles_empty_event(self) -> None: self._test_filters_private( [ { @@ -160,13 +175,18 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest( + self, receipt_type: str + ) -> None: self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - ReceiptTypes.READ_PRIVATE: { + receipt_type: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -207,7 +227,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_handles_string_data(self): + def test_handles_string_data(self) -> None: """ Tests that an invalid shape for read-receipts is handled. Context: https://github.com/matrix-org/synapse/issues/10603 @@ -242,13 +262,16 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_leaves_our_private_and_their_public(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_leaves_our_private_and_their_public(self, receipt_type: str) -> None: self._test_filters_private( [ { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - ReceiptTypes.READ_PRIVATE: { + receipt_type: { "@me:server.org": { "ts": 1436451550453, }, @@ -273,7 +296,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - ReceiptTypes.READ_PRIVATE: { + receipt_type: { "@me:server.org": { "ts": 1436451550453, }, @@ -296,13 +319,16 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_we_do_not_mutate(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_we_do_not_mutate(self, receipt_type: str) -> None: """Ensure the input values are not modified.""" events = [ { "content": { "$1435641916114394fHBLK:matrix.org": { - ReceiptTypes.READ_PRIVATE: { + receipt_type: { "@rikj:jki.re": { "ts": 1436451550453, } @@ -320,7 +346,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): def _test_filters_private( self, events: List[JsonDict], expected_output: List[JsonDict] - ): + ) -> None: """Tests that the _filter_out_private returns the expected output""" filtered_events = self.event_source.filter_out_private_receipts( events, "@me:server.org" diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index ae16184828..de0dec8539 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -38,7 +38,6 @@ from tests.federation.transport.test_knocking import ( KnockingStrippedStateEventHelperMixin, ) from tests.server import TimedOutException -from tests.unittest import override_config class FilterTestCase(unittest.HomeserverTestCase): @@ -390,6 +389,12 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + config["experimental_features"] = {"msc2285_enabled": True} + + return self.setup_test_homeserver(config=config) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/sync?since=%s" self.next_batch = "s0" @@ -408,15 +413,17 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Join the second user self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_private_read_receipts(self) -> None: + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_private_read_receipts(self, receipt_type: str) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) # Send a private read receipt to tell the server the first user's message was read channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -425,8 +432,10 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that the first user can't see the other user's private read receipt self.assertIsNone(self._get_read_receipt()) - @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_public_receipt_can_override_private(self) -> None: + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_public_receipt_can_override_private(self, receipt_type: str) -> None: """ Sending a public read receipt to the same event which has a private read receipt should cause that receipt to become public. @@ -437,7 +446,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a private read receipt channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -456,8 +465,10 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that we did override the private read receipt self.assertNotEqual(self._get_read_receipt(), None) - @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_private_receipt_cannot_override_public(self) -> None: + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_private_receipt_cannot_override_public(self, receipt_type: str) -> None: """ Sending a private read receipt to the same event which has a public read receipt should cause no change. @@ -478,7 +489,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a private read receipt channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -590,7 +601,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): tok=self.tok, ) - def test_unread_counts(self) -> None: + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_unread_counts(self, receipt_type: str) -> None: """Tests that /sync returns the right value for the unread count (MSC2654).""" # Check that our own messages don't increase the unread count. @@ -624,7 +638,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Send a read receipt to tell the server we've read the latest event. channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", {}, access_token=self.tok, ) @@ -700,7 +714,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(5) res2 = self.helper.send(self.room_id, "hello", tok=self.tok2) - # Make sure both m.read and org.matrix.msc2285.read.private advance + # Make sure both m.read and m.read.private advance channel = self.make_request( "POST", f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}", @@ -712,16 +726,22 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res2['event_id']}", + f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}", {}, access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) self._check_unread_count(0) - # We test for both receipt types that influence notification counts - @parameterized.expand([ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]) - def test_read_receipts_only_go_down(self, receipt_type: ReceiptTypes) -> None: + # We test for all three receipt types that influence notification counts + @parameterized.expand( + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ] + ) + def test_read_receipts_only_go_down(self, receipt_type: str) -> None: # Join the new user self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) @@ -739,11 +759,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.json_body) self._check_unread_count(0) - # Make sure neither m.read nor org.matrix.msc2285.read.private make the + # Make sure neither m.read nor m.read.private make the # read receipt go up to an older event channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res1['event_id']}", + f"/rooms/{self.room_id}/receipt/{receipt_type}/{res1['event_id']}", {}, access_token=self.tok, ) diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index b1a8f8bba7..191c957fb5 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from parameterized import parameterized + from synapse.api.constants import ReceiptTypes from synapse.types import UserID, create_requester @@ -23,7 +25,7 @@ OUR_USER_ID = "@our:test" class ReceiptTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor, clock, homeserver) -> None: super().prepare(reactor, clock, homeserver) self.store = homeserver.get_datastores().main @@ -83,10 +85,15 @@ class ReceiptTestCase(HomeserverTestCase): ) ) - def test_return_empty_with_no_data(self): + def test_return_empty_with_no_data(self) -> None: res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + OUR_USER_ID, + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ], ) ) self.assertEqual(res, {}) @@ -94,7 +101,11 @@ class ReceiptTestCase(HomeserverTestCase): res = self.get_success( self.store.get_receipts_for_user_with_orderings( OUR_USER_ID, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ], ) ) self.assertEqual(res, {}) @@ -103,12 +114,19 @@ class ReceiptTestCase(HomeserverTestCase): self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ], ) ) self.assertEqual(res, None) - def test_get_receipts_for_user(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_get_receipts_for_user(self, receipt_type: str) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -126,14 +144,14 @@ class ReceiptTestCase(HomeserverTestCase): # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} ) ) # Test we get the latest event when we want both private and public receipts res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + OUR_USER_ID, [ReceiptTypes.READ, receipt_type] ) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -146,7 +164,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test we get the latest event when we want only the public receipt res = self.get_success( - self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE]) + self.store.get_receipts_for_user(OUR_USER_ID, [receipt_type]) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -169,17 +187,20 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + OUR_USER_ID, [ReceiptTypes.READ, receipt_type] ) ) self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id}) - def test_get_last_receipt_event_id_for_user(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_get_last_receipt_event_id_for_user(self, receipt_type: str) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -197,7 +218,7 @@ class ReceiptTestCase(HomeserverTestCase): # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} ) ) @@ -206,7 +227,7 @@ class ReceiptTestCase(HomeserverTestCase): self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ReceiptTypes.READ, receipt_type], ) ) self.assertEqual(res, event1_2_id) @@ -222,7 +243,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test we get the latest event when we want only the private receipt res = self.get_success( self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] + OUR_USER_ID, self.room_id1, [receipt_type] ) ) self.assertEqual(res, event1_2_id) @@ -248,14 +269,14 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id2, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ReceiptTypes.READ, receipt_type], ) ) self.assertEqual(res, event2_1_id) -- cgit 1.5.1 From c97042f7eef3748e17c90e48a4122389a89c4735 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 8 Aug 2022 22:21:27 +0200 Subject: Use literals in place of `HTTPStatus` constants in tests (#13469) --- changelog.d/13469.misc | 1 + tests/rest/admin/test_admin.py | 24 +-- tests/rest/admin/test_background_updates.py | 18 +- tests/rest/admin/test_device.py | 36 ++-- tests/rest/admin/test_event_reports.py | 34 ++-- tests/rest/admin/test_federation.py | 46 +++--- tests/rest/admin/test_media.py | 51 +++--- tests/rest/admin/test_registration_tokens.py | 32 ++-- tests/rest/admin/test_room.py | 124 +++++++------- tests/rest/admin/test_server_notice.py | 14 +- tests/rest/admin/test_statistics.py | 36 ++-- tests/rest/admin/test_user.py | 238 +++++++++++++-------------- tests/rest/admin/test_username_available.py | 6 +- 13 files changed, 329 insertions(+), 331 deletions(-) create mode 100644 changelog.d/13469.misc (limited to 'tests/rest') diff --git a/changelog.d/13469.misc b/changelog.d/13469.misc new file mode 100644 index 0000000000..315930deab --- /dev/null +++ b/changelog.d/13469.misc @@ -0,0 +1 @@ +Use literals in place of `HTTPStatus` constants in tests. \ No newline at end of file diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 82ac5991e6..06e74d5e58 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -42,7 +42,7 @@ class VersionTestCase(unittest.HomeserverTestCase): def test_version_string(self) -> None: channel = self.make_request("GET", self.url, shorthand=False) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual( {"server_version", "python_version"}, set(channel.json_body.keys()) ) @@ -139,7 +139,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ) # Should be successful - self.assertEqual(HTTPStatus.OK, channel.code) + self.assertEqual(200, channel.code) # Quarantine the media url = "/_synapse/admin/v1/media/quarantine/%s/%s" % ( @@ -152,7 +152,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Attempt to access the media self._ensure_quarantined(admin_user_tok, server_name_and_media_id) @@ -209,7 +209,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items" ) @@ -251,7 +251,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items" ) @@ -285,7 +285,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),) channel = self.make_request("POST", url, access_token=admin_user_tok) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Quarantine all media by this user url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( @@ -297,7 +297,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body, {"num_quarantined": 1}, "Expected 1 quarantined item" ) @@ -318,10 +318,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Shouldn't be quarantined self.assertEqual( - HTTPStatus.OK, + 200, channel.code, msg=( - "Expected to receive a HTTPStatus.OK on accessing not-quarantined media: %s" + "Expected to receive a 200 on accessing not-quarantined media: %s" % server_and_media_id_2 ), ) @@ -350,7 +350,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase): def test_purge_history(self) -> None: """ Simple test of purge history API. - Test only that is is possible to call, get status HTTPStatus.OK and purge_id. + Test only that is is possible to call, get status 200 and purge_id. """ channel = self.make_request( @@ -360,7 +360,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("purge_id", channel.json_body) purge_id = channel.json_body["purge_id"] @@ -371,5 +371,5 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("complete", channel.json_body["status"]) diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index 6cf56b1e35..8295ecf248 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -125,7 +125,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Background updates should be enabled, but none should be running. self.assertDictEqual( @@ -147,7 +147,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Background updates should be enabled, and one should be running. self.assertDictEqual( @@ -181,7 +181,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/enabled", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": True}) # Disable the BG updates @@ -191,7 +191,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): content={"enabled": False}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": False}) # Advance a bit and get the current status, note this will finish the in @@ -204,7 +204,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertDictEqual( channel.json_body, { @@ -231,7 +231,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # There should be no change from the previous /status response. self.assertDictEqual( @@ -259,7 +259,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): content={"enabled": True}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": True}) @@ -270,7 +270,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Background updates should be enabled and making progress. self.assertDictEqual( @@ -325,7 +325,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # test that each background update is waiting now for update in updates: diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index f7080bda87..779f1bfac1 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -122,7 +122,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): def test_unknown_device(self) -> None: """ - Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK. + Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or 200. """ url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote( self.other_user @@ -143,7 +143,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) channel = self.make_request( "DELETE", @@ -151,8 +151,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - # Delete unknown device returns status HTTPStatus.OK - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + # Delete unknown device returns status 200 + self.assertEqual(200, channel.code, msg=channel.json_body) def test_update_device_too_long_display_name(self) -> None: """ @@ -189,12 +189,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) def test_update_no_display_name(self) -> None: """ - Tests that a update for a device without JSON returns a HTTPStatus.OK + Tests that a update for a device without JSON returns a 200 """ # Set iniital display name. update = {"display_name": "new display"} @@ -210,7 +210,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Ensure the display name was not updated. channel = self.make_request( @@ -219,7 +219,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) def test_update_display_name(self) -> None: @@ -234,7 +234,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): content={"display_name": "new displayname"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check new display_name channel = self.make_request( @@ -243,7 +243,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new displayname", channel.json_body["display_name"]) def test_get_device(self) -> None: @@ -256,7 +256,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) # Check that all fields are available self.assertIn("user_id", channel.json_body) @@ -281,7 +281,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Ensure that the number of devices is decreased res = self.get_success(self.handler.get_devices_by_user(self.other_user)) @@ -379,7 +379,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["devices"])) @@ -399,7 +399,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_devices, channel.json_body["total"]) self.assertEqual(number_devices, len(channel.json_body["devices"])) self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"]) @@ -494,7 +494,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): def test_unknown_devices(self) -> None: """ - Tests that a remove of a device that does not exist returns HTTPStatus.OK. + Tests that a remove of a device that does not exist returns 200. """ channel = self.make_request( "POST", @@ -503,8 +503,8 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): content={"devices": ["unknown_device1", "unknown_device2"]}, ) - # Delete unknown devices returns status HTTPStatus.OK - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + # Delete unknown devices returns status 200 + self.assertEqual(200, channel.code, msg=channel.json_body) def test_delete_devices(self) -> None: """ @@ -533,7 +533,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): content={"devices": device_ids}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) res = self.get_success(self.handler.get_devices_by_user(self.other_user)) self.assertEqual(0, len(res)) diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 4f89f8b534..9bc6ce62cb 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -117,7 +117,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) @@ -134,7 +134,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 5) self.assertEqual(channel.json_body["next_token"], 5) @@ -151,7 +151,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -168,7 +168,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["event_reports"]), 10) @@ -185,7 +185,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertNotIn("next_token", channel.json_body) @@ -205,7 +205,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertNotIn("next_token", channel.json_body) @@ -225,7 +225,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 5) self.assertEqual(len(channel.json_body["event_reports"]), 5) self.assertNotIn("next_token", channel.json_body) @@ -247,7 +247,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) report = 1 @@ -265,7 +265,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) report = 1 @@ -344,7 +344,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) @@ -357,7 +357,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) @@ -370,7 +370,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -384,7 +384,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -400,7 +400,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): {"score": -100, "reason": "this makes me sad"}, access_token=user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) def _create_event_and_report_without_parameters( self, room_id: str, user_tok: str @@ -415,7 +415,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): {}, access_token=user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) def _check_fields(self, content: List[JsonDict]) -> None: """Checks that all attributes are present in an event report""" @@ -502,7 +502,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self._check_fields(channel.json_body) def test_invalid_report_id(self) -> None: @@ -594,7 +594,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): {"score": -100, "reason": "this makes me sad"}, access_token=user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) def _check_fields(self, content: JsonDict) -> None: """Checks that all attributes are present in a event report""" diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 929bbdc37d..c3927c2735 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -142,7 +142,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(len(channel.json_body["destinations"]), 5) self.assertEqual(channel.json_body["next_token"], "5") @@ -160,7 +160,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(len(channel.json_body["destinations"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(channel.json_body["next_token"], "15") self.assertEqual(len(channel.json_body["destinations"]), 10) @@ -198,7 +198,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(len(channel.json_body["destinations"]), number_destinations) self.assertNotIn("next_token", channel.json_body) @@ -211,7 +211,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(len(channel.json_body["destinations"]), number_destinations) self.assertNotIn("next_token", channel.json_body) @@ -224,7 +224,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(len(channel.json_body["destinations"]), 19) self.assertEqual(channel.json_body["next_token"], "19") @@ -238,7 +238,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(len(channel.json_body["destinations"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -255,7 +255,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_destinations, len(channel.json_body["destinations"])) self.assertEqual(number_destinations, channel.json_body["total"]) @@ -290,7 +290,7 @@ class FederationTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_destination_list)) returned_order = [ @@ -376,7 +376,7 @@ class FederationTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that destinations were returned self.assertTrue("destinations" in channel.json_body) @@ -418,7 +418,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("sub0.example.com", channel.json_body["destination"]) # Check that all fields are available @@ -435,7 +435,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("sub0.example.com", channel.json_body["destination"]) self.assertEqual(0, channel.json_body["retry_last_ts"]) self.assertEqual(0, channel.json_body["retry_interval"]) @@ -452,7 +452,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) retry_timings = self.get_success( self.store.get_destination_retry_timings("sub0.example.com") @@ -619,7 +619,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(len(channel.json_body["rooms"]), 3) self.assertEqual(channel.json_body["next_token"], "3") @@ -637,7 +637,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(len(channel.json_body["rooms"]), 5) self.assertNotIn("next_token", channel.json_body) @@ -655,7 +655,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(channel.json_body["next_token"], "8") self.assertEqual(len(channel.json_body["rooms"]), 5) @@ -673,7 +673,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel_asc.code, msg=channel_asc.json_body) + self.assertEqual(200, channel_asc.code, msg=channel_asc.json_body) self.assertEqual(channel_asc.json_body["total"], number_rooms) self.assertEqual(number_rooms, len(channel_asc.json_body["rooms"])) self._check_fields(channel_asc.json_body["rooms"]) @@ -685,7 +685,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel_desc.code, msg=channel_desc.json_body) + self.assertEqual(200, channel_desc.code, msg=channel_desc.json_body) self.assertEqual(channel_desc.json_body["total"], number_rooms) self.assertEqual(number_rooms, len(channel_desc.json_body["rooms"])) self._check_fields(channel_desc.json_body["rooms"]) @@ -711,7 +711,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(len(channel.json_body["rooms"]), number_rooms) self.assertNotIn("next_token", channel.json_body) @@ -724,7 +724,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(len(channel.json_body["rooms"]), number_rooms) self.assertNotIn("next_token", channel.json_body) @@ -737,7 +737,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(len(channel.json_body["rooms"]), 4) self.assertEqual(channel.json_body["next_token"], "4") @@ -751,7 +751,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(len(channel.json_body["rooms"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -767,7 +767,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(number_rooms, len(channel.json_body["rooms"])) self._check_fields(channel.json_body["rooms"]) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index e909e444ac..92fd6c780d 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -131,7 +131,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): upload_resource, SMALL_PNG, tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + expect_code=200, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -151,11 +151,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): # Should be successful self.assertEqual( - HTTPStatus.OK, + 200, channel.code, msg=( - "Expected to receive a HTTPStatus.OK on accessing media: %s" - % server_and_media_id + "Expected to receive a 200 on accessing media: %s" % server_and_media_id ), ) @@ -172,7 +171,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( media_id, @@ -388,7 +387,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( media_id, @@ -413,7 +412,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -425,7 +424,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -449,7 +448,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&size_gt=67", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -460,7 +459,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&size_gt=66", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -485,7 +484,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): content={"avatar_url": "mxc://%s" % (server_and_media_id,)}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() channel = self.make_request( @@ -493,7 +492,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -504,7 +503,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -530,7 +529,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): content={"url": "mxc://%s" % (server_and_media_id,)}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() channel = self.make_request( @@ -538,7 +537,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -549,7 +548,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -569,7 +568,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): upload_resource, SMALL_PNG, tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + expect_code=200, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -602,10 +601,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): if expect_success: self.assertEqual( - HTTPStatus.OK, + 200, channel.code, msg=( - "Expected to receive a HTTPStatus.OK on accessing media: %s" + "Expected to receive a 200 on accessing media: %s" % server_and_media_id ), ) @@ -648,7 +647,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): upload_resource, SMALL_PNG, tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + expect_code=200, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -712,7 +711,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -726,7 +725,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -753,7 +752,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) # verify that is not in quarantine @@ -785,7 +784,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): upload_resource, SMALL_PNG, tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + expect_code=200, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -845,7 +844,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -859,7 +858,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 8354250ec2..544daaa4c8 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -105,7 +105,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -129,7 +129,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["token"], token) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) @@ -150,7 +150,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -207,7 +207,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): data, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body) + self.assertEqual(200, channel1.code, msg=channel1.json_body) channel2 = self.make_request( "POST", @@ -251,7 +251,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 0) # Should fail with negative integer @@ -321,7 +321,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 64}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 64) # Should fail with 0 @@ -439,7 +439,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertIsNone(channel.json_body["expiry_time"]) @@ -450,7 +450,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 0) self.assertIsNone(channel.json_body["expiry_time"]) @@ -461,7 +461,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": None}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -506,7 +506,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": new_expiry_time}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -517,7 +517,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": None}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -568,7 +568,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) @@ -655,7 +655,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # GETTING ONE @@ -716,7 +716,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["token"], token) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -762,7 +762,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["registration_tokens"]), 1) token_info = channel.json_body["registration_tokens"][0] self.assertEqual(token_info["token"], token) @@ -816,7 +816,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["registration_tokens"]), 2) token_info_1 = channel.json_body["registration_tokens"][0] token_info_2 = channel.json_body["registration_tokens"][1] diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 989cbdb5e2..6ea7858db7 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -94,7 +94,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_room_is_not_valid(self) -> None: """ @@ -127,7 +127,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("new_room_id", channel.json_body) self.assertIn("kicked_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -202,7 +202,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -233,7 +233,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -265,7 +265,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -296,7 +296,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): ) # The room is now blocked. - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self._is_blocked(room_id) def test_shutdown_room_consent(self) -> None: @@ -337,7 +337,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -366,7 +366,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): {"history_visibility": "world_readable"}, access_token=self.other_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Test that room is not purged with self.assertRaises(AssertionError): @@ -383,7 +383,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -522,7 +522,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id = channel.json_body["delete_id"] @@ -533,7 +533,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, len(channel.json_body["results"])) self.assertEqual("complete", channel.json_body["results"][0]["status"]) self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"]) @@ -574,7 +574,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id = channel.json_body["delete_id"] @@ -639,7 +639,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id1 = channel.json_body["delete_id"] @@ -654,7 +654,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id2 = channel.json_body["delete_id"] @@ -665,7 +665,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(2, len(channel.json_body["results"])) self.assertEqual("complete", channel.json_body["results"][0]["status"]) self.assertEqual("complete", channel.json_body["results"][1]["status"]) @@ -682,7 +682,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, len(channel.json_body["results"])) self.assertEqual("complete", channel.json_body["results"][0]["status"]) self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"]) @@ -733,7 +733,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): # get result of first call first_channel.await_result() - self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body) + self.assertEqual(200, first_channel.code, msg=first_channel.json_body) self.assertIn("delete_id", first_channel.json_body) # check status after finish the task @@ -764,7 +764,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id = channel.json_body["delete_id"] @@ -795,7 +795,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id = channel.json_body["delete_id"] @@ -827,7 +827,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id = channel.json_body["delete_id"] @@ -876,7 +876,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id = channel.json_body["delete_id"] @@ -887,7 +887,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.url_status_by_room_id, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, len(channel.json_body["results"])) # Test that member has moved to new room @@ -914,7 +914,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): content={"history_visibility": "world_readable"}, access_token=self.other_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Test that room is not purged with self.assertRaises(AssertionError): @@ -931,7 +931,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id = channel.json_body["delete_id"] @@ -942,7 +942,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.url_status_by_room_id, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, len(channel.json_body["results"])) # Test that member has moved to new room @@ -1026,9 +1026,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.url_status_by_room_id, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body - ) + self.assertEqual(200, channel_room_id.code, msg=channel_room_id.json_body) self.assertEqual(1, len(channel_room_id.json_body["results"])) self.assertEqual( delete_id, channel_room_id.json_body["results"][0]["delete_id"] @@ -1041,7 +1039,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) self.assertEqual( - HTTPStatus.OK, + 200, channel_delete_id.code, msg=channel_delete_id.json_body, ) @@ -1100,7 +1098,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) # Check request completed successfully - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that response json body contains a "rooms" key self.assertTrue( @@ -1186,7 +1184,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue("rooms" in channel.json_body) for r in channel.json_body["rooms"]: @@ -1226,7 +1224,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_correct_room_attributes(self) -> None: """Test the correct attributes for a room are returned""" @@ -1253,7 +1251,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1285,7 +1283,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -1341,7 +1339,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -1487,7 +1485,7 @@ class RoomTestCase(unittest.HomeserverTestCase): def _search_test( expected_room_id: Optional[str], search_term: str, - expected_http_code: int = HTTPStatus.OK, + expected_http_code: int = 200, ) -> None: """Search for a room and check that the returned room's id is a match @@ -1505,7 +1503,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - if expected_http_code != HTTPStatus.OK: + if expected_http_code != 200: return # Check that rooms were returned @@ -1585,7 +1583,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id")) self.assertEqual("ж", channel.json_body["rooms"][0].get("name")) @@ -1618,7 +1616,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("room_id", channel.json_body) self.assertIn("name", channel.json_body) @@ -1650,7 +1648,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["joined_local_devices"]) # Have another user join the room @@ -1664,7 +1662,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(2, channel.json_body["joined_local_devices"]) # leave room @@ -1676,7 +1674,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["joined_local_devices"]) def test_room_members(self) -> None: @@ -1707,7 +1705,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertCountEqual( ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"] @@ -1720,7 +1718,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertCountEqual( ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"] @@ -1738,7 +1736,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("state", channel.json_body) # testing that the state events match is painful and not done here. We assume that # the create_room already does the right thing, so no need to verify that we got @@ -1755,7 +1753,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1924,7 +1922,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1934,7 +1932,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_not_member(self) -> None: @@ -1982,7 +1980,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) # Join user to room. @@ -1995,7 +1993,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -2005,7 +2003,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_owner(self) -> None: @@ -2025,7 +2023,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -2035,7 +2033,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_context_as_non_admin(self) -> None: @@ -2099,7 +2097,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body["event"]["event_id"], events[midway]["event_id"] ) @@ -2158,7 +2156,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) @@ -2185,7 +2183,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room (we should have received an # invite) and can ban a user. @@ -2211,7 +2209,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.second_user_id, tok=self.second_tok) @@ -2354,7 +2352,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): content={"block": True}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["block"]) self._is_blocked(room_id, expect=True) @@ -2378,7 +2376,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): content={"block": True}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["block"]) self._is_blocked(self.room_id, expect=True) @@ -2394,7 +2392,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): content={"block": False}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body["block"]) self._is_blocked(room_id, expect=False) @@ -2418,7 +2416,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): content={"block": False}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body["block"]) self._is_blocked(self.room_id, expect=False) @@ -2433,7 +2431,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.url % room_id, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["block"]) self.assertEqual(self.other_user, channel.json_body["user_id"]) @@ -2457,7 +2455,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.url % room_id, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body["block"]) self.assertNotIn("user_id", channel.json_body) diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index dbcba2663c..bea3ac34d8 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -197,7 +197,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -226,7 +226,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has no new invites or memberships self._check_invite_and_join_status(self.other_user, 0, 1) @@ -260,7 +260,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -301,7 +301,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -341,7 +341,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -388,7 +388,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -538,7 +538,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "/_matrix/client/r0/sync", access_token=token ) - self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.code, 200) # Get the messages room = channel.json_body["rooms"]["join"][room_id] diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 7cb8ec57ba..baed27a815 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -204,7 +204,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], 5) @@ -222,7 +222,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -240,7 +240,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["users"]), 10) @@ -262,7 +262,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -275,7 +275,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -288,7 +288,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -301,7 +301,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -318,7 +318,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["users"])) @@ -415,7 +415,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media starting at `ts1` after creating first media @@ -425,7 +425,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 0) self._create_media(self.other_user_tok, 3) @@ -440,7 +440,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media until `ts2` and earlier @@ -449,7 +449,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 6) def test_search_term(self) -> None: @@ -461,7 +461,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) # filter user 1 and 10-19 by `user_id` @@ -470,7 +470,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=foo_user_1", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 11) # filter on this user in `displayname` @@ -479,7 +479,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=bar_user_10", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10") self.assertEqual(channel.json_body["total"], 1) @@ -489,7 +489,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=foobar", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 0) def _create_users_with_media(self, number_users: int, media_per_user: int) -> None: @@ -515,7 +515,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): for _ in range(number_media): # Upload some media into the room self.helper.upload_media( - upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK + upload_resource, SMALL_PNG, tok=user_token, expect_code=200 ) def _check_fields(self, content: List[JsonDict]) -> None: @@ -549,7 +549,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) returned_order = [row["user_id"] for row in channel.json_body["users"]] diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 12db68d564..c2b54b1ef7 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -169,7 +169,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) def test_nonce_reuse(self) -> None: @@ -192,7 +192,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it @@ -323,11 +323,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob1:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob1:test/displayname") - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob1", channel.json_body["displayname"]) # displayname is None @@ -347,11 +347,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob2:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob2:test/displayname") - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob2", channel.json_body["displayname"]) # displayname is empty @@ -371,7 +371,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob3:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob3:test/displayname") @@ -394,11 +394,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob4:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob4:test/displayname") - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("Bob's Name", channel.json_body["displayname"]) @override_config( @@ -442,7 +442,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -494,7 +494,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(3, len(channel.json_body["users"])) self.assertEqual(3, channel.json_body["total"]) @@ -508,7 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): expected_user_id: Optional[str], search_term: str, search_field: Optional[str] = "name", - expected_http_code: Optional[int] = HTTPStatus.OK, + expected_http_code: Optional[int] = 200, ) -> None: """Search for a user and check that the returned user's id is a match @@ -530,7 +530,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - if expected_http_code != HTTPStatus.OK: + if expected_http_code != 200: return # Check that users were returned @@ -659,7 +659,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], "5") @@ -680,7 +680,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -701,7 +701,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["next_token"], "15") self.assertEqual(len(channel.json_body["users"]), 10) @@ -724,7 +724,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -737,7 +737,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -750,7 +750,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], "19") @@ -764,7 +764,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -867,7 +867,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) returned_order = [row["name"] for row in channel.json_body["users"]] @@ -1017,7 +1017,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1032,7 +1032,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1041,7 +1041,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1066,7 +1066,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"erase": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self._is_erased("@user:test", True) def test_deactivate_user_erase_false(self) -> None: @@ -1081,7 +1081,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1096,7 +1096,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": False}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1105,7 +1105,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1135,7 +1135,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1150,7 +1150,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1159,7 +1159,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1352,7 +1352,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("User", channel.json_body["displayname"]) self._check_fields(channel.json_body) @@ -1395,7 +1395,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1458,7 +1458,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1486,7 +1486,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # before limit of monthly active users is reached channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok) - if channel.code != HTTPStatus.OK: + if channel.code != 200: raise HttpResponseException( channel.code, channel.result["reason"], channel.result["body"] ) @@ -1684,7 +1684,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "hahaha"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self._check_fields(channel.json_body) def test_set_displayname(self) -> None: @@ -1700,7 +1700,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"displayname": "foobar"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1711,7 +1711,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1733,7 +1733,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1759,7 +1759,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1775,7 +1775,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1791,7 +1791,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": []}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) @@ -1818,7 +1818,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1837,7 +1837,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1859,7 +1859,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # other user has this two threepids - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1878,7 +1878,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): url_first_user, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) @@ -1907,7 +1907,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1939,7 +1939,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1958,7 +1958,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1977,7 +1977,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"external_ids": []}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["external_ids"])) @@ -2006,7 +2006,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2032,7 +2032,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2075,7 +2075,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2093,7 +2093,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2124,7 +2124,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -2139,7 +2139,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -2158,7 +2158,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -2188,7 +2188,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) @@ -2204,7 +2204,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"displayname": "Foobar"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertEqual("Foobar", channel.json_body["displayname"]) @@ -2237,7 +2237,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self._is_erased("@user:test", False) @@ -2271,7 +2271,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self._is_erased("@user:test", False) @@ -2305,7 +2305,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self._is_erased("@user:test", False) @@ -2326,7 +2326,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"admin": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -2337,7 +2337,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -2354,7 +2354,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"user_type": UserTypes.SUPPORT}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) @@ -2365,7 +2365,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) @@ -2377,7 +2377,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"user_type": None}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertIsNone(channel.json_body["user_type"]) @@ -2388,7 +2388,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertIsNone(channel.json_body["user_type"]) @@ -2418,7 +2418,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) self.assertEqual(0, channel.json_body["deactivated"]) @@ -2440,7 +2440,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) @@ -2465,7 +2465,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["deactivated"]) self._is_erased(user_id, False) d = self.store.mark_user_erased(user_id) @@ -2549,7 +2549,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2565,7 +2565,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2581,7 +2581,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2602,7 +2602,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) @@ -2649,7 +2649,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"]) @@ -2737,7 +2737,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) # Register the pusher @@ -2769,7 +2769,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) for p in channel.json_body["pushers"]: @@ -2865,7 +2865,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 5) self.assertEqual(channel.json_body["next_token"], 5) @@ -2884,7 +2884,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 5) self.assertEqual(len(channel.json_body["deleted_media"]), 5) @@ -2901,7 +2901,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -2920,7 +2920,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 15) self.assertEqual(len(channel.json_body["deleted_media"]), 15) @@ -2937,7 +2937,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["media"]), 10) @@ -2956,7 +2956,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["deleted_media"]), 10) @@ -3023,7 +3023,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -3036,7 +3036,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -3049,7 +3049,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -3063,7 +3063,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -3080,7 +3080,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["media"])) @@ -3095,7 +3095,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["deleted_media"])) @@ -3112,7 +3112,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) self.assertEqual(number_media, len(channel.json_body["media"])) self.assertNotIn("next_token", channel.json_body) @@ -3138,7 +3138,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) self.assertEqual(number_media, len(channel.json_body["deleted_media"])) self.assertCountEqual(channel.json_body["deleted_media"], media_ids) @@ -3283,7 +3283,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, image_data, user_token, filename, expect_code=HTTPStatus.OK + upload_resource, image_data, user_token, filename, expect_code=200 ) # Extract media ID from the response @@ -3301,10 +3301,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.OK, + 200, channel.code, msg=( - f"Expected to receive a HTTPStatus.OK on accessing media: {server_and_media_id}" + f"Expected to receive a 200 on accessing media: {server_and_media_id}" ), ) @@ -3350,7 +3350,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_media_list)) returned_order = [row["media_id"] for row in channel.json_body["media"]] @@ -3386,7 +3386,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, b"{}", access_token=self.admin_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) return channel.json_body["access_token"] def test_no_auth(self) -> None: @@ -3427,7 +3427,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # We should only see the one device (from the login in `prepare`) self.assertEqual(len(channel.json_body["devices"]), 1) @@ -3439,11 +3439,11 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout with the puppet token channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) @@ -3453,7 +3453,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_user_logout_all(self) -> None: """Tests that the target user calling `/logout/all` does *not* expire @@ -3464,17 +3464,17 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the real user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should still work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # .. but the real user's tokens shouldn't channel = self.make_request( @@ -3491,13 +3491,13 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the admin user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.admin_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) @@ -3507,7 +3507,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) @unittest.override_config( { @@ -3635,7 +3635,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) @@ -3650,7 +3650,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=other_user_token, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) @@ -3715,7 +3715,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): self.assertFalse(result.shadow_banned) channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual({}, channel.json_body) # Ensure the user is shadow-banned (and the cache was cleared). @@ -3727,7 +3727,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, access_token=self.admin_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual({}, channel.json_body) # Ensure the user is no longer shadow-banned (and the cache was cleared). @@ -3891,7 +3891,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["messages_per_second"]) self.assertEqual(0, channel.json_body["burst_count"]) @@ -3905,7 +3905,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3916,7 +3916,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 10, "burst_count": 11}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(10, channel.json_body["messages_per_second"]) self.assertEqual(11, channel.json_body["burst_count"]) @@ -3927,7 +3927,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 20, "burst_count": 21}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3937,7 +3937,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3947,7 +3947,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3957,7 +3957,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -4042,7 +4042,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual( {"a": 1}, channel.json_body["account_data"]["global"]["m.global"] ) diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index b21f6d4689..b5e7eecf87 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -50,18 +50,18 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): def test_username_available(self) -> None: """ - The endpoint should return a HTTPStatus.OK response if the username does not exist + The endpoint should return a 200 response if the username does not exist """ url = "%s?username=%s" % (self.url, "allowed") channel = self.make_request("GET", url, access_token=self.admin_user_tok) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["available"]) def test_username_unavailable(self) -> None: """ - The endpoint should return a HTTPStatus.OK response if the username does not exist + The endpoint should return a 200 response if the username does not exist """ url = "%s?username=%s" % (self.url, "disallowed") -- cgit 1.5.1 From 1595052b2681fb86c1c1b9a6028c1bc0d38a2e4b Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 9 Aug 2022 15:56:43 +0200 Subject: Use literals in place of `HTTPStatus` constants in tests (#13479) Replace - `HTTPStatus.NOT_FOUND` - `HTTPStatus.FORBIDDEN` - `HTTPStatus.UNAUTHORIZED` - `HTTPStatus.CONFLICT` - `HTTPStatus.CREATED` Signed-off-by: Dirk Klimpel --- changelog.d/13479.misc | 1 + tests/rest/admin/test_admin.py | 7 +- tests/rest/admin/test_background_updates.py | 4 +- tests/rest/admin/test_device.py | 28 ++++---- tests/rest/admin/test_event_reports.py | 16 ++--- tests/rest/admin/test_federation.py | 10 +-- tests/rest/admin/test_media.py | 32 ++++----- tests/rest/admin/test_registration_tokens.py | 26 +++---- tests/rest/admin/test_room.py | 40 +++++------ tests/rest/admin/test_server_notice.py | 8 +-- tests/rest/admin/test_statistics.py | 6 +- tests/rest/admin/test_user.py | 104 +++++++++++++-------------- 12 files changed, 141 insertions(+), 141 deletions(-) create mode 100644 changelog.d/13479.misc (limited to 'tests/rest') diff --git a/changelog.d/13479.misc b/changelog.d/13479.misc new file mode 100644 index 0000000000..315930deab --- /dev/null +++ b/changelog.d/13479.misc @@ -0,0 +1 @@ +Use literals in place of `HTTPStatus` constants in tests. \ No newline at end of file diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 06e74d5e58..a8f6436836 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -13,7 +13,6 @@ # limitations under the License. import urllib.parse -from http import HTTPStatus from parameterized import parameterized @@ -79,10 +78,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Should be quarantined self.assertEqual( - HTTPStatus.NOT_FOUND, + 404, channel.code, msg=( - "Expected to receive a HTTPStatus.NOT_FOUND on accessing quarantined media: %s" + "Expected to receive a 404 on accessing quarantined media: %s" % server_and_media_id ), ) @@ -107,7 +106,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Expect a forbidden error self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg="Expected forbidden on quarantining media as a non-admin", ) diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index 8295ecf248..7cd8b52f02 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -51,7 +51,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ) def test_requester_is_no_admin(self, method: str, url: str) -> None: """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ self.register_user("user", "pass", admin=False) @@ -64,7 +64,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): access_token=other_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index 779f1bfac1..6a6e8ad7d8 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -58,7 +58,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): channel = self.make_request(method, self.url, b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -76,7 +76,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -85,7 +85,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): @parameterized.expand(["GET", "PUT", "DELETE"]) def test_user_does_not_exist(self, method: str) -> None: """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = ( "/_synapse/admin/v2/users/@unknown_person:test/devices/%s" @@ -98,7 +98,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) @@ -122,7 +122,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): def test_unknown_device(self) -> None: """ - Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or 200. + Tests that a lookup for a device that does not exist returns either 404 or 200. """ url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote( self.other_user @@ -134,7 +134,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) channel = self.make_request( @@ -312,7 +312,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -331,7 +331,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -339,7 +339,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): def test_user_does_not_exist(self) -> None: """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/devices" channel = self.make_request( @@ -348,7 +348,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_user_is_not_local(self) -> None: @@ -438,7 +438,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -457,7 +457,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -465,7 +465,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): def test_user_does_not_exist(self) -> None: """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices" channel = self.make_request( @@ -474,7 +474,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_user_is_not_local(self) -> None: diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 9bc6ce62cb..c63a86d5b3 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -82,7 +82,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -90,7 +90,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): def test_requester_is_no_admin(self) -> None: """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -100,7 +100,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -467,7 +467,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -475,7 +475,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): def test_requester_is_no_admin(self) -> None: """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -485,7 +485,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -566,7 +566,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): def test_report_id_not_found(self) -> None: """ - Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND. + Testing that a not existing `report_id` returns a 404. """ channel = self.make_request( @@ -576,7 +576,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.NOT_FOUND, + 404, channel.code, msg=channel.json_body, ) diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index c3927c2735..8affd830d1 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -64,7 +64,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=other_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: @@ -117,7 +117,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) # invalid destination @@ -127,7 +127,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_limit(self) -> None: @@ -561,7 +561,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=other_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: @@ -604,7 +604,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_limit(self) -> None: diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 92fd6c780d..d51c10a515 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -60,7 +60,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): channel = self.make_request("DELETE", url, b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -82,7 +82,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -90,7 +90,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): def test_media_does_not_exist(self) -> None: """ - Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a media that does not exist returns a 404 """ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") @@ -100,7 +100,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_media_is_not_local(self) -> None: @@ -188,10 +188,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) self.assertEqual( - HTTPStatus.NOT_FOUND, + 404, channel.code, msg=( - "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s" + "Expected to receive a 404 on accessing deleted media: %s" % server_and_media_id ), ) @@ -231,7 +231,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -251,7 +251,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -612,10 +612,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertTrue(os.path.exists(local_path)) else: self.assertEqual( - HTTPStatus.NOT_FOUND, + 404, channel.code, msg=( - "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s" + "Expected to receive a 404 on accessing deleted media: %s" % (server_and_media_id) ), ) @@ -668,7 +668,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -689,7 +689,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -801,7 +801,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url % (action, self.media_id), b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -822,7 +822,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -894,7 +894,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -914,7 +914,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 544daaa4c8..bcb602382c 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -75,7 +75,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): """Try to create a token without authentication.""" channel = self.make_request("POST", self.url + "/new", {}) self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -90,7 +90,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -390,7 +390,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, ) self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -405,7 +405,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -421,7 +421,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.NOT_FOUND, + 404, channel.code, msg=channel.json_body, ) @@ -606,7 +606,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, ) self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -621,7 +621,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -637,7 +637,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.NOT_FOUND, + 404, channel.code, msg=channel.json_body, ) @@ -667,7 +667,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, ) self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -682,7 +682,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -698,7 +698,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.NOT_FOUND, + 404, channel.code, msg=channel.json_body, ) @@ -729,7 +729,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): """Try to list tokens without authentication.""" channel = self.make_request("GET", self.url, {}) self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -744,7 +744,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 6ea7858db7..8350f9025a 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -68,7 +68,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): def test_requester_is_no_admin(self) -> None: """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -78,7 +78,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_room_does_not_exist(self) -> None: @@ -319,7 +319,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self.room_id, body="foo", tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + expect_code=403, ) # Test that room is not purged @@ -398,7 +398,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._has_no_members(self.room_id) # Assert we can no longer peek into the room - self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN) + self._assert_peek(self.room_id, expect_code=403) def _is_blocked(self, room_id: str, expect: bool = True) -> None: """Assert that the room is blocked or not""" @@ -494,7 +494,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): ) def test_requester_is_no_admin(self, method: str, url: str) -> None: """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -504,7 +504,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_room_does_not_exist(self) -> None: @@ -696,7 +696,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_delete_same_room_twice(self) -> None: @@ -858,7 +858,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.room_id, body="foo", tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + expect_code=403, ) # Test that room is not purged @@ -955,7 +955,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self._has_no_members(self.room_id) # Assert we can no longer peek into the room - self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN) + self._assert_peek(self.room_id, expect_code=403) def _is_blocked(self, room_id: str, expect: bool = True) -> None: """Assert that the room is blocked or not""" @@ -1782,7 +1782,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # delete the rooms and get joined roomed membership url = f"/_matrix/client/r0/rooms/{room_id}/joined_members" channel = self.make_request("GET", url.encode("ascii"), access_token=user_tok) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1811,7 +1811,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): def test_requester_is_no_admin(self) -> None: """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -1821,7 +1821,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.second_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: @@ -1841,7 +1841,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): def test_local_user_does_not_exist(self) -> None: """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ channel = self.make_request( @@ -1851,7 +1851,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_remote_user(self) -> None: @@ -1874,7 +1874,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): def test_room_does_not_exist(self) -> None: """ - Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. + Check that unknown rooms/server return error 404. """ url = "/_synapse/admin/v1/join/!unknown:test" @@ -1885,7 +1885,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual( "Can't join remote room because no servers that are in the room have been provided.", channel.json_body["error"], @@ -1952,7 +1952,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_join_private_room_if_member(self) -> None: @@ -2067,7 +2067,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_context_as_admin(self) -> None: @@ -2277,7 +2277,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): @parameterized.expand([("PUT",), ("GET",)]) def test_requester_is_no_admin(self, method: str) -> None: - """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.""" + """If the user is not a server admin, an error 403 is returned.""" channel = self.make_request( method, @@ -2286,7 +2286,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand([("PUT",), ("GET",)]) diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index bea3ac34d8..3a6b98fbb5 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -57,7 +57,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url) self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -72,7 +72,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -80,7 +80,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) def test_user_does_not_exist(self) -> None: - """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" + """Tests that a lookup for a user that does not exist returns a 404""" channel = self.make_request( "POST", self.url, @@ -88,7 +88,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": "@unknown_person:test", "content": ""}, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index baed27a815..3a8982afea 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -52,7 +52,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -60,7 +60,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): def test_requester_is_no_admin(self) -> None: """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( "GET", @@ -70,7 +70,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index c2b54b1ef7..beb1b1c120 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -142,7 +142,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("HMAC incorrect", channel.json_body["error"]) def test_register_correct_nonce(self) -> None: @@ -375,7 +375,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.assertEqual("@bob3:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob3:test/displayname") - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) # set displayname channel = self.make_request("GET", self.url) @@ -466,7 +466,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: @@ -478,7 +478,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, access_token=other_user_token) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_all_users(self) -> None: @@ -941,7 +941,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self) -> None: @@ -952,7 +952,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", url, access_token=self.other_user_token) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -962,12 +962,12 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self) -> None: """ - Tests that deactivation for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that deactivation for a user that does not exist returns a 404 """ channel = self.make_request( @@ -976,7 +976,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_erase_is_not_bool(self) -> None: @@ -1220,7 +1220,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -1230,12 +1230,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self) -> None: """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ channel = self.make_request( @@ -1244,7 +1244,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: @@ -1379,7 +1379,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1434,7 +1434,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1512,7 +1512,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123", "admin": False}, ) - self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) @@ -1550,7 +1550,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # Admin user is not blocked by mau anymore - self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) @@ -1585,7 +1585,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1626,7 +1626,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1666,7 +1666,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"]) self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"]) @@ -2064,7 +2064,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # must fail - self.assertEqual(HTTPStatus.CONFLICT, channel.code, msg=channel.json_body) + self.assertEqual(409, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("External id is already in use.", channel.json_body["error"]) @@ -2261,7 +2261,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -2295,7 +2295,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -2407,7 +2407,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123"}, ) - self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) @@ -2520,7 +2520,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: @@ -2535,7 +2535,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self) -> None: @@ -2678,7 +2678,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: @@ -2693,12 +2693,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self) -> None: """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/pushers" channel = self.make_request( @@ -2707,7 +2707,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_user_is_not_local(self) -> None: @@ -2808,7 +2808,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """Try to list media of an user without authentication.""" channel = self.make_request(method, self.url, {}) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) @@ -2822,12 +2822,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) def test_user_does_not_exist(self, method: str) -> None: - """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" + """Tests that a lookup for a user that does not exist returns a 404""" url = "/_synapse/admin/v1/users/@unknown_person:test/media" channel = self.make_request( method, @@ -2835,7 +2835,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) @@ -3393,7 +3393,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): """Try to login as a user without authentication.""" channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_not_admin(self) -> None: @@ -3402,7 +3402,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): "POST", self.url, b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) def test_send_event(self) -> None: """Test that sending event as a user works.""" @@ -3447,7 +3447,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( @@ -3480,7 +3480,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) def test_admin_logout_all(self) -> None: """Tests that the admin user calling `/logout/all` does expire the @@ -3501,7 +3501,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( @@ -3538,7 +3538,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): room_id, "com.example.test", tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + expect_code=403, ) # Login in as the user @@ -3559,7 +3559,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): room_id, user=self.other_user, tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + expect_code=403, ) # Logging in as the other user and joining a room should work, even @@ -3594,7 +3594,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): Try to get information of an user without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self) -> None: @@ -3609,7 +3609,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=other_user2_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_is_not_local(self) -> None: @@ -3680,7 +3680,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): Try to get information of an user without authentication. """ channel = self.make_request(method, self.url) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["POST", "DELETE"]) @@ -3691,7 +3691,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): other_user_token = self.login("user", "pass") channel = self.make_request(method, self.url, access_token=other_user_token) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["POST", "DELETE"]) @@ -3762,7 +3762,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): """ channel = self.make_request(method, self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "POST", "DELETE"]) @@ -3778,13 +3778,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "POST", "DELETE"]) def test_user_does_not_exist(self, method: str) -> None: """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit" @@ -3794,7 +3794,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand( @@ -3982,7 +3982,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): """Try to get information of a user without authentication.""" channel = self.make_request("GET", self.url, {}) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: @@ -3995,7 +3995,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self) -> None: @@ -4008,7 +4008,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_user_is_not_local(self) -> None: -- cgit 1.5.1 From 2281427175e4c93a30c39607fb4ac23c2a1f399f Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 10 Aug 2022 20:01:12 +0200 Subject: Use literals in place of `HTTPStatus` constants in tests (#13488) * Use literals in place of `HTTPStatus` constants in tests * newsfile * code style * code style --- changelog.d/13488.misc | 1 + tests/rest/admin/test_background_updates.py | 7 +- tests/rest/admin/test_device.py | 15 ++- tests/rest/admin/test_event_reports.py | 75 +++--------- tests/rest/admin/test_federation.py | 17 ++- tests/rest/admin/test_media.py | 99 +++------------- tests/rest/admin/test_registration_tokens.py | 167 +++++---------------------- tests/rest/admin/test_room.py | 49 ++++---- tests/rest/admin/test_server_notice.py | 15 ++- tests/rest/admin/test_statistics.py | 61 ++-------- tests/rest/admin/test_user.py | 107 +++++++++-------- tests/rest/admin/test_username_available.py | 11 +- 12 files changed, 177 insertions(+), 447 deletions(-) create mode 100644 changelog.d/13488.misc (limited to 'tests/rest') diff --git a/changelog.d/13488.misc b/changelog.d/13488.misc new file mode 100644 index 0000000000..315930deab --- /dev/null +++ b/changelog.d/13488.misc @@ -0,0 +1 @@ +Use literals in place of `HTTPStatus` constants in tests. \ No newline at end of file diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index 7cd8b52f02..d507a3af8d 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from typing import Collection from parameterized import parameterized @@ -81,7 +80,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # job_name invalid @@ -92,7 +91,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) def _register_bg_update(self) -> None: @@ -365,4 +364,4 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index 6a6e8ad7d8..d52aee8f92 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import urllib.parse -from http import HTTPStatus from parameterized import parameterized @@ -104,7 +103,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): @parameterized.expand(["GET", "PUT", "DELETE"]) def test_user_is_not_local(self, method: str) -> None: """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = ( "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s" @@ -117,7 +116,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) def test_unknown_device(self) -> None: @@ -179,7 +178,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): content=update, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. @@ -353,7 +352,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): def test_user_is_not_local(self) -> None: """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices" @@ -363,7 +362,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) def test_user_has_no_devices(self) -> None: @@ -479,7 +478,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): def test_user_is_not_local(self) -> None: """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices" @@ -489,7 +488,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) def test_unknown_devices(self) -> None: diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index c63a86d5b3..fbc490f46d 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from typing import List from twisted.test.proto_helpers import MemoryReactor @@ -81,11 +80,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual( - 401, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: @@ -99,11 +94,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual( - 403, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_default_success(self) -> None: @@ -278,7 +269,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): def test_invalid_search_order(self) -> None: """ - Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST + Testing that a invalid search order returns a 400 """ channel = self.make_request( @@ -287,17 +278,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual("Unknown direction: bar", channel.json_body["error"]) def test_limit_is_negative(self) -> None: """ - Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST + Testing that a negative limit parameter returns a 400 """ channel = self.make_request( @@ -306,16 +293,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_from_is_negative(self) -> None: """ - Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST + Testing that a negative from parameter returns a 400 """ channel = self.make_request( @@ -324,11 +307,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_next_token(self) -> None: @@ -466,11 +445,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual( - 401, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: @@ -484,11 +459,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual( - 403, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_default_success(self) -> None: @@ -507,7 +478,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): def test_invalid_report_id(self) -> None: """ - Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST. + Testing that an invalid `report_id` returns a 400. """ # `report_id` is negative @@ -517,11 +488,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -535,11 +502,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -553,11 +516,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -575,11 +534,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - 404, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual("Event report not found", channel.json_body["error"]) diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 8affd830d1..4c7864c629 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from typing import List, Optional from parameterized import parameterized @@ -77,7 +76,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -87,7 +86,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # unkown order_by @@ -97,7 +96,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -107,7 +106,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid destination @@ -469,7 +468,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "The retry timing does not need to be reset for this destination.", channel.json_body["error"], @@ -574,7 +573,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -584,7 +583,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -594,7 +593,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid destination diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index d51c10a515..aadb31ca83 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from http import HTTPStatus from parameterized import parameterized @@ -81,11 +80,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - 403, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_media_does_not_exist(self) -> None: @@ -105,7 +100,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): def test_media_is_not_local(self) -> None: """ - Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a media that is not a local returns a 400 """ url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345") @@ -115,7 +110,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) def test_delete_media(self) -> None: @@ -230,11 +225,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") - self.assertEqual( - 401, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: @@ -250,16 +241,12 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - 403, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_media_is_not_local(self) -> None: """ - Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for media that is not local returns a 400 """ url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain" @@ -269,7 +256,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) def test_missing_parameter(self) -> None: @@ -282,11 +269,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) self.assertEqual( "Missing integer query parameter 'before_ts'", channel.json_body["error"] @@ -302,11 +285,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts must be a positive integer.", @@ -319,11 +298,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts you provided is from the year 1970. " @@ -337,11 +312,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter size_gt must be a string representing a positive integer.", @@ -354,11 +325,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']", @@ -667,11 +634,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): b"{}", ) - self.assertEqual( - 401, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["quarantine", "unquarantine"]) @@ -688,11 +651,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - 403, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_quarantine_media(self) -> None: @@ -800,11 +759,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url % (action, self.media_id), b"{}") - self.assertEqual( - 401, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["protect", "unprotect"]) @@ -821,11 +776,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - 403, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_protect_media(self) -> None: @@ -913,11 +864,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - 403, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: @@ -930,11 +877,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts must be a positive integer.", @@ -947,11 +890,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts you provided is from the year 1970. " diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index bcb602382c..8f8abc21c7 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -13,7 +13,6 @@ # limitations under the License. import random import string -from http import HTTPStatus from typing import Optional from twisted.test.proto_helpers import MemoryReactor @@ -74,11 +73,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): def test_create_no_auth(self) -> None: """Try to create a token without authentication.""" channel = self.make_request("POST", self.url + "/new", {}) - self.assertEqual( - 401, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_create_requester_not_admin(self) -> None: @@ -89,11 +84,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - 403, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_create_using_defaults(self) -> None: @@ -168,11 +159,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_token_invalid_chars(self) -> None: @@ -188,11 +175,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_token_already_exists(self) -> None: @@ -215,7 +198,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): data, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body) + self.assertEqual(400, channel2.code, msg=channel2.json_body) self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) def test_create_unable_to_generate_token(self) -> None: @@ -262,7 +245,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) self.assertEqual( - HTTPStatus.BAD_REQUEST, + 400, channel.code, msg=channel.json_body, ) @@ -275,11 +258,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_expiry_time(self) -> None: @@ -291,11 +270,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": self.clock.time_msec() - 10000}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with float @@ -305,11 +280,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": self.clock.time_msec() + 1000000.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_length(self) -> None: @@ -331,11 +302,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 0}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -345,11 +312,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": -5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a float @@ -359,11 +322,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 8.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with 65 @@ -373,11 +332,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 65}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # UPDATING @@ -389,11 +344,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - 401, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_update_requester_not_admin(self) -> None: @@ -404,11 +355,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - 403, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_update_non_existent(self) -> None: @@ -420,11 +367,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - 404, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) def test_update_uses_allowed(self) -> None: @@ -472,11 +415,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -486,11 +425,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": -5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_update_expiry_time(self) -> None: @@ -529,11 +464,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": past_time}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail a float @@ -543,11 +474,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": new_expiry_time + 0.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_update_both(self) -> None: @@ -589,11 +516,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # DELETING @@ -605,11 +528,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - 401, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_delete_requester_not_admin(self) -> None: @@ -620,11 +539,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - 403, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_delete_non_existent(self) -> None: @@ -636,11 +551,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - 404, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) def test_delete(self) -> None: @@ -666,11 +577,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - 401, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_get_requester_not_admin(self) -> None: @@ -697,11 +604,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - 404, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) def test_get(self) -> None: @@ -728,11 +631,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): def test_list_no_auth(self) -> None: """Try to list tokens without authentication.""" channel = self.make_request("GET", self.url, {}) - self.assertEqual( - 401, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_list_requester_not_admin(self) -> None: @@ -743,11 +642,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - 403, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_list_all(self) -> None: @@ -780,11 +675,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) def _test_list_query_parameter(self, valid: str) -> None: """Helper used to test both valid=true and valid=false.""" diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 8350f9025a..dd5000679a 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import urllib.parse -from http import HTTPStatus from typing import List, Optional from unittest.mock import Mock @@ -98,7 +97,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): def test_room_is_not_valid(self) -> None: """ - Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. + Check that invalid room names, return an error 400. """ url = "/_synapse/admin/v1/rooms/%s" % "invalidroom" @@ -109,7 +108,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom is not a legal room ID", channel.json_body["error"], @@ -145,7 +144,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "User must be our own: @not:exist.bla", channel.json_body["error"], @@ -163,7 +162,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_is_not_bool(self) -> None: @@ -178,7 +177,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_room_and_block(self) -> None: @@ -546,7 +545,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): ) def test_room_is_not_valid(self, method: str, url: str) -> None: """ - Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. + Check that invalid room names, return an error 400. """ channel = self.make_request( @@ -556,7 +555,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom is not a legal room ID", channel.json_body["error"], @@ -592,7 +591,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "User must be our own: @not:exist.bla", channel.json_body["error"], @@ -610,7 +609,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_is_not_bool(self) -> None: @@ -625,7 +624,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_delete_expired_status(self) -> None: @@ -722,9 +721,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body - ) + self.assertEqual(400, second_channel.code, msg=second_channel.json_body) self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"]) self.assertEqual( f"History purge already in progress for {self.room_id}", @@ -1546,7 +1543,7 @@ class RoomTestCase(unittest.HomeserverTestCase): _search_test(None, "foo") _search_test(None, "bar") - _search_test(None, "", expected_http_code=HTTPStatus.BAD_REQUEST) + _search_test(None, "", expected_http_code=400) # Test that the whole room id returns the room _search_test(room_id_1, room_id_1) @@ -1836,7 +1833,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) def test_local_user_does_not_exist(self) -> None: @@ -1866,7 +1863,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "This endpoint can only be used with local users", channel.json_body["error"], @@ -1893,7 +1890,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): def test_room_is_not_valid(self) -> None: """ - Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. + Check that invalid room names, return an error 400. """ url = "/_synapse/admin/v1/join/invalidroom" @@ -1904,7 +1901,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom was not legal room ID or room alias", channel.json_body["error"], @@ -2243,11 +2240,11 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - # We expect this to fail with a HTTPStatus.BAD_REQUEST as there are no room admins. + # We expect this to fail with a 400 as there are no room admins. # # (Note we assert the error message to ensure that it's not denied for # some other reason) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body["error"], "No local admin user in room with power to update power levels.", @@ -2291,7 +2288,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): @parameterized.expand([("PUT",), ("GET",)]) def test_room_is_not_valid(self, method: str) -> None: - """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.""" + """Check that invalid room names, return an error 400.""" channel = self.make_request( method, @@ -2300,7 +2297,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom is not a legal room ID", channel.json_body["error"], @@ -2317,7 +2314,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) # `block` is not set @@ -2328,7 +2325,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # no content is send @@ -2338,7 +2335,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) def test_block_room(self) -> None: diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 3a6b98fbb5..81e125e27d 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from typing import List from twisted.test.proto_helpers import MemoryReactor @@ -94,7 +93,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) def test_user_is_not_local(self) -> None: """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ channel = self.make_request( "POST", @@ -106,7 +105,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "Server notices can only be sent to local users", channel.json_body["error"] ) @@ -122,7 +121,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) # no content @@ -133,7 +132,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # no body @@ -144,7 +143,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user, "content": ""}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("'body' not in content", channel.json_body["error"]) @@ -156,7 +155,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user, "content": {"body": ""}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("'msgtype' not in content", channel.json_body["error"]) @@ -172,7 +171,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual( "Server notices are not enabled on this server", channel.json_body["error"] diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 3a8982afea..b60f16b914 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor @@ -51,11 +50,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual( - 401, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: @@ -69,11 +64,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual( - 403, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: @@ -87,11 +78,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -101,11 +88,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative limit @@ -115,11 +98,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from_ts @@ -129,11 +108,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative until_ts @@ -143,11 +118,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # until_ts smaller from_ts @@ -157,11 +128,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # empty search term @@ -171,11 +138,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -185,11 +148,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_limit(self) -> None: diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index beb1b1c120..411e4ec005 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -17,7 +17,6 @@ import hmac import os import urllib.parse from binascii import unhexlify -from http import HTTPStatus from typing import List, Optional from unittest.mock import Mock, patch @@ -79,7 +78,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "Shared secret registration is not enabled", channel.json_body["error"] ) @@ -111,7 +110,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = {"nonce": nonce} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # 61 seconds @@ -119,7 +118,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_register_incorrect_nonce(self) -> None: @@ -198,7 +197,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Now, try and reuse it channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_missing_parts(self) -> None: @@ -219,7 +218,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be an empty body present channel = self.make_request("POST", self.url, {}) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("nonce must be specified", channel.json_body["error"]) # @@ -229,28 +228,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present channel = self.make_request("POST", self.url, {"nonce": nonce()}) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = {"nonce": nonce(), "username": 1234} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "abcd\u0000"} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "a" * 1000} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # @@ -261,28 +260,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = {"nonce": nonce(), "username": "a"} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string body = {"nonce": nonce(), "username": "a", "password": 1234} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = {"nonce": nonce(), "username": "a", "password": "A" * 1000} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # @@ -298,7 +297,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) def test_displayname(self) -> None: @@ -591,7 +590,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -601,7 +600,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid guests @@ -611,7 +610,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid deactivated @@ -621,7 +620,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # unkown order_by @@ -631,7 +630,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -641,7 +640,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_limit(self) -> None: @@ -991,18 +990,18 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_user_is_not_local(self) -> None: """ - Tests that deactivation for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that deactivation for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain" channel = self.make_request("POST", url, access_token=self.admin_user_tok) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only deactivate local users", channel.json_body["error"]) def test_deactivate_user_erase_true(self) -> None: @@ -1259,7 +1258,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"admin": "not_bool"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) # deactivated not bool @@ -1269,7 +1268,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": "not_bool"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # password not str @@ -1279,7 +1278,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"password": True}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # password not length @@ -1289,7 +1288,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"password": "x" * 513}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # user_type not valid @@ -1299,7 +1298,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"user_type": "new type"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # external_ids not valid @@ -1311,7 +1310,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): "external_ids": {"auth_provider": "prov", "wrong_external_id": "id"} }, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) channel = self.make_request( @@ -1320,7 +1319,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"external_ids": {"external_id": "id"}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # threepids not valid @@ -1330,7 +1329,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": {"medium": "email", "wrong_address": "id"}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) channel = self.make_request( @@ -1339,7 +1338,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": {"address": "value"}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) def test_get_user(self) -> None: @@ -2228,7 +2227,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) # Reactivate the user. channel = self.make_request( @@ -2431,7 +2430,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123", "deactivated": "false"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) # Check user is not deactivated channel = self.make_request( @@ -2712,7 +2711,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): def test_user_is_not_local(self) -> None: """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers" @@ -2722,7 +2721,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_get_pushers(self) -> None: @@ -2840,7 +2839,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): @parameterized.expand(["GET", "DELETE"]) def test_user_is_not_local(self, method: str) -> None: - """Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST""" + """Tests that a lookup for a user that is not a local returns a 400""" url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" channel = self.make_request( @@ -2849,7 +2848,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_limit_GET(self) -> None: @@ -2970,7 +2969,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -2980,7 +2979,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative limit @@ -2990,7 +2989,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -3000,7 +2999,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_next_token(self) -> None: @@ -3614,7 +3613,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): def test_user_is_not_local(self) -> None: """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = self.url_prefix % "@unknown_person:unknown_domain" # type: ignore[attr-defined] @@ -3623,7 +3622,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only whois a local user", channel.json_body["error"]) def test_get_whois_admin(self) -> None: @@ -3697,12 +3696,12 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): @parameterized.expand(["POST", "DELETE"]) def test_user_is_not_local(self, method: str) -> None: """ - Tests that shadow-banning for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that shadow-banning for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" channel = self.make_request(method, url, access_token=self.admin_user_tok) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) def test_success(self) -> None: """ @@ -3806,7 +3805,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ) def test_user_is_not_local(self, method: str, error_msg: str) -> None: """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = ( "/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit" @@ -3818,7 +3817,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(error_msg, channel.json_body["error"]) def test_invalid_parameter(self) -> None: @@ -3833,7 +3832,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": "string"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # messages_per_second is negative @@ -3844,7 +3843,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": -1}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is a string @@ -3855,7 +3854,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": "string"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is negative @@ -3866,7 +3865,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": -1}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_return_zero_when_null(self) -> None: @@ -4021,7 +4020,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_success(self) -> None: diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index b5e7eecf87..30f12f1bff 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -11,9 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from http import HTTPStatus - from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -40,7 +37,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): if username == "allowed": return True raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "User ID already taken.", errcode=Codes.USER_IN_USE, ) @@ -67,10 +64,6 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): url = "%s?username=%s" % (self.url, "disallowed") channel = self.make_request("GET", url, access_token=self.admin_user_tok) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE") self.assertEqual(channel.json_body["error"], "User ID already taken.") -- cgit 1.5.1 From d642ce4b3258012da6c024b0b5d1396d2a3e69dd Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 15 Aug 2022 20:05:57 +0100 Subject: Use Pydantic to systematically validate a first batch of endpoints in `synapse.rest.client.account`. (#13188) --- changelog.d/13188.feature | 1 + mypy.ini | 2 +- poetry.lock | 54 +++++++++++++- pyproject.toml | 3 + synapse/http/servlet.py | 25 +++++++ synapse/rest/client/account.py | 148 ++++++++++++++++---------------------- synapse/rest/client/models.py | 69 ++++++++++++++++++ synapse/rest/models.py | 23 ++++++ tests/rest/client/test_account.py | 10 +-- tests/rest/client/test_models.py | 53 ++++++++++++++ 10 files changed, 296 insertions(+), 92 deletions(-) create mode 100644 changelog.d/13188.feature create mode 100644 synapse/rest/client/models.py create mode 100644 synapse/rest/models.py create mode 100644 tests/rest/client/test_models.py (limited to 'tests/rest') diff --git a/changelog.d/13188.feature b/changelog.d/13188.feature new file mode 100644 index 0000000000..4c39b74289 --- /dev/null +++ b/changelog.d/13188.feature @@ -0,0 +1 @@ +Improve validation of request bodies for the following client-server API endpoints: [`/account/password`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3accountpassword), [`/account/password/email/requestToken`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3accountpasswordemailrequesttoken), [`/account/deactivate`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3accountdeactivate) and [`/account/3pid/email/requestToken`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidemailrequesttoken). diff --git a/mypy.ini b/mypy.ini index 6add272990..e2034e411f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,6 +1,6 @@ [mypy] namespace_packages = True -plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py +plugins = pydantic.mypy, mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py follow_imports = normal check_untyped_defs = True show_error_codes = True diff --git a/poetry.lock b/poetry.lock index 1acdb5da56..651659ec98 100644 --- a/poetry.lock +++ b/poetry.lock @@ -778,6 +778,21 @@ category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +[[package]] +name = "pydantic" +version = "1.9.1" +description = "Data validation and settings management using python type hints" +category = "main" +optional = false +python-versions = ">=3.6.1" + +[package.dependencies] +typing-extensions = ">=3.7.4.3" + +[package.extras] +dotenv = ["python-dotenv (>=0.10.4)"] +email = ["email-validator (>=1.0.3)"] + [[package]] name = "pyflakes" version = "2.4.0" @@ -1563,7 +1578,7 @@ url_preview = ["lxml"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "c24bbcee7e86dbbe7cdbf49f91a25b310bf21095452641e7440129f59b077f78" +content-hash = "7de518bf27967b3547eab8574342cfb67f87d6b47b4145c13de11112141dbf2d" [metadata.files] attrs = [ @@ -2260,6 +2275,43 @@ pycparser = [ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, ] +pydantic = [ + {file = "pydantic-1.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c8098a724c2784bf03e8070993f6d46aa2eeca031f8d8a048dff277703e6e193"}, + {file = "pydantic-1.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c320c64dd876e45254bdd350f0179da737463eea41c43bacbee9d8c9d1021f11"}, + {file = "pydantic-1.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18f3e912f9ad1bdec27fb06b8198a2ccc32f201e24174cec1b3424dda605a310"}, + {file = "pydantic-1.9.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c11951b404e08b01b151222a1cb1a9f0a860a8153ce8334149ab9199cd198131"}, + {file = "pydantic-1.9.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8bc541a405423ce0e51c19f637050acdbdf8feca34150e0d17f675e72d119580"}, + {file = "pydantic-1.9.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e565a785233c2d03724c4dc55464559639b1ba9ecf091288dd47ad9c629433bd"}, + {file = "pydantic-1.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:a4a88dcd6ff8fd47c18b3a3709a89adb39a6373f4482e04c1b765045c7e282fd"}, + {file = "pydantic-1.9.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:447d5521575f18e18240906beadc58551e97ec98142266e521c34968c76c8761"}, + {file = "pydantic-1.9.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:985ceb5d0a86fcaa61e45781e567a59baa0da292d5ed2e490d612d0de5796918"}, + {file = "pydantic-1.9.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:059b6c1795170809103a1538255883e1983e5b831faea6558ef873d4955b4a74"}, + {file = "pydantic-1.9.1-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:d12f96b5b64bec3f43c8e82b4aab7599d0157f11c798c9f9c528a72b9e0b339a"}, + {file = "pydantic-1.9.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:ae72f8098acb368d877b210ebe02ba12585e77bd0db78ac04a1ee9b9f5dd2166"}, + {file = "pydantic-1.9.1-cp36-cp36m-win_amd64.whl", hash = "sha256:79b485767c13788ee314669008d01f9ef3bc05db9ea3298f6a50d3ef596a154b"}, + {file = "pydantic-1.9.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:494f7c8537f0c02b740c229af4cb47c0d39840b829ecdcfc93d91dcbb0779892"}, + {file = "pydantic-1.9.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0f047e11febe5c3198ed346b507e1d010330d56ad615a7e0a89fae604065a0e"}, + {file = "pydantic-1.9.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:969dd06110cb780da01336b281f53e2e7eb3a482831df441fb65dd30403f4608"}, + {file = "pydantic-1.9.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:177071dfc0df6248fd22b43036f936cfe2508077a72af0933d0c1fa269b18537"}, + {file = "pydantic-1.9.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:9bcf8b6e011be08fb729d110f3e22e654a50f8a826b0575c7196616780683380"}, + {file = "pydantic-1.9.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a955260d47f03df08acf45689bd163ed9df82c0e0124beb4251b1290fa7ae728"}, + {file = "pydantic-1.9.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9ce157d979f742a915b75f792dbd6aa63b8eccaf46a1005ba03aa8a986bde34a"}, + {file = "pydantic-1.9.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0bf07cab5b279859c253d26a9194a8906e6f4a210063b84b433cf90a569de0c1"}, + {file = "pydantic-1.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d93d4e95eacd313d2c765ebe40d49ca9dd2ed90e5b37d0d421c597af830c195"}, + {file = "pydantic-1.9.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1542636a39c4892c4f4fa6270696902acb186a9aaeac6f6cf92ce6ae2e88564b"}, + {file = "pydantic-1.9.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a9af62e9b5b9bc67b2a195ebc2c2662fdf498a822d62f902bf27cccb52dbbf49"}, + {file = "pydantic-1.9.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fe4670cb32ea98ffbf5a1262f14c3e102cccd92b1869df3bb09538158ba90fe6"}, + {file = "pydantic-1.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:9f659a5ee95c8baa2436d392267988fd0f43eb774e5eb8739252e5a7e9cf07e0"}, + {file = "pydantic-1.9.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b83ba3825bc91dfa989d4eed76865e71aea3a6ca1388b59fc801ee04c4d8d0d6"}, + {file = "pydantic-1.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1dd8fecbad028cd89d04a46688d2fcc14423e8a196d5b0a5c65105664901f810"}, + {file = "pydantic-1.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02eefd7087268b711a3ff4db528e9916ac9aa18616da7bca69c1871d0b7a091f"}, + {file = "pydantic-1.9.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7eb57ba90929bac0b6cc2af2373893d80ac559adda6933e562dcfb375029acee"}, + {file = "pydantic-1.9.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:4ce9ae9e91f46c344bec3b03d6ee9612802682c1551aaf627ad24045ce090761"}, + {file = "pydantic-1.9.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:72ccb318bf0c9ab97fc04c10c37683d9eea952ed526707fabf9ac5ae59b701fd"}, + {file = "pydantic-1.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:61b6760b08b7c395975d893e0b814a11cf011ebb24f7d869e7118f5a339a82e1"}, + {file = "pydantic-1.9.1-py3-none-any.whl", hash = "sha256:4988c0f13c42bfa9ddd2fe2f569c9d54646ce84adc5de84228cfe83396f3bd58"}, + {file = "pydantic-1.9.1.tar.gz", hash = "sha256:1ed987c3ff29fff7fd8c3ea3a3ea877ad310aae2ef9889a119e22d3f2db0691a"}, +] pyflakes = [ {file = "pyflakes-2.4.0-py2.py3-none-any.whl", hash = "sha256:3bb3a3f256f4b7968c9c788781e4ff07dce46bdf12339dcda61053375426ee2e"}, {file = "pyflakes-2.4.0.tar.gz", hash = "sha256:05a85c2872edf37a4ed30b0cce2f6093e1d0581f8c19d7393122da7e25b2b24c"}, diff --git a/pyproject.toml b/pyproject.toml index a9f59a676f..4f1e0b5c19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,6 +158,9 @@ packaging = ">=16.1" # At the time of writing, we only use functions from the version `importlib.metadata` # which shipped in Python 3.8. This corresponds to version 1.4 of the backport. importlib_metadata = { version = ">=1.4", python = "<3.8" } +# This is the most recent version of Pydantic with available on common distros. +pydantic = ">=1.7.4" + # Optional Dependencies diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 4ff840ca0e..26aaabfb34 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -23,9 +23,12 @@ from typing import ( Optional, Sequence, Tuple, + Type, + TypeVar, overload, ) +from pydantic import BaseModel, ValidationError from typing_extensions import Literal from twisted.web.server import Request @@ -694,6 +697,28 @@ def parse_json_object_from_request( return content +Model = TypeVar("Model", bound=BaseModel) + + +def parse_and_validate_json_object_from_request( + request: Request, model_type: Type[Model] +) -> Model: + """Parse a JSON object from the body of a twisted HTTP request, then deserialise and + validate using the given pydantic model. + + Raises: + SynapseError if the request body couldn't be decoded as JSON or + if it wasn't a JSON object. + """ + content = parse_json_object_from_request(request, allow_empty_body=False) + try: + instance = model_type.parse_obj(content) + except ValidationError as e: + raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=Codes.BAD_JSON) + + return instance + + def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None: absent = [] for k in required: diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 50edc6b7d3..e5ee63133b 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -15,10 +15,11 @@ # limitations under the License. import logging import random -from http import HTTPStatus from typing import TYPE_CHECKING, Optional, Tuple from urllib.parse import urlparse +from pydantic import StrictBool, StrictStr, constr + from twisted.web.server import Request from synapse.api.constants import LoginType @@ -34,12 +35,15 @@ from synapse.http.server import HttpServer, finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_and_validate_json_object_from_request, parse_json_object_from_request, parse_string, ) from synapse.http.site import SynapseRequest from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer +from synapse.rest.client.models import AuthenticationData, EmailRequestTokenBody +from synapse.rest.models import RequestBodyModel from synapse.types import JsonDict from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -82,32 +86,16 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): 400, "Email-based password resets have been disabled on this server" ) - body = parse_json_object_from_request(request) - - assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) - - # Extract params from body - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - - # Canonicalise the email address. The addresses are all stored canonicalised - # in the database. This allows the user to reset his password without having to - # know the exact spelling (eg. upper and lower case) of address in the database. - # Stored in the database "foo@bar.com" - # User requests with "FOO@bar.com" would raise a Not Found error - try: - email = validate_email(body["email"]) - except ValueError as e: - raise SynapseError(400, str(e)) - send_attempt = body["send_attempt"] - next_link = body.get("next_link") # Optional param + body = parse_and_validate_json_object_from_request( + request, EmailRequestTokenBody + ) - if next_link: + if body.next_link: # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, next_link) + assert_valid_next_link(self.hs, body.next_link) await self.identity_handler.ratelimit_request_token_requests( - request, "email", email + request, "email", body.email ) # The email will be sent to the stored address. @@ -115,7 +103,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # an email address which is controlled by the attacker but which, after # canonicalisation, matches the one in our database. existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( - "email", email + "email", body.email ) if existing_user_id is None: @@ -135,26 +123,26 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Have the configured identity server handle the request ret = await self.identity_handler.request_email_token( self.hs.config.registration.account_threepid_delegate_email, - email, - client_secret, - send_attempt, - next_link, + body.email, + body.client_secret, + body.send_attempt, + body.next_link, ) else: # Send password reset emails from Synapse sid = await self.identity_handler.send_threepid_validation( - email, - client_secret, - send_attempt, + body.email, + body.client_secret, + body.send_attempt, self.mailer.send_password_reset_mail, - next_link, + body.next_link, ) # Wrap the session id in a JSON object ret = {"sid": sid} threepid_send_requests.labels(type="email", reason="password_reset").observe( - send_attempt + body.send_attempt ) return 200, ret @@ -172,16 +160,23 @@ class PasswordRestServlet(RestServlet): self.password_policy_handler = hs.get_password_policy_handler() self._set_password_handler = hs.get_set_password_handler() + class PostBody(RequestBodyModel): + auth: Optional[AuthenticationData] = None + logout_devices: StrictBool = True + if TYPE_CHECKING: + # workaround for https://github.com/samuelcolvin/pydantic/issues/156 + new_password: Optional[StrictStr] = None + else: + new_password: Optional[constr(max_length=512, strict=True)] = None + @interactive_auth_handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - body = parse_json_object_from_request(request) + body = parse_and_validate_json_object_from_request(request, self.PostBody) # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the new password provided to us. - new_password = body.pop("new_password", None) + new_password = body.new_password if new_password is not None: - if not isinstance(new_password, str) or len(new_password) > 512: - raise SynapseError(400, "Invalid password") self.password_policy_handler.validate_password(new_password) # there are two possibilities here. Either the user does not have an @@ -201,7 +196,7 @@ class PasswordRestServlet(RestServlet): params, session_id = await self.auth_handler.validate_user_via_ui_auth( requester, request, - body, + body.dict(), "modify your account password", ) except InteractiveAuthIncompleteError as e: @@ -224,7 +219,7 @@ class PasswordRestServlet(RestServlet): result, params, session_id = await self.auth_handler.check_ui_auth( [[LoginType.EMAIL_IDENTITY]], request, - body, + body.dict(), "modify your account password", ) except InteractiveAuthIncompleteError as e: @@ -299,37 +294,33 @@ class DeactivateAccountRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self._deactivate_account_handler = hs.get_deactivate_account_handler() + class PostBody(RequestBodyModel): + auth: Optional[AuthenticationData] = None + id_server: Optional[StrictStr] = None + # Not specced, see https://github.com/matrix-org/matrix-spec/issues/297 + erase: StrictBool = False + @interactive_auth_handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - body = parse_json_object_from_request(request) - erase = body.get("erase", False) - if not isinstance(erase, bool): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Param 'erase' must be a boolean, if given", - Codes.BAD_JSON, - ) + body = parse_and_validate_json_object_from_request(request, self.PostBody) requester = await self.auth.get_user_by_req(request) # allow ASes to deactivate their own users if requester.app_service: await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, requester + requester.user.to_string(), body.erase, requester ) return 200, {} await self.auth_handler.validate_user_via_ui_auth( requester, request, - body, + body.dict(), "deactivate your account", ) result = await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), - erase, - requester, - id_server=body.get("id_server"), + requester.user.to_string(), body.erase, requester, id_server=body.id_server ) if result: id_server_unbind_result = "success" @@ -364,28 +355,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): "Adding emails have been disabled due to lack of an email config" ) raise SynapseError( - 400, "Adding an email to your account is disabled on this server" + 400, + "Adding an email to your account is disabled on this server", ) - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - - # Canonicalise the email address. The addresses are all stored canonicalised - # in the database. - # This ensures that the validation email is sent to the canonicalised address - # as it will later be entered into the database. - # Otherwise the email will be sent to "FOO@bar.com" and stored as - # "foo@bar.com" in database. - try: - email = validate_email(body["email"]) - except ValueError as e: - raise SynapseError(400, str(e)) - send_attempt = body["send_attempt"] - next_link = body.get("next_link") # Optional param + body = parse_and_validate_json_object_from_request( + request, EmailRequestTokenBody + ) - if not await check_3pid_allowed(self.hs, "email", email): + if not await check_3pid_allowed(self.hs, "email", body.email): raise SynapseError( 403, "Your email domain is not authorized on this server", @@ -393,14 +371,14 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): ) await self.identity_handler.ratelimit_request_token_requests( - request, "email", email + request, "email", body.email ) - if next_link: + if body.next_link: # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, next_link) + assert_valid_next_link(self.hs, body.next_link) - existing_user_id = await self.store.get_user_id_by_threepid("email", email) + existing_user_id = await self.store.get_user_id_by_threepid("email", body.email) if existing_user_id is not None: if self.config.server.request_token_inhibit_3pid_errors: @@ -419,26 +397,26 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): # Have the configured identity server handle the request ret = await self.identity_handler.request_email_token( self.hs.config.registration.account_threepid_delegate_email, - email, - client_secret, - send_attempt, - next_link, + body.email, + body.client_secret, + body.send_attempt, + body.next_link, ) else: # Send threepid validation emails from Synapse sid = await self.identity_handler.send_threepid_validation( - email, - client_secret, - send_attempt, + body.email, + body.client_secret, + body.send_attempt, self.mailer.send_add_threepid_mail, - next_link, + body.next_link, ) # Wrap the session id in a JSON object ret = {"sid": sid} threepid_send_requests.labels(type="email", reason="add_threepid").observe( - send_attempt + body.send_attempt ) return 200, ret diff --git a/synapse/rest/client/models.py b/synapse/rest/client/models.py new file mode 100644 index 0000000000..3150602997 --- /dev/null +++ b/synapse/rest/client/models.py @@ -0,0 +1,69 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING, Dict, Optional + +from pydantic import Extra, StrictInt, StrictStr, constr, validator + +from synapse.rest.models import RequestBodyModel +from synapse.util.threepids import validate_email + + +class AuthenticationData(RequestBodyModel): + """ + Data used during user-interactive authentication. + + (The name "Authentication Data" is taken directly from the spec.) + + Additional keys will be present, depending on the `type` field. Use `.dict()` to + access them. + """ + + class Config: + extra = Extra.allow + + session: Optional[StrictStr] = None + type: Optional[StrictStr] = None + + +class EmailRequestTokenBody(RequestBodyModel): + if TYPE_CHECKING: + client_secret: StrictStr + else: + # See also assert_valid_client_secret() + client_secret: constr( + regex="[0-9a-zA-Z.=_-]", # noqa: F722 + min_length=0, + max_length=255, + strict=True, + ) + email: StrictStr + id_server: Optional[StrictStr] + id_access_token: Optional[StrictStr] + next_link: Optional[StrictStr] + send_attempt: StrictInt + + @validator("id_access_token", always=True) + def token_required_for_identity_server( + cls, token: Optional[str], values: Dict[str, object] + ) -> Optional[str]: + if values.get("id_server") is not None and token is None: + raise ValueError("id_access_token is required if an id_server is supplied.") + return token + + # Canonicalise the email address. The addresses are all stored canonicalised + # in the database. This allows the user to reset his password without having to + # know the exact spelling (eg. upper and lower case) of address in the database. + # Without this, an email stored in the database as "foo@bar.com" would cause + # user requests for "FOO@bar.com" to raise a Not Found error. + _email_validator = validator("email", allow_reuse=True)(validate_email) diff --git a/synapse/rest/models.py b/synapse/rest/models.py new file mode 100644 index 0000000000..ac39cda8e5 --- /dev/null +++ b/synapse/rest/models.py @@ -0,0 +1,23 @@ +from pydantic import BaseModel, Extra + + +class RequestBodyModel(BaseModel): + """A custom version of Pydantic's BaseModel which + + - ignores unknown fields and + - does not allow fields to be overwritten after construction, + + but otherwise uses Pydantic's default behaviour. + + Ignoring unknown fields is a useful default. It means that clients can provide + unstable field not known to the server without the request being refused outright. + + Subclassing in this way is recommended by + https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally + """ + + class Config: + # By default, ignore fields that we don't recognise. + extra = Extra.ignore + # By default, don't allow fields to be reassigned after parsing. + allow_mutation = False diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 7ae926dc9c..c1a7fb2f8a 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -488,7 +488,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, 200, channel.json_body) class WhoamiTestCase(unittest.HomeserverTestCase): @@ -641,21 +641,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_no_at(self) -> None: self._request_token_invalid_email( "address-without-at.bar", - expected_errcode=Codes.UNKNOWN, + expected_errcode=Codes.BAD_JSON, expected_error="Unable to parse email address", ) def test_add_email_two_at(self) -> None: self._request_token_invalid_email( "foo@foo@test.bar", - expected_errcode=Codes.UNKNOWN, + expected_errcode=Codes.BAD_JSON, expected_error="Unable to parse email address", ) def test_add_email_bad_format(self) -> None: self._request_token_invalid_email( "user@bad.example.net@good.example.com", - expected_errcode=Codes.UNKNOWN, + expected_errcode=Codes.BAD_JSON, expected_error="Unable to parse email address", ) @@ -1001,7 +1001,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] ) self.assertEqual(expected_errcode, channel.json_body["errcode"]) - self.assertEqual(expected_error, channel.json_body["error"]) + self.assertIn(expected_error, channel.json_body["error"]) def _validate_token(self, link: str) -> None: # Remove the host diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py new file mode 100644 index 0000000000..a9da00665e --- /dev/null +++ b/tests/rest/client/test_models.py @@ -0,0 +1,53 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from pydantic import ValidationError + +from synapse.rest.client.models import EmailRequestTokenBody + + +class EmailRequestTokenBodyTestCase(unittest.TestCase): + base_request = { + "client_secret": "hunter2", + "email": "alice@wonderland.com", + "send_attempt": 1, + } + + def test_token_required_if_id_server_provided(self) -> None: + with self.assertRaises(ValidationError): + EmailRequestTokenBody.parse_obj( + { + **self.base_request, + "id_server": "identity.wonderland.com", + } + ) + with self.assertRaises(ValidationError): + EmailRequestTokenBody.parse_obj( + { + **self.base_request, + "id_server": "identity.wonderland.com", + "id_access_token": None, + } + ) + + def test_token_typechecked_when_id_server_provided(self) -> None: + with self.assertRaises(ValidationError): + EmailRequestTokenBody.parse_obj( + { + **self.base_request, + "id_server": "identity.wonderland.com", + "id_access_token": 1337, + } + ) -- cgit 1.5.1 From d75512d19ebea6c0f9e38e9f55474fdb6da02b46 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 17 Aug 2022 11:42:01 +0200 Subject: Add forgotten status to Room Details API (#13503) --- changelog.d/13503.feature | 1 + docs/admin_api/rooms.md | 5 +- synapse/rest/admin/rooms.py | 1 + synapse/storage/databases/main/roommember.py | 24 ++++++++++ tests/rest/admin/test_room.py | 1 + tests/storage/test_roommember.py | 70 ++++++++++++++++++++++++++++ 6 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 changelog.d/13503.feature (limited to 'tests/rest') diff --git a/changelog.d/13503.feature b/changelog.d/13503.feature new file mode 100644 index 0000000000..4baabd1e32 --- /dev/null +++ b/changelog.d/13503.feature @@ -0,0 +1 @@ +Add forgotten status to Room Details API. \ No newline at end of file diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 9aa489e4a3..ac7c54c20e 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -302,6 +302,8 @@ The following fields are possible in the JSON response body: * `state_events` - Total number of state_events of a room. Complexity of the room. * `room_type` - The type of the room taken from the room's creation event; for example "m.space" if the room is a space. If the room does not define a type, the value will be `null`. +* `forgotten` - Whether all local users have + [forgotten](https://spec.matrix.org/latest/client-server-api/#leaving-rooms) the room. The API is: @@ -330,7 +332,8 @@ A response body like the following is returned: "guest_access": null, "history_visibility": "shared", "state_events": 93534, - "room_type": "m.space" + "room_type": "m.space", + "forgotten": false } ``` diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 9d953d58de..68054ffc28 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -303,6 +303,7 @@ class RoomRestServlet(RestServlet): members = await self.store.get_users_in_room(room_id) ret["joined_local_devices"] = await self.store.count_devices_by_users(members) + ret["forgotten"] = await self.store.is_locally_forgotten_room(room_id) return HTTPStatus.OK, ret diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 5e5f607a14..827c1f1efd 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -1215,6 +1215,30 @@ class RoomMemberWorkerStore(EventsWorkerStore): "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) + async def is_locally_forgotten_room(self, room_id: str) -> bool: + """Returns whether all local users have forgotten this room_id. + + Args: + room_id: The room ID to query. + + Returns: + Whether the room is forgotten. + """ + + sql = """ + SELECT count(*) > 0 FROM local_current_membership + INNER JOIN room_memberships USING (room_id, event_id) + WHERE + room_id = ? + AND forgotten = 0; + """ + + rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id) + + # `count(*)` returns always an integer + # If any rows still exist it means someone has not forgotten this room yet + return not rows[0][0] + async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]: """Get all rooms that the user has ever been in. diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index dd5000679a..fd6da557c1 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1633,6 +1633,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("history_visibility", channel.json_body) self.assertIn("state_events", channel.json_body) self.assertIn("room_type", channel.json_body) + self.assertIn("forgotten", channel.json_body) self.assertEqual(room_id_1, channel.json_body["room_id"]) def test_single_room_devices(self) -> None: diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 240b02cb9f..ceec690285 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -23,6 +23,7 @@ from synapse.util import Clock from tests import unittest from tests.server import TestHomeServer +from tests.test_utils import event_injection class RoomMemberStoreTestCase(unittest.HomeserverTestCase): @@ -157,6 +158,75 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # Check that alice's display name is now None self.assertEqual(row[0]["display_name"], None) + def test_room_is_locally_forgotten(self): + """Test that when the last local user has forgotten a room it is known as forgotten.""" + # join two local and one remote user + self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + self.get_success( + event_injection.inject_member_event(self.hs, self.room, self.u_bob, "join") + ) + self.get_success( + event_injection.inject_member_event( + self.hs, self.room, self.u_charlie.to_string(), "join" + ) + ) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # local users leave the room and the room is not forgotten + self.get_success( + event_injection.inject_member_event( + self.hs, self.room, self.u_alice, "leave" + ) + ) + self.get_success( + event_injection.inject_member_event(self.hs, self.room, self.u_bob, "leave") + ) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # first user forgets the room, room is not forgotten + self.get_success(self.store.forget(self.u_alice, self.room)) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # second (last local) user forgets the room and the room is forgotten + self.get_success(self.store.forget(self.u_bob, self.room)) + self.assertTrue( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + def test_join_locally_forgotten_room(self): + """Tests if a user joins a forgotten room the room is not forgotten anymore.""" + self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # after leaving and forget the room, it is forgotten + self.get_success( + event_injection.inject_member_event( + self.hs, self.room, self.u_alice, "leave" + ) + ) + self.get_success(self.store.forget(self.u_alice, self.room)) + self.assertTrue( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # after rejoin the room is not forgotten anymore + self.get_success( + event_injection.inject_member_event( + self.hs, self.room, self.u_alice, "join" + ) + ) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: -- cgit 1.5.1 From 8bdf2bd31ef003f0e89a588d8977d4f689ef6856 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 17 Aug 2022 18:08:23 +0000 Subject: Fix a bug in the `/event_reports` Admin API which meant that the total count could be larger than the number of results you can actually query for. (#13525) Co-authored-by: Brendan Abolivier --- changelog.d/13525.bugfix | 1 + synapse/storage/databases/main/room.py | 6 ++++++ tests/rest/admin/test_event_reports.py | 27 +++++++++++++++++++++++++++ 3 files changed, 34 insertions(+) create mode 100644 changelog.d/13525.bugfix (limited to 'tests/rest') diff --git a/changelog.d/13525.bugfix b/changelog.d/13525.bugfix new file mode 100644 index 0000000000..dbd1adbc88 --- /dev/null +++ b/changelog.d/13525.bugfix @@ -0,0 +1 @@ +Fix a bug in the `/event_reports` Admin API which meant that the total count could be larger than the number of results you can actually query for. \ No newline at end of file diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 0f1f0d11ea..b7d4baa6bb 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -2001,9 +2001,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + # We join on room_stats_state despite not using any columns from it + # because the join can influence the number of rows returned; + # e.g. a room that doesn't have state, maybe because it was deleted. + # The query returning the total count should be consistent with + # the query returning the results. sql = """ SELECT COUNT(*) as total_event_reports FROM event_reports AS er + JOIN room_stats_state ON room_stats_state.room_id = er.room_id {} """.format( where_clause diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index fbc490f46d..8a4e5c3f77 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -410,6 +410,33 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertIn("score", c) self.assertIn("reason", c) + def test_count_correct_despite_table_deletions(self) -> None: + """ + Tests that the count matches the number of rows, even if rows in joined tables + are missing. + """ + + # Delete rows from room_stats_state for one of our rooms. + self.get_success( + self.hs.get_datastores().main.db_pool.simple_delete( + "room_stats_state", {"room_id": self.room_id1}, desc="_" + ) + ) + + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + # The 'total' field is 10 because only 10 reports will actually + # be retrievable since we deleted the rows in the room_stats_state + # table. + self.assertEqual(channel.json_body["total"], 10) + # This is consistent with the number of rows actually returned. + self.assertEqual(len(channel.json_body["event_reports"]), 10) + class EventReportDetailTestCase(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From f9f03426de338ae1879e174f63adf698bbfc3a4b Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 19 Aug 2022 17:17:10 +0100 Subject: Implement MSC3852: Expose `last_seen_user_agent` to users for their own devices; also expose to Admin API (#13549) --- changelog.d/13549.feature | 1 + changelog.d/13549.misc | 1 + docs/admin_api/user_admin_api.md | 7 +++ synapse/config/experimental.py | 3 ++ synapse/handlers/device.py | 9 +++- synapse/rest/client/devices.py | 27 ++++++++++++ tests/rest/admin/test_user.py | 92 +++++++++++++++++++++++++++++++++++++++- tests/unittest.py | 15 +++++++ 8 files changed, 153 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13549.feature create mode 100644 changelog.d/13549.misc (limited to 'tests/rest') diff --git a/changelog.d/13549.feature b/changelog.d/13549.feature new file mode 100644 index 0000000000..b6a726789c --- /dev/null +++ b/changelog.d/13549.feature @@ -0,0 +1 @@ +Add an experimental implementation for [MSC3852](https://github.com/matrix-org/matrix-spec-proposals/pull/3852). \ No newline at end of file diff --git a/changelog.d/13549.misc b/changelog.d/13549.misc new file mode 100644 index 0000000000..5b4303e87e --- /dev/null +++ b/changelog.d/13549.misc @@ -0,0 +1 @@ +Allow specifying additional request fields when using the `HomeServerTestCase.login` helper method. \ No newline at end of file diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 0871cfebf5..c1ca0c8a64 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -753,6 +753,7 @@ A response body like the following is returned: "device_id": "QBUAZIFURK", "display_name": "android", "last_seen_ip": "1.2.3.4", + "last_seen_user_agent": "Mozilla/5.0 (X11; Linux x86_64; rv:103.0) Gecko/20100101 Firefox/103.0", "last_seen_ts": 1474491775024, "user_id": "" }, @@ -760,6 +761,7 @@ A response body like the following is returned: "device_id": "AUIECTSRND", "display_name": "ios", "last_seen_ip": "1.2.3.5", + "last_seen_user_agent": "Mozilla/5.0 (X11; Linux x86_64; rv:103.0) Gecko/20100101 Firefox/103.0", "last_seen_ts": 1474491775025, "user_id": "" } @@ -786,6 +788,8 @@ The following fields are returned in the JSON response body: Absent if no name has been set. - `last_seen_ip` - The IP address where this device was last seen. (May be a few minutes out of date, for efficiency reasons). + - `last_seen_user_agent` - The user agent of the device when it was last seen. + (May be a few minutes out of date, for efficiency reasons). - `last_seen_ts` - The timestamp (in milliseconds since the unix epoch) when this devices was last seen. (May be a few minutes out of date, for efficiency reasons). - `user_id` - Owner of device. @@ -837,6 +841,7 @@ A response body like the following is returned: "device_id": "", "display_name": "android", "last_seen_ip": "1.2.3.4", + "last_seen_user_agent": "Mozilla/5.0 (X11; Linux x86_64; rv:103.0) Gecko/20100101 Firefox/103.0", "last_seen_ts": 1474491775024, "user_id": "" } @@ -858,6 +863,8 @@ The following fields are returned in the JSON response body: Absent if no name has been set. - `last_seen_ip` - The IP address where this device was last seen. (May be a few minutes out of date, for efficiency reasons). + - `last_seen_user_agent` - The user agent of the device when it was last seen. + (May be a few minutes out of date, for efficiency reasons). - `last_seen_ts` - The timestamp (in milliseconds since the unix epoch) when this devices was last seen. (May be a few minutes out of date, for efficiency reasons). - `user_id` - Owner of device. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 7d17c958bb..c1ff417539 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -90,3 +90,6 @@ class ExperimentalConfig(Config): # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) + + # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. + self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 1a8379854c..f5c586f657 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -74,6 +74,7 @@ class DeviceWorkerHandler: self._state_storage = hs.get_storage_controllers().state self._auth_handler = hs.get_auth_handler() self.server_name = hs.hostname + self._msc3852_enabled = hs.config.experimental.msc3852_enabled @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: @@ -747,7 +748,13 @@ def _update_device_from_client_ips( device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] ) -> None: ip = client_ips.get((device["user_id"], device["device_id"]), {}) - device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) + device.update( + { + "last_seen_user_agent": ip.get("user_agent"), + "last_seen_ts": ip.get("last_seen"), + "last_seen_ip": ip.get("ip"), + } + ) class DeviceListUpdater: diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 6fab102437..ed6ce78d47 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -42,12 +42,26 @@ class DevicesRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() + self._msc3852_enabled = hs.config.experimental.msc3852_enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) devices = await self.device_handler.get_devices_by_user( requester.user.to_string() ) + + # If MSC3852 is disabled, then the "last_seen_user_agent" field will be + # removed from each device. If it is enabled, then the field name will + # be replaced by the unstable identifier. + # + # When MSC3852 is accepted, this block of code can just be removed to + # expose "last_seen_user_agent" to clients. + for device in devices: + last_seen_user_agent = device["last_seen_user_agent"] + del device["last_seen_user_agent"] + if self._msc3852_enabled: + device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent + return 200, {"devices": devices} @@ -108,6 +122,7 @@ class DeviceRestServlet(RestServlet): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() + self._msc3852_enabled = hs.config.experimental.msc3852_enabled async def on_GET( self, request: SynapseRequest, device_id: str @@ -118,6 +133,18 @@ class DeviceRestServlet(RestServlet): ) if device is None: raise NotFoundError("No device found") + + # If MSC3852 is disabled, then the "last_seen_user_agent" field will be + # removed from each device. If it is enabled, then the field name will + # be replaced by the unstable identifier. + # + # When MSC3852 is accepted, this block of code can just be removed to + # expose "last_seen_user_agent" to clients. + last_seen_user_agent = device["last_seen_user_agent"] + del device["last_seen_user_agent"] + if self._msc3852_enabled: + device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent + return 200, device @interactive_auth_handler diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 411e4ec005..1afd082707 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1,4 +1,4 @@ -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# Copyright 2018-2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -904,6 +904,96 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) +class UserDevicesTestCase(unittest.HomeserverTestCase): + """ + Tests user device management-related Admin APIs. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Set up an Admin user to query the Admin API with. + self.admin_user_id = self.register_user("admin", "pass", admin=True) + self.admin_user_token = self.login("admin", "pass") + + # Set up a test user to query the devices of. + self.other_user_device_id = "TESTDEVICEID" + self.other_user_device_display_name = "My Test Device" + self.other_user_client_ip = "1.2.3.4" + self.other_user_user_agent = "EquestriaTechnology/123.0" + + self.other_user_id = self.register_user("user", "pass", displayname="User1") + self.other_user_token = self.login( + "user", + "pass", + device_id=self.other_user_device_id, + additional_request_fields={ + "initial_device_display_name": self.other_user_device_display_name, + }, + ) + + # Have the "other user" make a request so that the "last_seen_*" fields are + # populated in the tests below. + channel = self.make_request( + "GET", + "/_matrix/client/v3/sync", + access_token=self.other_user_token, + client_ip=self.other_user_client_ip, + custom_headers=[ + ("User-Agent", self.other_user_user_agent), + ], + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + def test_list_user_devices(self) -> None: + """ + Tests that a user's devices and attributes are listed correctly via the Admin API. + """ + # Request all devices of "other user" + channel = self.make_request( + "GET", + f"/_synapse/admin/v2/users/{self.other_user_id}/devices", + access_token=self.admin_user_token, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Double-check we got the single device expected + user_devices = channel.json_body["devices"] + self.assertEqual(len(user_devices), 1) + self.assertEqual(channel.json_body["total"], 1) + + # Check that all the attributes of the device reported are as expected. + self._validate_attributes_of_device_response(user_devices[0]) + + # Request just a single device for "other user" by its ID + channel = self.make_request( + "GET", + f"/_synapse/admin/v2/users/{self.other_user_id}/devices/" + f"{self.other_user_device_id}", + access_token=self.admin_user_token, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Check that all the attributes of the device reported are as expected. + self._validate_attributes_of_device_response(channel.json_body) + + def _validate_attributes_of_device_response(self, response: JsonDict) -> None: + # Check that all device expected attributes are present + self.assertEqual(response["user_id"], self.other_user_id) + self.assertEqual(response["device_id"], self.other_user_device_id) + self.assertEqual(response["display_name"], self.other_user_device_display_name) + self.assertEqual(response["last_seen_ip"], self.other_user_client_ip) + self.assertEqual(response["last_seen_user_agent"], self.other_user_user_agent) + self.assertIsInstance(response["last_seen_ts"], int) + self.assertGreater(response["last_seen_ts"], 0) + + class DeactivateAccountTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/unittest.py b/tests/unittest.py index bec4a3d023..975b0a23a7 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -677,14 +677,29 @@ class HomeserverTestCase(TestCase): username: str, password: str, device_id: Optional[str] = None, + additional_request_fields: Optional[Dict[str, str]] = None, custom_headers: Optional[Iterable[CustomHeaderType]] = None, ) -> str: """ Log in a user, and get an access token. Requires the Login API be registered. + + Args: + username: The localpart to assign to the new user. + password: The password to assign to the new user. + device_id: An optional device ID to assign to the new device created during + login. + additional_request_fields: A dictionary containing any additional /login + request fields and their values. + custom_headers: Custom HTTP headers and values to add to the /login request. + + Returns: + The newly registered user's Matrix ID. """ body = {"type": "m.login.password", "user": username, "password": password} if device_id: body["device_id"] = device_id + if additional_request_fields: + body.update(additional_request_fields) channel = self.make_request( "POST", -- cgit 1.5.1 From 3dd175b628bab5638165f20de9eade36a4e88147 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 22 Aug 2022 15:17:59 +0200 Subject: `synapse.api.auth.Auth` cleanup: make permission-related methods use `Requester` instead of the `UserID` (#13024) Part of #13019 This changes all the permission-related methods to rely on the Requester instead of the UserID. This is a first step towards enabling scoped access tokens at some point, since I expect the Requester to have scope-related informations in it. It also changes methods which figure out the user/device/appservice out of the access token to return a Requester instead of something else. This avoids having store-related objects in the methods signatures. --- changelog.d/13024.misc | 1 + synapse/api/auth.py | 202 +++++++++++------------ synapse/handlers/auth.py | 17 +- synapse/handlers/directory.py | 24 ++- synapse/handlers/initial_sync.py | 6 +- synapse/handlers/message.py | 23 +-- synapse/handlers/pagination.py | 2 +- synapse/handlers/register.py | 15 +- synapse/handlers/relations.py | 2 +- synapse/handlers/room.py | 4 +- synapse/handlers/room_member.py | 10 +- synapse/handlers/typing.py | 10 +- synapse/http/site.py | 2 +- synapse/rest/admin/_base.py | 10 +- synapse/rest/admin/media.py | 6 +- synapse/rest/admin/rooms.py | 12 +- synapse/rest/admin/users.py | 15 +- synapse/rest/client/profile.py | 4 +- synapse/rest/client/register.py | 3 - synapse/rest/client/room.py | 13 +- synapse/server_notices/server_notices_manager.py | 2 +- synapse/storage/databases/main/registration.py | 2 +- tests/api/test_auth.py | 8 +- tests/handlers/test_typing.py | 8 +- tests/rest/client/test_retention.py | 4 +- tests/rest/client/test_shadow_banned.py | 6 +- 26 files changed, 203 insertions(+), 208 deletions(-) create mode 100644 changelog.d/13024.misc (limited to 'tests/rest') diff --git a/changelog.d/13024.misc b/changelog.d/13024.misc new file mode 100644 index 0000000000..aa43c82429 --- /dev/null +++ b/changelog.d/13024.misc @@ -0,0 +1 @@ +Refactor methods in `synapse.api.auth.Auth` to use `Requester` objects everywhere instead of user IDs. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 523bad0c55..9a1aea083f 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -37,8 +37,7 @@ from synapse.logging.opentracing import ( start_active_span, trace, ) -from synapse.storage.databases.main.registration import TokenLookupResult -from synapse.types import Requester, UserID, create_requester +from synapse.types import Requester, create_requester if TYPE_CHECKING: from synapse.server import HomeServer @@ -70,14 +69,14 @@ class Auth: async def check_user_in_room( self, room_id: str, - user_id: str, + requester: Requester, allow_departed_users: bool = False, ) -> Tuple[str, Optional[str]]: """Check if the user is in the room, or was at some point. Args: room_id: The room to check. - user_id: The user to check. + requester: The user making the request, according to the access token. current_state: Optional map of the current state of the room. If provided then that map is used to check whether they are a @@ -94,6 +93,7 @@ class Auth: membership event ID of the user. """ + user_id = requester.user.to_string() ( membership, member_event_id, @@ -182,96 +182,69 @@ class Auth: access_token = self.get_access_token_from_request(request) - ( - user_id, - device_id, - app_service, - ) = await self._get_appservice_user_id_and_device_id(request) - if user_id and app_service: - if ip_addr and self._track_appservice_user_ips: - await self.store.insert_client_ip( - user_id=user_id, - access_token=access_token, - ip=ip_addr, - user_agent=user_agent, - device_id="dummy-device" - if device_id is None - else device_id, # stubbed - ) - - requester = create_requester( - user_id, app_service=app_service, device_id=device_id + # First check if it could be a request from an appservice + requester = await self._get_appservice_user(request) + if not requester: + # If not, it should be from a regular user + requester = await self.get_user_by_access_token( + access_token, allow_expired=allow_expired ) - request.requester = user_id - return requester - - user_info = await self.get_user_by_access_token( - access_token, allow_expired=allow_expired - ) - token_id = user_info.token_id - is_guest = user_info.is_guest - shadow_banned = user_info.shadow_banned - - # Deny the request if the user account has expired. - if not allow_expired: - if await self._account_validity_handler.is_user_expired( - user_info.user_id - ): - # Raise the error if either an account validity module has determined - # the account has expired, or the legacy account validity - # implementation is enabled and determined the account has expired - raise AuthError( - 403, - "User account has expired", - errcode=Codes.EXPIRED_ACCOUNT, - ) - - device_id = user_info.device_id - - if access_token and ip_addr: + # Deny the request if the user account has expired. + # This check is only done for regular users, not appservice ones. + if not allow_expired: + if await self._account_validity_handler.is_user_expired( + requester.user.to_string() + ): + # Raise the error if either an account validity module has determined + # the account has expired, or the legacy account validity + # implementation is enabled and determined the account has expired + raise AuthError( + 403, + "User account has expired", + errcode=Codes.EXPIRED_ACCOUNT, + ) + + if ip_addr and ( + not requester.app_service or self._track_appservice_user_ips + ): + # XXX(quenting): I'm 95% confident that we could skip setting the + # device_id to "dummy-device" for appservices, and that the only impact + # would be some rows which whould not deduplicate in the 'user_ips' + # table during the transition + recorded_device_id = ( + "dummy-device" + if requester.device_id is None and requester.app_service is not None + else requester.device_id + ) await self.store.insert_client_ip( - user_id=user_info.token_owner, + user_id=requester.authenticated_entity, access_token=access_token, ip=ip_addr, user_agent=user_agent, - device_id=device_id, + device_id=recorded_device_id, ) + # Track also the puppeted user client IP if enabled and the user is puppeting if ( - user_info.user_id != user_info.token_owner + requester.user.to_string() != requester.authenticated_entity and self._track_puppeted_user_ips ): await self.store.insert_client_ip( - user_id=user_info.user_id, + user_id=requester.user.to_string(), access_token=access_token, ip=ip_addr, user_agent=user_agent, - device_id=device_id, + device_id=requester.device_id, ) - if is_guest and not allow_guest: + if requester.is_guest and not allow_guest: raise AuthError( 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN, ) - # Mark the token as used. This is used to invalidate old refresh - # tokens after some time. - if not user_info.token_used and token_id is not None: - await self.store.mark_access_token_as_used(token_id) - - requester = create_requester( - user_info.user_id, - token_id, - is_guest, - shadow_banned, - device_id, - app_service=app_service, - authenticated_entity=user_info.token_owner, - ) - request.requester = requester return requester except KeyError: @@ -308,9 +281,7 @@ class Auth: 403, "Application service has not registered this user (%s)" % user_id ) - async def _get_appservice_user_id_and_device_id( - self, request: Request - ) -> Tuple[Optional[str], Optional[str], Optional[ApplicationService]]: + async def _get_appservice_user(self, request: Request) -> Optional[Requester]: """ Given a request, reads the request parameters to determine: - whether it's an application service that's making this request @@ -325,15 +296,13 @@ class Auth: Must use `org.matrix.msc3202.device_id` in place of `device_id` for now. Returns: - 3-tuple of - (user ID?, device ID?, application service?) + the application service `Requester` of that request Postconditions: - - If an application service is returned, so is a user ID - - A user ID is never returned without an application service - - A device ID is never returned without a user ID or an application service - - The returned application service, if present, is permitted to control the - returned user ID. + - The `app_service` field in the returned `Requester` is set + - The `user_id` field in the returned `Requester` is either the application + service sender or the controlled user set by the `user_id` URI parameter + - The returned application service is permitted to control the returned user ID. - The returned device ID, if present, has been checked to be a valid device ID for the returned user ID. """ @@ -343,12 +312,12 @@ class Auth: self.get_access_token_from_request(request) ) if app_service is None: - return None, None, None + return None if app_service.ip_range_whitelist: ip_address = IPAddress(request.getClientAddress().host) if ip_address not in app_service.ip_range_whitelist: - return None, None, None + return None # This will always be set by the time Twisted calls us. assert request.args is not None @@ -382,13 +351,15 @@ class Auth: Codes.EXCLUSIVE, ) - return effective_user_id, effective_device_id, app_service + return create_requester( + effective_user_id, app_service=app_service, device_id=effective_device_id + ) async def get_user_by_access_token( self, token: str, allow_expired: bool = False, - ) -> TokenLookupResult: + ) -> Requester: """Validate access token and get user_id from it Args: @@ -405,9 +376,9 @@ class Auth: # First look in the database to see if the access token is present # as an opaque token. - r = await self.store.get_user_by_access_token(token) - if r: - valid_until_ms = r.valid_until_ms + user_info = await self.store.get_user_by_access_token(token) + if user_info: + valid_until_ms = user_info.valid_until_ms if ( not allow_expired and valid_until_ms is not None @@ -419,7 +390,20 @@ class Auth: msg="Access token has expired", soft_logout=True ) - return r + # Mark the token as used. This is used to invalidate old refresh + # tokens after some time. + await self.store.mark_access_token_as_used(user_info.token_id) + + requester = create_requester( + user_id=user_info.user_id, + access_token_id=user_info.token_id, + is_guest=user_info.is_guest, + shadow_banned=user_info.shadow_banned, + device_id=user_info.device_id, + authenticated_entity=user_info.token_owner, + ) + + return requester # If the token isn't found in the database, then it could still be a # macaroon for a guest, so we check that here. @@ -445,11 +429,12 @@ class Auth: "Guest access token used for regular user" ) - return TokenLookupResult( + return create_requester( user_id=user_id, is_guest=True, # all guests get the same device id device_id=GUEST_DEVICE_ID, + authenticated_entity=user_id, ) except ( pymacaroons.exceptions.MacaroonException, @@ -472,32 +457,33 @@ class Auth: request.requester = create_requester(service.sender, app_service=service) return service - async def is_server_admin(self, user: UserID) -> bool: + async def is_server_admin(self, requester: Requester) -> bool: """Check if the given user is a local server admin. Args: - user: user to check + requester: The user making the request, according to the access token. Returns: True if the user is an admin """ - return await self.store.is_server_admin(user) + return await self.store.is_server_admin(requester.user) - async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool: + async def check_can_change_room_list( + self, room_id: str, requester: Requester + ) -> bool: """Determine whether the user is allowed to edit the room's entry in the published room list. Args: - room_id - user + room_id: The room to check. + requester: The user making the request, according to the access token. """ - is_admin = await self.is_server_admin(user) + is_admin = await self.is_server_admin(requester) if is_admin: return True - user_id = user.to_string() - await self.check_user_in_room(room_id, user_id) + await self.check_user_in_room(room_id, requester) # We currently require the user is a "moderator" in the room. We do this # by checking if they would (theoretically) be able to change the @@ -516,7 +502,9 @@ class Auth: send_level = event_auth.get_send_level( EventTypes.CanonicalAlias, "", power_level_event ) - user_level = event_auth.get_user_power_level(user_id, auth_events) + user_level = event_auth.get_user_power_level( + requester.user.to_string(), auth_events + ) return user_level >= send_level @@ -574,16 +562,16 @@ class Auth: @trace async def check_user_in_room_or_world_readable( - self, room_id: str, user_id: str, allow_departed_users: bool = False + self, room_id: str, requester: Requester, allow_departed_users: bool = False ) -> Tuple[str, Optional[str]]: """Checks that the user is or was in the room or the room is world readable. If it isn't then an exception is raised. Args: - room_id: room to check - user_id: user to check - allow_departed_users: if True, accept users that were previously - members but have now departed + room_id: The room to check. + requester: The user making the request, according to the access token. + allow_departed_users: If True, accept users that were previously + members but have now departed. Returns: Resolves to the current membership of the user in the room and the @@ -598,7 +586,7 @@ class Auth: # * The user is a guest user, and has joined the room # else it will throw. return await self.check_user_in_room( - room_id, user_id, allow_departed_users=allow_departed_users + room_id, requester, allow_departed_users=allow_departed_users ) except AuthError: visibility = await self._storage_controllers.state.get_current_state_event( @@ -613,6 +601,6 @@ class Auth: raise UnstableSpecAuthError( 403, "User %s not in room %s, and room previews are disabled" - % (user_id, room_id), + % (requester.user, room_id), errcode=Codes.NOT_JOINED, ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index bfa5535044..0327fc57a4 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -280,7 +280,7 @@ class AuthHandler: that it isn't stolen by re-authenticating them. Args: - requester: The user, as given by the access token + requester: The user making the request, according to the access token. request: The request sent by the client. @@ -1435,20 +1435,25 @@ class AuthHandler: access_token: access token to be deleted """ - user_info = await self.auth.get_user_by_access_token(access_token) + token = await self.store.get_user_by_access_token(access_token) + if not token: + # At this point, the token should already have been fetched once by + # the caller, so this should not happen, unless of a race condition + # between two delete requests + raise SynapseError(HTTPStatus.UNAUTHORIZED, "Unrecognised access token") await self.store.delete_access_token(access_token) # see if any modules want to know about this await self.password_auth_provider.on_logged_out( - user_id=user_info.user_id, - device_id=user_info.device_id, + user_id=token.user_id, + device_id=token.device_id, access_token=access_token, ) # delete pushers associated with this access token - if user_info.token_id is not None: + if token.token_id is not None: await self.hs.get_pusherpool().remove_pushers_by_access_token( - user_info.user_id, (user_info.token_id,) + token.user_id, (token.token_id,) ) async def delete_access_tokens_for_user( diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 09a7a4b238..948f66a94d 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -30,7 +30,7 @@ from synapse.api.errors import ( from synapse.appservice import ApplicationService from synapse.module_api import NOT_SPAM from synapse.storage.databases.main.directory import RoomAliasMapping -from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, RoomAlias, get_domain_from_id if TYPE_CHECKING: from synapse.server import HomeServer @@ -133,7 +133,7 @@ class DirectoryHandler: else: # Server admins are not subject to the same constraints as normal # users when creating an alias (e.g. being in the room). - is_admin = await self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester) if (self.require_membership and check_membership) and not is_admin: rooms_for_user = await self.store.get_rooms_for_user(user_id) @@ -197,7 +197,7 @@ class DirectoryHandler: user_id = requester.user.to_string() try: - can_delete = await self._user_can_delete_alias(room_alias, user_id) + can_delete = await self._user_can_delete_alias(room_alias, requester) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown room alias") @@ -400,7 +400,9 @@ class DirectoryHandler: # either no interested services, or no service with an exclusive lock return True - async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool: + async def _user_can_delete_alias( + self, alias: RoomAlias, requester: Requester + ) -> bool: """Determine whether a user can delete an alias. One of the following must be true: @@ -413,7 +415,7 @@ class DirectoryHandler: """ creator = await self.store.get_room_alias_creator(alias.to_string()) - if creator == user_id: + if creator == requester.user.to_string(): return True # Resolve the alias to the corresponding room. @@ -422,9 +424,7 @@ class DirectoryHandler: if not room_id: return False - return await self.auth.check_can_change_room_list( - room_id, UserID.from_string(user_id) - ) + return await self.auth.check_can_change_room_list(room_id, requester) async def edit_published_room_list( self, requester: Requester, room_id: str, visibility: str @@ -463,7 +463,7 @@ class DirectoryHandler: raise SynapseError(400, "Unknown room") can_change_room_list = await self.auth.check_can_change_room_list( - room_id, requester.user + room_id, requester ) if not can_change_room_list: raise AuthError( @@ -528,10 +528,8 @@ class DirectoryHandler: Get a list of the aliases that currently point to this room on this server """ # allow access to server admins and current members of the room - is_admin = await self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester) if not is_admin: - await self.auth.check_user_in_room_or_world_readable( - room_id, requester.user.to_string() - ) + await self.auth.check_user_in_room_or_world_readable(room_id, requester) return await self.store.get_aliases_for_room(room_id) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 6484e47e5f..860c82c110 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -309,18 +309,18 @@ class InitialSyncHandler: if blocked: raise SynapseError(403, "This room has been blocked on this server") - user_id = requester.user.to_string() - ( membership, member_event_id, ) = await self.auth.check_user_in_room_or_world_readable( room_id, - user_id, + requester, allow_departed_users=True, ) is_peeking = member_event_id is None + user_id = requester.user.to_string() + if membership == Membership.JOIN: result = await self._room_initial_sync_joined( user_id, room_id, pagin_config, membership, is_peeking diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8f29ee9a87..acd3de06f6 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -104,7 +104,7 @@ class MessageHandler: async def get_room_data( self, - user_id: str, + requester: Requester, room_id: str, event_type: str, state_key: str, @@ -112,7 +112,7 @@ class MessageHandler: """Get data from a room. Args: - user_id + requester: The user who did the request. room_id event_type state_key @@ -125,7 +125,7 @@ class MessageHandler: membership, membership_event_id, ) = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True + room_id, requester, allow_departed_users=True ) if membership == Membership.JOIN: @@ -161,11 +161,10 @@ class MessageHandler: async def get_state_events( self, - user_id: str, + requester: Requester, room_id: str, state_filter: Optional[StateFilter] = None, at_token: Optional[StreamToken] = None, - is_guest: bool = False, ) -> List[dict]: """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has @@ -174,14 +173,13 @@ class MessageHandler: visible. Args: - user_id: The user requesting state events. + requester: The user requesting state events. room_id: The room ID to get all state events from. state_filter: The state filter used to fetch state from the database. at_token: the stream token of the at which we are requesting the stats. If the user is not allowed to view the state as of that stream token, we raise a 403 SynapseError. If None, returns the current state based on the current_state_events table. - is_guest: whether this user is a guest Returns: A list of dicts representing state events. [{}, {}, {}] Raises: @@ -191,6 +189,7 @@ class MessageHandler: members of this room. """ state_filter = state_filter or StateFilter.all() + user_id = requester.user.to_string() if at_token: last_event_id = ( @@ -223,7 +222,7 @@ class MessageHandler: membership, membership_event_id, ) = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True + room_id, requester, allow_departed_users=True ) if membership == Membership.JOIN: @@ -317,12 +316,11 @@ class MessageHandler: Returns: A dict of user_id to profile info """ - user_id = requester.user.to_string() if not requester.app_service: # We check AS auth after fetching the room membership, as it # requires us to pull out all joined members anyway. membership, _ = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True + room_id, requester, allow_departed_users=True ) if membership != Membership.JOIN: raise SynapseError( @@ -340,7 +338,10 @@ class MessageHandler: # If this is an AS, double check that they are allowed to see the members. # This can either be because the AS user is in the room or because there # is a user in the room that the AS is "interested in" - if requester.app_service and user_id not in users_with_profile: + if ( + requester.app_service + and requester.user.to_string() not in users_with_profile + ): for uid in users_with_profile: if requester.app_service.is_interested_in_user(uid): break diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index e1e34e3b16..74e944bce7 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -464,7 +464,7 @@ class PaginationHandler: membership, member_event_id, ) = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True + room_id, requester, allow_departed_users=True ) if pagin_config.direction == "b": diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c77d181722..20ec22105a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -29,7 +29,13 @@ from synapse.api.constants import ( JoinRules, LoginType, ) -from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + ConsentNotGivenError, + InvalidClientTokenError, + SynapseError, +) from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict @@ -180,10 +186,7 @@ class RegistrationHandler: ) if guest_access_token: user_data = await self.auth.get_user_by_access_token(guest_access_token) - if ( - not user_data.is_guest - or UserID.from_string(user_data.user_id).localpart != localpart - ): + if not user_data.is_guest or user_data.user.localpart != localpart: raise AuthError( 403, "Cannot register taken user ID without valid guest " @@ -618,7 +621,7 @@ class RegistrationHandler: user_id = user.to_string() service = self.store.get_app_service_by_token(as_token) if not service: - raise AuthError(403, "Invalid application service token.") + raise InvalidClientTokenError() if not service.is_interested_in_user(user_id): raise SynapseError( 400, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 72d25df8c8..28d7093f08 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -103,7 +103,7 @@ class RelationsHandler: # TODO Properly handle a user leaving a room. (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True + room_id, requester, allow_departed_users=True ) # This gets the original event and checks that a) the event exists and diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 55395457c3..2bf0ebd025 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -721,7 +721,7 @@ class RoomCreationHandler: # allow the server notices mxid to create rooms is_requester_admin = True else: - is_requester_admin = await self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester) # Let the third party rules modify the room creation config if needed, or abort # the room creation entirely with an exception. @@ -1279,7 +1279,7 @@ class RoomContextHandler: """ user = requester.user if use_admin_priviledge: - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 70dc69c809..d1909665d6 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -179,7 +179,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): """Try and join a room that this server is not in Args: - requester + requester: The user making the request, according to the access token. remote_room_hosts: List of servers that can be used to join via. room_id: Room that we are trying to join user: User who is trying to join @@ -744,7 +744,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): is_requester_admin = True else: - is_requester_admin = await self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester) if not is_requester_admin: if self.config.server.block_non_admin_invites: @@ -868,7 +868,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): bypass_spam_checker = True else: - bypass_spam_checker = await self.auth.is_server_admin(requester.user) + bypass_spam_checker = await self.auth.is_server_admin(requester) inviter = await self._get_inviter(target.to_string(), room_id) if ( @@ -1410,7 +1410,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ShadowBanError if the requester has been shadow-banned. """ if self.config.server.block_non_admin_invites: - is_requester_admin = await self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester) if not is_requester_admin: raise SynapseError( 403, "Invites have been disabled on this server", Codes.FORBIDDEN @@ -1693,7 +1693,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): check_complexity and self.hs.config.server.limit_remote_rooms.admins_can_join ): - check_complexity = not await self.auth.is_server_admin(user) + check_complexity = not await self.store.is_server_admin(user) if check_complexity: # Fetch the room complexity diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 27aa0d3126..bcac3372a2 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -253,12 +253,11 @@ class TypingWriterHandler(FollowerTypingHandler): self, target_user: UserID, requester: Requester, room_id: str, timeout: int ) -> None: target_user_id = target_user.to_string() - auth_user_id = requester.user.to_string() if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this homeserver") - if target_user_id != auth_user_id: + if target_user != requester.user: raise AuthError(400, "Cannot set another user's typing state") if requester.shadow_banned: @@ -266,7 +265,7 @@ class TypingWriterHandler(FollowerTypingHandler): await self.clock.sleep(random.randint(1, 10)) raise ShadowBanError() - await self.auth.check_user_in_room(room_id, target_user_id) + await self.auth.check_user_in_room(room_id, requester) logger.debug("%s has started typing in %s", target_user_id, room_id) @@ -289,12 +288,11 @@ class TypingWriterHandler(FollowerTypingHandler): self, target_user: UserID, requester: Requester, room_id: str ) -> None: target_user_id = target_user.to_string() - auth_user_id = requester.user.to_string() if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this homeserver") - if target_user_id != auth_user_id: + if target_user != requester.user: raise AuthError(400, "Cannot set another user's typing state") if requester.shadow_banned: @@ -302,7 +300,7 @@ class TypingWriterHandler(FollowerTypingHandler): await self.clock.sleep(random.randint(1, 10)) raise ShadowBanError() - await self.auth.check_user_in_room(room_id, target_user_id) + await self.auth.check_user_in_room(room_id, requester) logger.debug("%s has stopped typing in %s", target_user_id, room_id) diff --git a/synapse/http/site.py b/synapse/http/site.py index eeec74b78a..1155f3f610 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -226,7 +226,7 @@ class SynapseRequest(Request): # If this is a request where the target user doesn't match the user who # authenticated (e.g. and admin is puppetting a user) then we return both. - if self._requester.user.to_string() != authenticated_entity: + if requester != authenticated_entity: return requester, authenticated_entity return requester, None diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index 399b205aaf..b467a61dfb 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -19,7 +19,7 @@ from typing import Iterable, Pattern from synapse.api.auth import Auth from synapse.api.errors import AuthError from synapse.http.site import SynapseRequest -from synapse.types import UserID +from synapse.types import Requester def admin_patterns(path_regex: str, version: str = "v1") -> Iterable[Pattern]: @@ -48,19 +48,19 @@ async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None AuthError if the requester is not a server admin """ requester = await auth.get_user_by_req(request) - await assert_user_is_admin(auth, requester.user) + await assert_user_is_admin(auth, requester) -async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None: +async def assert_user_is_admin(auth: Auth, requester: Requester) -> None: """Verify that the given user is an admin user Args: auth: Auth singleton - user_id: user to check + requester: The user making the request, according to the access token. Raises: AuthError if the user is not a server admin """ - is_admin = await auth.is_server_admin(user_id) + is_admin = await auth.is_server_admin(requester) if not is_admin: raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 19d4a008e8..73470f09ae 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -54,7 +54,7 @@ class QuarantineMediaInRoom(RestServlet): self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) logging.info("Quarantining room: %s", room_id) @@ -81,7 +81,7 @@ class QuarantineMediaByUser(RestServlet): self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) logging.info("Quarantining media by user: %s", user_id) @@ -110,7 +110,7 @@ class QuarantineMediaByID(RestServlet): self, request: SynapseRequest, server_name: str, media_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) logging.info("Quarantining media by ID: %s/%s", server_name, media_id) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 68054ffc28..3d870629c4 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -75,7 +75,7 @@ class RoomRestV2Servlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_user_is_admin(self._auth, requester) content = parse_json_object_from_request(request) @@ -327,7 +327,7 @@ class RoomRestServlet(RestServlet): pagination_handler: "PaginationHandler", ) -> Tuple[int, JsonDict]: requester = await auth.get_user_by_req(request) - await assert_user_is_admin(auth, requester.user) + await assert_user_is_admin(auth, requester) content = parse_json_object_from_request(request) @@ -461,7 +461,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): assert request.args is not None requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) content = parse_json_object_from_request(request) @@ -551,7 +551,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) content = parse_json_object_from_request(request, allow_empty_body=True) room_id, _ = await self.resolve_room_id(room_identifier) @@ -742,7 +742,7 @@ class RoomEventContextServlet(RestServlet): self, request: SynapseRequest, room_id: str, event_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=False) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) limit = parse_integer(request, "limit", default=10) @@ -834,7 +834,7 @@ class BlockRoomRestServlet(RestServlet): self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_user_is_admin(self._auth, requester) content = parse_json_object_from_request(request) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index ba2f7fa6d8..78ee9b6532 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -183,7 +183,7 @@ class UserRestServletV2(RestServlet): self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) target_user = UserID.from_string(user_id) body = parse_json_object_from_request(request) @@ -575,10 +575,9 @@ class WhoisRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) - auth_user = requester.user - if target_user != auth_user: - await assert_user_is_admin(self.auth, auth_user) + if target_user != requester.user: + await assert_user_is_admin(self.auth, requester) if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user") @@ -601,7 +600,7 @@ class DeactivateAccountRestServlet(RestServlet): self, request: SynapseRequest, target_user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) if not self.is_mine(UserID.from_string(target_user_id)): raise SynapseError( @@ -693,7 +692,7 @@ class ResetPasswordRestServlet(RestServlet): This needs user to have administrator access in Synapse. """ requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) UserID.from_string(target_user_id) @@ -807,7 +806,7 @@ class UserAdminServlet(RestServlet): self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) auth_user = requester.user target_user = UserID.from_string(user_id) @@ -921,7 +920,7 @@ class UserTokenRestServlet(RestServlet): self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) auth_user = requester.user if not self.is_mine_id(user_id): diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index c16d707909..e69fa0829d 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -66,7 +66,7 @@ class ProfileDisplaynameRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) - is_admin = await self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester) content = parse_json_object_from_request(request) @@ -123,7 +123,7 @@ class ProfileAvatarURLRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) - is_admin = await self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester) content = parse_json_object_from_request(request) try: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 956c45e60a..1b953d3fa0 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -484,9 +484,6 @@ class RegisterRestServlet(RestServlet): "Appservice token must be provided when using a type of m.login.application_service", ) - # Verify the AS - self.auth.get_appservice_by_req(request) - # Set the desired user according to the AS API (which uses the # 'user' key not 'username'). Since this is a new addition, we'll # fallback to 'username' if they gave one. diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 13bc9482c5..0eafbae457 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -229,7 +229,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): msg_handler = self.message_handler data = await msg_handler.get_room_data( - user_id=requester.user.to_string(), + requester=requester, room_id=room_id, event_type=event_type, state_key=state_key, @@ -574,7 +574,7 @@ class RoomMemberListRestServlet(RestServlet): events = await handler.get_state_events( room_id=room_id, - user_id=requester.user.to_string(), + requester=requester, at_token=at_token, state_filter=StateFilter.from_types([(EventTypes.Member, None)]), ) @@ -696,8 +696,7 @@ class RoomStateRestServlet(RestServlet): # Get all the current state for this room events = await self.message_handler.get_state_events( room_id=room_id, - user_id=requester.user.to_string(), - is_guest=requester.is_guest, + requester=requester, ) return 200, events @@ -755,7 +754,7 @@ class RoomEventServlet(RestServlet): == "true" ) if include_unredacted_content and not await self.auth.is_server_admin( - requester.user + requester ): power_level_event = ( await self._storage_controllers.state.get_current_state_event( @@ -1260,9 +1259,7 @@ class TimestampLookupRestServlet(RestServlet): self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self._auth.get_user_by_req(request) - await self._auth.check_user_in_room_or_world_readable( - room_id, requester.user.to_string() - ) + await self._auth.check_user_in_room_or_world_readable(room_id, requester) timestamp = parse_integer(request, "ts", required=True) direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 8ecab86ec7..70d054a8f4 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -244,7 +244,7 @@ class ServerNoticesManager: assert self.server_notices_mxid is not None notice_user_data_in_room = await self._message_handler.get_room_data( - self.server_notices_mxid, + create_requester(self.server_notices_mxid), room_id, EventTypes.Member, self.server_notices_mxid, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index cb63cd9b7d..7fb9c801da 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -69,9 +69,9 @@ class TokenLookupResult: """ user_id: str + token_id: int is_guest: bool = False shadow_banned: bool = False - token_id: Optional[int] = None device_id: Optional[str] = None valid_until_ms: Optional[int] = None token_owner: str = attr.ib() diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index dfcfaf79b6..e0f363555b 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -284,10 +284,13 @@ class AuthTestCase(unittest.HomeserverTestCase): TokenLookupResult( user_id="@baldrick:matrix.org", device_id="device", + token_id=5, token_owner="@admin:matrix.org", + token_used=True, ) ) self.store.insert_client_ip = simple_async_mock(None) + self.store.mark_access_token_as_used = simple_async_mock(None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] @@ -301,10 +304,13 @@ class AuthTestCase(unittest.HomeserverTestCase): TokenLookupResult( user_id="@baldrick:matrix.org", device_id="device", + token_id=5, token_owner="@admin:matrix.org", + token_used=True, ) ) self.store.insert_client_ip = simple_async_mock(None) + self.store.mark_access_token_as_used = simple_async_mock(None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] @@ -347,7 +353,7 @@ class AuthTestCase(unittest.HomeserverTestCase): serialized = macaroon.serialize() user_info = self.get_success(self.auth.get_user_by_access_token(serialized)) - self.assertEqual(user_id, user_info.user_id) + self.assertEqual(user_id, user_info.user.to_string()) self.assertTrue(user_info.is_guest) self.store.get_user_by_id.assert_called_with(user_id) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 7af1333126..8adba29d7f 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -25,7 +25,7 @@ from synapse.api.constants import EduTypes from synapse.api.errors import AuthError from synapse.federation.transport.server import TransportLayerServer from synapse.server import HomeServer -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util import Clock from tests import unittest @@ -117,8 +117,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.room_members = [] - async def check_user_in_room(room_id: str, user_id: str) -> None: - if user_id not in [u.to_string() for u in self.room_members]: + async def check_user_in_room(room_id: str, requester: Requester) -> None: + if requester.user.to_string() not in [ + u.to_string() for u in self.room_members + ]: raise AuthError(401, "User is not in the room") return None diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index ac9c113354..9c8c1889d3 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import JsonDict +from synapse.types import JsonDict, create_requester from synapse.util import Clock from synapse.visibility import filter_events_for_client @@ -188,7 +188,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): message_handler = self.hs.get_message_handler() create_event = self.get_success( message_handler.get_room_data( - self.user_id, room_id, EventTypes.Create, state_key="" + create_requester(self.user_id), room_id, EventTypes.Create, state_key="" ) ) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index d9bd8c4a28..c50f034b34 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -26,7 +26,7 @@ from synapse.rest.client import ( room_upgrade_rest_servlet, ) from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import UserID, create_requester from synapse.util import Clock from tests import unittest @@ -275,7 +275,7 @@ class ProfileTestCase(_ShadowBannedBase): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, + create_requester(self.banned_user_id), room_id, "m.room.member", self.banned_user_id, @@ -310,7 +310,7 @@ class ProfileTestCase(_ShadowBannedBase): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, + create_requester(self.banned_user_id), room_id, "m.room.member", self.banned_user_id, -- cgit 1.5.1 From 37f329c9adf6ed02df15661850f999edd9e5fd93 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 23 Aug 2022 10:48:35 +0200 Subject: Fix that sending server notices fail if avatar is `None` (#13566) Indroduced in #11846. --- changelog.d/13566.bugfix | 1 + synapse/handlers/room_member.py | 2 +- tests/rest/admin/test_server_notice.py | 56 ++++++++++++++++++++++ .../test_resource_limits_server_notices.py | 9 ++-- 4 files changed, 64 insertions(+), 4 deletions(-) create mode 100644 changelog.d/13566.bugfix (limited to 'tests/rest') diff --git a/changelog.d/13566.bugfix b/changelog.d/13566.bugfix new file mode 100644 index 0000000000..6c44024add --- /dev/null +++ b/changelog.d/13566.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.52.0 where sending server notices fails if `max_avatar_size` or `allowed_avatar_mimetypes` is set and not `system_mxid_avatar_url`. \ No newline at end of file diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index d1909665d6..65b9a655d4 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -689,7 +689,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): errcode=Codes.BAD_JSON, ) - if "avatar_url" in content: + if "avatar_url" in content and content.get("avatar_url") is not None: if not await self.profile_handler.check_avatar_size_and_mime_type( content["avatar_url"], ): diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 81e125e27d..a2f347f666 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -159,6 +159,62 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("'msgtype' not in content", channel.json_body["error"]) + @override_config( + { + "server_notices": { + "system_mxid_localpart": "notices", + "system_mxid_avatar_url": "somthingwrong", + }, + "max_avatar_size": "10M", + } + ) + def test_invalid_avatar_url(self) -> None: + """If avatar url in homeserver.yaml is invalid and + "check avatar size and mime type" is set, an error is returned. + TODO: Should be checked when reading the configuration.""" + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg"}, + }, + ) + + self.assertEqual(500, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + @override_config( + { + "server_notices": { + "system_mxid_localpart": "notices", + "system_mxid_display_name": "test display name", + "system_mxid_avatar_url": None, + }, + "max_avatar_size": "10M", + } + ) + def test_displayname_is_set_avatar_is_none(self) -> None: + """ + Tests that sending a server notices is successfully, + if a display_name is set, avatar_url is `None` and + "check avatar size and mime type" is set. + """ + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + self._check_invite_and_join_status(self.other_user, 1, 0) + def test_server_notice_disabled(self) -> None: """Tests that server returns error if server notice is disabled""" channel = self.make_request( diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index e07ae78fc4..bf403045e9 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -11,16 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError from synapse.rest import admin from synapse.rest.client import login, room, sync +from synapse.server import HomeServer from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -52,7 +55,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): return config - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.server_notices_sender = self.hs.get_server_notices_sender() # relying on [1] is far from ideal, but the only case where @@ -251,7 +254,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): c["admin_contact"] = "mailto:user@test.com" return c - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = self.hs.get_datastores().main self.server_notices_sender = self.hs.get_server_notices_sender() self.server_notices_manager = self.hs.get_server_notices_manager() -- cgit 1.5.1 From 956e015413d3da417c1058e3e72d97b3d1bc8170 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 23 Aug 2022 12:40:00 +0100 Subject: Drop support for delegating email validation, round 2 (#13596) --- CHANGES.md | 12 +++ changelog.d/13596.removal | 1 + docs/upgrade.md | 19 ++++ docs/usage/configuration/config_documentation.md | 5 +- synapse/app/homeserver.py | 3 +- synapse/config/emailconfig.py | 46 ++-------- synapse/config/registration.py | 13 +-- synapse/handlers/identity.py | 56 +----------- synapse/handlers/ui_auth/checkers.py | 21 +---- synapse/rest/client/account.py | 108 ++++++++--------------- synapse/rest/client/register.py | 59 +++++-------- synapse/rest/synapse/client/password_reset.py | 8 +- tests/rest/client/test_register.py | 2 +- 13 files changed, 108 insertions(+), 245 deletions(-) create mode 100644 changelog.d/13596.removal (limited to 'tests/rest') diff --git a/CHANGES.md b/CHANGES.md index 778713f528..14fafc260d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,12 @@ Synapse 1.66.0rc1 (2022-08-23) ============================== +This release removes the ability for homeservers to delegate email ownership +verification and password reset confirmation to identity servers. This removal +was originally planned for Synapse 1.64, but was later deferred until now. + +See the [upgrade notes](https://matrix-org.github.io/synapse/v1.66/upgrade.html#upgrading-to-v1660) for more details. + Features -------- @@ -33,6 +39,12 @@ Improved Documentation - Fix the doc and some warnings that were referring to the nonexistent `custom_templates_directory` setting (instead of `custom_template_directory`). ([\#13538](https://github.com/matrix-org/synapse/issues/13538)) +Deprecations and Removals +------------------------- + +- Remove the ability for homeservers to delegate email ownership verification + and password reset confirmation to identity servers. See [upgrade notes](https://matrix-org.github.io/synapse/v1.66/upgrade.html#upgrading-to-v1660) for more details. + Internal Changes ---------------- diff --git a/changelog.d/13596.removal b/changelog.d/13596.removal new file mode 100644 index 0000000000..6c12ae75b4 --- /dev/null +++ b/changelog.d/13596.removal @@ -0,0 +1 @@ +Remove the ability for homeservers to delegate email ownership verification and password reset confirmation to identity servers. See [upgrade notes](https://github.com/matrix-org/synapse/blob/release-v1.66/docs/upgrade.md#upgrading-to-v1660) for more details. \ No newline at end of file diff --git a/docs/upgrade.md b/docs/upgrade.md index 47a74b67de..0ab5bfeaf0 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -89,6 +89,25 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.66.0 + +## Delegation of email validation no longer supported + +As of this version, Synapse no longer allows the tasks of verifying email address +ownership, and password reset confirmation, to be delegated to an identity server. +This removal was previously planned for Synapse 1.64.0, but was +[delayed](https://github.com/matrix-org/synapse/issues/13421) until now to give +homeserver administrators more notice of the change. + +To continue to allow users to add email addresses to their homeserver accounts, +and perform password resets, make sure that Synapse is configured with a working +email server in the [`email` configuration +section](https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#email) +(including, at a minimum, a `notif_from` setting.) + +Specifying an `email` setting under `account_threepid_delegates` will now cause +an error at startup. + # Upgrading to v1.64.0 ## Deprecation of the ability to delegate e-mail verification to identity servers diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index cc72966823..8ae018e628 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2182,7 +2182,10 @@ their account. by the Matrix Identity Service API [specification](https://matrix.org/docs/spec/identity_service/latest).) -*Updated in Synapse 1.64.0*: The `email` option is deprecated. +*Deprecated in Synapse 1.64.0*: The `email` option is deprecated. + +*Removed in Synapse 1.66.0*: The `email` option has been removed. +If present, Synapse will report a configuration error on startup. Example configuration: ```yaml diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index d98012adeb..68993d91a9 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -44,7 +44,6 @@ from synapse.app._base import ( register_start, ) from synapse.config._base import ConfigError, format_config_error -from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.homeserver import HomeServerConfig from synapse.config.server import ListenerConfig from synapse.federation.transport.server import TransportLayerServer @@ -202,7 +201,7 @@ class SynapseHomeServer(HomeServer): } ) - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.can_verify_email: from synapse.rest.synapse.client.password_reset import ( PasswordResetSubmitTokenResource, ) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 66a6dbf1fe..a3af35b7c4 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -18,7 +18,6 @@ import email.utils import logging import os -from enum import Enum from typing import Any import attr @@ -136,40 +135,22 @@ class EmailConfig(Config): self.email_enable_notifs = email_config.get("enable_notifs", False) - self.threepid_behaviour_email = ( - # Have Synapse handle the email sending if account_threepid_delegates.email - # is not defined - # msisdn is currently always remote while Synapse does not support any method of - # sending SMS messages - ThreepidBehaviour.REMOTE - if self.root.registration.account_threepid_delegate_email - else ThreepidBehaviour.LOCAL - ) - if config.get("trust_identity_server_for_password_resets"): raise ConfigError( - 'The config option "trust_identity_server_for_password_resets" has been removed.' - "Please consult the configuration manual at docs/usage/configuration/config_documentation.md for " - "details and update your config file." + 'The config option "trust_identity_server_for_password_resets" ' + "is no longer supported. Please remove it from the config file." ) - self.local_threepid_handling_disabled_due_to_email_config = False - if ( - self.threepid_behaviour_email == ThreepidBehaviour.LOCAL - and email_config == {} - ): - # We cannot warn the user this has happened here - # Instead do so when a user attempts to reset their password - self.local_threepid_handling_disabled_due_to_email_config = True - - self.threepid_behaviour_email = ThreepidBehaviour.OFF + # If we have email config settings, assume that we can verify ownership of + # email addresses. + self.can_verify_email = email_config != {} # Get lifetime of a validation token in milliseconds self.email_validation_token_lifetime = self.parse_duration( email_config.get("validation_token_lifetime", "1h") ) - if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.can_verify_email: missing = [] if not self.email_notif_from: missing.append("email.notif_from") @@ -360,18 +341,3 @@ class EmailConfig(Config): "Config option email.invite_client_location must be a http or https URL", path=("email", "invite_client_location"), ) - - -class ThreepidBehaviour(Enum): - """ - Enum to define the behaviour of Synapse with regards to when it contacts an identity - server for 3pid registration and password resets - - REMOTE = use an external server to send tokens - LOCAL = send tokens ourselves - OFF = disable registration via 3pid and password resets - """ - - REMOTE = "remote" - LOCAL = "local" - OFF = "off" diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 01fb0331bc..a888d976f2 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse -import logging from typing import Any, Optional from synapse.api.constants import RoomCreationPreset @@ -21,15 +20,11 @@ from synapse.config._base import Config, ConfigError from synapse.types import JsonDict, RoomAlias, UserID from synapse.util.stringutils import random_string_with_symbols, strtobool -logger = logging.getLogger(__name__) - -LEGACY_EMAIL_DELEGATE_WARNING = """\ -Delegation of email verification to an identity server is now deprecated. To +NO_EMAIL_DELEGATE_ERROR = """\ +Delegation of email verification to an identity server is no longer supported. To continue to allow users to add email addresses to their accounts, and use them for password resets, configure Synapse with an SMTP server via the `email` setting, and remove `account_threepid_delegates.email`. - -This will be an error in a future version. """ @@ -64,9 +59,7 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} if "email" in account_threepid_delegates: - logger.warning(LEGACY_EMAIL_DELEGATE_WARNING) - - self.account_threepid_delegate_email = account_threepid_delegates.get("email") + raise ConfigError(NO_EMAIL_DELEGATE_ERROR) self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index e5afe84df9..9571d461c8 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -26,7 +26,6 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.config.emailconfig import ThreepidBehaviour from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient from synapse.http.site import SynapseRequest @@ -416,48 +415,6 @@ class IdentityHandler: return session_id - async def request_email_token( - self, - id_server: str, - email: str, - client_secret: str, - send_attempt: int, - next_link: Optional[str] = None, - ) -> JsonDict: - """ - Request an external server send an email on our behalf for the purposes of threepid - validation. - - Args: - id_server: The identity server to proxy to - email: The email to send the message to - client_secret: The unique client_secret sends by the user - send_attempt: Which attempt this is - next_link: A link to redirect the user to once they submit the token - - Returns: - The json response body from the server - """ - params = { - "email": email, - "client_secret": client_secret, - "send_attempt": send_attempt, - } - if next_link: - params["next_link"] = next_link - - try: - data = await self.http_client.post_json_get_json( - id_server + "/_matrix/identity/api/v1/validate/email/requestToken", - params, - ) - return data - except HttpResponseException as e: - logger.info("Proxied requestToken failed: %r", e) - raise e.to_synapse_error() - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - async def requestMsisdnToken( self, id_server: str, @@ -531,18 +488,7 @@ class IdentityHandler: validation_session = None # Try to validate as email - if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - # Remote emails will only be used if a valid identity server is provided. - assert ( - self.hs.config.registration.account_threepid_delegate_email is not None - ) - - # Ask our delegated email identity server - validation_session = await self.threepid_from_creds( - self.hs.config.registration.account_threepid_delegate_email, - threepid_creds, - ) - elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.hs.config.email.can_verify_email: # Get a validated session matching these details validation_session = await self.store.get_threepid_validation_session( "email", client_secret, sid=sid, validated=True diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 05cebb5d4d..a744d68c64 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -19,7 +19,6 @@ from twisted.web.client import PartialDownloadError from synapse.api.constants import LoginType from synapse.api.errors import Codes, LoginError, SynapseError -from synapse.config.emailconfig import ThreepidBehaviour from synapse.util import json_decoder if TYPE_CHECKING: @@ -153,7 +152,7 @@ class _BaseThreepidAuthChecker: logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,)) - # msisdns are currently always ThreepidBehaviour.REMOTE + # msisdns are currently always verified via the IS if medium == "msisdn": if not self.hs.config.registration.account_threepid_delegate_msisdn: raise SynapseError( @@ -164,18 +163,7 @@ class _BaseThreepidAuthChecker: threepid_creds, ) elif medium == "email": - if ( - self.hs.config.email.threepid_behaviour_email - == ThreepidBehaviour.REMOTE - ): - assert self.hs.config.registration.account_threepid_delegate_email - threepid = await identity_handler.threepid_from_creds( - self.hs.config.registration.account_threepid_delegate_email, - threepid_creds, - ) - elif ( - self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL - ): + if self.hs.config.email.can_verify_email: threepid = None row = await self.store.get_threepid_validation_session( medium, @@ -227,10 +215,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec _BaseThreepidAuthChecker.__init__(self, hs) def is_enabled(self) -> bool: - return self.hs.config.email.threepid_behaviour_email in ( - ThreepidBehaviour.REMOTE, - ThreepidBehaviour.LOCAL, - ) + return self.hs.config.email.can_verify_email async def check_auth(self, authdict: dict, clientip: str) -> Any: return await self._check_threepid("email", authdict) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 9041e29d6c..1f9a8ccc23 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -29,7 +29,6 @@ from synapse.api.errors import ( SynapseError, ThreepidValidationError, ) -from synapse.config.emailconfig import ThreepidBehaviour from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http.server import HttpServer, finish_request, respond_with_html from synapse.http.servlet import ( @@ -68,7 +67,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): self.config = hs.config self.identity_handler = hs.get_identity_handler() - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.can_verify_email: self.mailer = Mailer( hs=self.hs, app_name=self.config.email.email_app_name, @@ -77,11 +76,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.email.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "User password resets have been disabled due to lack of email config" - ) + if not self.config.email.can_verify_email: + logger.warning( + "User password resets have been disabled due to lack of email config" + ) raise SynapseError( 400, "Email-based password resets have been disabled on this server" ) @@ -117,35 +115,20 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.registration.account_threepid_delegate_email - - # Have the configured identity server handle the request - ret = await self.identity_handler.request_email_token( - self.hs.config.registration.account_threepid_delegate_email, - body.email, - body.client_secret, - body.send_attempt, - body.next_link, - ) - else: - # Send password reset emails from Synapse - sid = await self.identity_handler.send_threepid_validation( - body.email, - body.client_secret, - body.send_attempt, - self.mailer.send_password_reset_mail, - body.next_link, - ) - - # Wrap the session id in a JSON object - ret = {"sid": sid} - + # Send password reset emails from Synapse + sid = await self.identity_handler.send_threepid_validation( + body.email, + body.client_secret, + body.send_attempt, + self.mailer.send_password_reset_mail, + body.next_link, + ) threepid_send_requests.labels(type="email", reason="password_reset").observe( body.send_attempt ) - return 200, ret + # Wrap the session id in a JSON object + return 200, {"sid": sid} class PasswordRestServlet(RestServlet): @@ -340,7 +323,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.store = self.hs.get_datastores().main - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.can_verify_email: self.mailer = Mailer( hs=self.hs, app_name=self.config.email.email_app_name, @@ -349,11 +332,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.email.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "Adding emails have been disabled due to lack of an email config" - ) + if not self.config.email.can_verify_email: + logger.warning( + "Adding emails have been disabled due to lack of an email config" + ) raise SynapseError( 400, "Adding an email to your account is disabled on this server", @@ -391,35 +373,21 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.registration.account_threepid_delegate_email - - # Have the configured identity server handle the request - ret = await self.identity_handler.request_email_token( - self.hs.config.registration.account_threepid_delegate_email, - body.email, - body.client_secret, - body.send_attempt, - body.next_link, - ) - else: - # Send threepid validation emails from Synapse - sid = await self.identity_handler.send_threepid_validation( - body.email, - body.client_secret, - body.send_attempt, - self.mailer.send_add_threepid_mail, - body.next_link, - ) - - # Wrap the session id in a JSON object - ret = {"sid": sid} + # Send threepid validation emails from Synapse + sid = await self.identity_handler.send_threepid_validation( + body.email, + body.client_secret, + body.send_attempt, + self.mailer.send_add_threepid_mail, + body.next_link, + ) threepid_send_requests.labels(type="email", reason="add_threepid").observe( body.send_attempt ) - return 200, ret + # Wrap the session id in a JSON object + return 200, {"sid": sid} class MsisdnThreepidRequestTokenRestServlet(RestServlet): @@ -512,24 +480,18 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.config = hs.config self.clock = hs.get_clock() self.store = hs.get_datastores().main - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.can_verify_email: self._failure_email_template = ( self.config.email.email_add_threepid_template_failure_html ) async def on_GET(self, request: Request) -> None: - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.email.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "Adding emails have been disabled due to lack of an email config" - ) - raise SynapseError( - 400, "Adding an email to your account is disabled on this server" + if not self.config.email.can_verify_email: + logger.warning( + "Adding emails have been disabled due to lack of an email config" ) - elif self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: raise SynapseError( - 400, - "This homeserver is not validating threepids.", + 400, "Adding an email to your account is disabled on this server" ) sid = parse_string(request, "sid", required=True) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 1b953d3fa0..20bab20c8f 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -31,7 +31,6 @@ from synapse.api.errors import ( ) from synapse.api.ratelimiting import Ratelimiter from synapse.config import ConfigError -from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRatelimitSettings from synapse.config.server import is_threepid_reserved @@ -74,7 +73,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.config = hs.config - if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.hs.config.email.can_verify_email: self.mailer = Mailer( hs=self.hs, app_name=self.config.email.email_app_name, @@ -83,13 +82,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: - if ( - self.hs.config.email.local_threepid_handling_disabled_due_to_email_config - ): - logger.warning( - "Email registration has been disabled due to lack of email config" - ) + if not self.hs.config.email.can_verify_email: + logger.warning( + "Email registration has been disabled due to lack of email config" + ) raise SynapseError( 400, "Email-based registration has been disabled on this server" ) @@ -138,35 +134,21 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.registration.account_threepid_delegate_email - - # Have the configured identity server handle the request - ret = await self.identity_handler.request_email_token( - self.hs.config.registration.account_threepid_delegate_email, - email, - client_secret, - send_attempt, - next_link, - ) - else: - # Send registration emails from Synapse, - # wrapping the session id in a JSON object. - ret = { - "sid": await self.identity_handler.send_threepid_validation( - email, - client_secret, - send_attempt, - self.mailer.send_registration_mail, - next_link, - ) - } + # Send registration emails from Synapse + sid = await self.identity_handler.send_threepid_validation( + email, + client_secret, + send_attempt, + self.mailer.send_registration_mail, + next_link, + ) threepid_send_requests.labels(type="email", reason="register").observe( send_attempt ) - return 200, ret + # Wrap the session id in a JSON object + return 200, {"sid": sid} class MsisdnRegisterRequestTokenRestServlet(RestServlet): @@ -260,7 +242,7 @@ class RegistrationSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastores().main - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.can_verify_email: self._failure_email_template = ( self.config.email.email_registration_template_failure_html ) @@ -270,11 +252,10 @@ class RegistrationSubmitTokenServlet(RestServlet): raise SynapseError( 400, "This medium is currently not supported for registration" ) - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.email.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "User registration via email has been disabled due to lack of email config" - ) + if not self.config.email.can_verify_email: + logger.warning( + "User registration via email has been disabled due to lack of email config" + ) raise SynapseError( 400, "Email-based registration is disabled on this server" ) diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py index 6ac9dbc7c9..b9402cfb75 100644 --- a/synapse/rest/synapse/client/password_reset.py +++ b/synapse/rest/synapse/client/password_reset.py @@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Tuple from twisted.web.server import Request from synapse.api.errors import ThreepidValidationError -from synapse.config.emailconfig import ThreepidBehaviour from synapse.http.server import DirectServeHtmlResource from synapse.http.servlet import parse_string from synapse.util.stringutils import assert_valid_client_secret @@ -46,9 +45,6 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource): self.clock = hs.get_clock() self.store = hs.get_datastores().main - self._local_threepid_handling_disabled_due_to_email_config = ( - hs.config.email.local_threepid_handling_disabled_due_to_email_config - ) self._confirmation_email_template = ( hs.config.email.email_password_reset_template_confirmation_html ) @@ -59,8 +55,8 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource): hs.config.email.email_password_reset_template_failure_html ) - # This resource should not be mounted if threepid behaviour is not LOCAL - assert hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL + # This resource should only be mounted if email validation is enabled + assert hs.config.email.can_verify_email async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]: sid = parse_string(request, "sid", required=True) diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index ab4277dd31..b781875d52 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -586,9 +586,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "require_at_registration": True, }, "account_threepid_delegates": { - "email": "https://id_server", "msisdn": "https://id_server", }, + "email": {"notif_from": "Synapse "}, } ) def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: -- cgit 1.5.1 From d58615c82cec5bd866bedcb33e3e2a5d2a961c44 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 24 Aug 2022 14:13:12 -0500 Subject: Directly lookup local membership instead of getting all members in a room first (`get_users_in_room` mis-use) (#13608) See https://github.com/matrix-org/synapse/pull/13575#discussion_r953023755 --- changelog.d/13608.misc | 1 + synapse/handlers/events.py | 9 +++++--- synapse/handlers/message.py | 6 ++++-- synapse/handlers/room.py | 7 +++++-- synapse/handlers/room_member.py | 6 ++++-- synapse/server_notices/server_notices_manager.py | 10 +++++++-- synapse/storage/databases/main/roommember.py | 26 ++++++++++++++++++++++++ tests/rest/client/test_relations.py | 12 +++++------ 8 files changed, 60 insertions(+), 17 deletions(-) create mode 100644 changelog.d/13608.misc (limited to 'tests/rest') diff --git a/changelog.d/13608.misc b/changelog.d/13608.misc new file mode 100644 index 0000000000..19bcc45e33 --- /dev/null +++ b/changelog.d/13608.misc @@ -0,0 +1 @@ +Refactor `get_users_in_room(room_id)` mis-use to lookup single local user with dedicated `check_local_user_in_room(...)` function. diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index ac13340d3a..949b69cb41 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -151,7 +151,7 @@ class EventHandler: """Retrieve a single specified event. Args: - user: The user requesting the event + user: The local user requesting the event room_id: The expected room id. We'll return None if the event's room does not match. event_id: The event ID to obtain. @@ -173,8 +173,11 @@ class EventHandler: if not event: return None - users = await self.store.get_users_in_room(event.room_id) - is_peeking = user.to_string() not in users + is_user_in_room = await self.store.check_local_user_in_room( + user_id=user.to_string(), room_id=event.room_id + ) + # The user is peeking if they aren't in the room already + is_peeking = not is_user_in_room filtered = await filter_events_for_client( self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index acd3de06f6..72157d5a36 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -761,8 +761,10 @@ class EventCreationHandler: async def _is_server_notices_room(self, room_id: str) -> bool: if self.config.servernotices.server_notices_mxid is None: return False - user_ids = await self.store.get_users_in_room(room_id) - return self.config.servernotices.server_notices_mxid in user_ids + is_server_notices_room = await self.store.check_local_user_in_room( + user_id=self.config.servernotices.server_notices_mxid, room_id=room_id + ) + return is_server_notices_room async def assert_accepted_privacy_policy(self, requester: Requester) -> None: """Check if a user has accepted the privacy policy diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2bf0ebd025..2fc8264858 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1284,8 +1284,11 @@ class RoomContextHandler: before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit - users = await self.store.get_users_in_room(room_id) - is_peeking = user.to_string() not in users + is_user_in_room = await self.store.check_local_user_in_room( + user_id=user.to_string(), room_id=room_id + ) + # The user is peeking if they aren't in the room already + is_peeking = not is_user_in_room async def filter_evts(events: List[EventBase]) -> List[EventBase]: if use_admin_priviledge: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 65b9a655d4..709682622f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1620,8 +1620,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def _is_server_notice_room(self, room_id: str) -> bool: if self._server_notices_mxid is None: return False - user_ids = await self.store.get_users_in_room(room_id) - return self._server_notices_mxid in user_ids + is_server_notices_room = await self.store.check_local_user_in_room( + user_id=self._server_notices_mxid, room_id=room_id + ) + return is_server_notices_room class RoomMemberMasterHandler(RoomMemberHandler): diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 70d054a8f4..564e3705c2 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -102,6 +102,10 @@ class ServerNoticesManager: Returns: The room's ID, or None if no room could be found. """ + # If there is no server notices MXID, then there is no server notices room + if self.server_notices_mxid is None: + return None + rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) @@ -111,8 +115,10 @@ class ServerNoticesManager: # be joined. This is kinda deliberate, in that if somebody somehow # manages to invite the system user to a room, that doesn't make it # the server notices room. - user_ids = await self._store.get_users_in_room(room.room_id) - if len(user_ids) <= 2 and self.server_notices_mxid in user_ids: + is_server_notices_room = await self._store.check_local_user_in_room( + user_id=self.server_notices_mxid, room_id=room.room_id + ) + if is_server_notices_room: # we found a room which our user shares with the system notice # user return room.room_id diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 046ad3a11c..9e5034b401 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -534,6 +534,32 @@ class RoomMemberWorkerStore(EventsWorkerStore): desc="get_local_users_in_room", ) + async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool: + """ + Check whether a given local user is currently joined to the given room. + + Returns: + A boolean indicating whether the user is currently joined to the room + + Raises: + Exeption when called with a non-local user to this homeserver + """ + if not self.hs.is_mine_id(user_id): + raise Exception( + "Cannot call 'check_local_user_in_room' on " + "non-local user %s" % (user_id,), + ) + + ( + membership, + member_event_id, + ) = await self.get_local_current_membership_for_user_in_room( + user_id=user_id, + room_id=room_id, + ) + + return membership == Membership.JOIN + async def get_local_current_membership_for_user_in_room( self, user_id: str, room_id: str ) -> Tuple[Optional[str], Optional[str]]: diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index d589f07314..651f4f415d 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -999,7 +999,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations, ) - self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6) + self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) def test_annotation_to_annotation(self) -> None: """Any relation to an annotation should be ignored.""" @@ -1035,7 +1035,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations, ) - self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6) + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7) def test_thread(self) -> None: """ @@ -1080,21 +1080,21 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # # Note that this re-uses some cached values, so the total number of # queries is much smaller. self._test_bundled_aggregations( - RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token + RelationTypes.THREAD, _gen_assert(True), 3, access_token=self.user2_token ) # A user with no interactions with the thread: they have not participated. user3_id, user3_token = self._create_user("charlie") self.helper.join(self.room, user=user3_id, tok=user3_token) self._test_bundled_aggregations( - RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token + RelationTypes.THREAD, _gen_assert(False), 3, access_token=user3_token ) def test_thread_with_bundled_aggregations_for_latest(self) -> None: @@ -1142,7 +1142,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) def test_nested_thread(self) -> None: """ -- cgit 1.5.1 From a160406d245cb84b48fa67fe3ee73f0cffceb495 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 31 Aug 2022 11:38:16 +0100 Subject: Fix admin List Room API return type on sqlite (#13509) --- changelog.d/13509.bugfix | 1 + synapse/storage/databases/main/room.py | 6 ++++-- tests/rest/admin/test_room.py | 19 ++++++++++++++----- 3 files changed, 19 insertions(+), 7 deletions(-) create mode 100644 changelog.d/13509.bugfix (limited to 'tests/rest') diff --git a/changelog.d/13509.bugfix b/changelog.d/13509.bugfix new file mode 100644 index 0000000000..6dcb9741d9 --- /dev/null +++ b/changelog.d/13509.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.13 where the [List Rooms admin API](https://matrix-org.github.io/synapse/develop/admin_api/rooms.html#list-room-api) would return integers instead of booleans for the `federatable` and `public` fields when using a Sqlite database. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index b7d4baa6bb..367424b4a8 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -641,8 +641,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): "version": room[5], "creator": room[6], "encryption": room[7], - "federatable": room[8], - "public": room[9], + # room_stats_state.federatable is an integer on sqlite. + "federatable": bool(room[8]), + # rooms.is_public is an integer on sqlite. + "public": bool(room[9]), "join_rules": room[10], "guest_access": room[11], "history_visibility": room[12], diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index fd6da557c1..9d71a97524 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1080,7 +1080,9 @@ class RoomTestCase(unittest.HomeserverTestCase): room_ids = [] for _ in range(total_rooms): room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok + self.admin_user, + tok=self.admin_user_tok, + is_public=True, ) room_ids.append(room_id) @@ -1119,8 +1121,8 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("version", r) self.assertIn("creator", r) self.assertIn("encryption", r) - self.assertIn("federatable", r) - self.assertIn("public", r) + self.assertIs(r["federatable"], True) + self.assertIs(r["public"], True) self.assertIn("join_rules", r) self.assertIn("guest_access", r) self.assertIn("history_visibility", r) @@ -1587,8 +1589,12 @@ class RoomTestCase(unittest.HomeserverTestCase): def test_single_room(self) -> None: """Test that a single room can be requested correctly""" # Create two test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_1 = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok, is_public=True + ) + room_id_2 = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok, is_public=False + ) room_name_1 = "something" room_name_2 = "else" @@ -1634,7 +1640,10 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("state_events", channel.json_body) self.assertIn("room_type", channel.json_body) self.assertIn("forgotten", channel.json_body) + self.assertEqual(room_id_1, channel.json_body["room_id"]) + self.assertIs(True, channel.json_body["federatable"]) + self.assertIs(True, channel.json_body["public"]) def test_single_room_devices(self) -> None: """Test that `joined_local_devices` can be requested correctly""" -- cgit 1.5.1 From 84ddcd7bbfe4100101741a408a91f283a8f742c7 Mon Sep 17 00:00:00 2001 From: Jacek Kuśnierz Date: Wed, 31 Aug 2022 14:10:25 +0200 Subject: Drop support for calling `/_matrix/client/v3/rooms/{roomId}/invite` without an `id_access_token` (#13241) Fixes #13206 Signed-off-by: Jacek Kusnierz jacek.kusnierz@tum.de --- changelog.d/13241.removal | 1 + synapse/handlers/identity.py | 142 +++++------------------------- synapse/handlers/room.py | 20 ++++- synapse/handlers/room_member.py | 6 +- synapse/rest/client/room.py | 20 +++-- synapse/rest/media/v1/media_repository.py | 1 - tests/rest/client/test_identity.py | 3 +- tests/rest/client/test_rooms.py | 18 ++++ tests/rest/client/test_shadow_banned.py | 7 +- 9 files changed, 81 insertions(+), 137 deletions(-) create mode 100644 changelog.d/13241.removal (limited to 'tests/rest') diff --git a/changelog.d/13241.removal b/changelog.d/13241.removal new file mode 100644 index 0000000000..60b0e7969c --- /dev/null +++ b/changelog.d/13241.removal @@ -0,0 +1 @@ +Drop support for calling `/_matrix/client/v3/rooms/{roomId}/invite` without an `id_access_token`, which was not permitted by the spec. Contributed by @Vetchu. \ No newline at end of file diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 9571d461c8..93d09e9939 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -538,11 +538,7 @@ class IdentityHandler: raise SynapseError(400, "Error contacting the identity server") async def lookup_3pid( - self, - id_server: str, - medium: str, - address: str, - id_access_token: Optional[str] = None, + self, id_server: str, medium: str, address: str, id_access_token: str ) -> Optional[str]: """Looks up a 3pid in the passed identity server. @@ -557,60 +553,15 @@ class IdentityHandler: Returns: the matrix ID of the 3pid, or None if it is not recognized. """ - if id_access_token is not None: - try: - results = await self._lookup_3pid_v2( - id_server, id_access_token, medium, address - ) - return results - - except Exception as e: - # Catch HttpResponseExcept for a non-200 response code - # Check if this identity server does not know about v2 lookups - if isinstance(e, HttpResponseException) and e.code == 404: - # This is an old identity server that does not yet support v2 lookups - logger.warning( - "Attempted v2 lookup on v1 identity server %s. Falling " - "back to v1", - id_server, - ) - else: - logger.warning("Error when looking up hashing details: %s", e) - return None - - return await self._lookup_3pid_v1(id_server, medium, address) - - async def _lookup_3pid_v1( - self, id_server: str, medium: str, address: str - ) -> Optional[str]: - """Looks up a 3pid in the passed identity server using v1 lookup. - Args: - id_server: The server name (including port, if required) - of the identity server to use. - medium: The type of the third party identifier (e.g. "email"). - address: The third party identifier (e.g. "foo@example.com"). - - Returns: - the matrix ID of the 3pid, or None if it is not recognized. - """ try: - data = await self.blacklisting_http_client.get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server), - {"medium": medium, "address": address}, + results = await self._lookup_3pid_v2( + id_server, id_access_token, medium, address ) - - if "mxid" in data: - # note: we used to verify the identity server's signature here, but no longer - # require or validate it. See the following for context: - # https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950 - return data["mxid"] - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - except OSError as e: - logger.warning("Error from v1 identity server lookup: %s" % (e,)) - - return None + return results + except Exception as e: + logger.warning("Error when looking up hashing details: %s", e) + return None async def _lookup_3pid_v2( self, id_server: str, id_access_token: str, medium: str, address: str @@ -739,7 +690,7 @@ class IdentityHandler: room_type: Optional[str], inviter_display_name: str, inviter_avatar_url: str, - id_access_token: Optional[str] = None, + id_access_token: str, ) -> Tuple[str, List[Dict[str, str]], Dict[str, str], str]: """ Asks an identity server for a third party invite. @@ -760,7 +711,7 @@ class IdentityHandler: inviter_display_name: The current display name of the inviter. inviter_avatar_url: The URL of the inviter's avatar. - id_access_token (str|None): The access token to authenticate to the identity + id_access_token (str): The access token to authenticate to the identity server with Returns: @@ -792,71 +743,24 @@ class IdentityHandler: invite_config["org.matrix.web_client_location"] = self._web_client_location # Add the identity service access token to the JSON body and use the v2 - # Identity Service endpoints if id_access_token is present + # Identity Service endpoints data = None - base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server) - if id_access_token: - key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % ( - id_server_scheme, - id_server, - ) + key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % ( + id_server_scheme, + id_server, + ) - # Attempt a v2 lookup - url = base_url + "/v2/store-invite" - try: - data = await self.blacklisting_http_client.post_json_get_json( - url, - invite_config, - {"Authorization": create_id_access_token_header(id_access_token)}, - ) - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - except HttpResponseException as e: - if e.code != 404: - logger.info("Failed to POST %s with JSON: %s", url, e) - raise e - - if data is None: - key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, - id_server, + url = "%s%s/_matrix/identity/v2/store-invite" % (id_server_scheme, id_server) + try: + data = await self.blacklisting_http_client.post_json_get_json( + url, + invite_config, + {"Authorization": create_id_access_token_header(id_access_token)}, ) - url = base_url + "/api/v1/store-invite" - - try: - data = await self.blacklisting_http_client.post_json_get_json( - url, invite_config - ) - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - except HttpResponseException as e: - logger.warning( - "Error trying to call /store-invite on %s%s: %s", - id_server_scheme, - id_server, - e, - ) - - if data is None: - # Some identity servers may only support application/x-www-form-urlencoded - # types. This is especially true with old instances of Sydent, see - # https://github.com/matrix-org/sydent/pull/170 - try: - data = await self.blacklisting_http_client.post_urlencoded_get_json( - url, invite_config - ) - except HttpResponseException as e: - logger.warning( - "Error calling /store-invite on %s%s with fallback " - "encoding: %s", - id_server_scheme, - id_server, - e, - ) - raise e - - # TODO: Check for success + except RequestTimedOutError: + raise SynapseError(500, "Timed out contacting identity server") + token = data["token"] public_keys = data.get("public_keys", []) if "public_key" in data: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f64a8690a5..33e9a87002 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -19,6 +19,7 @@ import math import random import string from collections import OrderedDict +from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, @@ -704,8 +705,8 @@ class RoomCreationHandler: was, requested, `room_alias`. Secondly, the stream_id of the last persisted event. Raises: - SynapseError if the room ID couldn't be stored, or something went - horribly wrong. + SynapseError if the room ID couldn't be stored, 3pid invitation config + validation failed, or something went horribly wrong. ResourceLimitError if server is blocked to some resource being exceeded """ @@ -731,6 +732,19 @@ class RoomCreationHandler: invite_3pid_list = config.get("invite_3pid", []) invite_list = config.get("invite", []) + # validate each entry for correctness + for invite_3pid in invite_3pid_list: + if not all( + key in invite_3pid + for key in ("medium", "address", "id_server", "id_access_token") + ): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "all of `medium`, `address`, `id_server` and `id_access_token` " + "are required when making a 3pid invite", + Codes.MISSING_PARAM, + ) + if not is_requester_admin: spam_check = await self.spam_checker.user_may_create_room(user_id) if spam_check != NOT_SPAM: @@ -978,7 +992,7 @@ class RoomCreationHandler: for invite_3pid in invite_3pid_list: id_server = invite_3pid["id_server"] - id_access_token = invite_3pid.get("id_access_token") # optional + id_access_token = invite_3pid["id_access_token"] address = invite_3pid["address"] medium = invite_3pid["medium"] # Note that do_3pid_invite can raise a ShadowBanError, but this was diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e726997d83..5d4adf5bfd 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1382,7 +1382,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): id_server: str, requester: Requester, txn_id: Optional[str], - id_access_token: Optional[str] = None, + id_access_token: str, prev_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, ) -> Tuple[str, int]: @@ -1397,7 +1397,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): requester: The user making the request. txn_id: The transaction ID this is part of, or None if this is not part of a transaction. - id_access_token: The optional identity server access token. + id_access_token: Identity server access token. depth: Override the depth used to order the event in the DAG. prev_event_ids: The event IDs to use as the prev events Should normally be set to None, which will cause the depth to be calculated @@ -1494,7 +1494,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_id: str, user: UserID, txn_id: Optional[str], - id_access_token: Optional[str] = None, + id_access_token: str, prev_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, ) -> Tuple[EventBase, int]: diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 0e2834008e..0bca012535 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -17,6 +17,7 @@ import logging import re from enum import Enum +from http import HTTPStatus from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple from urllib import parse as urlparse @@ -947,7 +948,16 @@ class RoomMembershipRestServlet(TransactionRestServlet): # cheekily send invalid bodies. content = {} - if membership_action == "invite" and self._has_3pid_invite_keys(content): + if membership_action == "invite" and all( + key in content for key in ("medium", "address") + ): + if not all(key in content for key in ("id_server", "id_access_token")): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "`id_server` and `id_access_token` are required when doing 3pid invite", + Codes.MISSING_PARAM, + ) + try: await self.room_member_handler.do_3pid_invite( room_id, @@ -957,7 +967,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): content["id_server"], requester, txn_id, - content.get("id_access_token"), + content["id_access_token"], ) except ShadowBanError: # Pretend the request succeeded. @@ -994,12 +1004,6 @@ class RoomMembershipRestServlet(TransactionRestServlet): return 200, return_value - def _has_3pid_invite_keys(self, content: JsonDict) -> bool: - for key in {"id_server", "medium", "address"}: - if key not in content: - return False - return True - def on_PUT( self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str ) -> Awaitable[Tuple[int, JsonDict]]: diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 7435fd9130..9dd3c8d4bb 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -64,7 +64,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) - # How often to run the background job to update the "recently accessed" # attribute of local and remote media. UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 # 1 minute diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index dc17c9d113..b0c8215744 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -25,7 +25,6 @@ from tests import unittest class IdentityTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -33,7 +32,6 @@ class IdentityTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["enable_3pid_lookup"] = False self.hs = self.setup_test_homeserver(config=config) @@ -54,6 +52,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): "id_server": "testis", "medium": "email", "address": "test@example.com", + "id_access_token": tok, } request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") channel = self.make_request( diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index aa2f578441..c7eb88d33f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -3461,3 +3461,21 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Also check that it stopped before calling _make_and_store_3pid_invite. make_invite_mock.assert_called_once() + + def test_400_missing_param_without_id_access_token(self) -> None: + """ + Test that a 3pid invite request returns 400 M_MISSING_PARAM + if we do not include id_access_token. + """ + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "medium": "email", + "address": "teresa@example.com", + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 400) + self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM") diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index c50f034b34..c807a37bc2 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -97,7 +97,12 @@ class RoomTestCase(_ShadowBannedBase): channel = self.make_request( "POST", "/rooms/%s/invite" % (room_id,), - {"id_server": "test", "medium": "email", "address": "test@test.test"}, + { + "id_server": "test", + "medium": "email", + "address": "test@test.test", + "id_access_token": "anytoken", + }, access_token=self.banned_access_token, ) self.assertEqual(200, channel.code, channel.result) -- cgit 1.5.1 From 0e99f07952edcb6396654e34da50ddeb0a211067 Mon Sep 17 00:00:00 2001 From: Šimon Brandner Date: Thu, 1 Sep 2022 14:31:54 +0200 Subject: Remove support for unstable private read receipts (#13653) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- changelog.d/13653.removal | 1 + synapse/api/constants.py | 1 - synapse/config/experimental.py | 3 -- synapse/handlers/receipts.py | 29 +++---------- synapse/replication/tcp/client.py | 5 +-- synapse/rest/client/notifications.py | 1 - synapse/rest/client/read_marker.py | 2 - synapse/rest/client/receipts.py | 2 - synapse/rest/client/versions.py | 1 - .../storage/databases/main/event_push_actions.py | 2 - tests/handlers/test_receipts.py | 48 ++++++---------------- tests/rest/client/test_sync.py | 37 +++++------------ tests/storage/test_receipts.py | 34 ++++++--------- 13 files changed, 44 insertions(+), 122 deletions(-) create mode 100644 changelog.d/13653.removal (limited to 'tests/rest') diff --git a/changelog.d/13653.removal b/changelog.d/13653.removal new file mode 100644 index 0000000000..eb075d4517 --- /dev/null +++ b/changelog.d/13653.removal @@ -0,0 +1 @@ +Remove support for unstable [private read receipts](https://github.com/matrix-org/matrix-spec-proposals/pull/2285). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index c73aea622a..c178ddf070 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -258,7 +258,6 @@ class GuestAccess: class ReceiptTypes: READ: Final = "m.read" READ_PRIVATE: Final = "m.read.private" - UNSTABLE_READ_PRIVATE: Final = "org.matrix.msc2285.read.private" FULLY_READ: Final = "m.fully_read" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index c1ff417539..260db49cad 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -32,9 +32,6 @@ class ExperimentalConfig(Config): # MSC2716 (importing historical messages) self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) - # MSC2285 (unstable private read receipts) - self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False) - # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index d4a866b346..d2bdb9c8be 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -163,10 +163,7 @@ class ReceiptsHandler: if not is_new: return - if self.federation_sender and receipt_type not in ( - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ): + if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE: await self.federation_sender.send_read_receipt(receipt) @@ -206,38 +203,24 @@ class ReceiptEventSource(EventSource[int, JsonDict]): for event_id, orig_event_content in room.get("content", {}).items(): event_content = orig_event_content # If there are private read receipts, additional logic is necessary. - if ( - ReceiptTypes.READ_PRIVATE in event_content - or ReceiptTypes.UNSTABLE_READ_PRIVATE in event_content - ): + if ReceiptTypes.READ_PRIVATE in event_content: # Make a copy without private read receipts to avoid leaking # other user's private read receipts.. event_content = { receipt_type: receipt_value for receipt_type, receipt_value in event_content.items() - if receipt_type - not in ( - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ) + if receipt_type != ReceiptTypes.READ_PRIVATE } # Copy the current user's private read receipt from the # original content, if it exists. - user_private_read_receipt = orig_event_content.get( - ReceiptTypes.READ_PRIVATE, {} - ).get(user_id, None) + user_private_read_receipt = orig_event_content[ + ReceiptTypes.READ_PRIVATE + ].get(user_id, None) if user_private_read_receipt: event_content[ReceiptTypes.READ_PRIVATE] = { user_id: user_private_read_receipt } - user_unstable_private_read_receipt = orig_event_content.get( - ReceiptTypes.UNSTABLE_READ_PRIVATE, {} - ).get(user_id, None) - if user_unstable_private_read_receipt: - event_content[ReceiptTypes.UNSTABLE_READ_PRIVATE] = { - user_id: user_unstable_private_read_receipt - } # Include the event if there is at least one non-private read # receipt or the current user has a private read receipt. diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 1ed7230e32..e4f2201c92 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -416,10 +416,7 @@ class FederationSenderHandler: if not self._is_mine_id(receipt.user_id): continue # Private read receipts never get sent over federation. - if receipt.receipt_type in ( - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ): + if receipt.receipt_type == ReceiptTypes.READ_PRIVATE: continue receipt_info = ReadReceipt( receipt.room_id, diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index a73322a6a4..61268e3af1 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -62,7 +62,6 @@ class NotificationsServlet(RestServlet): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index aaad8b233f..5e53096539 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -45,8 +45,6 @@ class ReadMarkerRestServlet(RestServlet): ReceiptTypes.FULLY_READ, ReceiptTypes.READ_PRIVATE, } - if hs.config.experimental.msc2285_enabled: - self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE) async def on_POST( self, request: SynapseRequest, room_id: str diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index c6108fc5eb..5b7fad7402 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -49,8 +49,6 @@ class ReceiptRestServlet(RestServlet): ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ, } - if hs.config.experimental.msc2285_enabled: - self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE) async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c9a830cbac..c516cda95d 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -95,7 +95,6 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, # Supports receiving private read receipts as per MSC2285 "org.matrix.msc2285.stable": True, # TODO: Remove when MSC2285 becomes a part of the spec - "org.matrix.msc2285": self.config.experimental.msc2285_enabled, # Supports filtering of /publicRooms by room type as per MSC3827 "org.matrix.msc3827.stable": True, # Adds support for importing historical messages as per MSC2716 diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 9f410d69de..f4a07de2a3 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -274,7 +274,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas receipt_types=( ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ), ) @@ -468,7 +467,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ( ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ), ) diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 5f70a2db79..b55238650c 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -15,8 +15,6 @@ from copy import deepcopy from typing import List -from parameterized import parameterized - from synapse.api.constants import EduTypes, ReceiptTypes from synapse.types import JsonDict @@ -27,16 +25,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.event_source = hs.get_event_sources().sources.receipt - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_filters_out_private_receipt(self, receipt_type: str) -> None: + def test_filters_out_private_receipt(self) -> None: self._test_filters_private( [ { "content": { "$1435641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, } @@ -50,18 +45,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): [], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_filters_out_private_receipt_and_ignores_rest( - self, receipt_type: str - ) -> None: + def test_filters_out_private_receipt_and_ignores_rest(self) -> None: self._test_filters_private( [ { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -94,18 +84,15 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest( - self, receipt_type: str + self, ) -> None: self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -175,18 +162,15 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest( - self, receipt_type: str + self, ) -> None: self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -262,16 +246,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_leaves_our_private_and_their_public(self, receipt_type: str) -> None: + def test_leaves_our_private_and_their_public(self) -> None: self._test_filters_private( [ { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@me:server.org": { "ts": 1436451550453, }, @@ -296,7 +277,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@me:server.org": { "ts": 1436451550453, }, @@ -319,16 +300,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_we_do_not_mutate(self, receipt_type: str) -> None: + def test_we_do_not_mutate(self) -> None: """Ensure the input values are not modified.""" events = [ { "content": { "$1435641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, } diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index de0dec8539..0af643ecd9 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -391,7 +391,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() - config["experimental_features"] = {"msc2285_enabled": True} return self.setup_test_homeserver(config=config) @@ -413,17 +412,14 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Join the second user self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_private_read_receipts(self, receipt_type: str) -> None: + def test_private_read_receipts(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) # Send a private read receipt to tell the server the first user's message was read channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -432,10 +428,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that the first user can't see the other user's private read receipt self.assertIsNone(self._get_read_receipt()) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_public_receipt_can_override_private(self, receipt_type: str) -> None: + def test_public_receipt_can_override_private(self) -> None: """ Sending a public read receipt to the same event which has a private read receipt should cause that receipt to become public. @@ -446,7 +439,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a private read receipt channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -465,10 +458,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that we did override the private read receipt self.assertNotEqual(self._get_read_receipt(), None) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_private_receipt_cannot_override_public(self, receipt_type: str) -> None: + def test_private_receipt_cannot_override_public(self) -> None: """ Sending a private read receipt to the same event which has a public read receipt should cause no change. @@ -489,7 +479,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a private read receipt channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -554,7 +544,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): config = super().default_config() config["experimental_features"] = { "msc2654_enabled": True, - "msc2285_enabled": True, } return config @@ -601,10 +590,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): tok=self.tok, ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_unread_counts(self, receipt_type: str) -> None: + def test_unread_counts(self) -> None: """Tests that /sync returns the right value for the unread count (MSC2654).""" # Check that our own messages don't increase the unread count. @@ -638,7 +624,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Send a read receipt to tell the server we've read the latest event. channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok, ) @@ -726,7 +712,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}", {}, access_token=self.tok, ) @@ -738,7 +724,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ] ) def test_read_receipts_only_go_down(self, receipt_type: str) -> None: @@ -752,7 +737,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Read last event channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}", {}, access_token=self.tok, ) @@ -763,7 +748,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # read receipt go up to an older event channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res1['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res1['event_id']}", {}, access_token=self.tok, ) diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index 191c957fb5..c89bfff241 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from parameterized import parameterized from synapse.api.constants import ReceiptTypes from synapse.types import UserID, create_requester @@ -92,7 +91,6 @@ class ReceiptTestCase(HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) ) @@ -104,7 +102,6 @@ class ReceiptTestCase(HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) ) @@ -117,16 +114,12 @@ class ReceiptTestCase(HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) ) self.assertEqual(res, None) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_get_receipts_for_user(self, receipt_type: str) -> None: + def test_get_receipts_for_user(self) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -144,14 +137,14 @@ class ReceiptTestCase(HomeserverTestCase): # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} ) ) # Test we get the latest event when we want both private and public receipts res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, receipt_type] + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -164,7 +157,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test we get the latest event when we want only the public receipt res = self.get_success( - self.store.get_receipts_for_user(OUR_USER_ID, [receipt_type]) + self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE]) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -187,20 +180,17 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} + self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, receipt_type] + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) ) self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id}) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_get_last_receipt_event_id_for_user(self, receipt_type: str) -> None: + def test_get_last_receipt_event_id_for_user(self) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -218,7 +208,7 @@ class ReceiptTestCase(HomeserverTestCase): # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} ) ) @@ -227,7 +217,7 @@ class ReceiptTestCase(HomeserverTestCase): self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, - [ReceiptTypes.READ, receipt_type], + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], ) ) self.assertEqual(res, event1_2_id) @@ -243,7 +233,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test we get the latest event when we want only the private receipt res = self.get_success( self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [receipt_type] + OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] ) ) self.assertEqual(res, event1_2_id) @@ -269,14 +259,14 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} + self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id2, - [ReceiptTypes.READ, receipt_type], + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], ) ) self.assertEqual(res, event2_1_id) -- cgit 1.5.1 From bb5b47b62a11b14a3458e5a8aafd9ddaf1294199 Mon Sep 17 00:00:00 2001 From: Connor Davis Date: Wed, 7 Sep 2022 05:54:44 -0400 Subject: Add Admin API to Fetch Messages Within a Particular Window (#13672) This adds two new admin APIs that allow us to fetch messages from a room within a particular time. --- changelog.d/13672.feature | 1 + docs/admin_api/rooms.md | 145 +++++++++++++++++++++++++++++++++++++ synapse/handlers/pagination.py | 37 ++++++---- synapse/rest/admin/__init__.py | 4 ++ synapse/rest/admin/rooms.py | 104 +++++++++++++++++++++++++++ tests/rest/admin/test_room.py | 158 ++++++++++++++++++++++++++++++++++++++++- 6 files changed, 435 insertions(+), 14 deletions(-) create mode 100644 changelog.d/13672.feature (limited to 'tests/rest') diff --git a/changelog.d/13672.feature b/changelog.d/13672.feature new file mode 100644 index 0000000000..2334e6fe15 --- /dev/null +++ b/changelog.d/13672.feature @@ -0,0 +1 @@ +Add admin APIs to fetch messages within a particular window of time. diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 7526956bec..8f727b363e 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -393,6 +393,151 @@ A response body like the following is returned: } ``` +# Room Messages API + +The Room Messages admin API allows server admins to get all messages +sent to a room in a given timeframe. There are various parameters available +that allow for filtering and ordering the returned list. This API supports pagination. + +To use it, you will need to authenticate by providing an `access_token` +for a server admin: see [Admin API](../usage/administration/admin_api). + +This endpoint mirrors the [Matrix Spec defined Messages API](https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3roomsroomidmessages). + +The API is: +``` +GET /_synapse/admin/v1/rooms//messages +``` + +**Parameters** + +The following path parameters are required: + +* `room_id` - The ID of the room you wish you fetch messages from. + +The following query parameters are available: + +* `from` (required) - The token to start returning events from. This token can be obtained from a prev_batch + or next_batch token returned by the /sync endpoint, or from an end token returned by a previous request to this endpoint. +* `to` - The token to spot returning events at. +* `limit` - The maximum number of events to return. Defaults to `10`. +* `filter` - A JSON RoomEventFilter to filter returned events with. +* `dir` - The direction to return events from. Either `f` for forwards or `b` for backwards. Setting + this value to `b` will reverse the above sort order. Defaults to `f`. + +**Response** + +The following fields are possible in the JSON response body: + +* `chunk` - A list of room events. The order depends on the dir parameter. + Note that an empty chunk does not necessarily imply that no more events are available. Clients should continue to paginate until no end property is returned. +* `end` - A token corresponding to the end of chunk. This token can be passed back to this endpoint to request further events. + If no further events are available, this property is omitted from the response. +* `start` - A token corresponding to the start of chunk. +* `state` - A list of state events relevant to showing the chunk. + +**Example** + +For more details on each chunk, read [the Matrix specification](https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3roomsroomidmessages). + +```json +{ + "chunk": [ + { + "content": { + "body": "This is an example text message", + "format": "org.matrix.custom.html", + "formatted_body": "This is an example text message", + "msgtype": "m.text" + }, + "event_id": "$143273582443PhrSn:example.org", + "origin_server_ts": 1432735824653, + "room_id": "!636q39766251:example.com", + "sender": "@example:example.org", + "type": "m.room.message", + "unsigned": { + "age": 1234 + } + }, + { + "content": { + "name": "The room name" + }, + "event_id": "$143273582443PhrSn:example.org", + "origin_server_ts": 1432735824653, + "room_id": "!636q39766251:example.com", + "sender": "@example:example.org", + "state_key": "", + "type": "m.room.name", + "unsigned": { + "age": 1234 + } + }, + { + "content": { + "body": "Gangnam Style", + "info": { + "duration": 2140786, + "h": 320, + "mimetype": "video/mp4", + "size": 1563685, + "thumbnail_info": { + "h": 300, + "mimetype": "image/jpeg", + "size": 46144, + "w": 300 + }, + "thumbnail_url": "mxc://example.org/FHyPlCeYUSFFxlgbQYZmoEoe", + "w": 480 + }, + "msgtype": "m.video", + "url": "mxc://example.org/a526eYUSFFxlgbQYZmo442" + }, + "event_id": "$143273582443PhrSn:example.org", + "origin_server_ts": 1432735824653, + "room_id": "!636q39766251:example.com", + "sender": "@example:example.org", + "type": "m.room.message", + "unsigned": { + "age": 1234 + } + } + ], + "end": "t47409-4357353_219380_26003_2265", + "start": "t47429-4392820_219380_26003_2265" +} +``` + +# Room Timestamp to Event API + +The Room Timestamp to Event API endpoint fetches the `event_id` of the closest event to the given +timestamp (`ts` query parameter) in the given direction (`dir` query parameter). + +Useful for cases like jump to date so you can start paginating messages from +a given date in the archive. + +The API is: +``` + GET /_synapse/admin/v1/rooms//timestamp_to_event +``` + +**Parameters** + +The following path parameters are required: + +* `room_id` - The ID of the room you wish to check. + +The following query parameters are available: + +* `ts` - a timestamp in milliseconds where we will find the closest event in + the given direction. +* `dir` - can be `f` or `b` to indicate forwards and backwards in time from the + given timestamp. Defaults to `f`. + +**Response** + +* `event_id` - converted from timestamp + # Block Room API The Block Room admin API allows server admins to block and unblock rooms, and query to see if a given room is blocked. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index a0c39778ab..1f83bab836 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -26,6 +26,7 @@ from synapse.events.utils import SerializeEventConfig from synapse.handlers.room import ShutdownRoomResponse from synapse.logging.opentracing import trace from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.rest.admin._base import assert_user_is_admin from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamKeyType @@ -423,6 +424,7 @@ class PaginationHandler: pagin_config: PaginationConfig, as_client_event: bool = True, event_filter: Optional[Filter] = None, + use_admin_priviledge: bool = False, ) -> JsonDict: """Get messages in a room. @@ -432,10 +434,16 @@ class PaginationHandler: pagin_config: The pagination config rules to apply, if any. as_client_event: True to get events in client-server format. event_filter: Filter to apply to results or None + use_admin_priviledge: if `True`, return all events, regardless + of whether `user` has access to them. To be used **ONLY** + from the admin API. Returns: Pagination API results """ + if use_admin_priviledge: + await assert_user_is_admin(self.auth, requester) + user_id = requester.user.to_string() if pagin_config.from_token: @@ -458,12 +466,14 @@ class PaginationHandler: room_token = from_token.room_key async with self.pagination_lock.read(room_id): - ( - membership, - member_event_id, - ) = await self.auth.check_user_in_room_or_world_readable( - room_id, requester, allow_departed_users=True - ) + (membership, member_event_id) = (None, None) + if not use_admin_priviledge: + ( + membership, + member_event_id, + ) = await self.auth.check_user_in_room_or_world_readable( + room_id, requester, allow_departed_users=True + ) if pagin_config.direction == "b": # if we're going backwards, we might need to backfill. This @@ -475,7 +485,7 @@ class PaginationHandler: room_id, room_token.stream ) - if membership == Membership.LEAVE: + if not use_admin_priviledge and membership == Membership.LEAVE: # If they have left the room then clamp the token to be before # they left the room, to save the effort of loading from the # database. @@ -528,12 +538,13 @@ class PaginationHandler: if event_filter: events = await event_filter.filter(events) - events = await filter_events_for_client( - self._storage_controllers, - user_id, - events, - is_peeking=(member_event_id is None), - ) + if not use_admin_priviledge: + events = await filter_events_for_client( + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), + ) # if after the filter applied there are no more events # return immediately - but there might be more in next_token batch diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index fa3266720b..bac754e1b1 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -61,9 +61,11 @@ from synapse.rest.admin.rooms import ( MakeRoomAdminRestServlet, RoomEventContextServlet, RoomMembersRestServlet, + RoomMessagesRestServlet, RoomRestServlet, RoomRestV2Servlet, RoomStateRestServlet, + RoomTimestampToEventRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet @@ -271,6 +273,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: DestinationResetConnectionRestServlet(hs).register(http_server) DestinationRestServlet(hs).register(http_server) ListDestinationsRestServlet(hs).register(http_server) + RoomMessagesRestServlet(hs).register(http_server) + RoomTimestampToEventRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. if hs.config.worker.worker_app is None: diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 3d870629c4..747e6fda83 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -35,6 +35,7 @@ from synapse.rest.admin._base import ( ) from synapse.storage.databases.main.room import RoomSortOrder from synapse.storage.state import StateFilter +from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, RoomID, UserID, create_requester from synapse.util import json_decoder @@ -858,3 +859,106 @@ class BlockRoomRestServlet(RestServlet): await self._store.unblock_room(room_id) return HTTPStatus.OK, {"block": block} + + +class RoomMessagesRestServlet(RestServlet): + """ + Get messages list of a room. + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/messages$") + + def __init__(self, hs: "HomeServer"): + self._hs = hs + self._clock = hs.get_clock() + self._pagination_handler = hs.get_pagination_handler() + self._auth = hs.get_auth() + self._store = hs.get_datastores().main + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester) + + pagination_config = await PaginationConfig.from_request( + self._store, request, default_limit=10 + ) + # Twisted will have processed the args by now. + assert request.args is not None + as_client_event = b"raw" not in request.args + filter_str = parse_string(request, "filter", encoding="utf-8") + if filter_str: + filter_json = urlparse.unquote(filter_str) + event_filter: Optional[Filter] = Filter( + self._hs, json_decoder.decode(filter_json) + ) + if ( + event_filter + and event_filter.filter_json.get("event_format", "client") + == "federation" + ): + as_client_event = False + else: + event_filter = None + + msgs = await self._pagination_handler.get_messages( + room_id=room_id, + requester=requester, + pagin_config=pagination_config, + as_client_event=as_client_event, + event_filter=event_filter, + use_admin_priviledge=True, + ) + + return HTTPStatus.OK, msgs + + +class RoomTimestampToEventRestServlet(RestServlet): + """ + API endpoint to fetch the `event_id` of the closest event to the given + timestamp (`ts` query parameter) in the given direction (`dir` query + parameter). + + Useful for cases like jump to date so you can start paginating messages from + a given date in the archive. + + `ts` is a timestamp in milliseconds where we will find the closest event in + the given direction. + + `dir` can be `f` or `b` to indicate forwards and backwards in time from the + given timestamp. + + GET /_synapse/admin/v1/rooms//timestamp_to_event?ts=&dir= + { + "event_id": ... + } + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/timestamp_to_event$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastores().main + self._timestamp_lookup_handler = hs.get_timestamp_lookup_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester) + + timestamp = parse_integer(request, "ts", required=True) + direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) + + ( + event_id, + origin_server_ts, + ) = await self._timestamp_lookup_handler.get_event_for_timestamp( + requester, room_id, timestamp, direction + ) + + return HTTPStatus.OK, { + "event_id": event_id, + "origin_server_ts": origin_server_ts, + } diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 9d71a97524..d156be82b0 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json +import time import urllib.parse from typing import List, Optional from unittest.mock import Mock @@ -22,10 +24,11 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.constants import EventTypes, Membership, RoomTypes from synapse.api.errors import Codes -from synapse.handlers.pagination import PaginationHandler +from synapse.handlers.pagination import PaginationHandler, PurgeStatus from synapse.rest.client import directory, events, login, room from synapse.server import HomeServer from synapse.util import Clock +from synapse.util.stringutils import random_string from tests import unittest @@ -1793,6 +1796,159 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) +class RoomMessagesTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.user = self.register_user("foo", "pass") + self.user_tok = self.login("foo", "pass") + self.room_id = self.helper.create_room_as(self.user, tok=self.user_tok) + + def test_timestamp_to_event(self) -> None: + """Test that providing the current timestamp can get the last event.""" + self.helper.send(self.room_id, body="message 1", tok=self.user_tok) + second_event_id = self.helper.send( + self.room_id, body="message 2", tok=self.user_tok + )["event_id"] + ts = str(round(time.time() * 1000)) + + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/timestamp_to_event?dir=b&ts=%s" + % (self.room_id, ts), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code) + self.assertIn("event_id", channel.json_body) + self.assertEqual(second_event_id, channel.json_body["event_id"]) + + def test_topo_token_is_accepted(self) -> None: + """Test Topo Token is accepted.""" + token = "t1-0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code) + self.assertIn("start", channel.json_body) + self.assertEqual(token, channel.json_body["start"]) + self.assertIn("chunk", channel.json_body) + self.assertIn("end", channel.json_body) + + def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: + """Test that stream token is accepted for forward pagination.""" + token = "s0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code) + self.assertIn("start", channel.json_body) + self.assertEqual(token, channel.json_body["start"]) + self.assertIn("chunk", channel.json_body) + self.assertIn("end", channel.json_body) + + def test_room_messages_purge(self) -> None: + """Test room messages can be retrieved by an admin that isn't in the room.""" + store = self.hs.get_datastores().main + pagination_handler = self.hs.get_pagination_handler() + + # Send a first message in the room, which will be removed by the purge. + first_event_id = self.helper.send( + self.room_id, body="message 1", tok=self.user_tok + )["event_id"] + first_token = self.get_success( + store.get_topological_token_for_event(first_event_id) + ) + first_token_str = self.get_success(first_token.to_string(store)) + + # Send a second message in the room, which won't be removed, and which we'll + # use as the marker to purge events before. + second_event_id = self.helper.send( + self.room_id, body="message 2", tok=self.user_tok + )["event_id"] + second_token = self.get_success( + store.get_topological_token_for_event(second_event_id) + ) + second_token_str = self.get_success(second_token.to_string(store)) + + # Send a third event in the room to ensure we don't fall under any edge case + # due to our marker being the latest forward extremity in the room. + self.helper.send(self.room_id, body="message 3", tok=self.user_tok) + + # Check that we get the first and second message when querying /messages. + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/messages?from=%s&dir=b&filter=%s" + % ( + self.room_id, + second_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) + + # Purge every event before the second event. + purge_id = random_string(16) + pagination_handler._purges_by_id[purge_id] = PurgeStatus() + self.get_success( + pagination_handler._purge_history( + purge_id=purge_id, + room_id=self.room_id, + token=second_token_str, + delete_local_events=True, + ) + ) + + # Check that we only get the second message through /message now that the first + # has been purged. + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/messages?from=%s&dir=b&filter=%s" + % ( + self.room_id, + second_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) + + # Check that we get no event, but also no error, when querying /messages with + # the token that was pointing at the first event, because we don't have it + # anymore. + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/messages?from=%s&dir=b&filter=%s" + % ( + self.room_id, + first_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) + + class JoinAliasRoomTestCase(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From d3d9ca156e323fe194b1bcb1af1628f65a2f3c1c Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 7 Sep 2022 11:03:32 +0000 Subject: Cancel the processing of key query requests when they time out. (#13680) --- changelog.d/13680.feature | 1 + synapse/api/auth.py | 5 +++ synapse/handlers/device.py | 3 ++ synapse/handlers/e2e_keys.py | 40 +++++++++++++--------- synapse/rest/client/keys.py | 6 ++-- synapse/storage/controllers/state.py | 4 +++ synapse/storage/databases/main/devices.py | 4 +++ synapse/storage/databases/main/end_to_end_keys.py | 5 ++- synapse/storage/databases/main/event_federation.py | 2 ++ synapse/storage/databases/main/events_worker.py | 4 +++ synapse/storage/databases/main/roommember.py | 2 ++ synapse/storage/databases/main/state.py | 2 ++ synapse/storage/databases/main/stream.py | 2 ++ synapse/storage/databases/state/store.py | 3 ++ .../storage/util/partial_state_events_tracker.py | 3 ++ synapse/types.py | 5 +++ tests/http/server/_base.py | 10 +++++- tests/rest/client/test_keys.py | 29 ++++++++++++++++ 18 files changed, 110 insertions(+), 20 deletions(-) create mode 100644 changelog.d/13680.feature (limited to 'tests/rest') diff --git a/changelog.d/13680.feature b/changelog.d/13680.feature new file mode 100644 index 0000000000..4234c7e082 --- /dev/null +++ b/changelog.d/13680.feature @@ -0,0 +1 @@ +Cancel the processing of key query requests when they time out. \ No newline at end of file diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9a1aea083f..8e54ef84b2 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -38,6 +38,7 @@ from synapse.logging.opentracing import ( trace, ) from synapse.types import Requester, create_requester +from synapse.util.cancellation import cancellable if TYPE_CHECKING: from synapse.server import HomeServer @@ -118,6 +119,7 @@ class Auth: errcode=Codes.NOT_JOINED, ) + @cancellable async def get_user_by_req( self, request: SynapseRequest, @@ -166,6 +168,7 @@ class Auth: parent_span.set_tag("appservice_id", requester.app_service.id) return requester + @cancellable async def _wrapped_get_user_by_req( self, request: SynapseRequest, @@ -281,6 +284,7 @@ class Auth: 403, "Application service has not registered this user (%s)" % user_id ) + @cancellable async def _get_appservice_user(self, request: Request) -> Optional[Requester]: """ Given a request, reads the request parameters to determine: @@ -523,6 +527,7 @@ class Auth: return bool(query_params) or bool(auth_headers) @staticmethod + @cancellable def get_access_token_from_request(request: Request) -> str: """Extracts the access_token from the request. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9c2c3a0e68..c5ac169644 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -52,6 +52,7 @@ from synapse.types import ( from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.cancellation import cancellable from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination @@ -124,6 +125,7 @@ class DeviceWorkerHandler: return device + @cancellable async def get_device_changes_in_shared_rooms( self, user_id: str, room_ids: Collection[str], from_token: StreamToken ) -> Collection[str]: @@ -163,6 +165,7 @@ class DeviceWorkerHandler: @trace @measure_func("device.get_user_ids_changed") + @cancellable async def get_user_ids_changed( self, user_id: str, from_token: StreamToken ) -> JsonDict: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index c938339ddd..ec81639c78 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -37,7 +37,8 @@ from synapse.types import ( get_verify_key_from_cross_signing_key, ) from synapse.util import json_decoder, unwrapFirstError -from synapse.util.async_helpers import Linearizer +from synapse.util.async_helpers import Linearizer, delay_cancellation +from synapse.util.cancellation import cancellable from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: @@ -91,6 +92,7 @@ class E2eKeysHandler: ) @trace + @cancellable async def query_devices( self, query_body: JsonDict, @@ -208,22 +210,26 @@ class E2eKeysHandler: r[user_id] = remote_queries[user_id] # Now fetch any devices that we don't have in our cache + # TODO It might make sense to propagate cancellations into the + # deferreds which are querying remote homeservers. await make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self._query_devices_for_destination, - results, - cross_signing_keys, - failures, - destination, - queries, - timeout, - ) - for destination, queries in remote_queries_not_in_cache.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + delay_cancellation( + defer.gatherResults( + [ + run_in_background( + self._query_devices_for_destination, + results, + cross_signing_keys, + failures, + destination, + queries, + timeout, + ) + for destination, queries in remote_queries_not_in_cache.items() + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) ) ret = {"device_keys": results, "failures": failures} @@ -347,6 +353,7 @@ class E2eKeysHandler: return + @cancellable async def get_cross_signing_keys_from_cache( self, query: Iterable[str], from_user_id: Optional[str] ) -> Dict[str, Dict[str, dict]]: @@ -393,6 +400,7 @@ class E2eKeysHandler: } @trace + @cancellable async def query_local_devices( self, query: Mapping[str, Optional[List[str]]] ) -> Dict[str, Dict[str, dict]]: diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index a395694fa5..f653d2a3e1 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -27,9 +27,9 @@ from synapse.http.servlet import ( ) from synapse.http.site import SynapseRequest from synapse.logging.opentracing import log_kv, set_tag +from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.types import JsonDict, StreamToken - -from ._base import client_patterns, interactive_auth_handler +from synapse.util.cancellation import cancellable if TYPE_CHECKING: from synapse.server import HomeServer @@ -156,6 +156,7 @@ class KeyQueryServlet(RestServlet): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() + @cancellable async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() @@ -199,6 +200,7 @@ class KeyChangesServlet(RestServlet): self.device_handler = hs.get_device_handler() self.store = hs.get_datastores().main + @cancellable async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index ba5380ce3e..bbe568bf05 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -36,6 +36,7 @@ from synapse.storage.util.partial_state_events_tracker import ( PartialStateEventsTracker, ) from synapse.types import MutableStateMap, StateMap +from synapse.util.cancellation import cancellable if TYPE_CHECKING: from synapse.server import HomeServer @@ -229,6 +230,7 @@ class StateStorageController: @trace @tag_args + @cancellable async def get_state_ids_for_events( self, event_ids: Collection[str], @@ -350,6 +352,7 @@ class StateStorageController: @trace @tag_args + @cancellable async def get_state_group_for_events( self, event_ids: Collection[str], @@ -398,6 +401,7 @@ class StateStorageController: event_id, room_id, prev_group, delta_ids, current_state_ids ) + @cancellable async def get_current_state_ids( self, room_id: str, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index ca0fe8c4be..5d700ca6c3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -53,6 +53,7 @@ from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -668,6 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore): ... @trace + @cancellable async def get_user_devices_from_cache( self, query_list: List[Tuple[str, Optional[str]]] ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: @@ -743,6 +745,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore): return self._device_list_stream_cache.get_all_entities_changed(from_key) + @cancellable async def get_users_whose_devices_changed( self, from_key: int, @@ -1221,6 +1224,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore): desc="get_min_device_lists_changes_in_room", ) + @cancellable async def get_device_list_changes_in_rooms( self, room_ids: Collection[str], from_id: int ) -> Optional[Set[str]]: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 46c0d06157..8e9e1b0b4b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -50,6 +50,7 @@ from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -135,6 +136,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker return now_stream_id, [] @trace + @cancellable async def get_e2e_device_keys_for_cs_api( self, query_list: List[Tuple[str, Optional[str]]] ) -> Dict[str, Dict[str, JsonDict]]: @@ -197,6 +199,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ... @trace + @cancellable async def get_e2e_device_keys_and_signatures( self, query_list: Collection[Tuple[str, Optional[str]]], @@ -887,6 +890,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker return keys + @cancellable async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None ) -> Dict[str, Optional[Dict[str, JsonDict]]]: @@ -902,7 +906,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker keys were not found, either their user ID will not be in the dict, or their user ID will map to None. """ - result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids) if from_user_id: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index e687f87eca..ca47a22bf1 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -48,6 +48,7 @@ from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache +from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -976,6 +977,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas return int(min_depth) if min_depth is not None else None + @cancellable async def get_forward_extremities_for_room_at_stream_ordering( self, room_id: str, stream_ordering: int ) -> List[str]: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 84f17a9945..52914febf9 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -81,6 +81,7 @@ from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import AsyncLruCache +from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -339,6 +340,7 @@ class EventsWorkerStore(SQLBaseStore): ) -> Optional[EventBase]: ... + @cancellable async def get_event( self, event_id: str, @@ -433,6 +435,7 @@ class EventsWorkerStore(SQLBaseStore): @trace @tag_args + @cancellable async def get_events_as_list( self, event_ids: Collection[str], @@ -584,6 +587,7 @@ class EventsWorkerStore(SQLBaseStore): return events + @cancellable async def _get_events_from_cache_or_db( self, event_ids: Iterable[str], allow_rejected: bool = False ) -> Dict[str, EventCacheEntry]: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4f0adb136a..a77e49dc66 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -55,6 +55,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList +from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -770,6 +771,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): _get_users_server_still_shares_room_with_txn, ) + @cancellable async def get_rooms_for_user( self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None ) -> FrozenSet[str]: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 0b10af0e58..e607ccfdc9 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -36,6 +36,7 @@ from synapse.storage.state import StateFilter from synapse.types import JsonDict, JsonMapping, StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -281,6 +282,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) # FIXME: how should this be cached? + @cancellable async def get_partial_filtered_current_state_ids( self, room_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index a347430aa7..3f9bfaeac5 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -72,6 +72,7 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import PersistedEventPosition, RoomStreamToken from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.cancellation import cancellable if TYPE_CHECKING: from synapse.server import HomeServer @@ -597,6 +598,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret, key + @cancellable async def get_membership_changes_for_user( self, user_id: str, diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index bb64543c1f..f8cfcaca83 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -31,6 +31,7 @@ from synapse.storage.util.sequence import build_sequence_generator from synapse.types import MutableStateMap, StateKey, StateMap from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.cancellation import cancellable if TYPE_CHECKING: from synapse.server import HomeServer @@ -156,6 +157,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "get_state_group_delta", _get_state_group_delta_txn ) + @cancellable async def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter ) -> Dict[int, StateMap[str]]: @@ -235,6 +237,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types + @cancellable async def _get_state_for_groups( self, groups: Iterable[int], state_filter: Optional[StateFilter] = None ) -> Dict[int, MutableStateMap[str]]: diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py index b4bf49dace..8d8894d1d5 100644 --- a/synapse/storage/util/partial_state_events_tracker.py +++ b/synapse/storage/util/partial_state_events_tracker.py @@ -24,6 +24,7 @@ from synapse.logging.opentracing import trace_with_opname from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.room import RoomWorkerStore from synapse.util import unwrapFirstError +from synapse.util.cancellation import cancellable logger = logging.getLogger(__name__) @@ -60,6 +61,7 @@ class PartialStateEventsTracker: o.callback(None) @trace_with_opname("PartialStateEventsTracker.await_full_state") + @cancellable async def await_full_state(self, event_ids: Collection[str]) -> None: """Wait for all the given events to have full state. @@ -154,6 +156,7 @@ class PartialCurrentStateTracker: o.callback(None) @trace_with_opname("PartialCurrentStateTracker.await_full_state") + @cancellable async def await_full_state(self, room_id: str) -> None: # We add the deferred immediately so that the DB call to check for # partial state doesn't race when we unpartial the room. diff --git a/synapse/types.py b/synapse/types.py index 668d48d646..ec44601f54 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -52,6 +52,7 @@ from twisted.internet.interfaces import ( ) from synapse.api.errors import Codes, SynapseError +from synapse.util.cancellation import cancellable from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: @@ -699,7 +700,11 @@ class StreamToken: START: ClassVar["StreamToken"] @classmethod + @cancellable async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": + """ + Creates a RoomStreamToken from its textual representation. + """ try: keys = string.split(cls._SEPARATOR) while len(keys) < len(attr.fields(cls)): diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 5726e60cee..5071f83574 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -140,6 +140,8 @@ def make_request_with_cancellation_test( method: str, path: str, content: Union[bytes, str, JsonDict] = b"", + *, + token: Optional[str] = None, ) -> FakeChannel: """Performs a request repeatedly, disconnecting at successive `await`s, until one completes. @@ -211,7 +213,13 @@ def make_request_with_cancellation_test( with deferred_patch.patch(): # Start the request. channel = make_request( - reactor, site, method, path, content, await_result=False + reactor, + site, + method, + path, + content, + await_result=False, + access_token=token, ) request = channel.request diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py index bbc8e74243..741fecea77 100644 --- a/tests/rest/client/test_keys.py +++ b/tests/rest/client/test_keys.py @@ -19,6 +19,7 @@ from synapse.rest import admin from synapse.rest.client import keys, login from tests import unittest +from tests.http.server._base import make_request_with_cancellation_test class KeyQueryTestCase(unittest.HomeserverTestCase): @@ -89,3 +90,31 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): Codes.BAD_JSON, channel.result, ) + + def test_key_query_cancellation(self) -> None: + """ + Tests that /keys/query is cancellable and does not swallow the + CancelledError. + """ + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + + bob = self.register_user("bob", "uncle") + + channel = make_request_with_cancellation_test( + "test_key_query_cancellation", + self.reactor, + self.site, + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + # Empty list means we request keys for all bob's devices + bob: [], + }, + }, + token=alice_token, + ) + + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertIn(bob, channel.json_body["device_keys"]) -- cgit 1.5.1 From f799eac7ea96f943ad1272a5a81f845dfa08a254 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 8 Sep 2022 17:41:48 +0200 Subject: Add timestamp to user's consent (#13741) Co-authored-by: reivilibre --- changelog.d/13741.feature | 1 + docs/admin_api/user_admin_api.md | 2 ++ synapse/handlers/admin.py | 1 + synapse/storage/databases/main/registration.py | 6 +++- .../main/delta/72/06add_consent_ts_to_users.sql | 16 +++++++++++ tests/rest/admin/test_user.py | 1 + tests/storage/test_registration.py | 33 +++++++++++++++++----- 7 files changed, 52 insertions(+), 8 deletions(-) create mode 100644 changelog.d/13741.feature create mode 100644 synapse/storage/schema/main/delta/72/06add_consent_ts_to_users.sql (limited to 'tests/rest') diff --git a/changelog.d/13741.feature b/changelog.d/13741.feature new file mode 100644 index 0000000000..dff46f373f --- /dev/null +++ b/changelog.d/13741.feature @@ -0,0 +1 @@ +Document the timestamp when a user accepts the consent, if [consent tracking](https://matrix-org.github.io/synapse/latest/consent_tracking.html) is used. \ No newline at end of file diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index c1ca0c8a64..975f05c929 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -42,6 +42,7 @@ It returns a JSON body like the following: "appservice_id": null, "consent_server_notice_sent": null, "consent_version": null, + "consent_ts": null, "external_ids": [ { "auth_provider": "", @@ -364,6 +365,7 @@ The following actions are **NOT** performed. The list may be incomplete. - Remove the user's creation (registration) timestamp - [Remove rate limit overrides](#override-ratelimiting-for-users) - Remove from monthly active users +- Remove user's consent information (consent version and timestamp) ## Reset password diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index d4fe7df533..cf9f19608a 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -70,6 +70,7 @@ class AdminHandler: "appservice_id", "consent_server_notice_sent", "consent_version", + "consent_ts", "user_type", "is_guest", } diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7fb9c801da..ac821878b0 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -175,6 +175,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "is_guest", "admin", "consent_version", + "consent_ts", "consent_server_notice_sent", "appservice_id", "creation_ts", @@ -2227,7 +2228,10 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): txn, table="users", keyvalues={"name": user_id}, - updatevalues={"consent_version": consent_version}, + updatevalues={ + "consent_version": consent_version, + "consent_ts": self._clock.time_msec(), + }, ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) diff --git a/synapse/storage/schema/main/delta/72/06add_consent_ts_to_users.sql b/synapse/storage/schema/main/delta/72/06add_consent_ts_to_users.sql new file mode 100644 index 0000000000..609eb1750f --- /dev/null +++ b/synapse/storage/schema/main/delta/72/06add_consent_ts_to_users.sql @@ -0,0 +1,16 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE users ADD consent_ts bigint; diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 1afd082707..ec5ccf6fca 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2580,6 +2580,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertIn("appservice_id", content) self.assertIn("consent_server_notice_sent", content) self.assertIn("consent_version", content) + self.assertIn("consent_ts", content) self.assertIn("external_ids", content) # This key was removed intentionally. Ensure it is not accidentally re-included. diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index a49ac1525e..853a93afab 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -11,15 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import UserTypes from synapse.api.errors import ThreepidValidationError +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase class RegistrationStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = "@my-user:test" @@ -27,7 +30,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): self.pwhash = "{xx1}123456789" self.device_id = "akgjhdjklgshg" - def test_register(self): + def test_register(self) -> None: self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.assertEqual( @@ -38,6 +41,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): "admin": 0, "is_guest": 0, "consent_version": None, + "consent_ts": None, "consent_server_notice_sent": None, "appservice_id": None, "creation_ts": 0, @@ -48,7 +52,20 @@ class RegistrationStoreTestCase(HomeserverTestCase): (self.get_success(self.store.get_user_by_id(self.user_id))), ) - def test_add_tokens(self): + def test_consent(self) -> None: + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + before_consent = self.clock.time_msec() + self.reactor.advance(5) + self.get_success(self.store.user_set_consent_version(self.user_id, "1")) + self.reactor.advance(5) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + assert user + self.assertEqual(user["consent_version"], "1") + self.assertGreater(user["consent_ts"], before_consent) + self.assertLess(user["consent_ts"], self.clock.time_msec()) + + def test_add_tokens(self) -> None: self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.get_success( self.store.add_access_token_to_user( @@ -58,11 +75,12 @@ class RegistrationStoreTestCase(HomeserverTestCase): result = self.get_success(self.store.get_user_by_access_token(self.tokens[1])) + assert result self.assertEqual(result.user_id, self.user_id) self.assertEqual(result.device_id, self.device_id) self.assertIsNotNone(result.token_id) - def test_user_delete_access_tokens(self): + def test_user_delete_access_tokens(self) -> None: # add some tokens self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.get_success( @@ -87,6 +105,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): # check the one not associated with the device was not deleted user = self.get_success(self.store.get_user_by_access_token(self.tokens[0])) + assert user self.assertEqual(self.user_id, user.user_id) # now delete the rest @@ -95,11 +114,11 @@ class RegistrationStoreTestCase(HomeserverTestCase): user = self.get_success(self.store.get_user_by_access_token(self.tokens[0])) self.assertIsNone(user, "access token was not deleted without device_id") - def test_is_support_user(self): + def test_is_support_user(self) -> None: TEST_USER = "@test:test" SUPPORT_USER = "@support:test" - res = self.get_success(self.store.is_support_user(None)) + res = self.get_success(self.store.is_support_user(None)) # type: ignore[arg-type] self.assertFalse(res) self.get_success( self.store.register_user(user_id=TEST_USER, password_hash=None) @@ -115,7 +134,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): res = self.get_success(self.store.is_support_user(SUPPORT_USER)) self.assertTrue(res) - def test_3pid_inhibit_invalid_validation_session_error(self): + def test_3pid_inhibit_invalid_validation_session_error(self) -> None: """Tests that enabling the configuration option to inhibit 3PID errors on /requestToken also inhibits validation errors caused by an unknown session ID. """ -- cgit 1.5.1 From 918c74bfb57e3ca4d300ed9a3bfb99b99126f821 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 15 Sep 2022 13:57:16 +0100 Subject: Add a `MXCUri` class to make working with mxc uri's easier. (#13162) --- changelog.d/13162.misc | 1 + poetry.lock | 10 +-- pyproject.toml | 2 +- synapse/rest/media/v1/media_repository.py | 6 +- synapse/rest/media/v1/upload_resource.py | 6 +- tests/rest/media/test_media_retention.py | 102 +++++++++++------------------- 6 files changed, 53 insertions(+), 74 deletions(-) create mode 100644 changelog.d/13162.misc (limited to 'tests/rest') diff --git a/changelog.d/13162.misc b/changelog.d/13162.misc new file mode 100644 index 0000000000..b0d7c05e74 --- /dev/null +++ b/changelog.d/13162.misc @@ -0,0 +1 @@ +Bump the minimum dependency of `matrix_common` to 1.3.0 to make use of the `MXCUri` class. Use `MXCUri` to simplify media retention test code. \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index cdc69f8ea9..291f3c51e6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -524,11 +524,11 @@ python-versions = ">=3.7" [[package]] name = "matrix-common" -version = "1.2.1" +version = "1.3.0" description = "Common utilities for Synapse, Sydent and Sygnal" category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] attrs = "*" @@ -1625,7 +1625,7 @@ url_preview = ["lxml"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "79cfa09d59f9f8b5ef24318fb860df1915f54328692aa56d04331ecbdd92a8cb" +content-hash = "1b14fc274d9e2a495a7f864150f3ffcf4d9f585e09a67e53301ae4ef3c2f3e48" [metadata.files] attrs = [ @@ -2113,8 +2113,8 @@ markupsafe = [ {file = "MarkupSafe-2.1.0.tar.gz", hash = "sha256:80beaf63ddfbc64a0452b841d8036ca0611e049650e20afcb882f5d3c266d65f"}, ] matrix-common = [ - {file = "matrix_common-1.2.1-py3-none-any.whl", hash = "sha256:946709c405944a0d4b1d73207b77eb064b6dbfc5d70a69471320b06d8ce98b20"}, - {file = "matrix_common-1.2.1.tar.gz", hash = "sha256:a99dcf02a6bd95b24a5a61b354888a2ac92bf2b4b839c727b8dd9da2cdfa3853"}, + {file = "matrix_common-1.3.0-py3-none-any.whl", hash = "sha256:524e2785b9b03be4d15f3a8a6b857c5b6af68791ffb1b9918f0ad299abc4db20"}, + {file = "matrix_common-1.3.0.tar.gz", hash = "sha256:62e121cccd9f243417b57ec37a76dc44aeb198a7a5c67afd6b8275992ff2abd1"}, ] matrix-synapse-ldap3 = [ {file = "matrix-synapse-ldap3-0.2.2.tar.gz", hash = "sha256:b388d95693486eef69adaefd0fd9e84463d52fe17b0214a00efcaa669b73cb74"}, diff --git a/pyproject.toml b/pyproject.toml index 157385ad8a..8e50dd2852 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -164,7 +164,7 @@ typing-extensions = ">=3.10.0.1" cryptography = ">=3.4.7" # ijson 3.1.4 fixes a bug with "." in property names ijson = ">=3.1.4" -matrix-common = "^1.2.1" +matrix-common = "^1.3.0" # We need packaging.requirements.Requirement, added in 16.1. packaging = ">=16.1" # At the time of writing, we only use functions from the version `importlib.metadata` diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 9dd3c8d4bb..328c0c5477 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -19,6 +19,8 @@ import shutil from io import BytesIO from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from matrix_common.types.mxc_uri import MXCUri + import twisted.internet.error import twisted.web.http from twisted.internet.defer import Deferred @@ -186,7 +188,7 @@ class MediaRepository: content: IO, content_length: int, auth_user: UserID, - ) -> str: + ) -> MXCUri: """Store uploaded content for a local user and return the mxc URL Args: @@ -219,7 +221,7 @@ class MediaRepository: await self._generate_thumbnails(None, media_id, media_id, media_type) - return "mxc://%s/%s" % (self.server_name, media_id) + return MXCUri(self.server_name, media_id) async def get_local_media( self, request: SynapseRequest, media_id: str, name: Optional[str] diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index e73e431dc9..97548b54e5 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -101,6 +101,8 @@ class UploadResource(DirectServeJsonResource): # the default 404, as that would just be confusing. raise SynapseError(400, "Bad content") - logger.info("Uploaded content with URI %r", content_uri) + logger.info("Uploaded content with URI '%s'", content_uri) - respond_with_json(request, 200, {"content_uri": content_uri}, send_cors=True) + respond_with_json( + request, 200, {"content_uri": str(content_uri)}, send_cors=True + ) diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py index 14af07c5af..23f227aed6 100644 --- a/tests/rest/media/test_media_retention.py +++ b/tests/rest/media/test_media_retention.py @@ -13,7 +13,9 @@ # limitations under the License. import io -from typing import Iterable, Optional, Tuple +from typing import Iterable, Optional + +from matrix_common.types.mxc_uri import MXCUri from twisted.test.proto_helpers import MemoryReactor @@ -63,9 +65,9 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): last_accessed_ms: Optional[int], is_quarantined: Optional[bool] = False, is_protected: Optional[bool] = False, - ) -> str: + ) -> MXCUri: # "Upload" some media to the local media store - mxc_uri = self.get_success( + mxc_uri: MXCUri = self.get_success( media_repository.create_content( media_type="text/plain", upload_name=None, @@ -75,13 +77,11 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): ) ) - media_id = mxc_uri.split("/")[-1] - # Set the last recently accessed time for this media if last_accessed_ms is not None: self.get_success( self.store.update_cached_last_access_time( - local_media=(media_id,), + local_media=(mxc_uri.media_id,), remote_media=(), time_ms=last_accessed_ms, ) @@ -92,7 +92,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): self.get_success( self.store.quarantine_media_by_id( server_name=self.hs.config.server.server_name, - media_id=media_id, + media_id=mxc_uri.media_id, quarantined_by="@theadmin:test", ) ) @@ -101,18 +101,18 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): # Mark this media as protected from quarantine self.get_success( self.store.mark_local_media_as_safe( - media_id=media_id, + media_id=mxc_uri.media_id, safe=True, ) ) - return media_id + return mxc_uri def _cache_remote_media_and_set_attributes( media_id: str, last_accessed_ms: Optional[int], is_quarantined: Optional[bool] = False, - ) -> str: + ) -> MXCUri: # Pretend to cache some remote media self.get_success( self.store.store_cached_remote_media( @@ -146,7 +146,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): ) ) - return media_id + return MXCUri(self.remote_server_name, media_id) # Start with the local media store self.local_recently_accessed_media = _create_media_and_set_attributes( @@ -214,28 +214,16 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): # Remote media should be unaffected. self._assert_if_mxc_uris_purged( purged=[ - ( - self.hs.config.server.server_name, - self.local_not_recently_accessed_media, - ), - (self.hs.config.server.server_name, self.local_never_accessed_media), + self.local_not_recently_accessed_media, + self.local_never_accessed_media, ], not_purged=[ - (self.hs.config.server.server_name, self.local_recently_accessed_media), - ( - self.hs.config.server.server_name, - self.local_not_recently_accessed_quarantined_media, - ), - ( - self.hs.config.server.server_name, - self.local_not_recently_accessed_protected_media, - ), - (self.remote_server_name, self.remote_recently_accessed_media), - (self.remote_server_name, self.remote_not_recently_accessed_media), - ( - self.remote_server_name, - self.remote_not_recently_accessed_quarantined_media, - ), + self.local_recently_accessed_media, + self.local_not_recently_accessed_quarantined_media, + self.local_not_recently_accessed_protected_media, + self.remote_recently_accessed_media, + self.remote_not_recently_accessed_media, + self.remote_not_recently_accessed_quarantined_media, ], ) @@ -261,49 +249,35 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): # Remote media accessed <30 days ago should still exist. self._assert_if_mxc_uris_purged( purged=[ - (self.remote_server_name, self.remote_not_recently_accessed_media), + self.remote_not_recently_accessed_media, ], not_purged=[ - (self.remote_server_name, self.remote_recently_accessed_media), - (self.hs.config.server.server_name, self.local_recently_accessed_media), - ( - self.hs.config.server.server_name, - self.local_not_recently_accessed_media, - ), - ( - self.hs.config.server.server_name, - self.local_not_recently_accessed_quarantined_media, - ), - ( - self.hs.config.server.server_name, - self.local_not_recently_accessed_protected_media, - ), - ( - self.remote_server_name, - self.remote_not_recently_accessed_quarantined_media, - ), - (self.hs.config.server.server_name, self.local_never_accessed_media), + self.remote_recently_accessed_media, + self.local_recently_accessed_media, + self.local_not_recently_accessed_media, + self.local_not_recently_accessed_quarantined_media, + self.local_not_recently_accessed_protected_media, + self.remote_not_recently_accessed_quarantined_media, + self.local_never_accessed_media, ], ) def _assert_if_mxc_uris_purged( - self, purged: Iterable[Tuple[str, str]], not_purged: Iterable[Tuple[str, str]] + self, purged: Iterable[MXCUri], not_purged: Iterable[MXCUri] ) -> None: - def _assert_mxc_uri_purge_state( - server_name: str, media_id: str, expect_purged: bool - ) -> None: + def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None: """Given an MXC URI, assert whether it has been purged or not.""" - if server_name == self.hs.config.server.server_name: + if mxc_uri.server_name == self.hs.config.server.server_name: found_media_dict = self.get_success( - self.store.get_local_media(media_id) + self.store.get_local_media(mxc_uri.media_id) ) else: found_media_dict = self.get_success( - self.store.get_cached_remote_media(server_name, media_id) + self.store.get_cached_remote_media( + mxc_uri.server_name, mxc_uri.media_id + ) ) - mxc_uri = f"mxc://{server_name}/{media_id}" - if expect_purged: self.assertIsNone( found_media_dict, msg=f"{mxc_uri} unexpectedly not purged" @@ -315,7 +289,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): ) # Assert that the given MXC URIs have either been correctly purged or not. - for server_name, media_id in purged: - _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=True) - for server_name, media_id in not_purged: - _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=False) + for mxc_uri in purged: + _assert_mxc_uri_purge_state(mxc_uri, expect_purged=True) + for mxc_uri in not_purged: + _assert_mxc_uri_purge_state(mxc_uri, expect_purged=False) -- cgit 1.5.1 From 742f9f9d78490f7f16bdb607a8f61ca258d520ef Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 15 Sep 2022 18:36:02 +0100 Subject: A third batch of Pydantic validation for rest/client/account.py (#13736) --- changelog.d/13736.feature | 1 + synapse/rest/client/account.py | 65 ++++++++++++++++++++++------------------ synapse/rest/client/models.py | 28 +++++++++-------- tests/rest/client/test_models.py | 29 ++++++++++++++++-- 4 files changed, 78 insertions(+), 45 deletions(-) create mode 100644 changelog.d/13736.feature (limited to 'tests/rest') diff --git a/changelog.d/13736.feature b/changelog.d/13736.feature new file mode 100644 index 0000000000..60a63c1009 --- /dev/null +++ b/changelog.d/13736.feature @@ -0,0 +1 @@ +Improve validation of request bodies for the following client-server API endpoints: [`/account/3pid/add`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidadd), [`/account/3pid/bind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidbind), [`/account/3pid/delete`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3piddelete) and [`/account/3pid/unbind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidunbind). diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index a09aaf3448..2db2a04f95 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple from urllib.parse import urlparse from pydantic import StrictBool, StrictStr, constr +from typing_extensions import Literal from twisted.web.server import Request @@ -43,6 +44,7 @@ from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer from synapse.rest.client.models import ( AuthenticationData, + ClientSecretStr, EmailRequestTokenBody, MsisdnRequestTokenBody, ) @@ -627,6 +629,11 @@ class ThreepidAddRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + class PostBody(RequestBodyModel): + auth: Optional[AuthenticationData] = None + client_secret: ClientSecretStr + sid: StrictStr + @interactive_auth_handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.registration.enable_3pid_changes: @@ -636,22 +643,17 @@ class ThreepidAddRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - - assert_params_in_dict(body, ["client_secret", "sid"]) - sid = body["sid"] - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) + body = parse_and_validate_json_object_from_request(request, self.PostBody) await self.auth_handler.validate_user_via_ui_auth( requester, request, - body, + body.dict(exclude_unset=True), "add a third-party identifier to your account", ) validation_session = await self.identity_handler.validate_threepid_session( - client_secret, sid + body.client_secret, body.sid ) if validation_session: await self.auth_handler.add_threepid( @@ -676,23 +678,20 @@ class ThreepidBindRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - body = parse_json_object_from_request(request) + class PostBody(RequestBodyModel): + client_secret: ClientSecretStr + id_access_token: StrictStr + id_server: StrictStr + sid: StrictStr - assert_params_in_dict( - body, ["id_server", "sid", "id_access_token", "client_secret"] - ) - id_server = body["id_server"] - sid = body["sid"] - id_access_token = body["id_access_token"] - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + body = parse_and_validate_json_object_from_request(request, self.PostBody) requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() await self.identity_handler.bind_threepid( - client_secret, sid, user_id, id_server, id_access_token + body.client_secret, body.sid, user_id, body.id_server, body.id_access_token ) return 200, {} @@ -708,23 +707,27 @@ class ThreepidUnbindRestServlet(RestServlet): self.auth = hs.get_auth() self.datastore = self.hs.get_datastores().main + class PostBody(RequestBodyModel): + address: StrictStr + id_server: Optional[StrictStr] = None + medium: Literal["email", "msisdn"] + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """Unbind the given 3pid from a specific identity server, or identity servers that are known to have this 3pid bound """ requester = await self.auth.get_user_by_req(request) - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["medium", "address"]) - - medium = body.get("medium") - address = body.get("address") - id_server = body.get("id_server") + body = parse_and_validate_json_object_from_request(request, self.PostBody) # Attempt to unbind the threepid from an identity server. If id_server is None, try to # unbind from all identity servers this threepid has been added to in the past result = await self.identity_handler.try_unbind_threepid( requester.user.to_string(), - {"address": address, "medium": medium, "id_server": id_server}, + { + "address": body.address, + "medium": body.medium, + "id_server": body.id_server, + }, ) return 200, {"id_server_unbind_result": "success" if result else "no-support"} @@ -738,21 +741,25 @@ class ThreepidDeleteRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + class PostBody(RequestBodyModel): + address: StrictStr + id_server: Optional[StrictStr] = None + medium: Literal["email", "msisdn"] + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["medium", "address"]) + body = parse_and_validate_json_object_from_request(request, self.PostBody) requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() try: ret = await self.auth_handler.delete_threepid( - user_id, body["medium"], body["address"], body.get("id_server") + user_id, body.medium, body.address, body.id_server ) except Exception: # NB. This endpoint should succeed if there is nothing to diff --git a/synapse/rest/client/models.py b/synapse/rest/client/models.py index 6278450c70..3d7940b0fc 100644 --- a/synapse/rest/client/models.py +++ b/synapse/rest/client/models.py @@ -36,18 +36,20 @@ class AuthenticationData(RequestBodyModel): type: Optional[StrictStr] = None -class ThreePidRequestTokenBody(RequestBodyModel): - if TYPE_CHECKING: - client_secret: StrictStr - else: - # See also assert_valid_client_secret() - client_secret: constr( - regex="[0-9a-zA-Z.=_-]", # noqa: F722 - min_length=0, - max_length=255, - strict=True, - ) +if TYPE_CHECKING: + ClientSecretStr = StrictStr +else: + # See also assert_valid_client_secret() + ClientSecretStr = constr( + regex="[0-9a-zA-Z.=_-]", # noqa: F722 + min_length=1, + max_length=255, + strict=True, + ) + +class ThreepidRequestTokenBody(RequestBodyModel): + client_secret: ClientSecretStr id_server: Optional[StrictStr] id_access_token: Optional[StrictStr] next_link: Optional[StrictStr] @@ -62,7 +64,7 @@ class ThreePidRequestTokenBody(RequestBodyModel): return token -class EmailRequestTokenBody(ThreePidRequestTokenBody): +class EmailRequestTokenBody(ThreepidRequestTokenBody): email: StrictStr # Canonicalise the email address. The addresses are all stored canonicalised @@ -80,6 +82,6 @@ else: ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True) -class MsisdnRequestTokenBody(ThreePidRequestTokenBody): +class MsisdnRequestTokenBody(ThreepidRequestTokenBody): country: ISO3116_1_Alpha_2 phone_number: StrictStr diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py index a9da00665e..0b8fcb0c47 100644 --- a/tests/rest/client/test_models.py +++ b/tests/rest/client/test_models.py @@ -11,14 +11,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import unittest +import unittest as stdlib_unittest -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError +from typing_extensions import Literal from synapse.rest.client.models import EmailRequestTokenBody -class EmailRequestTokenBodyTestCase(unittest.TestCase): +class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase): + class Model(BaseModel): + medium: Literal["email", "msisdn"] + + def test_accepts_valid_medium_string(self) -> None: + """Sanity check that Pydantic behaves sensibly with an enum-of-str + + This is arguably more of a test of a class that inherits from str and Enum + simultaneously. + """ + model = self.Model.parse_obj({"medium": "email"}) + self.assertEqual(model.medium, "email") + + def test_rejects_invalid_medium_value(self) -> None: + with self.assertRaises(ValidationError): + self.Model.parse_obj({"medium": "interpretive_dance"}) + + def test_rejects_invalid_medium_type(self) -> None: + with self.assertRaises(ValidationError): + self.Model.parse_obj({"medium": 123}) + + +class EmailRequestTokenBodyTestCase(stdlib_unittest.TestCase): base_request = { "client_secret": "hunter2", "email": "alice@wonderland.com", -- cgit 1.5.1 From 74f60cec92c5aff87d6e74d177e95ec5f1a69f2b Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 16 Sep 2022 14:29:03 +0200 Subject: Add an admin API endpoint to find a user based on its external ID in an auth provider. (#13810) --- changelog.d/13810.feature | 1 + docs/admin_api/user_admin_api.md | 38 ++++++++++++++++++ synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/users.py | 27 +++++++++++++ tests/rest/admin/test_user.py | 87 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 155 insertions(+) create mode 100644 changelog.d/13810.feature (limited to 'tests/rest') diff --git a/changelog.d/13810.feature b/changelog.d/13810.feature new file mode 100644 index 0000000000..f0258af661 --- /dev/null +++ b/changelog.d/13810.feature @@ -0,0 +1 @@ +Add an admin API endpoint to find a user based on its external ID in an auth provider. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 975f05c929..3625c7b6c5 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -1155,3 +1155,41 @@ GET /_synapse/admin/v1/username_available?username=$localpart The request and response format is the same as the [/_matrix/client/r0/register/available](https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-register-available) API. + +### Find a user based on their ID in an auth provider + +The API is: + +``` +GET /_synapse/admin/v1/auth_providers/$provider/users/$external_id +``` + +When a user matched the given ID for the given provider, an HTTP code `200` with a response body like the following is returned: + +```json +{ + "user_id": "@hello:example.org" +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +- `provider` - The ID of the authentication provider, as advertised by the [`GET /_matrix/client/v3/login`](https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3login) API in the `m.login.sso` authentication method. +- `external_id` - The user ID from the authentication provider. Usually corresponds to the `sub` claim for OIDC providers, or to the `uid` attestation for SAML2 providers. + +The `external_id` may have characters that are not URL-safe (typically `/`, `:` or `@`), so it is advised to URL-encode those parameters. + +**Errors** + +Returns a `404` HTTP status code if no user was found, with a response body like this: + +```json +{ + "errcode":"M_NOT_FOUND", + "error":"User not found" +} +``` + +_Added in Synapse 1.68.0._ diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index bac754e1b1..885669f9c7 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -80,6 +80,7 @@ from synapse.rest.admin.users import ( SearchUsersRestServlet, ShadowBanRestServlet, UserAdminServlet, + UserByExternalId, UserMembershipRestServlet, UserRegisterServlet, UserRestServletV2, @@ -275,6 +276,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ListDestinationsRestServlet(hs).register(http_server) RoomMessagesRestServlet(hs).register(http_server) RoomTimestampToEventRestServlet(hs).register(http_server) + UserByExternalId(hs).register(http_server) # Some servlets only get registered for the main process. if hs.config.worker.worker_app is None: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 78ee9b6532..2ca6b2d08a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -1156,3 +1156,30 @@ class AccountDataRestServlet(RestServlet): "rooms": by_room_data, }, } + + +class UserByExternalId(RestServlet): + """Find a user based on an external ID from an auth provider""" + + PATTERNS = admin_patterns( + "/auth_providers/(?P[^/]*)/users/(?P[^/]*)" + ) + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastores().main + + async def on_GET( + self, + request: SynapseRequest, + provider: str, + external_id: str, + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + user_id = await self._store.get_user_by_external_id(provider, external_id) + + if user_id is None: + raise NotFoundError("User not found") + + return HTTPStatus.OK, {"user_id": user_id} diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ec5ccf6fca..9f536ceeb3 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -4140,3 +4140,90 @@ class AccountDataTestCase(unittest.HomeserverTestCase): {"b": 2}, channel.json_body["account_data"]["rooms"]["test_room"]["m.per_room"], ) + + +class UsersByExternalIdTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.get_success( + self.store.record_user_external_id( + "the-auth-provider", "the-external-id", self.other_user + ) + ) + self.get_success( + self.store.record_user_external_id( + "another-auth-provider", "a:complex@external/id", self.other_user + ) + ) + + def test_no_auth(self) -> None: + """Try to lookup a user without authentication.""" + url = ( + "/_synapse/admin/v1/auth_providers/the-auth-provider/users/the-external-id" + ) + + channel = self.make_request( + "GET", + url, + ) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_binding_does_not_exist(self) -> None: + """Tests that a lookup for an external ID that does not exist returns a 404""" + url = "/_synapse/admin/v1/auth_providers/the-auth-provider/users/unknown-id" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_success(self) -> None: + """Tests a successful external ID lookup""" + url = ( + "/_synapse/admin/v1/auth_providers/the-auth-provider/users/the-external-id" + ) + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + {"user_id": self.other_user}, + channel.json_body, + ) + + def test_success_urlencoded(self) -> None: + """Tests a successful external ID lookup with an url-encoded ID""" + url = "/_synapse/admin/v1/auth_providers/another-auth-provider/users/a%3Acomplex%40external%2Fid" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + {"user_id": self.other_user}, + channel.json_body, + ) -- cgit 1.5.1 From 8ae42ab8fa3c6b52d74c24daa7ca75a478fa4fbb Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 21 Sep 2022 15:39:01 +0100 Subject: Support enabling/disabling pushers (from MSC3881) (#13799) Partial implementation of MSC3881 --- changelog.d/13799.feature | 1 + synapse/_scripts/synapse_port_db.py | 1 + synapse/config/experimental.py | 3 + synapse/handlers/register.py | 4 +- synapse/push/__init__.py | 2 + synapse/push/pusherpool.py | 81 ++++++++--- synapse/replication/tcp/client.py | 10 +- synapse/rest/admin/users.py | 4 +- synapse/rest/client/pusher.py | 18 ++- synapse/storage/databases/main/pusher.py | 69 ++++++---- .../schema/main/delta/73/02add_pusher_enabled.sql | 16 +++ tests/push/test_email.py | 4 +- tests/push/test_http.py | 148 +++++++++++++++++++-- tests/replication/test_pusher_shard.py | 2 +- tests/rest/admin/test_user.py | 2 +- 15 files changed, 294 insertions(+), 71 deletions(-) create mode 100644 changelog.d/13799.feature create mode 100644 synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql (limited to 'tests/rest') diff --git a/changelog.d/13799.feature b/changelog.d/13799.feature new file mode 100644 index 0000000000..6c8e5cffe2 --- /dev/null +++ b/changelog.d/13799.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881). diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 30983c47fb..450ba462ba 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -111,6 +111,7 @@ BOOLEAN_COLUMNS = { "e2e_fallback_keys_json": ["used"], "access_tokens": ["used"], "device_lists_changes_in_room": ["converted_to_destinations"], + "pushers": ["enabled"], } diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 702b81e636..f4541a8db0 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -93,3 +93,6 @@ class ExperimentalConfig(Config): # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) + + # MSC3881: Remotely toggle push notifications for another client + self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 20ec22105a..cfcadb34db 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -997,7 +997,7 @@ class RegistrationHandler: assert user_tuple token_id = user_tuple.token_id - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user_id, access_token=token_id, kind="email", @@ -1005,7 +1005,7 @@ class RegistrationHandler: app_display_name="Email Notifications", device_display_name=threepid["address"], pushkey=threepid["address"], - lang=None, # We don't know a user's language here + lang=None, data={}, ) diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 57c4d70466..ac99d35a7e 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -116,6 +116,7 @@ class PusherConfig: last_stream_ordering: int last_success: Optional[int] failing_since: Optional[int] + enabled: bool def as_dict(self) -> Dict[str, Any]: """Information that can be retrieved about a pusher after creation.""" @@ -128,6 +129,7 @@ class PusherConfig: "lang": self.lang, "profile_tag": self.profile_tag, "pushkey": self.pushkey, + "enabled": self.enabled, } diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 1e0ef44fc7..2597898cf4 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -94,7 +94,7 @@ class PusherPool: return run_as_background_process("start_pushers", self._start_pushers) - async def add_pusher( + async def add_or_update_pusher( self, user_id: str, access_token: Optional[int], @@ -106,6 +106,7 @@ class PusherPool: lang: Optional[str], data: JsonDict, profile_tag: str = "", + enabled: bool = True, ) -> Optional[Pusher]: """Creates a new pusher and adds it to the pool @@ -147,9 +148,20 @@ class PusherPool: last_stream_ordering=last_stream_ordering, last_success=None, failing_since=None, + enabled=enabled, ) ) + # Before we actually persist the pusher, we check if the user already has one + # for this app ID and pushkey. If so, we want to keep the access token in place, + # since this could be one device modifying (e.g. enabling/disabling) another + # device's pusher. + existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( + user_id, app_id, pushkey + ) + if existing_config: + access_token = existing_config.access_token + await self.store.add_pusher( user_id=user_id, access_token=access_token, @@ -163,8 +175,9 @@ class PusherPool: data=data, last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, + enabled=enabled, ) - pusher = await self.start_pusher_by_id(app_id, pushkey, user_id) + pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id) return pusher @@ -276,10 +289,25 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_receipts") - async def start_pusher_by_id( + async def _get_pusher_config_for_user_by_app_id_and_pushkey( + self, user_id: str, app_id: str, pushkey: str + ) -> Optional[PusherConfig]: + resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + + pusher_config = None + for r in resultlist: + if r.user_name == user_id: + pusher_config = r + + return pusher_config + + async def process_pusher_change_by_id( self, app_id: str, pushkey: str, user_id: str ) -> Optional[Pusher]: - """Look up the details for the given pusher, and start it + """Look up the details for the given pusher, and either start it if its + "enabled" flag is True, or try to stop it otherwise. + + If the pusher is new and its "enabled" flag is False, the stop is a noop. Returns: The pusher started, if any @@ -290,12 +318,13 @@ class PusherPool: if not self._pusher_shard_config.should_handle(self._instance_name, user_id): return None - resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + pusher_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( + user_id, app_id, pushkey + ) - pusher_config = None - for r in resultlist: - if r.user_name == user_id: - pusher_config = r + if pusher_config and not pusher_config.enabled: + self.maybe_stop_pusher(app_id, pushkey, user_id) + return None pusher = None if pusher_config: @@ -305,7 +334,7 @@ class PusherPool: async def _start_pushers(self) -> None: """Start all the pushers""" - pushers = await self.store.get_all_pushers() + pushers = await self.store.get_enabled_pushers() # Stagger starting up the pushers so we don't completely drown the # process on start up. @@ -363,6 +392,8 @@ class PusherPool: synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc() + logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey) + # Check if there *may* be push to process. We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to # push. @@ -382,16 +413,7 @@ class PusherPool: return pusher async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: - appid_pushkey = "%s:%s" % (app_id, pushkey) - - byuser = self.pushers.get(user_id, {}) - - if appid_pushkey in byuser: - logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) - pusher = byuser.pop(appid_pushkey) - pusher.on_stop() - - synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() + self.maybe_stop_pusher(app_id, pushkey, user_id) # We can only delete pushers on master. if self._remove_pusher_client: @@ -402,3 +424,22 @@ class PusherPool: await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id ) + + def maybe_stop_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: + """Stops a pusher with the given app ID and push key if one is running. + + Args: + app_id: the pusher's app ID. + pushkey: the pusher's push key. + user_id: the user the pusher belongs to. Only used for logging. + """ + appid_pushkey = "%s:%s" % (app_id, pushkey) + + byuser = self.pushers.get(user_id, {}) + + if appid_pushkey in byuser: + logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) + pusher = byuser.pop(appid_pushkey) + pusher.on_stop() + + synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e4f2201c92..cf9cd6833b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -189,7 +189,9 @@ class ReplicationDataHandler: if row.deleted: self.stop_pusher(row.user_id, row.app_id, row.pushkey) else: - await self.start_pusher(row.user_id, row.app_id, row.pushkey) + await self.process_pusher_change( + row.user_id, row.app_id, row.pushkey + ) elif stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. @@ -334,13 +336,15 @@ class ReplicationDataHandler: logger.info("Stopping pusher %r / %r", user_id, key) pusher.on_stop() - async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None: + async def process_pusher_change( + self, user_id: str, app_id: str, pushkey: str + ) -> None: if not self._notify_pushers: return key = "%s:%s" % (app_id, pushkey) logger.info("Starting pusher %r / %r", user_id, key) - await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) + await self._pusher_pool.process_pusher_change_by_id(app_id, pushkey, user_id) class FederationSenderHandler: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2ca6b2d08a..1274773d7e 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -375,7 +375,7 @@ class UserRestServletV2(RestServlet): and self.hs.config.email.email_notif_for_new_users and medium == "email" ): - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user_id, access_token=None, kind="email", @@ -383,7 +383,7 @@ class UserRestServletV2(RestServlet): app_display_name="Email Notifications", device_display_name=address, pushkey=address, - lang=None, # We don't know a user's language here + lang=None, data={}, ) diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index 9a1f10f4be..c9f76125dc 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -42,6 +42,7 @@ class PushersRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() + self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -51,9 +52,14 @@ class PushersRestServlet(RestServlet): user.to_string() ) - filtered_pushers = [p.as_dict() for p in pushers] + pusher_dicts = [p.as_dict() for p in pushers] - return 200, {"pushers": filtered_pushers} + for pusher in pusher_dicts: + if self._msc3881_enabled: + pusher["org.matrix.msc3881.enabled"] = pusher["enabled"] + del pusher["enabled"] + + return 200, {"pushers": pusher_dicts} class PushersSetRestServlet(RestServlet): @@ -65,6 +71,7 @@ class PushersSetRestServlet(RestServlet): self.auth = hs.get_auth() self.notifier = hs.get_notifier() self.pusher_pool = self.hs.get_pusherpool() + self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -103,6 +110,10 @@ class PushersSetRestServlet(RestServlet): if "append" in content: append = content["append"] + enabled = True + if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content: + enabled = content["org.matrix.msc3881.enabled"] + if not append: await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( app_id=content["app_id"], @@ -111,7 +122,7 @@ class PushersSetRestServlet(RestServlet): ) try: - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user.to_string(), access_token=requester.access_token_id, kind=content["kind"], @@ -122,6 +133,7 @@ class PushersSetRestServlet(RestServlet): lang=content["lang"], data=content["data"], profile_tag=content.get("profile_tag", ""), + enabled=enabled, ) except PusherConfigException as pce: raise SynapseError( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index bd0cfa7f32..ee55b8c4a9 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -89,6 +89,11 @@ class PusherWorkerStore(SQLBaseStore): ) continue + # If we're using SQLite, then boolean values are integers. This is + # troublesome since some code using the return value of this method might + # expect it to be a boolean, or will expose it to clients (in responses). + r["enabled"] = bool(r["enabled"]) + yield PusherConfig(**r) async def get_pushers_by_app_id_and_pushkey( @@ -100,38 +105,52 @@ class PusherWorkerStore(SQLBaseStore): return await self.get_pushers_by({"user_name": user_id}) async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]: - ret = await self.db_pool.simple_select_list( - "pushers", - keyvalues, - [ - "id", - "user_name", - "access_token", - "profile_tag", - "kind", - "app_id", - "app_display_name", - "device_display_name", - "pushkey", - "ts", - "lang", - "data", - "last_stream_ordering", - "last_success", - "failing_since", - ], + """Retrieve pushers that match the given criteria. + + Args: + keyvalues: A {column: value} dictionary. + + Returns: + The pushers for which the given columns have the given values. + """ + + def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]: + # We could technically use simple_select_list here, but we need to call + # COALESCE on the 'enabled' column. While it is technically possible to give + # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it + # feels a bit hacky, so it's probably better to just inline the query. + sql = """ + SELECT + id, user_name, access_token, profile_tag, kind, app_id, + app_display_name, device_display_name, pushkey, ts, lang, data, + last_stream_ordering, last_success, failing_since, + COALESCE(enabled, TRUE) AS enabled + FROM pushers + """ + + sql += "WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),) + + txn.execute(sql, list(keyvalues.values())) + + return self.db_pool.cursor_to_dict(txn) + + ret = await self.db_pool.runInteraction( desc="get_pushers_by", + func=get_pushers_by_txn, ) + return self._decode_pushers_rows(ret) - async def get_all_pushers(self) -> Iterator[PusherConfig]: - def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]: - txn.execute("SELECT * FROM pushers") + async def get_enabled_pushers(self) -> Iterator[PusherConfig]: + def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]: + txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)") rows = self.db_pool.cursor_to_dict(txn) return self._decode_pushers_rows(rows) - return await self.db_pool.runInteraction("get_all_pushers", get_pushers) + return await self.db_pool.runInteraction( + "get_enabled_pushers", get_enabled_pushers_txn + ) async def get_all_updated_pushers_rows( self, instance_name: str, last_id: int, current_id: int, limit: int @@ -476,6 +495,7 @@ class PusherStore(PusherWorkerStore): data: Optional[JsonDict], last_stream_ordering: int, profile_tag: str = "", + enabled: bool = True, ) -> None: async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on @@ -494,6 +514,7 @@ class PusherStore(PusherWorkerStore): "last_stream_ordering": last_stream_ordering, "profile_tag": profile_tag, "id": stream_id, + "enabled": enabled, }, desc="add_pusher", lock=False, diff --git a/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql b/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql new file mode 100644 index 0000000000..dba3b4900b --- /dev/null +++ b/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql @@ -0,0 +1,16 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE pushers ADD COLUMN enabled BOOLEAN; \ No newline at end of file diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 7a3b0d6755..fd14568f55 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase): ) self.pusher = self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, access_token=self.token_id, kind="email", @@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase): """ with self.assertRaises(SynapseError) as cm: self.get_success_or_raise( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, access_token=self.token_id, kind="email", diff --git a/tests/push/test_http.py b/tests/push/test_http.py index d9c68cdd2d..af67d84463 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -19,8 +19,8 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable -from synapse.push import PusherConfigException -from synapse.rest.client import login, push_rule, receipts, room +from synapse.push import PusherConfig, PusherConfigException +from synapse.rest.client import login, push_rule, pusher, receipts, room from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -35,6 +35,7 @@ class HTTPPusherTests(HomeserverTestCase): login.register_servlets, receipts.register_servlets, push_rule.register_servlets, + pusher.register_servlets, ] user_id = True hijack_auth = False @@ -74,7 +75,7 @@ class HTTPPusherTests(HomeserverTestCase): def test_data(data: Optional[JsonDict]) -> None: self.get_failure( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -119,7 +120,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -235,7 +236,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -355,7 +356,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -441,7 +442,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -518,7 +519,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -624,7 +625,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -728,18 +729,38 @@ class HTTPPusherTests(HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.json_body) - def _make_user_with_pusher(self, username: str) -> Tuple[str, str]: + def _make_user_with_pusher( + self, username: str, enabled: bool = True + ) -> Tuple[str, str]: + """Registers a user and creates a pusher for them. + + Args: + username: the localpart of the new user's Matrix ID. + enabled: whether to create the pusher in an enabled or disabled state. + """ user_id = self.register_user(username, "pass") access_token = self.login(username, "pass") # Register the pusher + self._set_pusher(user_id, access_token, enabled) + + return user_id, access_token + + def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None: + """Creates or updates the pusher for the given user. + + Args: + user_id: the user's Matrix ID. + access_token: the access token associated with the pusher. + enabled: whether to enable or disable the pusher. + """ user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -749,11 +770,10 @@ class HTTPPusherTests(HomeserverTestCase): pushkey="a@example.com", lang=None, data={"url": "http://example.com/_matrix/push/v1/notify"}, + enabled=enabled, ) ) - return user_id, access_token - def test_dont_notify_rule_overrides_message(self) -> None: """ The override push rule will suppress notification @@ -791,3 +811,105 @@ class HTTPPusherTests(HomeserverTestCase): # The user sends a message back (sends a notification) self.helper.send(room, body="Hello", tok=access_token) self.assertEqual(len(self.push_attempts), 1) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_disable(self) -> None: + """Tests that disabling a pusher means it's not pushed to anymore.""" + user_id, access_token = self._make_user_with_pusher("user") + other_user_id, other_access_token = self._make_user_with_pusher("otheruser") + + room = self.helper.create_room_as(user_id, tok=access_token) + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Send a message and check that it generated a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Disable the pusher. + self._set_pusher(user_id, access_token, enabled=False) + + # Send another message and check that it did not generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Get the pushers for the user and check that it is marked as disabled. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + + enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"] + self.assertFalse(enabled) + self.assertTrue(isinstance(enabled, bool)) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_enable(self) -> None: + """Tests that enabling a disabled pusher means it gets pushed to.""" + # Create the user with the pusher already disabled. + user_id, access_token = self._make_user_with_pusher("user", enabled=False) + other_user_id, other_access_token = self._make_user_with_pusher("otheruser") + + room = self.helper.create_room_as(user_id, tok=access_token) + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Send a message and check that it did not generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 0) + + # Enable the pusher. + self._set_pusher(user_id, access_token, enabled=True) + + # Send another message and check that it did generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Get the pushers for the user and check that it is marked as enabled. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + + enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"] + self.assertTrue(enabled) + self.assertTrue(isinstance(enabled, bool)) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_null_enabled(self) -> None: + """Tests that a pusher that has an 'enabled' column set to NULL (eg pushers + created before the column was introduced) is considered enabled. + """ + # We intentionally set 'enabled' to None so that it's stored as NULL in the + # database. + user_id, access_token = self._make_user_with_pusher("user", enabled=None) # type: ignore[arg-type] + + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]) + + def test_update_different_device_access_token(self) -> None: + """Tests that if we create a pusher from one device, the update it from another + device, the access token associated with the pusher stays the same. + """ + # Create a user with a pusher. + user_id, access_token = self._make_user_with_pusher("user") + + # Get the token ID for the current access token, since that's what we store in + # the pushers table. + user_tuple = self.get_success( + self.hs.get_datastores().main.get_user_by_access_token(access_token) + ) + token_id = user_tuple.token_id + + # Generate a new access token, and update the pusher with it. + new_token = self.login("user", "pass") + self._set_pusher(user_id, new_token, enabled=False) + + # Get the current list of pushers for the user. + ret = self.get_success( + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) + ) + pushers: List[PusherConfig] = list(ret) + + # Check that we still have one pusher, and that the access token associated with + # it didn't change. + self.assertEqual(len(pushers), 1) + self.assertEqual(pushers[0].access_token, token_id) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 8f4f6688ce..59fea93e49 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): token_id = user_dict.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 9f536ceeb3..1847e6ad6b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2839,7 +2839,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.other_user, access_token=token_id, kind="http", -- cgit 1.5.1 From 0fd2f2d46064efd37284a36d5b478815d69ddd96 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Wed, 21 Sep 2022 16:12:29 +0100 Subject: Implementation of MSC3882 login token request (#13722) --- changelog.d/13722.feature | 1 + synapse/config/experimental.py | 7 ++ synapse/rest/__init__.py | 2 + synapse/rest/client/login_token_request.py | 94 ++++++++++++++++++ synapse/rest/client/versions.py | 2 + tests/rest/client/test_login_token_request.py | 132 ++++++++++++++++++++++++++ 6 files changed, 238 insertions(+) create mode 100644 changelog.d/13722.feature create mode 100644 synapse/rest/client/login_token_request.py create mode 100644 tests/rest/client/test_login_token_request.py (limited to 'tests/rest') diff --git a/changelog.d/13722.feature b/changelog.d/13722.feature new file mode 100644 index 0000000000..588d143c0f --- /dev/null +++ b/changelog.d/13722.feature @@ -0,0 +1 @@ +Experimental implementation of MSC3882 to allow an existing device/session to generate a login token for use on a new device/session. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f4541a8db0..bf27f6c101 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -96,3 +96,10 @@ class ExperimentalConfig(Config): # MSC3881: Remotely toggle push notifications for another client self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False) + + # MSC3882: Allow an existing session to sign in a new session + self.msc3882_enabled: bool = experimental.get("msc3882_enabled", False) + self.msc3882_ui_auth: bool = experimental.get("msc3882_ui_auth", True) + self.msc3882_token_timeout = self.parse_duration( + experimental.get("msc3882_token_timeout", "5m") + ) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index b712215112..9a2ab99ede 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -30,6 +30,7 @@ from synapse.rest.client import ( keys, knock, login as v1_login, + login_token_request, logout, mutual_rooms, notifications, @@ -130,3 +131,4 @@ class ClientRestResource(JsonResource): # unstable mutual_rooms.register_servlets(hs, client_resource) + login_token_request.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py new file mode 100644 index 0000000000..ca5c54bf17 --- /dev/null +++ b/synapse/rest/client/login_token_request.py @@ -0,0 +1,94 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Tuple + +from synapse.http.server import HttpServer +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns, interactive_auth_handler +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class LoginTokenRequestServlet(RestServlet): + """ + Get a token that can be used with `m.login.token` to log in a second device. + + Request: + + POST /login/token HTTP/1.1 + Content-Type: application/json + + {} + + Response: + + HTTP/1.1 200 OK + { + "login_token": "ABDEFGH", + "expires_in": 3600, + } + """ + + PATTERNS = client_patterns("/login/token$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self.clock = hs.get_clock() + self.server_name = hs.config.server.server_name + self.macaroon_gen = hs.get_macaroon_generator() + self.auth_handler = hs.get_auth_handler() + self.token_timeout = hs.config.experimental.msc3882_token_timeout + self.ui_auth = hs.config.experimental.msc3882_ui_auth + + @interactive_auth_handler + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + body = parse_json_object_from_request(request) + + if self.ui_auth: + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "issue a new access token for your account", + can_skip_ui_auth=False, # Don't allow skipping of UI auth + ) + + login_token = self.macaroon_gen.generate_short_term_login_token( + user_id=requester.user.to_string(), + auth_provider_id="org.matrix.msc3882.login_token_request", + duration_in_ms=self.token_timeout, + ) + + return ( + 200, + { + "login_token": login_token, + "expires_in": self.token_timeout // 1000, + }, + ) + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc3882_enabled: + LoginTokenRequestServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c516cda95d..c3488f4330 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -105,6 +105,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, + # Adds support for login token requests as per MSC3882 + "org.matrix.msc3882": self.config.experimental.msc3882_enabled, }, }, ) diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py new file mode 100644 index 0000000000..d5bb16c98d --- /dev/null +++ b/tests/rest/client/test_login_token_request.py @@ -0,0 +1,132 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin +from synapse.rest.client import login, login_token_request +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.unittest import override_config + + +class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + admin.register_servlets, + login_token_request.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.hs = self.setup_test_homeserver() + self.hs.config.registration.enable_registration = True + self.hs.config.registration.registrations_require_3pid = [] + self.hs.config.registration.auto_join_rooms = [] + self.hs.config.captcha.enable_registration_captcha = False + + return self.hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user = "user123" + self.password = "password" + + def test_disabled(self) -> None: + channel = self.make_request("POST", "/login/token", {}, access_token=None) + self.assertEqual(channel.code, 400) + + self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 400) + + @override_config({"experimental_features": {"msc3882_enabled": True}}) + def test_require_auth(self) -> None: + channel = self.make_request("POST", "/login/token", {}, access_token=None) + self.assertEqual(channel.code, 401) + + @override_config({"experimental_features": {"msc3882_enabled": True}}) + def test_uia_on(self) -> None: + user_id = self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 401) + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + session = channel.json_body["session"] + + uia = { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.password, + "session": session, + }, + } + + channel = self.make_request("POST", "/login/token", uia, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 300) + + login_token = channel.json_body["login_token"] + + channel = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["user_id"], user_id) + + @override_config( + {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}} + ) + def test_uia_off(self) -> None: + user_id = self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 300) + + login_token = channel.json_body["login_token"] + + channel = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["user_id"], user_id) + + @override_config( + { + "experimental_features": { + "msc3882_enabled": True, + "msc3882_ui_auth": False, + "msc3882_token_timeout": "15s", + } + } + ) + def test_expires_in(self) -> None: + self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 15) -- cgit 1.5.1 From b7272b73aa38dcb19c9b075514f963390358113d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 22 Sep 2022 08:47:49 -0400 Subject: Properly paginate forward in the /relations API. (#13840) This fixes a bug where the `/relations` API with `dir=f` would skip the first item of each page (except the first page), causing incomplete data to be returned to the client. --- changelog.d/13840.bugfix | 1 + synapse/storage/databases/main/relations.py | 38 +++++++++++++++++++++-------- synapse/storage/databases/main/stream.py | 6 ++--- tests/rest/client/test_relations.py | 29 +++++++++++++++++++++- 4 files changed, 60 insertions(+), 14 deletions(-) create mode 100644 changelog.d/13840.bugfix (limited to 'tests/rest') diff --git a/changelog.d/13840.bugfix b/changelog.d/13840.bugfix new file mode 100644 index 0000000000..0f014439a8 --- /dev/null +++ b/changelog.d/13840.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.53.0 where the experimental implementation of [MSC3715](https://github.com/matrix-org/matrix-spec-proposals/pull/3715) would give incorrect results when paginating forward. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 7bd27790eb..898947af95 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -51,6 +51,8 @@ class _RelatedEvent: event_id: str # The sender of the related event. sender: str + topological_ordering: Optional[int] + stream_ordering: int class RelationsWorkerStore(SQLBaseStore): @@ -91,6 +93,9 @@ class RelationsWorkerStore(SQLBaseStore): # it. The `event_id` must match the `event.event_id`. assert event.event_id == event_id + # Ensure bad limits aren't being passed in. + assert limit >= 0 + where_clause = ["relates_to_id = ?", "room_id = ?"] where_args: List[Union[str, int]] = [event.event_id, room_id] is_redacted = event.internal_metadata.is_redacted() @@ -139,21 +144,34 @@ class RelationsWorkerStore(SQLBaseStore): ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: txn.execute(sql, where_args + [limit + 1]) - last_topo_id = None - last_stream_id = None events = [] - for row in txn: + for event_id, relation_type, sender, topo_ordering, stream_ordering in txn: # Do not include edits for redacted events as they leak event # content. - if not is_redacted or row[1] != RelationTypes.REPLACE: - events.append(_RelatedEvent(row[0], row[2])) - last_topo_id = row[3] - last_stream_id = row[4] + if not is_redacted or relation_type != RelationTypes.REPLACE: + events.append( + _RelatedEvent(event_id, sender, topo_ordering, stream_ordering) + ) - # If there are more events, generate the next pagination key. + # If there are more events, generate the next pagination key from the + # last event returned. next_token = None - if len(events) > limit and last_topo_id and last_stream_id: - next_key = RoomStreamToken(last_topo_id, last_stream_id) + if len(events) > limit: + # Instead of using the last row (which tells us there is more + # data), use the last row to be returned. + events = events[:limit] + + topo = events[-1].topological_ordering + token = events[-1].stream_ordering + if direction == "b": + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + token -= 1 + next_key = RoomStreamToken(topo, token) + if from_token: next_token = from_token.copy_and_replace( StreamKeyType.ROOM, next_key diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 3f9bfaeac5..530f04e149 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1334,15 +1334,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if rows: topo = rows[-1].topological_ordering - toke = rows[-1].stream_ordering + token = rows[-1].stream_ordering if direction == "b": # Tokens are positions between events. # This token points *after* the last event in the chunk. # We need it to point to the event before it in the chunk # when we are going backwards so we subtract one from the # stream part. - toke -= 1 - next_token = RoomStreamToken(topo, toke) + token -= 1 + next_token = RoomStreamToken(topo, token) else: # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 651f4f415d..d33e34d829 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -788,6 +788,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel.json_body["chunk"][0], ) + @unittest.override_config({"experimental_features": {"msc3715_enabled": True}}) def test_repeated_paginate_relations(self) -> None: """Test that if we paginate using a limit and tokens then we get the expected events. @@ -809,7 +810,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", - f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=3{from_token}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -827,6 +828,32 @@ class RelationPaginationTestCase(BaseRelationsTestCase): found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) + # Test forward pagination. + prev_token = "" + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?org.matrix.msc3715.dir=f&limit=3{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + self.assertEqual(found_event_ids, expected_event_ids) + def test_pagination_from_sync_and_messages(self) -> None: """Pagination tokens from /sync and /messages can be used to paginate /relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") -- cgit 1.5.1 From 87fe9db4675e510ea9c0234429b4773341c4e86d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 27 Sep 2022 10:47:34 -0400 Subject: Support the stable dir parameter for /relations. (#13920) Since MSC3715 has passed FCP, the stable parameter can be used. This currently falls back to the unstable parameter if the stable parameter is not provided (and MSC3715 support is enabled in the configuration). --- changelog.d/13920.feature | 1 + synapse/rest/client/relations.py | 24 +++++++++++++++--------- tests/rest/client/test_relations.py | 6 ++---- 3 files changed, 18 insertions(+), 13 deletions(-) create mode 100644 changelog.d/13920.feature (limited to 'tests/rest') diff --git a/changelog.d/13920.feature b/changelog.d/13920.feature new file mode 100644 index 0000000000..aee702bcd2 --- /dev/null +++ b/changelog.d/13920.feature @@ -0,0 +1 @@ +Support a `dir` parameter on the `/relations` endpoint per [MSC3715](https://github.com/matrix-org/matrix-doc/pull/3715). diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index ce97080013..205c556f64 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -56,15 +56,21 @@ class RelationPaginationServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) limit = parse_integer(request, "limit", default=5) - if self._msc3715_enabled: - direction = parse_string( - request, - "org.matrix.msc3715.dir", - default="b", - allowed_values=["f", "b"], - ) - else: - direction = "b" + # Fetch the direction parameter, if provided. + # + # TODO Use PaginationConfig.from_request when the unstable parameter is + # no longer needed. + direction = parse_string(request, "dir", allowed_values=["f", "b"]) + if direction is None: + if self._msc3715_enabled: + direction = parse_string( + request, + "org.matrix.msc3715.dir", + default="b", + allowed_values=["f", "b"], + ) + else: + direction = "b" from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index d33e34d829..fef3b72d76 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -728,7 +728,6 @@ class RelationsTestCase(BaseRelationsTestCase): class RelationPaginationTestCase(BaseRelationsTestCase): - @unittest.override_config({"experimental_features": {"msc3715_enabled": True}}) def test_basic_paginate_relations(self) -> None: """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -771,7 +770,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/_matrix/client/v1/rooms/{self.room}/relations" - f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", + f"/{self.parent_id}?limit=1&dir=f", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -788,7 +787,6 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel.json_body["chunk"][0], ) - @unittest.override_config({"experimental_features": {"msc3715_enabled": True}}) def test_repeated_paginate_relations(self) -> None: """Test that if we paginate using a limit and tokens then we get the expected events. @@ -838,7 +836,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", - f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?org.matrix.msc3715.dir=f&limit=3{from_token}", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?dir=f&limit=3{from_token}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) -- cgit 1.5.1 From a2cf66a94d5dfd9d6496ac3e48ec9a22f17be69a Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 28 Sep 2022 02:39:03 -0700 Subject: Prepatory work for batching events to send (#13487) This PR begins work on batching up events during the creation of a room. The PR splits out the creation and sending/persisting of the events. The first three events in the creation of the room-creating the room, joining the creator to the room, and the power levels event are sent sequentially, while the subsequent events are created and collected to be sent at the end of the function. This is currently done by appending them to a list and then iterating over the list to send, the next step (after this PR) would be to send and persist the collected events as a batch. --- changelog.d/13487.misc | 1 + synapse/handlers/message.py | 175 ++++++++++++++++++++++++++-------------- synapse/handlers/room.py | 155 ++++++++++++++++++++++++----------- synapse/state/__init__.py | 63 +++++++++++++++ tests/rest/client/test_rooms.py | 4 +- 5 files changed, 290 insertions(+), 108 deletions(-) create mode 100644 changelog.d/13487.misc (limited to 'tests/rest') diff --git a/changelog.d/13487.misc b/changelog.d/13487.misc new file mode 100644 index 0000000000..761adc8b05 --- /dev/null +++ b/changelog.d/13487.misc @@ -0,0 +1 @@ +Speed up creation of DM rooms. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e07cda133a..062f93bc67 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -63,6 +63,7 @@ from synapse.types import ( MutableStateMap, Requester, RoomAlias, + StateMap, StreamToken, UserID, create_requester, @@ -567,9 +568,17 @@ class EventCreationHandler: outlier: bool = False, historical: bool = False, depth: Optional[int] = None, + state_map: Optional[StateMap[str]] = None, + for_batch: bool = False, + current_state_group: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: """ - Given a dict from a client, create a new event. + Given a dict from a client, create a new event. If bool for_batch is true, will + create an event using the prev_event_ids, and will create an event context for + the event using the parameters state_map and current_state_group, thus these parameters + must be provided in this case if for_batch is True. The subsequently created event + and context are suitable for being batched up and bulk persisted to the database + with other similarly created events. Creates an FrozenEvent object, filling out auth_events, prev_events, etc. @@ -612,16 +621,27 @@ class EventCreationHandler: outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted back in time around some existing events. This is used to skip a few checks and mark the event as backfilled. + depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + state_map: A state map of previously created events, used only when creating events + for batch persisting + + for_batch: whether the event is being created for batch persisting to the db + + current_state_group: the current state group, used only for creating events for + batch persisting + Raises: ResourceLimitError if server is blocked to some resource being exceeded + Returns: Tuple of created event, Context """ @@ -693,6 +713,9 @@ class EventCreationHandler: auth_event_ids=auth_event_ids, state_event_ids=state_event_ids, depth=depth, + state_map=state_map, + for_batch=for_batch, + current_state_group=current_state_group, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -707,10 +730,14 @@ class EventCreationHandler: # federation as well as those created locally. As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). - prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.Member, None)]) - ) - prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) + if for_batch: + assert state_map is not None + prev_event_id = state_map.get((EventTypes.Member, event.sender)) + else: + prev_state_ids = await context.get_prev_state_ids( + StateFilter.from_types([(EventTypes.Member, None)]) + ) + prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( await self.store.get_event(prev_event_id, allow_none=True) if prev_event_id @@ -1009,8 +1036,16 @@ class EventCreationHandler: auth_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, + state_map: Optional[StateMap[str]] = None, + for_batch: bool = False, + current_state_group: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: - """Create a new event for a local client + """Create a new event for a local client. If bool for_batch is true, will + create an event using the prev_event_ids, and will create an event context for + the event using the parameters state_map and current_state_group, thus these parameters + must be provided in this case if for_batch is True. The subsequently created event + and context are suitable for being batched up and bulk persisted to the database + with other similarly created events. Args: builder: @@ -1043,6 +1078,14 @@ class EventCreationHandler: Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + state_map: A state map of previously created events, used only when creating events + for batch persisting + + for_batch: whether the event is being created for batch persisting to the db + + current_state_group: the current state group, used only for creating events for + batch persisting + Returns: Tuple of created event, context """ @@ -1095,64 +1138,76 @@ class EventCreationHandler: builder.type == EventTypes.Create or prev_event_ids ), "Attempting to create a non-m.room.create event with no prev_events" - event = await builder.build( - prev_event_ids=prev_event_ids, - auth_event_ids=auth_event_ids, - depth=depth, - ) + if for_batch: + assert prev_event_ids is not None + assert state_map is not None + assert current_state_group is not None + auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) + event = await builder.build( + prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth + ) + context = await self.state.compute_event_context_for_batched( + event, state_map, current_state_group + ) + else: + event = await builder.build( + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + depth=depth, + ) - # Pass on the outlier property from the builder to the event - # after it is created - if builder.internal_metadata.outlier: - event.internal_metadata.outlier = True - context = EventContext.for_outlier(self._storage_controllers) - elif ( - event.type == EventTypes.MSC2716_INSERTION - and state_event_ids - and builder.internal_metadata.is_historical() - ): - # Add explicit state to the insertion event so it has state to derive - # from even though it's floating with no `prev_events`. The rest of - # the batch can derive from this state and state_group. - # - # TODO(faster_joins): figure out how this works, and make sure that the - # old state is complete. - # https://github.com/matrix-org/synapse/issues/13003 - metadata = await self.store.get_metadata_for_events(state_event_ids) - - state_map_for_event: MutableStateMap[str] = {} - for state_id in state_event_ids: - data = metadata.get(state_id) - if data is None: - # We're trying to persist a new historical batch of events - # with the given state, e.g. via - # `RoomBatchSendEventRestServlet`. The state can be inferred - # by Synapse or set directly by the client. - # - # Either way, we should have persisted all the state before - # getting here. - raise Exception( - f"State event {state_id} not found in DB," - " Synapse should have persisted it before using it." - ) + # Pass on the outlier property from the builder to the event + # after it is created + if builder.internal_metadata.outlier: + event.internal_metadata.outlier = True + context = EventContext.for_outlier(self._storage_controllers) + elif ( + event.type == EventTypes.MSC2716_INSERTION + and state_event_ids + and builder.internal_metadata.is_historical() + ): + # Add explicit state to the insertion event so it has state to derive + # from even though it's floating with no `prev_events`. The rest of + # the batch can derive from this state and state_group. + # + # TODO(faster_joins): figure out how this works, and make sure that the + # old state is complete. + # https://github.com/matrix-org/synapse/issues/13003 + metadata = await self.store.get_metadata_for_events(state_event_ids) + + state_map_for_event: MutableStateMap[str] = {} + for state_id in state_event_ids: + data = metadata.get(state_id) + if data is None: + # We're trying to persist a new historical batch of events + # with the given state, e.g. via + # `RoomBatchSendEventRestServlet`. The state can be inferred + # by Synapse or set directly by the client. + # + # Either way, we should have persisted all the state before + # getting here. + raise Exception( + f"State event {state_id} not found in DB," + " Synapse should have persisted it before using it." + ) - if data.state_key is None: - raise Exception( - f"Trying to set non-state event {state_id} as state" - ) + if data.state_key is None: + raise Exception( + f"Trying to set non-state event {state_id} as state" + ) - state_map_for_event[(data.event_type, data.state_key)] = state_id + state_map_for_event[(data.event_type, data.state_key)] = state_id - context = await self.state.compute_event_context( - event, - state_ids_before_event=state_map_for_event, - # TODO(faster_joins): check how MSC2716 works and whether we can have - # partial state here - # https://github.com/matrix-org/synapse/issues/13003 - partial_state=False, - ) - else: - context = await self.state.compute_event_context(event) + context = await self.state.compute_event_context( + event, + state_ids_before_event=state_map_for_event, + # TODO(faster_joins): check how MSC2716 works and whether we can have + # partial state here + # https://github.com/matrix-org/synapse/issues/13003 + partial_state=False, + ) + else: + context = await self.state.compute_event_context(event) if requester: context.app_service = requester.app_service diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 33e9a87002..09a1a82e6c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -716,7 +716,7 @@ class RoomCreationHandler: if ( self._server_notices_mxid is not None - and requester.user.to_string() == self._server_notices_mxid + and user_id == self._server_notices_mxid ): # allow the server notices mxid to create rooms is_requester_admin = True @@ -1042,7 +1042,9 @@ class RoomCreationHandler: creator_join_profile: Optional[JsonDict] = None, ratelimit: bool = True, ) -> Tuple[int, str, int]: - """Sends the initial events into a new room. + """Sends the initial events into a new room. Sends the room creation, membership, + and power level events into the room sequentially, then creates and batches up the + rest of the events to persist as a batch to the DB. `power_level_content_override` doesn't apply when initial state has power level state event content. @@ -1053,13 +1055,21 @@ class RoomCreationHandler: """ creator_id = creator.user.to_string() - event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} - depth = 1 + # the last event sent/persisted to the db last_sent_event_id: Optional[str] = None - - def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict: + # the most recently created event + prev_event: List[str] = [] + # a map of event types, state keys -> event_ids. We collect these mappings this as events are + # created (but not persisted to the db) to determine state for future created events + # (as this info can't be pulled from the db) + state_map: MutableStateMap[str] = {} + # current_state_group of last event created. Used for computing event context of + # events to be batched + current_state_group = None + + def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict: e = {"type": etype, "content": content} e.update(event_keys) @@ -1067,32 +1077,52 @@ class RoomCreationHandler: return e - async def send(etype: str, content: JsonDict, **kwargs: Any) -> int: - nonlocal last_sent_event_id + async def create_event( + etype: str, + content: JsonDict, + for_batch: bool, + **kwargs: Any, + ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: nonlocal depth + nonlocal prev_event - event = create(etype, content, **kwargs) - logger.debug("Sending %s in new room", etype) - # Allow these events to be sent even if the user is shadow-banned to - # allow the room creation to complete. - ( - sent_event, - last_stream_id, - ) = await self.event_creation_handler.create_and_send_nonmember_event( + event_dict = create_event_dict(etype, content, **kwargs) + + new_event, new_context = await self.event_creation_handler.create_event( creator, - event, + event_dict, + prev_event_ids=prev_event, + depth=depth, + state_map=state_map, + for_batch=for_batch, + current_state_group=current_state_group, + ) + depth += 1 + prev_event = [new_event.event_id] + state_map[(new_event.type, new_event.state_key)] = new_event.event_id + + return new_event, new_context + + async def send( + event: EventBase, + context: synapse.events.snapshot.EventContext, + creator: Requester, + ) -> int: + nonlocal last_sent_event_id + + ev = await self.event_creation_handler.handle_new_client_event( + requester=creator, + event=event, + context=context, ratelimit=False, ignore_shadow_ban=True, - # Note: we don't pass state_event_ids here because this triggers - # an additional query per event to look them up from the events table. - prev_event_ids=[last_sent_event_id] if last_sent_event_id else [], - depth=depth, ) - last_sent_event_id = sent_event.event_id - depth += 1 + last_sent_event_id = ev.event_id - return last_stream_id + # we know it was persisted, so must have a stream ordering + assert ev.internal_metadata.stream_ordering + return ev.internal_metadata.stream_ordering try: config = self._presets_dict[preset_config] @@ -1102,9 +1132,13 @@ class RoomCreationHandler: ) creation_content.update({"creator": creator_id}) - await send(etype=EventTypes.Create, content=creation_content) + creation_event, creation_context = await create_event( + EventTypes.Create, creation_content, False + ) logger.debug("Sending %s in new room", EventTypes.Member) + await send(creation_event, creation_context, creator) + # Room create event must exist at this point assert last_sent_event_id is not None member_event_id, _ = await self.room_member_handler.update_membership( @@ -1119,14 +1153,22 @@ class RoomCreationHandler: depth=depth, ) last_sent_event_id = member_event_id + prev_event = [member_event_id] + + # update the depth and state map here as the membership event has been created + # through a different code path + depth += 1 + state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - last_sent_stream_id = await send( - etype=EventTypes.PowerLevels, content=pl_content + power_event, power_context = await create_event( + EventTypes.PowerLevels, pl_content, False ) + current_state_group = power_context._state_group + last_sent_stream_id = await send(power_event, power_context, creator) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1169,47 +1211,68 @@ class RoomCreationHandler: # apply those. if power_level_content_override: power_level_content.update(power_level_content_override) - - last_sent_stream_id = await send( - etype=EventTypes.PowerLevels, content=power_level_content + pl_event, pl_context = await create_event( + EventTypes.PowerLevels, + power_level_content, + False, ) + current_state_group = pl_context._state_group + last_sent_stream_id = await send(pl_event, pl_context, creator) + events_to_send = [] if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.CanonicalAlias, - content={"alias": room_alias.to_string()}, + room_alias_event, room_alias_context = await create_event( + EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True ) + current_state_group = room_alias_context._state_group + events_to_send.append((room_alias_event, room_alias_context)) if (EventTypes.JoinRules, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} + join_rules_event, join_rules_context = await create_event( + EventTypes.JoinRules, + {"join_rule": config["join_rules"]}, + True, ) + current_state_group = join_rules_context._state_group + events_to_send.append((join_rules_event, join_rules_context)) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.RoomHistoryVisibility, - content={"history_visibility": config["history_visibility"]}, + visibility_event, visibility_context = await create_event( + EventTypes.RoomHistoryVisibility, + {"history_visibility": config["history_visibility"]}, + True, ) + current_state_group = visibility_context._state_group + events_to_send.append((visibility_event, visibility_context)) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.GuestAccess, - content={EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, + guest_access_event, guest_access_context = await create_event( + EventTypes.GuestAccess, + {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, + True, ) + current_state_group = guest_access_context._state_group + events_to_send.append((guest_access_event, guest_access_context)) for (etype, state_key), content in initial_state.items(): - last_sent_stream_id = await send( - etype=etype, state_key=state_key, content=content + event, context = await create_event( + etype, content, True, state_key=state_key ) + current_state_group = context._state_group + events_to_send.append((event, context)) if config["encrypted"]: - last_sent_stream_id = await send( - etype=EventTypes.RoomEncryption, + encryption_event, encryption_context = await create_event( + EventTypes.RoomEncryption, + {"algorithm": RoomEncryptionAlgorithms.DEFAULT}, + True, state_key="", - content={"algorithm": RoomEncryptionAlgorithms.DEFAULT}, ) + events_to_send.append((encryption_event, encryption_context)) + for event, context in events_to_send: + last_sent_stream_id = await send(event, context, creator) return last_sent_stream_id, last_sent_event_id, depth def _generate_room_id(self) -> str: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 3787d35b24..6f3dd0463e 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -420,6 +420,69 @@ class StateHandler: partial_state=partial_state, ) + async def compute_event_context_for_batched( + self, + event: EventBase, + state_ids_before_event: StateMap[str], + current_state_group: int, + ) -> EventContext: + """ + Generate an event context for an event that has not yet been persisted to the + database. Intended for use with events that are created to be persisted in a batch. + Args: + event: the event the context is being computed for + state_ids_before_event: a state map consisting of the state ids of the events + created prior to this event. + current_state_group: the current state group before the event. + """ + state_group_before_event_prev_group = None + deltas_to_state_group_before_event = None + + state_group_before_event = current_state_group + + # if the event is not state, we are set + if not event.is_state(): + return EventContext.with_state( + storage=self._storage_controllers, + state_group_before_event=state_group_before_event, + state_group=state_group_before_event, + state_delta_due_to_event={}, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + partial_state=False, + ) + + # otherwise, we'll need to create a new state group for after the event + key = (event.type, event.state_key) + + if state_ids_before_event is not None: + replaces = state_ids_before_event.get(key) + + if replaces and replaces != event.event_id: + event.unsigned["replaces_state"] = replaces + + delta_ids = {key: event.event_id} + + state_group_after_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event, + delta_ids=delta_ids, + current_state_ids=None, + ) + ) + + return EventContext.with_state( + storage=self._storage_controllers, + state_group=state_group_after_event, + state_group_before_event=state_group_before_event, + state_delta_due_to_event=delta_ids, + prev_group=state_group_before_event, + delta_ids=delta_ids, + partial_state=False, + ) + @measure_func() async def resolve_state_groups_for_events( self, room_id: str, event_ids: Collection[str], await_full_state: bool = True diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index c7eb88d33f..e281aef779 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -710,7 +710,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(44, channel.resource_usage.db_txn_count) + self.assertEqual(35, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -723,7 +723,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(50, channel.resource_usage.db_txn_count) + self.assertEqual(38, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id -- cgit 1.5.1 From e5fdf16d4680b00ca8120ddb697bd14ab89fdf0c Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Thu, 29 Sep 2022 12:22:27 +0100 Subject: Expose MSC3882 only be under an unstable endpoint. (#13868) --- changelog.d/13868.misc | 1 + synapse/rest/client/login_token_request.py | 4 +++- tests/rest/client/test_login_token_request.py | 16 +++++++++------- 3 files changed, 13 insertions(+), 8 deletions(-) create mode 100644 changelog.d/13868.misc (limited to 'tests/rest') diff --git a/changelog.d/13868.misc b/changelog.d/13868.misc new file mode 100644 index 0000000000..d7a99c042a --- /dev/null +++ b/changelog.d/13868.misc @@ -0,0 +1 @@ +Fix unstable MSC3882 endpoint being incorrectly available on stable API versions. \ No newline at end of file diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py index ca5c54bf17..277b20fb63 100644 --- a/synapse/rest/client/login_token_request.py +++ b/synapse/rest/client/login_token_request.py @@ -47,7 +47,9 @@ class LoginTokenRequestServlet(RestServlet): } """ - PATTERNS = client_patterns("/login/token$") + PATTERNS = client_patterns( + "/org.matrix.msc3882/login/token$", releases=[], v1=False, unstable=True + ) def __init__(self, hs: "HomeServer"): super().__init__() diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py index d5bb16c98d..c2e1e08811 100644 --- a/tests/rest/client/test_login_token_request.py +++ b/tests/rest/client/test_login_token_request.py @@ -22,6 +22,8 @@ from synapse.util import Clock from tests import unittest from tests.unittest import override_config +endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token" + class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): @@ -45,18 +47,18 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): self.password = "password" def test_disabled(self) -> None: - channel = self.make_request("POST", "/login/token", {}, access_token=None) + channel = self.make_request("POST", endpoint, {}, access_token=None) self.assertEqual(channel.code, 400) self.register_user(self.user, self.password) token = self.login(self.user, self.password) - channel = self.make_request("POST", "/login/token", {}, access_token=token) + channel = self.make_request("POST", endpoint, {}, access_token=token) self.assertEqual(channel.code, 400) @override_config({"experimental_features": {"msc3882_enabled": True}}) def test_require_auth(self) -> None: - channel = self.make_request("POST", "/login/token", {}, access_token=None) + channel = self.make_request("POST", endpoint, {}, access_token=None) self.assertEqual(channel.code, 401) @override_config({"experimental_features": {"msc3882_enabled": True}}) @@ -64,7 +66,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): user_id = self.register_user(self.user, self.password) token = self.login(self.user, self.password) - channel = self.make_request("POST", "/login/token", {}, access_token=token) + channel = self.make_request("POST", endpoint, {}, access_token=token) self.assertEqual(channel.code, 401) self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) @@ -79,7 +81,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): }, } - channel = self.make_request("POST", "/login/token", uia, access_token=token) + channel = self.make_request("POST", endpoint, uia, access_token=token) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["expires_in"], 300) @@ -100,7 +102,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): user_id = self.register_user(self.user, self.password) token = self.login(self.user, self.password) - channel = self.make_request("POST", "/login/token", {}, access_token=token) + channel = self.make_request("POST", endpoint, {}, access_token=token) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["expires_in"], 300) @@ -127,6 +129,6 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): self.register_user(self.user, self.password) token = self.login(self.user, self.password) - channel = self.make_request("POST", "/login/token", {}, access_token=token) + channel = self.make_request("POST", endpoint, {}, access_token=token) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["expires_in"], 15) -- cgit 1.5.1 From be76cd8200b18f3c68b895f85ac7ef5b0ddc2466 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 29 Sep 2022 14:23:24 +0100 Subject: Allow admins to require a manual approval process before new accounts can be used (using MSC3866) (#13556) --- changelog.d/13556.feature | 1 + synapse/_scripts/synapse_port_db.py | 2 +- synapse/api/constants.py | 11 ++ synapse/api/errors.py | 16 ++ synapse/config/experimental.py | 19 +++ synapse/handlers/admin.py | 5 + synapse/handlers/auth.py | 11 ++ synapse/handlers/register.py | 8 + synapse/replication/http/register.py | 5 + synapse/rest/admin/users.py | 43 ++++- synapse/rest/client/login.py | 37 +++- synapse/rest/client/register.py | 22 ++- synapse/storage/databases/main/__init__.py | 9 +- synapse/storage/databases/main/registration.py | 150 +++++++++++++++-- .../main/delta/73/03users_approved_column.sql | 20 +++ tests/rest/admin/test_user.py | 186 ++++++++++++++++++++- tests/rest/client/test_auth.py | 33 +++- tests/rest/client/test_login.py | 41 +++++ tests/rest/client/test_register.py | 32 +++- tests/rest/client/utils.py | 12 +- tests/storage/test_registration.py | 102 ++++++++++- 21 files changed, 731 insertions(+), 34 deletions(-) create mode 100644 changelog.d/13556.feature create mode 100644 synapse/storage/schema/main/delta/73/03users_approved_column.sql (limited to 'tests/rest') diff --git a/changelog.d/13556.feature b/changelog.d/13556.feature new file mode 100644 index 0000000000..f9d63db6c0 --- /dev/null +++ b/changelog.d/13556.feature @@ -0,0 +1 @@ +Allow server admins to require a manual approval process before new accounts can be used (using [MSC3866](https://github.com/matrix-org/matrix-spec-proposals/pull/3866)). diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 450ba462ba..5fa599e70e 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -107,7 +107,7 @@ BOOLEAN_COLUMNS = { "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], "local_media_repository": ["safe_from_quarantine"], - "users": ["shadow_banned"], + "users": ["shadow_banned", "approved"], "e2e_fallback_keys_json": ["used"], "access_tokens": ["used"], "device_lists_changes_in_room": ["converted_to_destinations"], diff --git a/synapse/api/constants.py b/synapse/api/constants.py index c178ddf070..c031903b1a 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -269,3 +269,14 @@ class PublicRoomsFilterFields: GENERIC_SEARCH_TERM: Final = "generic_search_term" ROOM_TYPES: Final = "room_types" + + +class ApprovalNoticeMedium: + """Identifier for the medium this server will use to serve notice of approval for a + specific user's registration. + + As defined in https://github.com/matrix-org/matrix-spec-proposals/blob/babolivier/m_not_approved/proposals/3866-user-not-approved-error.md + """ + + NONE = "org.matrix.msc3866.none" + EMAIL = "org.matrix.msc3866.email" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 1c6b53aa24..c606207569 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -106,6 +106,8 @@ class Codes(str, Enum): # Part of MSC3895. UNABLE_DUE_TO_PARTIAL_STATE = "ORG.MATRIX.MSC3895_UNABLE_DUE_TO_PARTIAL_STATE" + USER_AWAITING_APPROVAL = "ORG.MATRIX.MSC3866_USER_AWAITING_APPROVAL" + class CodeMessageException(RuntimeError): """An exception with integer code and message string attributes. @@ -566,6 +568,20 @@ class UnredactedContentDeletedError(SynapseError): return cs_error(self.msg, self.errcode, **extra) +class NotApprovedError(SynapseError): + def __init__( + self, + msg: str, + approval_notice_medium: str, + ): + super().__init__( + code=403, + msg=msg, + errcode=Codes.USER_AWAITING_APPROVAL, + additional_fields={"approval_notice_medium": approval_notice_medium}, + ) + + def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict": """Utility method for constructing an error response for client-server interactions. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 933779c23a..31834fb27d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -14,10 +14,25 @@ from typing import Any +import attr + from synapse.config._base import Config from synapse.types import JsonDict +@attr.s(auto_attribs=True, frozen=True, slots=True) +class MSC3866Config: + """Configuration for MSC3866 (mandating approval for new users)""" + + # Whether the base support for the approval process is enabled. This includes the + # ability for administrators to check and update the approval of users, even if no + # approval is currently required. + enabled: bool = False + # Whether to require that new users are approved by an admin before their account + # can be used. Note that this setting is ignored if 'enabled' is false. + require_approval_for_new_accounts: bool = False + + class ExperimentalConfig(Config): """Config section for enabling experimental features""" @@ -97,6 +112,10 @@ class ExperimentalConfig(Config): # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) + # MSC3866: M_USER_AWAITING_APPROVAL error code + raw_msc3866_config = experimental.get("msc3866", {}) + self.msc3866 = MSC3866Config(**raw_msc3866_config) + # MSC3881: Remotely toggle push notifications for another client self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index cf9f19608a..f2989cc4a2 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -32,6 +32,7 @@ class AdminHandler: self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state + self._msc3866_enabled = hs.config.experimental.msc3866.enabled async def get_whois(self, user: UserID) -> JsonDict: connections = [] @@ -75,6 +76,10 @@ class AdminHandler: "is_guest", } + if self._msc3866_enabled: + # Only include the approved flag if support for MSC3866 is enabled. + user_info_to_return.add("approved") + # Restrict returned keys to a known set. user_info_dict = { key: value diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index eacd631ee0..f5f0e0e7a7 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1009,6 +1009,17 @@ class AuthHandler: return res[0] return None + async def is_user_approved(self, user_id: str) -> bool: + """Checks if a user is approved and therefore can be allowed to log in. + + Args: + user_id: the user to check the approval status of. + + Returns: + A boolean that is True if the user is approved, False otherwise. + """ + return await self.store.is_user_approved(user_id) + async def _find_user_id_and_pwd_hash( self, user_id: str ) -> Optional[Tuple[str, str]]: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cfcadb34db..ca1c7a1866 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -220,6 +220,7 @@ class RegistrationHandler: by_admin: bool = False, user_agent_ips: Optional[List[Tuple[str, str]]] = None, auth_provider_id: Optional[str] = None, + approved: bool = False, ) -> str: """Registers a new client on the server. @@ -246,6 +247,8 @@ class RegistrationHandler: user_agent_ips: Tuples of user-agents and IP addresses used during the registration process. auth_provider_id: The SSO IdP the user used, if any. + approved: True if the new user should be considered already + approved by an administrator. Returns: The registered user_id. Raises: @@ -307,6 +310,7 @@ class RegistrationHandler: user_type=user_type, address=address, shadow_banned=shadow_banned, + approved=approved, ) profile = await self.store.get_profileinfo(localpart) @@ -695,6 +699,7 @@ class RegistrationHandler: user_type: Optional[str] = None, address: Optional[str] = None, shadow_banned: bool = False, + approved: bool = False, ) -> None: """Register user in the datastore. @@ -713,6 +718,7 @@ class RegistrationHandler: api.constants.UserTypes, or None for a normal user. address: the IP address used to perform the registration. shadow_banned: Whether to shadow-ban the user + approved: Whether to mark the user as approved by an administrator """ if self.hs.config.worker.worker_app: await self._register_client( @@ -726,6 +732,7 @@ class RegistrationHandler: user_type=user_type, address=address, shadow_banned=shadow_banned, + approved=approved, ) else: await self.store.register_user( @@ -738,6 +745,7 @@ class RegistrationHandler: admin=admin, user_type=user_type, shadow_banned=shadow_banned, + approved=approved, ) # Only call the account validity module(s) on the main process, to avoid diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 6c8f8388fd..61abb529c8 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -51,6 +51,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type: Optional[str], address: Optional[str], shadow_banned: bool, + approved: bool, ) -> JsonDict: """ Args: @@ -68,6 +69,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint): or None for a normal user. address: the IP address used to perform the regitration. shadow_banned: Whether to shadow-ban the user + approved: Whether the user should be considered already approved by an + administrator. """ return { "password_hash": password_hash, @@ -79,6 +82,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "user_type": user_type, "address": address, "shadow_banned": shadow_banned, + "approved": approved, } async def _handle_request( # type: ignore[override] @@ -99,6 +103,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type=content["user_type"], address=content["address"], shadow_banned=content["shadow_banned"], + approved=content["approved"], ) return 200, {} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 1274773d7e..15ac2059aa 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -69,6 +69,7 @@ class UsersRestServletV2(RestServlet): self.store = hs.get_datastores().main self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() + self._msc3866_enabled = hs.config.experimental.msc3866.enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -95,6 +96,13 @@ class UsersRestServletV2(RestServlet): guests = parse_boolean(request, "guests", default=True) deactivated = parse_boolean(request, "deactivated", default=False) + # If support for MSC3866 is not enabled, apply no filtering based on the + # `approved` column. + if self._msc3866_enabled: + approved = parse_boolean(request, "approved", default=True) + else: + approved = True + order_by = parse_string( request, "order_by", @@ -115,8 +123,22 @@ class UsersRestServletV2(RestServlet): direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) users, total = await self.store.get_users_paginate( - start, limit, user_id, name, guests, deactivated, order_by, direction + start, + limit, + user_id, + name, + guests, + deactivated, + order_by, + direction, + approved, ) + + # If support for MSC3866 is not enabled, don't show the approval flag. + if not self._msc3866_enabled: + for user in users: + del user["approved"] + ret = {"users": users, "total": total} if (start + limit) < total: ret["next_token"] = str(start + len(users)) @@ -163,6 +185,7 @@ class UserRestServletV2(RestServlet): self.deactivate_account_handler = hs.get_deactivate_account_handler() self.registration_handler = hs.get_registration_handler() self.pusher_pool = hs.get_pusherpool() + self._msc3866_enabled = hs.config.experimental.msc3866.enabled async def on_GET( self, request: SynapseRequest, user_id: str @@ -239,6 +262,15 @@ class UserRestServletV2(RestServlet): HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean" ) + approved: Optional[bool] = None + if "approved" in body and self._msc3866_enabled: + approved = body["approved"] + if not isinstance(approved, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "'approved' parameter is not of type boolean", + ) + # convert List[Dict[str, str]] into List[Tuple[str, str]] if external_ids is not None: new_external_ids = [ @@ -343,6 +375,9 @@ class UserRestServletV2(RestServlet): if "user_type" in body: await self.store.set_user_type(target_user, user_type) + if approved is not None: + await self.store.update_user_approval_status(target_user, approved) + user = await self.admin_handler.get_user(target_user) assert user is not None @@ -355,6 +390,10 @@ class UserRestServletV2(RestServlet): if password is not None: password_hash = await self.auth_handler.hash(password) + new_user_approved = True + if self._msc3866_enabled and approved is not None: + new_user_approved = approved + user_id = await self.registration_handler.register_user( localpart=target_user.localpart, password_hash=password_hash, @@ -362,6 +401,7 @@ class UserRestServletV2(RestServlet): default_display_name=displayname, user_type=user_type, by_admin=True, + approved=new_user_approved, ) if threepids is not None: @@ -550,6 +590,7 @@ class UserRegisterServlet(RestServlet): user_type=user_type, default_display_name=displayname, by_admin=True, + approved=True, ) result = await register._create_registration_details(user_id, body) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 0437c87d8d..f554586ac3 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -28,7 +28,14 @@ from typing import ( from typing_extensions import TypedDict -from synapse.api.errors import Codes, InvalidClientTokenError, LoginError, SynapseError +from synapse.api.constants import ApprovalNoticeMedium +from synapse.api.errors import ( + Codes, + InvalidClientTokenError, + LoginError, + NotApprovedError, + SynapseError, +) from synapse.api.ratelimiting import Ratelimiter from synapse.api.urls import CLIENT_API_PREFIX from synapse.appservice import ApplicationService @@ -55,11 +62,11 @@ logger = logging.getLogger(__name__) class LoginResponse(TypedDict, total=False): user_id: str - access_token: str + access_token: Optional[str] home_server: str expires_in_ms: Optional[int] refresh_token: Optional[str] - device_id: str + device_id: Optional[str] well_known: Optional[Dict[str, Any]] @@ -92,6 +99,12 @@ class LoginRestServlet(RestServlet): hs.config.registration.refreshable_access_token_lifetime is not None ) + # Whether we need to check if the user has been approved or not. + self._require_approval = ( + hs.config.experimental.msc3866.enabled + and hs.config.experimental.msc3866.require_approval_for_new_accounts + ) + self.auth = hs.get_auth() self.clock = hs.get_clock() @@ -220,6 +233,14 @@ class LoginRestServlet(RestServlet): except KeyError: raise SynapseError(400, "Missing JSON keys.") + if self._require_approval: + approved = await self.auth_handler.is_user_approved(result["user_id"]) + if not approved: + raise NotApprovedError( + msg="This account is pending approval by a server administrator.", + approval_notice_medium=ApprovalNoticeMedium.NONE, + ) + well_known_data = self._well_known_builder.get_well_known() if well_known_data: result["well_known"] = well_known_data @@ -356,6 +377,16 @@ class LoginRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) + if self._require_approval: + approved = await self.auth_handler.is_user_approved(user_id) + if not approved: + # If the user isn't approved (and needs to be) we won't allow them to + # actually log in, so we don't want to create a device/access token. + return LoginResponse( + user_id=user_id, + home_server=self.hs.hostname, + ) + initial_display_name = login_submission.get("initial_device_display_name") ( device_id, diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 20bab20c8f..de810ae3ec 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -21,10 +21,15 @@ from twisted.web.server import Request import synapse import synapse.api.auth import synapse.types -from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType +from synapse.api.constants import ( + APP_SERVICE_REGISTRATION_TYPE, + ApprovalNoticeMedium, + LoginType, +) from synapse.api.errors import ( Codes, InteractiveAuthIncompleteError, + NotApprovedError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, @@ -414,6 +419,11 @@ class RegisterRestServlet(RestServlet): hs.config.registration.inhibit_user_in_use_error ) + self._require_approval = ( + hs.config.experimental.msc3866.enabled + and hs.config.experimental.msc3866.require_approval_for_new_accounts + ) + self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler ) @@ -734,6 +744,12 @@ class RegisterRestServlet(RestServlet): access_token=return_dict.get("access_token"), ) + if self._require_approval: + raise NotApprovedError( + msg="This account needs to be approved by an administrator before it can be used.", + approval_notice_medium=ApprovalNoticeMedium.NONE, + ) + return 200, return_dict async def _do_appservice_registration( @@ -778,7 +794,9 @@ class RegisterRestServlet(RestServlet): "user_id": user_id, "home_server": self.hs.hostname, } - if not params.get("inhibit_login", False): + # We don't want to log the user in if we're going to deny them access because + # they need to be approved first. + if not params.get("inhibit_login", False) and not self._require_approval: device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") ( diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0843f10340..a62b4abd4e 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -203,6 +203,7 @@ class DataStore( deactivated: bool = False, order_by: str = UserSortOrder.USER_ID.value, direction: str = "f", + approved: bool = True, ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the @@ -217,6 +218,7 @@ class DataStore( deactivated: whether to include deactivated users order_by: the sort order of the returned list direction: sort ascending or descending + approved: whether to include approved users Returns: A tuple of a list of mappings from user to information and a count of total users. """ @@ -249,6 +251,11 @@ class DataStore( if not deactivated: filters.append("deactivated = 0") + if not approved: + # We ignore NULL values for the approved flag because these should only + # be already existing users that we consider as already approved. + filters.append("approved IS FALSE") + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" sql_base = f""" @@ -262,7 +269,7 @@ class DataStore( sql = f""" SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, - displayname, avatar_url, creation_ts * 1000 as creation_ts + displayname, avatar_url, creation_ts * 1000 as creation_ts, approved {sql_base} ORDER BY {order_by_column} {order}, u.name ASC LIMIT ? OFFSET ? diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index ac821878b0..2996d6bb4d 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -166,27 +166,49 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: """Deprecated: use get_userinfo_by_id instead""" - return await self.db_pool.simple_select_one( - table="users", - keyvalues={"name": user_id}, - retcols=[ - "name", - "password_hash", - "is_guest", - "admin", - "consent_version", - "consent_ts", - "consent_server_notice_sent", - "appservice_id", - "creation_ts", - "user_type", - "deactivated", - "shadow_banned", - ], - allow_none=True, + + def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: + # We could technically use simple_select_one here, but it would not perform + # the COALESCEs (unless hacked into the column names), which could yield + # confusing results. + txn.execute( + """ + SELECT + name, password_hash, is_guest, admin, consent_version, consent_ts, + consent_server_notice_sent, appservice_id, creation_ts, user_type, + deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned, + COALESCE(approved, TRUE) AS approved + FROM users + WHERE name = ? + """, + (user_id,), + ) + + rows = self.db_pool.cursor_to_dict(txn) + + if len(rows) == 0: + return None + + return rows[0] + + row = await self.db_pool.runInteraction( desc="get_user_by_id", + func=get_user_by_id_txn, ) + if row is not None: + # If we're using SQLite our boolean values will be integers. Because we + # present some of this data as is to e.g. server admins via REST APIs, we + # want to make sure we're returning the right type of data. + # Note: when adding a column name to this list, be wary of NULLable columns, + # since NULL values will be turned into False. + boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"] + for column in boolean_columns: + if not isinstance(row[column], bool): + row[column] = bool(row[column]) + + return row + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: """Get a UserInfo object for a user by user ID. @@ -1779,6 +1801,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return res if res else False + @cached() + async def is_user_approved(self, user_id: str) -> bool: + """Checks if a user is approved and therefore can be allowed to log in. + + If the user's 'approved' column is NULL, we consider it as true given it means + the user was registered when support for an approval flow was either disabled + or nonexistent. + + Args: + user_id: the user to check the approval status of. + + Returns: + A boolean that is True if the user is approved, False otherwise. + """ + + def is_user_approved_txn(txn: LoggingTransaction) -> bool: + txn.execute( + """ + SELECT COALESCE(approved, TRUE) AS approved FROM users WHERE name = ? + """, + (user_id,), + ) + + rows = self.db_pool.cursor_to_dict(txn) + + # We cast to bool because the value returned by the database engine might + # be an integer if we're using SQLite. + return bool(rows[0]["approved"]) + + return await self.db_pool.runInteraction( + desc="is_user_pending_approval", + func=is_user_approved_txn, + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__( @@ -1916,6 +1972,29 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) + def update_user_approval_status_txn( + self, txn: LoggingTransaction, user_id: str, approved: bool + ) -> None: + """Set the user's 'approved' flag to the given value. + + The boolean is turned into an int because the column is a smallint. + + Args: + txn: the current database transaction. + user_id: the user to update the flag for. + approved: the value to set the flag to. + """ + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"approved": approved}, + ) + + # Invalidate the caches of methods that read the value of the 'approved' flag. + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,)) + class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): def __init__( @@ -1933,6 +2012,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") + # If support for MSC3866 is enabled and configured to require approval for new + # account, we will create new users with an 'approved' flag set to false. + self._require_approval = ( + hs.config.experimental.msc3866.enabled + and hs.config.experimental.msc3866.require_approval_for_new_accounts + ) + async def add_access_token_to_user( self, user_id: str, @@ -2065,6 +2151,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): admin: bool = False, user_type: Optional[str] = None, shadow_banned: bool = False, + approved: bool = False, ) -> None: """Attempts to register an account. @@ -2083,6 +2170,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): or None for a normal user. shadow_banned: Whether the user is shadow-banned, i.e. they may be told their requests succeeded but we ignore them. + approved: Whether to consider the user has already been approved by an + administrator. Raises: StoreError if the user_id could not be registered. @@ -2099,6 +2188,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): admin, user_type, shadow_banned, + approved, ) def _register_user( @@ -2113,11 +2203,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): admin: bool, user_type: Optional[str], shadow_banned: bool, + approved: bool, ) -> None: user_id_obj = UserID.from_string(user_id) now = int(self._clock.time()) + user_approved = approved or not self._require_approval + try: if was_guest: # Ensure that the guest user actually exists @@ -2143,6 +2236,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "admin": 1 if admin else 0, "user_type": user_type, "shadow_banned": shadow_banned, + "approved": user_approved, }, ) else: @@ -2158,6 +2252,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "admin": 1 if admin else 0, "user_type": user_type, "shadow_banned": shadow_banned, + "approved": user_approved, }, ) @@ -2503,6 +2598,25 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): start_or_continue_validation_session_txn, ) + async def update_user_approval_status( + self, user_id: UserID, approved: bool + ) -> None: + """Set the user's 'approved' flag to the given value. + + The boolean will be turned into an int (in update_user_approval_status_txn) + because the column is a smallint. + + Args: + user_id: the user to update the flag for. + approved: the value to set the flag to. + """ + await self.db_pool.runInteraction( + "update_user_approval_status", + self.update_user_approval_status_txn, + user_id.to_string(), + approved, + ) + def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/synapse/storage/schema/main/delta/73/03users_approved_column.sql b/synapse/storage/schema/main/delta/73/03users_approved_column.sql new file mode 100644 index 0000000000..5328d592ea --- /dev/null +++ b/synapse/storage/schema/main/delta/73/03users_approved_column.sql @@ -0,0 +1,20 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a column to the users table to track whether the user needs to be approved by an +-- administrator. +-- A NULL column means the user was created before this feature was supported by Synapse, +-- and should be considered as TRUE. +ALTER TABLE users ADD COLUMN approved BOOLEAN; diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 1847e6ad6b..4c1ce33463 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -25,10 +25,10 @@ from parameterized import parameterized, parameterized_class from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import UserTypes +from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions -from synapse.rest.client import devices, login, logout, profile, room, sync +from synapse.rest.client import devices, login, logout, profile, register, room, sync from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.server import HomeServer from synapse.types import JsonDict, UserID @@ -578,6 +578,16 @@ class UsersListTestCase(unittest.HomeserverTestCase): _search_test(None, "foo", "user_id") _search_test(None, "bar", "user_id") + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. @@ -623,6 +633,16 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + # invalid approved + channel = self.make_request( + "GET", + self.url + "?approved=not_bool", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + # unkown order_by channel = self.make_request( "GET", @@ -841,6 +861,69 @@ class UsersListTestCase(unittest.HomeserverTestCase): self._order_test([self.admin_user, user1, user2], "creation_ts", "f") self._order_test([user2, user1, self.admin_user], "creation_ts", "b") + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_filter_out_approved(self) -> None: + """Tests that the endpoint can filter out approved users.""" + # Create our users. + self._create_users(2) + + # Get the list of users. + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, channel.result) + + # Exclude the admin, because we don't want to accidentally un-approve the admin. + non_admin_user_ids = [ + user["name"] + for user in channel.json_body["users"] + if user["name"] != self.admin_user + ] + + self.assertEqual(2, len(non_admin_user_ids), non_admin_user_ids) + + # Select a user and un-approve them. We do this rather than the other way around + # because, since these users are created by an admin, we consider them already + # approved. + not_approved_user = non_admin_user_ids[0] + + channel = self.make_request( + "PUT", + f"/_synapse/admin/v2/users/{not_approved_user}", + {"approved": False}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, channel.result) + + # Now get the list of users again, this time filtering out approved users. + channel = self.make_request( + "GET", + self.url + "?approved=false", + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, channel.result) + + non_admin_user_ids = [ + user["name"] + for user in channel.json_body["users"] + if user["name"] != self.admin_user + ] + + # We should only have our unapproved user now. + self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids) + self.assertEqual(not_approved_user, non_admin_user_ids[0]) + def _order_test( self, expected_user_list: List[str], @@ -1272,6 +1355,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets, login.register_servlets, sync.register_servlets, + register.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -2536,6 +2620,104 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Ensure they're still alive self.assertEqual(0, channel.json_body["deactivated"]) + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_approve_account(self) -> None: + """Tests that approving an account correctly sets the approved flag for the user.""" + url = self.url_prefix % "@bob:test" + + # Create the user using the client-server API since otherwise the user will be + # marked as approved automatically. + channel = self.make_request( + "POST", + "register", + { + "username": "bob", + "password": "test", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + + # Get user + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIs(False, channel.json_body["approved"]) + + # Approve user + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content={"approved": True}, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIs(True, channel.json_body["approved"]) + + # Check that the user is now approved + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIs(True, channel.json_body["approved"]) + + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_register_approved(self) -> None: + url = self.url_prefix % "@bob:test" + + # Create user + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content={"password": "abc123", "approved": True}, + ) + + self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual(1, channel.json_body["approved"]) + + # Get user + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual(1, channel.json_body["approved"]) + def _is_erased(self, user_id: str, expect: bool) -> None: """Assert that the user is erased or not""" d = self.store.is_user_erased(user_id) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 05355c7fb6..090cef5216 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -20,7 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin -from synapse.api.constants import LoginType +from synapse.api.constants import ApprovalNoticeMedium, LoginType +from synapse.api.errors import Codes from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree @@ -567,6 +568,36 @@ class UIAuthTests(unittest.HomeserverTestCase): body={"auth": {"session": session_id}}, ) + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config( + { + "oidc_config": TEST_OIDC_CONFIG, + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + }, + } + ) + def test_sso_not_approved(self) -> None: + """Tests that if we register a user via SSO while requiring approval for new + accounts, we still raise the correct error before logging the user in. + """ + login_resp = self.helper.login_via_oidc("username", expected_status=403) + + self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL) + self.assertEqual( + ApprovalNoticeMedium.NONE, login_resp["approval_notice_medium"] + ) + + # Check that we didn't register a device for the user during the login attempt. + devices = self.get_success( + self.hs.get_datastores().main.get_devices_by_user("@username:test") + ) + + self.assertEqual(len(devices), 0) + class RefreshAuthTests(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index e2a4d98275..e801ba8c8b 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -23,6 +23,8 @@ from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin +from synapse.api.constants import ApprovalNoticeMedium, LoginType +from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import devices, login, logout, register from synapse.rest.client.account import WhoamiRestServlet @@ -94,6 +96,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): logout.register_servlets, devices.register_servlets, lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), + register.register_servlets, ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: @@ -406,6 +409,44 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400) self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM") + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_require_approval(self) -> None: + channel = self.make_request( + "POST", + "register", + { + "username": "kermit", + "password": "monkey", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + + params = { + "type": LoginType.PASSWORD, + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "monkey", + } + channel = self.make_request("POST", LOGIN_URL, params) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") class MultiSSOTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index b781875d52..11cf3939d8 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -22,7 +22,11 @@ import pkg_resources from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType +from synapse.api.constants import ( + APP_SERVICE_REGISTRATION_TYPE, + ApprovalNoticeMedium, + LoginType, +) from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import account, account_validity, login, logout, register, sync @@ -765,6 +769,32 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE) + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_require_approval(self) -> None: + channel = self.make_request( + "POST", + "register", + { + "username": "kermit", + "password": "monkey", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + class AccountValidityTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index dd26145bf8..c249a42bb6 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -543,8 +543,12 @@ class RestHelper: return channel.json_body - def login_via_oidc(self, remote_user_id: str) -> JsonDict: - """Log in (as a new user) via OIDC + def login_via_oidc( + self, + remote_user_id: str, + expected_status: int = 200, + ) -> JsonDict: + """Log in via OIDC Returns the result of the final token login. @@ -578,7 +582,9 @@ class RestHelper: "/login", content={"type": "m.login.token", "token": login_token}, ) - assert channel.code == HTTPStatus.OK + assert ( + channel.code == expected_status + ), f"unexpected status in response: {channel.code}" return channel.json_body def auth_via_oidc( diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 853a93afab..05ea802008 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -16,9 +16,10 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import UserTypes from synapse.api.errors import ThreepidValidationError from synapse.server import HomeServer +from synapse.types import JsonDict, UserID from synapse.util import Clock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class RegistrationStoreTestCase(HomeserverTestCase): @@ -48,6 +49,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): "user_type": None, "deactivated": 0, "shadow_banned": 0, + "approved": 1, }, (self.get_success(self.store.get_user_by_id(self.user_id))), ) @@ -166,3 +168,101 @@ class RegistrationStoreTestCase(HomeserverTestCase): ThreepidValidationError, ) self.assertEqual(e.value.msg, "Validation token not found or has expired", e) + + +class ApprovalRequiredRegistrationTestCase(HomeserverTestCase): + def default_config(self) -> JsonDict: + config = super().default_config() + + # If there's already some config for this feature in the default config, it + # means we're overriding it with @override_config. In this case we don't want + # to do anything more with it. + msc3866_config = config.get("experimental_features", {}).get("msc3866") + if msc3866_config is not None: + return config + + # Require approval for all new accounts. + config["experimental_features"] = { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.user_id = "@my-user:test" + self.pwhash = "{xx1}123456789" + + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": False, + } + } + } + ) + def test_approval_not_required(self) -> None: + """Tests that if we don't require approval for new accounts, newly created + accounts are automatically marked as approved. + """ + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + assert user is not None + self.assertTrue(user["approved"]) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertTrue(approved) + + def test_approval_required(self) -> None: + """Tests that if we require approval for new accounts, newly created accounts + are not automatically marked as approved. + """ + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + assert user is not None + self.assertFalse(user["approved"]) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertFalse(approved) + + def test_override(self) -> None: + """Tests that if we require approval for new accounts, but we explicitly say the + new user should be considered approved, they're marked as approved. + """ + self.get_success( + self.store.register_user( + self.user_id, + self.pwhash, + approved=True, + ) + ) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + self.assertIsNotNone(user) + assert user is not None + self.assertEqual(user["approved"], 1) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertTrue(approved) + + def test_approve_user(self) -> None: + """Tests that approving the user updates their approval status.""" + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertFalse(approved) + + self.get_success( + self.store.update_user_approval_status( + UserID.from_string(self.user_id), True + ) + ) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertTrue(approved) -- cgit 1.5.1 From 535f8c8f7d64d4058500a5988278fd3026645164 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 30 Sep 2022 17:40:33 +0100 Subject: Skip filtering during push if there are no push actions (#13992) --- changelog.d/13992.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 5 +++++ synapse/visibility.py | 4 ++++ tests/rest/client/test_rooms.py | 4 ++-- 4 files changed, 12 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13992.misc (limited to 'tests/rest') diff --git a/changelog.d/13992.misc b/changelog.d/13992.misc new file mode 100644 index 0000000000..58150a2b35 --- /dev/null +++ b/changelog.d/13992.misc @@ -0,0 +1 @@ +Speed up calculating push actions in large rooms. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 7bfe380543..4270438918 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -332,6 +332,11 @@ class BulkPushRuleEvaluator: # Push rules say we should notify the user of this event actions_by_user[uid] = actions + # If there aren't any actions then we can skip the rest of the + # processing. + if not actions_by_user: + return + # This is a check for the case where user joins a room without being # allowed to see history, and then the server receives a delayed event # from before the user joined, which they should not be pushed for diff --git a/synapse/visibility.py b/synapse/visibility.py index c810a05907..c4048d2477 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -162,6 +162,10 @@ async def filter_event_for_clients_with_state( if event.internal_metadata.is_soft_failed(): return [] + # Fast path if we don't have any user IDs to check. + if not user_ids: + return () + # Make a set for all user IDs that haven't been filtered out by a check. allowed_user_ids = set(user_ids) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index e281aef779..7f8cf4fab0 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -710,7 +710,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(35, channel.resource_usage.db_txn_count) + self.assertEqual(34, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -723,7 +723,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(38, channel.resource_usage.db_txn_count) + self.assertEqual(37, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id -- cgit 1.5.1 From 719488dda87b04e4650a32f0c2b0b71782e0d48b Mon Sep 17 00:00:00 2001 From: lukasdenk <63459921+lukasdenk@users.noreply.github.com> Date: Mon, 3 Oct 2022 14:30:45 +0100 Subject: Add query parameter `ts` to allow appservices set the `origin_server_ts` for state events. (#11866) MSC3316 declares that both /rooms/{roomId}/send and /rooms/{roomId}/state should accept a ts parameter for appservices. This change expands support to /state and adds tests. --- changelog.d/11866.feature | 1 + synapse/handlers/room_member.py | 13 +++++ synapse/rest/client/room.py | 34 +++++++----- tests/rest/client/test_rooms.py | 119 +++++++++++++++++++++++++++++++++++++++- 4 files changed, 152 insertions(+), 15 deletions(-) create mode 100644 changelog.d/11866.feature (limited to 'tests/rest') diff --git a/changelog.d/11866.feature b/changelog.d/11866.feature new file mode 100644 index 0000000000..0b52caf805 --- /dev/null +++ b/changelog.d/11866.feature @@ -0,0 +1 @@ +Allow application services to set the `origin_server_ts` of a state event by providing the query parameter `ts` in `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}`, per [MSC3316](https://github.com/matrix-org/matrix-doc/pull/3316). Contributed by @lukasdenk. \ No newline at end of file diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index ee669eb30f..6ad2b38b8f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -322,6 +322,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent: bool = True, outlier: bool = False, historical: bool = False, + origin_server_ts: Optional[int] = None, ) -> Tuple[str, int]: """ Internal membership update function to get an existing event or create @@ -361,6 +362,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): historical: Indicates whether the message is being inserted back in time around some existing events. This is used to skip a few checks and mark the event as backfilled. + origin_server_ts: The origin_server_ts to use if a new event is created. Uses + the current timestamp if set to None. Returns: Tuple of event ID and stream ordering position @@ -399,6 +402,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): "state_key": user_id, # For backwards compatibility: "membership": membership, + "origin_server_ts": origin_server_ts, }, txn_id=txn_id, allow_no_prev_events=allow_no_prev_events, @@ -504,6 +508,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): prev_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, + origin_server_ts: Optional[int] = None, ) -> Tuple[str, int]: """Update a user's membership in a room. @@ -542,6 +547,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + origin_server_ts: The origin_server_ts to use if a new event is created. Uses + the current timestamp if set to None. Returns: A tuple of the new event ID and stream ID. @@ -583,6 +590,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): prev_event_ids=prev_event_ids, state_event_ids=state_event_ids, depth=depth, + origin_server_ts=origin_server_ts, ) return result @@ -606,6 +614,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): prev_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, + origin_server_ts: Optional[int] = None, ) -> Tuple[str, int]: """Helper for update_membership. @@ -646,6 +655,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + origin_server_ts: The origin_server_ts to use if a new event is created. Uses + the current timestamp if set to None. Returns: A tuple of the new event ID and stream ID. @@ -785,6 +796,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent=require_consent, outlier=outlier, historical=historical, + origin_server_ts=origin_server_ts, ) latest_event_ids = await self.store.get_prev_events_for_room(room_id) @@ -1030,6 +1042,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content=content, require_consent=require_consent, outlier=outlier, + origin_server_ts=origin_server_ts, ) async def _should_perform_remote_join( diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 0bca012535..b6dedbed04 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -268,15 +268,9 @@ class RoomStateEventRestServlet(TransactionRestServlet): content = parse_json_object_from_request(request) - event_dict = { - "type": event_type, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - } - - if state_key is not None: - event_dict["state_key"] = state_key + origin_server_ts = None + if requester.app_service: + origin_server_ts = parse_integer(request, "ts") try: if event_type == EventTypes.Member: @@ -287,8 +281,22 @@ class RoomStateEventRestServlet(TransactionRestServlet): room_id=room_id, action=membership, content=content, + origin_server_ts=origin_server_ts, ) else: + event_dict: JsonDict = { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + } + + if state_key is not None: + event_dict["state_key"] = state_key + + if origin_server_ts is not None: + event_dict["origin_server_ts"] = origin_server_ts + ( event, _, @@ -333,10 +341,10 @@ class RoomSendEventRestServlet(TransactionRestServlet): "sender": requester.user.to_string(), } - # Twisted will have processed the args by now. - assert request.args is not None - if b"ts" in request.args and requester.app_service: - event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) + if requester.app_service: + origin_server_ts = parse_integer(request, "ts") + if origin_server_ts is not None: + event_dict["origin_server_ts"] = origin_server_ts try: ( diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 7f8cf4fab0..5e66b5b26c 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -20,7 +20,7 @@ import json from http import HTTPStatus from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -from unittest.mock import Mock, call +from unittest.mock import Mock, call, patch from urllib import parse as urlparse from parameterized import param, parameterized @@ -39,9 +39,10 @@ from synapse.api.constants import ( RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException +from synapse.appservice import ApplicationService from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin -from synapse.rest.client import account, directory, login, profile, room, sync +from synapse.rest.client import account, directory, login, profile, register, room, sync from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util import Clock @@ -1252,6 +1253,120 @@ class RoomJoinTestCase(RoomBase): ) +class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): + servlets = [ + room.register_servlets, + synapse.rest.admin.register_servlets, + register.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.appservice_user, _ = self.register_appservice_user( + "as_user_potato", self.appservice.token + ) + + # Create a room as the appservice user. + args = { + "access_token": self.appservice.token, + "user_id": self.appservice_user, + } + channel = self.make_request( + "POST", + f"/_matrix/client/r0/createRoom?{urlparse.urlencode(args)}", + content={"visibility": "public"}, + ) + + assert channel.code == 200 + self.room = channel.json_body["room_id"] + + self.main_store = self.hs.get_datastores().main + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + self.appservice = ApplicationService( + token="i_am_an_app_service", + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + + mock_load_appservices = Mock(return_value=[self.appservice]) + with patch( + "synapse.storage.databases.main.appservice.load_appservices", + mock_load_appservices, + ): + hs = self.setup_test_homeserver(config=config) + return hs + + def test_send_event_ts(self) -> None: + """Test sending a non-state event with a custom timestamp.""" + ts = 1 + + url_params = { + "user_id": self.appservice_user, + "ts": ts, + } + channel = self.make_request( + "PUT", + path=f"/_matrix/client/r0/rooms/{self.room}/send/m.room.message/1234?" + + urlparse.urlencode(url_params), + content={"body": "test", "msgtype": "m.text"}, + access_token=self.appservice.token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + event_id = channel.json_body["event_id"] + + # Ensure the event was persisted with the correct timestamp. + res = self.get_success(self.main_store.get_event(event_id)) + self.assertEquals(ts, res.origin_server_ts) + + def test_send_state_event_ts(self) -> None: + """Test sending a state event with a custom timestamp.""" + ts = 1 + + url_params = { + "user_id": self.appservice_user, + "ts": ts, + } + channel = self.make_request( + "PUT", + path=f"/_matrix/client/r0/rooms/{self.room}/state/m.room.name?" + + urlparse.urlencode(url_params), + content={"name": "test"}, + access_token=self.appservice.token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + event_id = channel.json_body["event_id"] + + # Ensure the event was persisted with the correct timestamp. + res = self.get_success(self.main_store.get_event(event_id)) + self.assertEquals(ts, res.origin_server_ts) + + def test_send_membership_event_ts(self) -> None: + """Test sending a membership event with a custom timestamp.""" + ts = 1 + + url_params = { + "user_id": self.appservice_user, + "ts": ts, + } + channel = self.make_request( + "PUT", + path=f"/_matrix/client/r0/rooms/{self.room}/state/m.room.member/{self.appservice_user}?" + + urlparse.urlencode(url_params), + content={"membership": "join", "display_name": "test"}, + access_token=self.appservice.token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + event_id = channel.json_body["event_id"] + + # Ensure the event was persisted with the correct timestamp. + res = self.get_success(self.main_store.get_event(event_id)) + self.assertEquals(ts, res.origin_server_ts) + + class RoomJoinRatelimitTestCase(RoomBase): user_id = "@sid1:red" -- cgit 1.5.1 From b706111b7805dceb268e114b6c291c4318288cf0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 3 Oct 2022 12:47:15 -0400 Subject: Do not return unspecced original_event field when using the stable /relations endpoint. (#14025) Keep the old behavior (of including the original_event field) for any requests to the /unstable version of the endpoint, but do not include the field when the /v1 version is used. This should avoid new clients from depending on this field, but will not help with current dependencies. --- changelog.d/14025.bugfix | 1 + synapse/handlers/relations.py | 25 +++++++++++++------------ synapse/rest/client/relations.py | 6 ++++++ tests/rest/client/test_relations.py | 13 ++++++++----- 4 files changed, 28 insertions(+), 17 deletions(-) create mode 100644 changelog.d/14025.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14025.bugfix b/changelog.d/14025.bugfix new file mode 100644 index 0000000000..391364f44d --- /dev/null +++ b/changelog.d/14025.bugfix @@ -0,0 +1 @@ +Do not return an unspecified `original_event` field when using the stable `/relations` endpoint. Introduced in Synapse v1.57.0. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 28d7093f08..63bc6a7aa5 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -78,6 +78,7 @@ class RelationsHandler: direction: str = "b", from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, + include_original_event: bool = False, ) -> JsonDict: """Get related events of a event, ordered by topological ordering. @@ -94,6 +95,7 @@ class RelationsHandler: oldest first (`"f"`). from_token: Fetch rows from the given token, or from the start if None. to_token: Fetch rows up to the given token, or up to the end if None. + include_original_event: Whether to include the parent event. Returns: The pagination chunk. @@ -138,25 +140,24 @@ class RelationsHandler: is_peeking=(member_event_id is None), ) - now = self._clock.time_msec() - # Do not bundle aggregations when retrieving the original event because - # we want the content before relations are applied to it. - original_event = self._event_serializer.serialize_event( - event, now, bundle_aggregations=None - ) # The relations returned for the requested event do include their # bundled aggregations. aggregations = await self.get_bundled_aggregations( events, requester.user.to_string() ) - serialized_events = self._event_serializer.serialize_events( - events, now, bundle_aggregations=aggregations - ) - return_value = { - "chunk": serialized_events, - "original_event": original_event, + now = self._clock.time_msec() + return_value: JsonDict = { + "chunk": self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations + ), } + if include_original_event: + # Do not bundle aggregations when retrieving the original event because + # we want the content before relations are applied to it. + return_value["original_event"] = self._event_serializer.serialize_event( + event, now, bundle_aggregations=None + ) if next_token: return_value["next_batch"] = await next_token.to_string(self._main_store) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 205c556f64..7a25de5c85 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -82,6 +82,11 @@ class RelationPaginationServlet(RestServlet): if to_token_str: to_token = await StreamToken.from_string(self.store, to_token_str) + # The unstable version of this API returns an extra field for client + # compatibility, see https://github.com/matrix-org/synapse/issues/12930. + assert request.path is not None + include_original_event = request.path.startswith(b"/_matrix/client/unstable/") + result = await self._relations_handler.get_relations( requester=requester, event_id=parent_id, @@ -92,6 +97,7 @@ class RelationPaginationServlet(RestServlet): direction=direction, from_token=from_token, to_token=to_token, + include_original_event=include_original_event, ) return 200, result diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index fef3b72d76..988cdb746d 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -654,6 +654,14 @@ class RelationsTestCase(BaseRelationsTestCase): ) # We also expect to get the original event (the id of which is self.parent_id) + # when requesting the unstable endpoint. + self.assertNotIn("original_event", channel.json_body) + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual( channel.json_body["original_event"]["event_id"], self.parent_id ) @@ -755,11 +763,6 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel.json_body["chunk"][0], ) - # We also expect to get the original event (the id of which is self.parent_id) - self.assertEqual( - channel.json_body["original_event"]["event_id"], self.parent_id - ) - # Make sure next_batch has something in it that looks like it could be a # valid token. self.assertIsInstance( -- cgit 1.5.1 From 0b037d6c918cb04f86b1fccae9610552de9386d7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 5 Oct 2022 08:49:52 -0400 Subject: Fix handling of public rooms filter with a network tuple. (#14053) Fixes two related bugs: * The handling of `[null]` for a `room_types` filter was incorrect. * The ordering of arguments when providing both a network tuple and room type field was incorrect. --- changelog.d/14053.bugfix | 1 + synapse/storage/databases/main/room.py | 43 ++++++++++++++++++++-------------- tests/rest/client/test_rooms.py | 41 ++++++++++++++++++++++++-------- 3 files changed, 58 insertions(+), 27 deletions(-) create mode 100644 changelog.d/14053.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14053.bugfix b/changelog.d/14053.bugfix new file mode 100644 index 0000000000..07769f51d0 --- /dev/null +++ b/changelog.d/14053.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.53.0 when querying `/publicRooms` with both a `room_type` filter and a `third_party_instance_id`. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7412bce255..e41c99027a 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -207,21 +207,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def _construct_room_type_where_clause( self, room_types: Union[List[Union[str, None]], None] - ) -> Tuple[Union[str, None], List[str]]: + ) -> Tuple[Union[str, None], list]: if not room_types: return None, [] - else: - # We use None when we want get rooms without a type - is_null_clause = "" - if None in room_types: - is_null_clause = "OR room_type IS NULL" - room_types = [value for value in room_types if value is not None] + # Since None is used to represent a room without a type, care needs to + # be taken into account when constructing the where clause. + clauses = [] + args: list = [] + + room_types_set = set(room_types) + + # We use None to represent a room without a type. + if None in room_types_set: + clauses.append("room_type IS NULL") + room_types_set.remove(None) + + # If there are other room types, generate the proper clause. + if room_types: list_clause, args = make_in_list_sql_clause( - self.database_engine, "room_type", room_types + self.database_engine, "room_type", room_types_set ) + clauses.append(list_clause) - return f"({list_clause} {is_null_clause})", args + return f"({' OR '.join(clauses)})", args async def count_public_rooms( self, @@ -241,14 +250,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def _count_public_rooms_txn(txn: LoggingTransaction) -> int: query_args = [] - room_type_clause, args = self._construct_room_type_where_clause( - search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None) - if search_filter - else None - ) - room_type_clause = f" AND {room_type_clause}" if room_type_clause else "" - query_args += args - if network_tuple: if network_tuple.appservice_id: published_sql = """ @@ -268,6 +269,14 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): UNION SELECT room_id from appservice_room_list """ + room_type_clause, args = self._construct_room_type_where_clause( + search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None) + if search_filter + else None + ) + room_type_clause = f" AND {room_type_clause}" if room_type_clause else "" + query_args += args + sql = f""" SELECT COUNT(*) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 5e66b5b26c..3612ebe7b9 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2213,14 +2213,17 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): ) def make_public_rooms_request( - self, room_types: Union[List[Union[str, None]], None] + self, + room_types: Optional[List[Union[str, None]]], + instance_id: Optional[str] = None, ) -> Tuple[List[Dict[str, Any]], int]: - channel = self.make_request( - "POST", - self.url, - {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}, - self.token, - ) + body: JsonDict = {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}} + if instance_id: + body["third_party_instance_id"] = "test|test" + + channel = self.make_request("POST", self.url, body, self.token) + self.assertEqual(channel.code, 200) + chunk = channel.json_body["chunk"] count = channel.json_body["total_room_count_estimate"] @@ -2230,31 +2233,49 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None: chunk, count = self.make_public_rooms_request(None) - self.assertEqual(count, 2) + # Also check if there's no filter property at all in the body. + channel = self.make_request("POST", self.url, {}, self.token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["chunk"]), 2) + self.assertEqual(channel.json_body["total_room_count_estimate"], 2) + + chunk, count = self.make_public_rooms_request(None, "test|test") + self.assertEqual(count, 0) + def test_returns_only_rooms_based_on_filter(self) -> None: chunk, count = self.make_public_rooms_request([None]) self.assertEqual(count, 1) self.assertEqual(chunk[0].get("room_type", None), None) + chunk, count = self.make_public_rooms_request([None], "test|test") + self.assertEqual(count, 0) + def test_returns_only_space_based_on_filter(self) -> None: chunk, count = self.make_public_rooms_request(["m.space"]) self.assertEqual(count, 1) self.assertEqual(chunk[0].get("room_type", None), "m.space") + chunk, count = self.make_public_rooms_request(["m.space"], "test|test") + self.assertEqual(count, 0) + def test_returns_both_rooms_and_space_based_on_filter(self) -> None: chunk, count = self.make_public_rooms_request(["m.space", None]) - self.assertEqual(count, 2) + chunk, count = self.make_public_rooms_request(["m.space", None], "test|test") + self.assertEqual(count, 0) + def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None: chunk, count = self.make_public_rooms_request([]) - self.assertEqual(count, 2) + chunk, count = self.make_public_rooms_request([], "test|test") + self.assertEqual(count, 0) + class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): """Test that we correctly fallback to local filtering if a remote server -- cgit 1.5.1 From 00c93d2e7ef5642c9cf900f3fdcfa229e70f843d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 7 Oct 2022 09:29:43 -0400 Subject: Be more lenient in the oEmbed response parsing. (#14089) Attempt to parse any valid information from an oEmbed response (instead of bailing at the first unexpected data). This should allow for more partial oEmbed data to be returned, resulting in better / more URL previews, even if those URL previews are only partial. --- changelog.d/14089.bugfix | 1 + synapse/rest/media/v1/oembed.py | 107 ++++++++++++++++++++----------------- tests/rest/media/v1/test_oembed.py | 103 ++++++++++++++++++++++++++++++++++- 3 files changed, 160 insertions(+), 51 deletions(-) create mode 100644 changelog.d/14089.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14089.bugfix b/changelog.d/14089.bugfix new file mode 100644 index 0000000000..4a398921bb --- /dev/null +++ b/changelog.d/14089.bugfix @@ -0,0 +1 @@ +Fix a bug where invalid oEmbed fields would cause the entire response to be discarded. Introduced in Synapse 1.18.0. diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py index 2177b46c9e..827afd868d 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py @@ -139,65 +139,72 @@ class OEmbedProvider: try: # oEmbed responses *must* be UTF-8 according to the spec. oembed = json_decoder.decode(raw_body.decode("utf-8")) + except ValueError: + return OEmbedResult({}, None, None) - # The version is a required string field, but not always provided, - # or sometimes provided as a float. Be lenient. - oembed_version = oembed.get("version", "1.0") - if oembed_version != "1.0" and oembed_version != 1: - raise RuntimeError(f"Invalid oEmbed version: {oembed_version}") + # The version is a required string field, but not always provided, + # or sometimes provided as a float. Be lenient. + oembed_version = oembed.get("version", "1.0") + if oembed_version != "1.0" and oembed_version != 1: + return OEmbedResult({}, None, None) - # Ensure the cache age is None or an int. - cache_age = oembed.get("cache_age") - if cache_age: - cache_age = int(cache_age) * 1000 - - # The results. - open_graph_response = { - "og:url": url, - } - - title = oembed.get("title") - if title: - open_graph_response["og:title"] = title - - author_name = oembed.get("author_name") + # Attempt to parse the cache age, if possible. + try: + cache_age = int(oembed.get("cache_age")) * 1000 + except (TypeError, ValueError): + # If the cache age cannot be parsed (e.g. wrong type or invalid + # string), ignore it. + cache_age = None - # Use the provider name and as the site. - provider_name = oembed.get("provider_name") - if provider_name: - open_graph_response["og:site_name"] = provider_name + # The oEmbed response converted to Open Graph. + open_graph_response: JsonDict = {"og:url": url} - # If a thumbnail exists, use it. Note that dimensions will be calculated later. - if "thumbnail_url" in oembed: - open_graph_response["og:image"] = oembed["thumbnail_url"] + title = oembed.get("title") + if title and isinstance(title, str): + open_graph_response["og:title"] = title - # Process each type separately. - oembed_type = oembed["type"] - if oembed_type == "rich": - calc_description_and_urls(open_graph_response, oembed["html"]) - - elif oembed_type == "photo": - # If this is a photo, use the full image, not the thumbnail. - open_graph_response["og:image"] = oembed["url"] + author_name = oembed.get("author_name") + if not isinstance(author_name, str): + author_name = None - elif oembed_type == "video": - open_graph_response["og:type"] = "video.other" + # Use the provider name and as the site. + provider_name = oembed.get("provider_name") + if provider_name and isinstance(provider_name, str): + open_graph_response["og:site_name"] = provider_name + + # If a thumbnail exists, use it. Note that dimensions will be calculated later. + thumbnail_url = oembed.get("thumbnail_url") + if thumbnail_url and isinstance(thumbnail_url, str): + open_graph_response["og:image"] = thumbnail_url + + # Process each type separately. + oembed_type = oembed.get("type") + if oembed_type == "rich": + html = oembed.get("html") + if isinstance(html, str): + calc_description_and_urls(open_graph_response, html) + + elif oembed_type == "photo": + # If this is a photo, use the full image, not the thumbnail. + url = oembed.get("url") + if url and isinstance(url, str): + open_graph_response["og:image"] = url + + elif oembed_type == "video": + open_graph_response["og:type"] = "video.other" + html = oembed.get("html") + if html and isinstance(html, str): calc_description_and_urls(open_graph_response, oembed["html"]) - open_graph_response["og:video:width"] = oembed["width"] - open_graph_response["og:video:height"] = oembed["height"] - - elif oembed_type == "link": - open_graph_response["og:type"] = "website" + for size in ("width", "height"): + val = oembed.get(size) + if val is not None and isinstance(val, int): + open_graph_response[f"og:video:{size}"] = val - else: - raise RuntimeError(f"Unknown oEmbed type: {oembed_type}") + elif oembed_type == "link": + open_graph_response["og:type"] = "website" - except Exception as e: - # Trap any exception and let the code follow as usual. - logger.warning("Error parsing oEmbed metadata from %s: %r", url, e) - open_graph_response = {} - author_name = None - cache_age = None + else: + logger.warning("Unknown oEmbed type: %s", oembed_type) return OEmbedResult(open_graph_response, author_name, cache_age) diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py index f38d7225f8..319ae8b1cc 100644 --- a/tests/rest/media/v1/test_oembed.py +++ b/tests/rest/media/v1/test_oembed.py @@ -14,6 +14,8 @@ import json +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult @@ -23,8 +25,16 @@ from synapse.util import Clock from tests.unittest import HomeserverTestCase +try: + import lxml +except ImportError: + lxml = None + class OEmbedTests(HomeserverTestCase): + if not lxml: + skip = "url preview feature requires lxml" + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.oembed = OEmbedProvider(hs) @@ -36,7 +46,7 @@ class OEmbedTests(HomeserverTestCase): def test_version(self) -> None: """Accept versions that are similar to 1.0 as a string or int (or missing).""" for version in ("1.0", 1.0, 1): - result = self.parse_response({"version": version, "type": "link"}) + result = self.parse_response({"version": version}) # An empty Open Graph response is an error, ensure the URL is included. self.assertIn("og:url", result.open_graph_result) @@ -49,3 +59,94 @@ class OEmbedTests(HomeserverTestCase): result = self.parse_response({"version": version, "type": "link"}) # An empty Open Graph response is an error, ensure the URL is included. self.assertEqual({}, result.open_graph_result) + + def test_cache_age(self) -> None: + """Ensure a cache-age is parsed properly.""" + # Correct-ish cache ages are allowed. + for cache_age in ("1", 1.0, 1): + result = self.parse_response({"cache_age": cache_age}) + self.assertEqual(result.cache_age, 1000) + + # Invalid cache ages are ignored. + for cache_age in ("invalid", {}): + result = self.parse_response({"cache_age": cache_age}) + self.assertIsNone(result.cache_age) + + # Cache age is optional. + result = self.parse_response({}) + self.assertIsNone(result.cache_age) + + @parameterized.expand( + [ + ("title", "title"), + ("provider_name", "site_name"), + ("thumbnail_url", "image"), + ], + name_func=lambda func, num, p: f"{func.__name__}_{p.args[0]}", + ) + def test_property(self, oembed_property: str, open_graph_property: str) -> None: + """Test properties which must be strings.""" + result = self.parse_response({oembed_property: "test"}) + self.assertIn(f"og:{open_graph_property}", result.open_graph_result) + self.assertEqual(result.open_graph_result[f"og:{open_graph_property}"], "test") + + result = self.parse_response({oembed_property: 1}) + self.assertNotIn(f"og:{open_graph_property}", result.open_graph_result) + + def test_author_name(self) -> None: + """Test the author_name property.""" + result = self.parse_response({"author_name": "test"}) + self.assertEqual(result.author_name, "test") + + result = self.parse_response({"author_name": 1}) + self.assertIsNone(result.author_name) + + def test_rich(self) -> None: + """Test a type of rich.""" + result = self.parse_response({"html": "test", "type": "rich"}) + self.assertIn("og:description", result.open_graph_result) + self.assertIn("og:image", result.open_graph_result) + self.assertEqual(result.open_graph_result["og:description"], "test") + self.assertEqual(result.open_graph_result["og:image"], "foo") + + result = self.parse_response({"type": "rich"}) + self.assertNotIn("og:description", result.open_graph_result) + + result = self.parse_response({"html": 1, "type": "rich"}) + self.assertNotIn("og:description", result.open_graph_result) + + def test_photo(self) -> None: + """Test a type of photo.""" + result = self.parse_response({"url": "test", "type": "photo"}) + self.assertIn("og:image", result.open_graph_result) + self.assertEqual(result.open_graph_result["og:image"], "test") + + result = self.parse_response({"type": "photo"}) + self.assertNotIn("og:image", result.open_graph_result) + + result = self.parse_response({"url": 1, "type": "photo"}) + self.assertNotIn("og:image", result.open_graph_result) + + def test_video(self) -> None: + """Test a type of video.""" + result = self.parse_response({"html": "test", "type": "video"}) + self.assertIn("og:type", result.open_graph_result) + self.assertEqual(result.open_graph_result["og:type"], "video.other") + self.assertIn("og:description", result.open_graph_result) + self.assertEqual(result.open_graph_result["og:description"], "test") + + result = self.parse_response({"type": "video"}) + self.assertIn("og:type", result.open_graph_result) + self.assertEqual(result.open_graph_result["og:type"], "video.other") + self.assertNotIn("og:description", result.open_graph_result) + + result = self.parse_response({"url": 1, "type": "video"}) + self.assertIn("og:type", result.open_graph_result) + self.assertEqual(result.open_graph_result["og:type"], "video.other") + self.assertNotIn("og:description", result.open_graph_result) + + def test_link(self) -> None: + """Test type of link.""" + result = self.parse_response({"type": "link"}) + self.assertIn("og:type", result.open_graph_result) + self.assertEqual(result.open_graph_result["og:type"], "website") -- cgit 1.5.1 From 3bbe532abb7bfc41467597731ac1a18c0331f539 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 13 Oct 2022 08:02:11 -0400 Subject: Add an API for listing threads in a room. (#13394) Implement the /threads endpoint from MSC3856. This is currently unstable and behind an experimental configuration flag. It includes a background update to backfill data, results from the /threads endpoint will be partial until that finishes. --- changelog.d/13394.feature | 1 + synapse/_scripts/synapse_port_db.py | 2 + synapse/config/experimental.py | 3 + synapse/handlers/relations.py | 86 ++++++++++- synapse/rest/client/relations.py | 50 ++++++- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/events.py | 38 ++++- synapse/storage/databases/main/relations.py | 166 ++++++++++++++++++++- .../schema/main/delta/73/09threads_table.sql | 30 ++++ tests/rest/client/test_relations.py | 151 +++++++++++++++++++ 10 files changed, 522 insertions(+), 6 deletions(-) create mode 100644 changelog.d/13394.feature create mode 100644 synapse/storage/schema/main/delta/73/09threads_table.sql (limited to 'tests/rest') diff --git a/changelog.d/13394.feature b/changelog.d/13394.feature new file mode 100644 index 0000000000..68de079cf3 --- /dev/null +++ b/changelog.d/13394.feature @@ -0,0 +1 @@ +Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 5fa599e70e..d850e54e17 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -72,6 +72,7 @@ from synapse.storage.databases.main.registration import ( RegistrationBackgroundUpdateStore, find_max_generated_user_id_localpart, ) +from synapse.storage.databases.main.relations import RelationsWorkerStore from synapse.storage.databases.main.room import RoomBackgroundUpdateStore from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.databases.main.search import SearchBackgroundUpdateStore @@ -206,6 +207,7 @@ class Store( PusherWorkerStore, PresenceBackgroundUpdateStore, ReceiptsBackgroundUpdateStore, + RelationsWorkerStore, ): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f44655516e..1860006536 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -101,6 +101,9 @@ class ExperimentalConfig(Config): # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) + # MSC3856: Threads list API + self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False) + # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index cc5e45c241..1fdd7a10bc 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import logging from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple @@ -20,7 +21,7 @@ from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace -from synapse.storage.databases.main.relations import _RelatedEvent +from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.visibility import filter_events_for_client @@ -32,6 +33,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class ThreadsListInclude(str, enum.Enum): + """Valid values for the 'include' flag of /threads.""" + + all = "all" + participated = "participated" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class _ThreadAggregation: # The latest event in the thread. @@ -482,3 +490,79 @@ class RelationsHandler: results.setdefault(event_id, BundledAggregations()).replace = edit return results + + async def get_threads( + self, + requester: Requester, + room_id: str, + include: ThreadsListInclude, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> JsonDict: + """Get related events of a event, ordered by topological ordering. + + Args: + requester: The user requesting the relations. + room_id: The room the event belongs to. + include: One of "all" or "participated" to indicate which threads should + be returned. + limit: Only fetch the most recent `limit` events. + from_token: Fetch rows from the given token, or from the start if None. + + Returns: + The pagination chunk. + """ + + user_id = requester.user.to_string() + + # TODO Properly handle a user leaving a room. + (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable( + room_id, requester, allow_departed_users=True + ) + + # Note that ignored users are not passed into get_relations_for_event + # below. Ignored users are handled in filter_events_for_client (and by + # not passing them in here we should get a better cache hit rate). + thread_roots, next_batch = await self._main_store.get_threads( + room_id=room_id, limit=limit, from_token=from_token + ) + + events = await self._main_store.get_events_as_list(thread_roots) + + if include == ThreadsListInclude.participated: + # Pre-seed thread participation with whether the requester sent the event. + participated = {event.event_id: event.sender == user_id for event in events} + # For events the requester did not send, check the database for whether + # the requester sent a threaded reply. + participated.update( + await self._main_store.get_threads_participated( + [eid for eid, p in participated.items() if not p], + user_id, + ) + ) + + # Limit the returned threads to those the user has participated in. + events = [event for event in events if participated[event.event_id]] + + events = await filter_events_for_client( + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), + ) + + aggregations = await self.get_bundled_aggregations( + events, requester.user.to_string() + ) + + now = self._clock.time_msec() + serialized_events = self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations + ) + + return_value: JsonDict = {"chunk": serialized_events} + + if next_batch: + return_value["next_batch"] = str(next_batch) + + return return_value diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index b31ce5a0d3..d1aa1947a5 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -13,12 +13,15 @@ # limitations under the License. import logging +import re from typing import TYPE_CHECKING, Optional, Tuple +from synapse.handlers.relations import ThreadsListInclude from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet +from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.storage.databases.main.relations import ThreadsNextBatch from synapse.streams.config import PaginationConfig from synapse.types import JsonDict @@ -78,5 +81,50 @@ class RelationPaginationServlet(RestServlet): return 200, result +class ThreadsServlet(RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P[^/]*)/threads" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self._relations_handler = hs.get_relations_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + limit = parse_integer(request, "limit", default=5) + from_token_str = parse_string(request, "from") + include = parse_string( + request, + "include", + default=ThreadsListInclude.all.value, + allowed_values=[v.value for v in ThreadsListInclude], + ) + + # Return the relations + from_token = None + if from_token_str: + from_token = ThreadsNextBatch.from_string(from_token_str) + + result = await self._relations_handler.get_threads( + requester=requester, + room_id=room_id, + include=ThreadsListInclude(include), + limit=limit, + from_token=from_token, + ) + + return 200, result + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) + if hs.config.experimental.msc3856_enabled: + ThreadsServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index a9f25a5904..0ce3156c9c 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) + self._attempt_to_invalidate_cache("get_threads", (room_id,)) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 060fe71454..6698cbf664 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -35,7 +35,7 @@ import attr from prometheus_client import Counter import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event @@ -1616,7 +1616,7 @@ class PersistEventsStore: ) # Remove from relations table. - self._handle_redact_relations(txn, event.redacts) + self._handle_redact_relations(txn, event.room_id, event.redacts) # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. @@ -1866,6 +1866,34 @@ class PersistEventsStore: }, ) + if relation.rel_type == RelationTypes.THREAD: + # Upsert into the threads table, but only overwrite the value if the + # new event is of a later topological order OR if the topological + # ordering is equal, but the stream ordering is later. + sql = """ + INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (room_id, thread_id) + DO UPDATE SET + latest_event_id = excluded.latest_event_id, + topological_ordering = excluded.topological_ordering, + stream_ordering = excluded.stream_ordering + WHERE + threads.topological_ordering <= excluded.topological_ordering AND + threads.stream_ordering < excluded.stream_ordering + """ + + txn.execute( + sql, + ( + event.room_id, + relation.parent_id, + event.event_id, + event.depth, + event.internal_metadata.stream_ordering, + ), + ) + def _handle_insertion_event( self, txn: LoggingTransaction, event: EventBase ) -> None: @@ -1989,13 +2017,14 @@ class PersistEventsStore: txn.execute(sql, (batch_id,)) def _handle_redact_relations( - self, txn: LoggingTransaction, redacted_event_id: str + self, txn: LoggingTransaction, room_id: str, redacted_event_id: str ) -> None: """Handles receiving a redaction and checking whether the redacted event has any relations which must be removed from the database. Args: txn + room_id: The room ID of the event that was redacted. redacted_event_id: The event that was redacted. """ @@ -2024,6 +2053,9 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_participated, (redacted_relates_to,) ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_threads, (room_id,) + ) self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index e7fbf950e6..ac9b96ab44 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -14,6 +14,7 @@ import logging from typing import ( + TYPE_CHECKING, Collection, Dict, FrozenSet, @@ -29,17 +30,46 @@ from typing import ( import attr from synapse.api.constants import MAIN_TIMELINE, RelationTypes +from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThreadsNextBatch: + topological_ordering: int + stream_ordering: int + + def __str__(self) -> str: + return f"{self.topological_ordering}_{self.stream_ordering}" + + @classmethod + def from_string(cls, string: str) -> "ThreadsNextBatch": + """ + Creates a ThreadsNextBatch from its textual representation. + """ + try: + keys = (int(s) for s in string.split("_")) + return cls(*keys) + except Exception: + raise SynapseError(400, "Invalid threads token") + + @attr.s(slots=True, frozen=True, auto_attribs=True) class _RelatedEvent: """ @@ -56,6 +86,76 @@ class _RelatedEvent: class RelationsWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + "threads_backfill", self._backfill_threads + ) + + async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int: + """Backfill the threads table.""" + + def threads_backfill_txn(txn: LoggingTransaction) -> int: + last_thread_id = progress.get("last_thread_id", "") + + # Get the latest event in each thread by topo ordering / stream ordering. + # + # Note that the MAX(event_id) is needed to abide by the rules of group by, + # but doesn't actually do anything since there should only be a single event + # ID per topo/stream ordering pair. + sql = f""" + SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id > ? AND + relation_type = '{RelationTypes.THREAD}' + GROUP BY room_id, relates_to_id + ORDER BY relates_to_id + LIMIT ? + """ + txn.execute(sql, (last_thread_id, batch_size)) + + # No more rows to process. + rows = txn.fetchall() + if not rows: + return 0 + + # Insert the rows into the threads table. If a matching thread already exists, + # assume it is from a newer event. + sql = """ + INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id) + VALUES %s + ON CONFLICT (room_id, thread_id) + DO NOTHING + """ + if isinstance(txn.database_engine, PostgresEngine): + txn.execute_values(sql % ("?",), rows, fetch=False) + else: + txn.execute_batch(sql % ("?, ?, ?, ?, ?",), rows) + + # Mark the progress. + self.db_pool.updates._background_update_progress_txn( + txn, "threads_backfill", {"last_thread_id": rows[-1][1]} + ) + + return txn.rowcount + + result = await self.db_pool.runInteraction( + "threads_backfill", threads_backfill_txn + ) + + if not result: + await self.db_pool.updates._end_background_update("threads_backfill") + + return result + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, @@ -776,6 +876,70 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) + @cached(tree=True) + async def get_threads( + self, + room_id: str, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + """Get a list of thread IDs, ordered by topological ordering of their + latest reply. + + Args: + room_id: The room the event belongs to. + limit: Only fetch the most recent `limit` threads. + from_token: Fetch rows from a previous next_batch, or from the start if None. + + Returns: + A tuple of: + A list of thread root event IDs. + + The next_batch, if one exists. + """ + # Generate the pagination clause, if necessary. + # + # Find any threads where the latest reply is equal / before the last + # thread's topo ordering and earlier in stream ordering. + pagination_clause = "" + pagination_args: tuple = () + if from_token: + pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?" + pagination_args = ( + from_token.topological_ordering, + from_token.stream_ordering, + ) + + sql = f""" + SELECT thread_id, topological_ordering, stream_ordering + FROM threads + WHERE + room_id = ? + {pagination_clause} + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT ? + """ + + def _get_threads_txn( + txn: LoggingTransaction, + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + txn.execute(sql, (room_id, *pagination_args, limit + 1)) + + rows = cast(List[Tuple[str, int, int]], txn.fetchall()) + thread_ids = [r[0] for r in rows] + + # If there are more events, generate the next pagination key from the + # last thread which will be returned. + next_token = None + if len(thread_ids) > limit: + last_topo_id = rows[-2][1] + last_stream_id = rows[-2][2] + next_token = ThreadsNextBatch(last_topo_id, last_stream_id) + + return thread_ids[:limit], next_token + + return await self.db_pool.runInteraction("get_threads", _get_threads_txn) + @cached() async def get_thread_id(self, event_id: str) -> str: """ diff --git a/synapse/storage/schema/main/delta/73/09threads_table.sql b/synapse/storage/schema/main/delta/73/09threads_table.sql new file mode 100644 index 0000000000..aa7c5e9a2e --- /dev/null +++ b/synapse/storage/schema/main/delta/73/09threads_table.sql @@ -0,0 +1,30 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE threads ( + room_id TEXT NOT NULL, + -- The event ID of the root event in the thread. + thread_id TEXT NOT NULL, + -- The latest event ID and corresponding topo / stream ordering. + latest_event_id TEXT NOT NULL, + topological_ordering BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL, + CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id) +); + +CREATE INDEX threads_ordering_idx ON threads(room_id, topological_ordering, stream_ordering); + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7309, 'threads_backfill', '{}'); diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 988cdb746d..d595295e2c 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1707,3 +1707,154 @@ class RelationRedactionTestCase(BaseRelationsTestCase): relations[RelationTypes.THREAD]["latest_event"]["event_id"], related_event_id, ) + + +class ThreadsTestCase(BaseRelationsTestCase): + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_threads(self) -> None: + """Create threads and ensure the ordering is due to their latest event.""" + # Create 2 threads. + thread_1 = self.parent_id + res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) + thread_2 = res["event_id"] + + self._send_relation(RelationTypes.THREAD, "m.room.test") + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2, thread_1]) + + # Update the first thread, the ordering should swap. + self._send_relation(RelationTypes.THREAD, "m.room.test") + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1, thread_2]) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_pagination(self) -> None: + """Create threads and paginate through them.""" + # Create 2 threads. + thread_1 = self.parent_id + res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) + thread_2 = res["event_id"] + + self._send_relation(RelationTypes.THREAD, "m.room.test") + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2]) + + # Make sure next_batch has something in it that looks like it could be a + # valid token. + next_batch = channel.json_body.get("next_batch") + self.assertIsInstance(next_batch, str, channel.json_body) + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1], channel.json_body) + + self.assertNotIn("next_batch", channel.json_body, channel.json_body) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_include(self) -> None: + """Filtering threads to all or participated in should work.""" + # Thread 1 has the user as the root event. + thread_1 = self.parent_id + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + + # Thread 2 has the user replying. + res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token) + thread_2 = res["event_id"] + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Thread 3 has the user not participating in. + res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token) + thread_3 = res["event_id"] + self._send_relation( + RelationTypes.THREAD, + "m.room.test", + access_token=self.user2_token, + parent_id=thread_3, + ) + + # All threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual( + thread_roots, [thread_3, thread_2, thread_1], channel.json_body + ) + + # Only participated threads. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_ignored_user(self) -> None: + """Events from ignored users should be ignored.""" + # Thread 1 has a reply from an ignored user. + thread_1 = self.parent_id + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + + # Thread 2 is created by an ignored user. + res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token) + thread_2 = res["event_id"] + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Ignore user2. + self.get_success( + self.store.add_account_data_for_user( + self.user_id, + AccountDataTypes.IGNORED_USER_LIST, + {"ignored_users": {self.user2_id: {}}}, + ) + ) + + # Only thread 1 is returned. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1], channel.json_body) -- cgit 1.5.1 From c3e4edb4d6ba33383bc056e3ff22b2d034d3e248 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 07:16:50 -0400 Subject: Stabilize the threads API. (#14175) Stabilize the threads API (MSC3856) by supporting (only) the v1 path for the endpoint. This also marks the API as safe for workers since it is a read-only API. --- changelog.d/13394.feature | 2 +- changelog.d/14175.feature | 1 + docker/configure_workers_and_start.py | 1 + docs/workers.md | 1 + synapse/config/experimental.py | 3 --- synapse/rest/client/relations.py | 9 ++----- tests/rest/client/test_relations.py | 47 +++++++++++++++++++++-------------- 7 files changed, 35 insertions(+), 29 deletions(-) create mode 100644 changelog.d/14175.feature (limited to 'tests/rest') diff --git a/changelog.d/13394.feature b/changelog.d/13394.feature index 68de079cf3..df3ce45a76 100644 --- a/changelog.d/13394.feature +++ b/changelog.d/13394.feature @@ -1 +1 @@ -Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. +Support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. diff --git a/changelog.d/14175.feature b/changelog.d/14175.feature new file mode 100644 index 0000000000..df3ce45a76 --- /dev/null +++ b/changelog.d/14175.feature @@ -0,0 +1 @@ +Support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 8e7f605b24..d708237f69 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -118,6 +118,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$", "^/_matrix/client/v1/rooms/.*/hierarchy$", "^/_matrix/client/(v1|unstable)/rooms/.*/relations/", + "^/_matrix/client/v1/rooms/.*/threads$", "^/_matrix/client/(api/v1|r0|v3|unstable)/login$", "^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$", "^/_matrix/client/(api/v1|r0|v3|unstable)/account/whoami$", diff --git a/docs/workers.md b/docs/workers.md index e8d6cbaf8b..c27b3f8bd5 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -204,6 +204,7 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$ ^/_matrix/client/v1/rooms/.*/hierarchy$ ^/_matrix/client/(v1|unstable)/rooms/.*/relations/ + ^/_matrix/client/v1/rooms/.*/threads$ ^/_matrix/client/unstable/org.matrix.msc2716/rooms/.*/batch_send$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ ^/_matrix/client/(r0|v3|unstable)/account/3pid$ diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 1860006536..f44655516e 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -101,9 +101,6 @@ class ExperimentalConfig(Config): # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) - # MSC3856: Threads list API - self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False) - # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index d1aa1947a5..9dd59196d9 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -82,11 +82,7 @@ class RelationPaginationServlet(RestServlet): class ThreadsServlet(RestServlet): - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P[^/]*)/threads" - ), - ) + PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P[^/]*)/threads"),) def __init__(self, hs: "HomeServer"): super().__init__() @@ -126,5 +122,4 @@ class ThreadsServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) - if hs.config.experimental.msc3856_enabled: - ThreadsServlet(hs).register(http_server) + ThreadsServlet(hs).register(http_server) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index d595295e2c..f5c1070b2c 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1710,7 +1710,15 @@ class RelationRedactionTestCase(BaseRelationsTestCase): class ThreadsTestCase(BaseRelationsTestCase): - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def _get_threads(self, body: JsonDict) -> List[Tuple[str, str]]: + return [ + ( + ev["event_id"], + ev["unsigned"]["m.relations"]["m.thread"]["latest_event"]["event_id"], + ) + for ev in body["chunk"] + ] + def test_threads(self) -> None: """Create threads and ensure the ordering is due to their latest event.""" # Create 2 threads. @@ -1718,32 +1726,37 @@ class ThreadsTestCase(BaseRelationsTestCase): res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) thread_2 = res["event_id"] - self._send_relation(RelationTypes.THREAD, "m.room.test") - self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + reply_1 = channel.json_body["event_id"] + channel = self._send_relation( + RelationTypes.THREAD, "m.room.test", parent_id=thread_2 + ) + reply_2 = channel.json_body["event_id"] # Request the threads in the room. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] - self.assertEqual(thread_roots, [thread_2, thread_1]) + threads = self._get_threads(channel.json_body) + self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)]) # Update the first thread, the ordering should swap. - self._send_relation(RelationTypes.THREAD, "m.room.test") + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + reply_3 = channel.json_body["event_id"] channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] - self.assertEqual(thread_roots, [thread_1, thread_2]) + # Tuple of (thread ID, latest event ID) for each thread. + threads = self._get_threads(channel.json_body) + self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)]) - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) def test_pagination(self) -> None: """Create threads and paginate through them.""" # Create 2 threads. @@ -1757,7 +1770,7 @@ class ThreadsTestCase(BaseRelationsTestCase): # Request the threads in the room. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1", + f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -1771,7 +1784,7 @@ class ThreadsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}", + f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -1780,7 +1793,6 @@ class ThreadsTestCase(BaseRelationsTestCase): self.assertNotIn("next_batch", channel.json_body, channel.json_body) - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) def test_include(self) -> None: """Filtering threads to all or participated in should work.""" # Thread 1 has the user as the root event. @@ -1807,7 +1819,7 @@ class ThreadsTestCase(BaseRelationsTestCase): # All threads in the room. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -1819,14 +1831,13 @@ class ThreadsTestCase(BaseRelationsTestCase): # Only participated threads. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated", + f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) def test_ignored_user(self) -> None: """Events from ignored users should be ignored.""" # Thread 1 has a reply from an ignored user. @@ -1852,7 +1863,7 @@ class ThreadsTestCase(BaseRelationsTestCase): # Only thread 1 is returned. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) -- cgit 1.5.1 From 126a15794c95002560709283640ad412636b29b8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 08:30:05 -0400 Subject: Do not allow a None-limit on PaginationConfig. (#14146) The callers either set a default limit or manually handle a None-limit later on (by setting a default value). Update the callers to always instantiate PaginationConfig with a default limit and then assume the limit is non-None. --- changelog.d/14146.removal | 1 + synapse/handlers/account_data.py | 2 +- synapse/handlers/initial_sync.py | 27 ++++----------------------- synapse/handlers/pagination.py | 5 ----- synapse/handlers/presence.py | 4 +++- synapse/handlers/receipts.py | 2 +- synapse/handlers/relations.py | 3 --- synapse/handlers/room.py | 2 +- synapse/handlers/typing.py | 2 +- synapse/rest/client/events.py | 4 +++- synapse/rest/client/initial_sync.py | 4 +++- synapse/rest/client/room.py | 4 +++- synapse/storage/databases/main/stream.py | 2 -- synapse/streams/__init__.py | 2 +- synapse/streams/config.py | 12 +++++------- tests/rest/client/test_typing.py | 3 ++- 16 files changed, 29 insertions(+), 50 deletions(-) create mode 100644 changelog.d/14146.removal (limited to 'tests/rest') diff --git a/changelog.d/14146.removal b/changelog.d/14146.removal new file mode 100644 index 0000000000..08fa752897 --- /dev/null +++ b/changelog.d/14146.removal @@ -0,0 +1 @@ +Remove the unstable identifier for [MSC3715](https://github.com/matrix-org/matrix-doc/pull/3715). diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 0478448b47..fc21d58001 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -225,7 +225,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 860c82c110..9c335e6863 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -57,13 +57,7 @@ class InitialSyncHandler: self.validator = EventValidator() self.snapshot_cache: ResponseCache[ Tuple[ - str, - Optional[StreamToken], - Optional[StreamToken], - str, - Optional[int], - bool, - bool, + str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() @@ -154,11 +148,6 @@ class InitialSyncHandler: public_room_ids = await self.store.get_public_room_ids() - if pagin_config.limit is not None: - limit = pagin_config.limit - else: - limit = 10 - serializer_options = SerializeEventConfig(as_client_event=as_client_event) async def handle_room(event: RoomsForUser) -> None: @@ -210,7 +199,7 @@ class InitialSyncHandler: run_in_background( self.store.get_recent_events_for_room, event.room_id, - limit=limit, + limit=pagin_config.limit, end_token=room_end_token, ), deferred_room_state, @@ -360,15 +349,11 @@ class InitialSyncHandler: member_event_id ) - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - leave_position = await self.store.get_position_for_event(member_event_id) stream_token = leave_position.to_room_stream_token() messages, token = await self.store.get_recent_events_for_room( - room_id, limit=limit, end_token=stream_token + room_id, limit=pagin_config.limit, end_token=stream_token ) messages = await filter_events_for_client( @@ -420,10 +405,6 @@ class InitialSyncHandler: now_token = self.hs.get_event_sources().get_current_token() - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - room_members = [ m for m in current_state.values() @@ -467,7 +448,7 @@ class InitialSyncHandler: run_in_background( self.store.get_recent_events_for_room, room_id, - limit=limit, + limit=pagin_config.limit, end_token=now_token.room_key, ), ), diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 1f83bab836..a4ca9cb8b4 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -458,11 +458,6 @@ class PaginationHandler: # `/messages` should still works with live tokens when manually provided. assert from_token.room_key.topological is not None - if pagin_config.limit is None: - # This shouldn't happen as we've set a default limit before this - # gets called. - raise Exception("limit not set") - room_token = from_token.room_key async with self.pagination_lock.read(room_id): diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 4e575ffbaa..2670e561d7 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1596,7 +1596,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): self, user: UserID, from_key: Optional[int], - limit: Optional[int] = None, + # Having a default limit doesn't match the EventSource API, but some + # callers do not provide it. It is unused in this class. + limit: int = 0, room_ids: Optional[Collection[str]] = None, is_guest: bool = False, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 4a7ec9e426..ac01582442 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -257,7 +257,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 1fdd7a10bc..0a0c6d938e 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -116,9 +116,6 @@ class RelationsHandler: if event is None: raise SynapseError(404, "Unknown parent event.") - # TODO Update pagination config to not allow None limits. - assert pagin_config.limit is not None - # Note that ignored users are not passed into get_relations_for_event # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 57ab05ad25..4e1aacb408 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1646,7 +1646,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): self, user: UserID, from_key: RoomStreamToken, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index f953691669..a0ea719430 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -513,7 +513,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 916f5230f1..782e7d14e8 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -50,7 +50,9 @@ class EventStreamRestServlet(RestServlet): raise SynapseError(400, "Guest users must specify room_id param") room_id = parse_string(request, "room_id") - pagin_config = await PaginationConfig.from_request(self.store, request) + pagin_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if b"timeout" in args: try: diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py index cfadcb8e50..9b1bb8b521 100644 --- a/synapse/rest/client/initial_sync.py +++ b/synapse/rest/client/initial_sync.py @@ -39,7 +39,9 @@ class InitialSyncRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) args: Dict[bytes, List[bytes]] = request.args # type: ignore as_client_event = b"raw" not in args - pagination_config = await PaginationConfig.from_request(self.store, request) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) include_archived = parse_boolean(request, "archived", default=False) content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index b6dedbed04..01e5079963 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -729,7 +729,9 @@ class RoomInitialSyncRestServlet(RestServlet): self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = await PaginationConfig.from_request(self.store, request) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index ffeb2b3683..5baffbfe55 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1200,8 +1200,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): `to_token`), or `limit` is zero. """ - assert int(limit) >= 0 - # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py index 806b671305..2dcd43d0a2 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py @@ -27,7 +27,7 @@ class EventSource(Generic[K, R]): self, user: UserID, from_key: K, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/streams/config.py b/synapse/streams/config.py index f6f7bf3d8b..6df2de919c 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -35,14 +35,14 @@ class PaginationConfig: from_token: Optional[StreamToken] to_token: Optional[StreamToken] direction: str - limit: Optional[int] + limit: int @classmethod async def from_request( cls, store: "DataStore", request: SynapseRequest, - default_limit: Optional[int] = None, + default_limit: int, default_dir: str = "f", ) -> "PaginationConfig": direction = parse_string( @@ -69,12 +69,10 @@ class PaginationConfig: raise SynapseError(400, "'to' parameter is invalid") limit = parse_integer(request, "limit", default=default_limit) + if limit < 0: + raise SynapseError(400, "Limit must be 0 or above") - if limit: - if limit < 0: - raise SynapseError(400, "Limit must be 0 or above") - - limit = min(int(limit), MAX_LIMIT) + limit = min(limit, MAX_LIMIT) try: return PaginationConfig(from_tok, to_tok, direction, limit) diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index 61b66d7685..fdc433a8b5 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -59,7 +59,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.event_source.get_new_events( user=UserID.from_string(self.user_id), from_key=0, - limit=None, + # Limit is unused. + limit=0, room_ids=[self.room_id], is_guest=False, ) -- cgit 1.5.1 From 4283bd1cf9c3da2157c3642a7c4f105e9fac2636 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 17 Oct 2022 11:32:11 -0400 Subject: Support filtering the /messages API by relation type (MSC3874). (#14148) Gated behind an experimental configuration flag. --- changelog.d/14148.feature | 1 + synapse/api/filtering.py | 27 +++++- synapse/config/experimental.py | 3 + synapse/rest/client/versions.py | 2 + synapse/storage/databases/main/stream.py | 29 ++++++- tests/api/test_filtering.py | 63 +++++++++++++- tests/rest/client/test_relations.py | 1 - tests/rest/client/test_rooms.py | 145 ++----------------------------- tests/storage/test_stream.py | 118 ++++++++++++++++++------- 9 files changed, 212 insertions(+), 177 deletions(-) create mode 100644 changelog.d/14148.feature (limited to 'tests/rest') diff --git a/changelog.d/14148.feature b/changelog.d/14148.feature new file mode 100644 index 0000000000..951d0cac80 --- /dev/null +++ b/changelog.d/14148.feature @@ -0,0 +1 @@ +Experimental support for [MSC3874](https://github.com/matrix-org/matrix-spec-proposals/pull/3874). diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index cc31cf8cc7..26be377d03 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -36,7 +36,7 @@ from jsonschema import FormatChecker from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.types import JsonDict, RoomID, UserID if TYPE_CHECKING: @@ -53,6 +53,12 @@ FILTER_SCHEMA = { # check types are valid event types "types": {"type": "array", "items": {"type": "string"}}, "not_types": {"type": "array", "items": {"type": "string"}}, + # MSC3874, filtering /messages. + "org.matrix.msc3874.rel_types": {"type": "array", "items": {"type": "string"}}, + "org.matrix.msc3874.not_rel_types": { + "type": "array", + "items": {"type": "string"}, + }, }, } @@ -334,8 +340,15 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) - self.related_by_senders = self.filter_json.get("related_by_senders", None) - self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None) + self.related_by_senders = filter_json.get("related_by_senders", None) + self.related_by_rel_types = filter_json.get("related_by_rel_types", None) + + # For compatibility with _check_fields. + self.rel_types = None + self.not_rel_types = [] + if hs.config.experimental.msc3874_enabled: + self.rel_types = filter_json.get("org.matrix.msc3874.rel_types", None) + self.not_rel_types = filter_json.get("org.matrix.msc3874.not_rel_types", []) def filters_all_types(self) -> bool: return "*" in self.not_types @@ -386,11 +399,19 @@ class Filter: # check if there is a string url field in the content for filtering purposes labels = content.get(EventContentFields.LABELS, []) + # Check if the event has a relation. + rel_type = None + if isinstance(event, EventBase): + relation = relation_from_event(event) + if relation: + rel_type = relation.rel_type + field_matchers = { "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, "types": lambda v: _matches_wildcard(ev_type, v), "labels": lambda v: v in labels, + "rel_types": lambda v: rel_type == v, } result = self._check_fields(field_matchers) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f44655516e..f9a49451d8 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -117,3 +117,6 @@ class ExperimentalConfig(Config): self.msc3882_token_timeout = self.parse_duration( experimental.get("msc3882_token_timeout", "5m") ) + + # MSC3874: Filtering /messages with rel_types / not_rel_types. + self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 4e1fd2bbe7..4b87ee978a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -114,6 +114,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3882": self.config.experimental.msc3882_enabled, # Adds support for remotely enabling/disabling pushers, as per MSC3881 "org.matrix.msc3881": self.config.experimental.msc3881_enabled, + # Adds support for filtering /messages by event relation. + "org.matrix.msc3874": self.config.experimental.msc3874_enabled, }, }, ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 5baffbfe55..09ce855aa8 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -357,6 +357,24 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: ) args.extend(event_filter.related_by_rel_types) + if event_filter.rel_types: + clauses.append( + "(%s)" + % " OR ".join( + "event_relation.relation_type = ?" for _ in event_filter.rel_types + ) + ) + args.extend(event_filter.rel_types) + + if event_filter.not_rel_types: + clauses.append( + "((%s) OR event_relation.relation_type IS NULL)" + % " AND ".join( + "event_relation.relation_type != ?" for _ in event_filter.not_rel_types + ) + ) + args.extend(event_filter.not_rel_types) + return " AND ".join(clauses), args @@ -1278,8 +1296,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # Multiple labels could cause the same event to appear multiple times. needs_distinct = True - # If there is a filter on relation_senders and relation_types join to the - # relations table. + # If there is a relation_senders and relation_types filter join to the + # relations table to get events related to the current event. if event_filter and ( event_filter.related_by_senders or event_filter.related_by_rel_types ): @@ -1294,6 +1312,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ + # If there is a not_rel_types filter join to the relations table to get + # the event's relation information. + if event_filter and (event_filter.rel_types or event_filter.not_rel_types): + join_clause += """ + LEFT JOIN event_relations AS event_relation USING (event_id) + """ + if needs_distinct: select_keywords += " DISTINCT" diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index a269c477fb..a82c4eed86 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -35,6 +35,8 @@ def MockEvent(**kwargs): kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: kwargs["type"] = "fake_type" + if "content" not in kwargs: + kwargs["content"] = {} return make_event_from_dict(kwargs) @@ -357,6 +359,66 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertTrue(Filter(self.hs, definition)._check(event)) + @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) + def test_filter_rel_type(self): + definition = {"org.matrix.msc3874.rel_types": ["m.thread"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + + @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) + def test_filter_not_rel_type(self): + definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} filter_id = self.get_success( @@ -456,7 +518,6 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertEqual(filtered_room_ids, ["!allowed:example.com"]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_filter_relations(self): events = [ # An event without a relation. diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f5c1070b2c..ddf315b894 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1677,7 +1677,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_parent_thread(self) -> None: """ Test that thread replies are still available when the root event is redacted. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 3612ebe7b9..71b1637be8 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -35,7 +35,6 @@ from synapse.api.constants import ( EventTypes, Membership, PublicRoomsFilterFields, - RelationTypes, RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException @@ -50,6 +49,7 @@ from synapse.util.stringutils import random_string from tests import unittest from tests.http.server._base import make_request_with_cancellation_test +from tests.storage.test_stream import PaginationTestCase from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -2915,149 +2915,20 @@ class LabelsTestCase(unittest.HomeserverTestCase): return event_id -class RelationsTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def default_config(self) -> Dict[str, Any]: - config = super().default_config() - config["experimental_features"] = {"msc3440_enabled": True} - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.user_id = self.register_user("test", "test") - self.tok = self.login("test", "test") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - self.second_user_id = self.register_user("second", "test") - self.second_tok = self.login("second", "test") - self.helper.join( - room=self.room_id, user=self.second_user_id, tok=self.second_tok - ) - - self.third_user_id = self.register_user("third", "test") - self.third_tok = self.login("third", "test") - self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) - - # An initial event with a relation from second user. - res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "Message 1"}, - tok=self.tok, - ) - self.event_id_1 = res["event_id"] - self.helper.send_event( - room_id=self.room_id, - type="m.reaction", - content={ - "m.relates_to": { - "rel_type": RelationTypes.ANNOTATION, - "event_id": self.event_id_1, - "key": "👍", - } - }, - tok=self.second_tok, - ) - - # Another event with a relation from third user. - res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "Message 2"}, - tok=self.tok, - ) - self.event_id_2 = res["event_id"] - self.helper.send_event( - room_id=self.room_id, - type="m.reaction", - content={ - "m.relates_to": { - "rel_type": RelationTypes.REFERENCE, - "event_id": self.event_id_2, - } - }, - tok=self.third_tok, - ) - - # An event with no relations. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "No relations"}, - tok=self.tok, - ) - - def _filter_messages(self, filter: JsonDict) -> List[JsonDict]: +class RelationsTestCase(PaginationTestCase): + def _filter_messages(self, filter: JsonDict) -> List[str]: """Make a request to /messages with a filter, returns the chunk of events.""" + from_token = self.get_success( + self.from_token.to_string(self.hs.get_datastores().main) + ) channel = self.make_request( "GET", - "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), + f"/rooms/{self.room_id}/messages?filter={json.dumps(filter)}&dir=f&from={from_token}", access_token=self.tok, ) self.assertEqual(channel.code, HTTPStatus.OK, channel.result) - return channel.json_body["chunk"] - - def test_filter_relation_senders(self) -> None: - # Messages which second user reacted to. - filter = {"related_by_senders": [self.second_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) - - # Messages which third user reacted to. - filter = {"related_by_senders": [self.third_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_2) - - # Messages which either user reacted to. - filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] - ) - - def test_filter_relation_type(self) -> None: - # Messages which have annotations. - filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) - - # Messages which have references. - filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_2) - - # Messages which have either annotations or references. - filter = { - "related_by_rel_types": [ - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - ] - } - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] - ) - - def test_filter_relation_senders_and_type(self) -> None: - # Messages which second user reacted to. - filter = { - "related_by_senders": [self.second_user_id], - "related_by_rel_types": [RelationTypes.ANNOTATION], - } - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) + return [ev["event_id"] for ev in channel.json_body["chunk"]] class ContextTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 78663a53fe..34fa810cf6 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -16,7 +16,6 @@ from typing import List from synapse.api.constants import EventTypes, RelationTypes from synapse.api.filtering import Filter -from synapse.events import EventBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.types import JsonDict @@ -40,7 +39,7 @@ class PaginationTestCase(HomeserverTestCase): def default_config(self): config = super().default_config() - config["experimental_features"] = {"msc3440_enabled": True} + config["experimental_features"] = {"msc3874_enabled": True} return config def prepare(self, reactor, clock, homeserver): @@ -58,6 +57,11 @@ class PaginationTestCase(HomeserverTestCase): self.third_tok = self.login("third", "test") self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) + # Store a token which is after all the room creation events. + self.from_token = self.get_success( + self.hs.get_event_sources().get_current_token_for_pagination(self.room_id) + ) + # An initial event with a relation from second user. res = self.helper.send_event( room_id=self.room_id, @@ -66,7 +70,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.tok, ) self.event_id_1 = res["event_id"] - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type="m.reaction", content={ @@ -78,6 +82,7 @@ class PaginationTestCase(HomeserverTestCase): }, tok=self.second_tok, ) + self.event_id_annotation = res["event_id"] # Another event with a relation from third user. res = self.helper.send_event( @@ -87,7 +92,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.tok, ) self.event_id_2 = res["event_id"] - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type="m.reaction", content={ @@ -98,68 +103,59 @@ class PaginationTestCase(HomeserverTestCase): }, tok=self.third_tok, ) + self.event_id_reference = res["event_id"] # An event with no relations. - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type=EventTypes.Message, content={"msgtype": "m.text", "body": "No relations"}, tok=self.tok, ) + self.event_id_none = res["event_id"] - def _filter_messages(self, filter: JsonDict) -> List[EventBase]: + def _filter_messages(self, filter: JsonDict) -> List[str]: """Make a request to /messages with a filter, returns the chunk of events.""" - from_token = self.get_success( - self.hs.get_event_sources().get_current_token_for_pagination(self.room_id) - ) - events, next_key = self.get_success( self.hs.get_datastores().main.paginate_room_events( room_id=self.room_id, - from_key=from_token.room_key, + from_key=self.from_token.room_key, to_key=None, - direction="b", + direction="f", limit=10, event_filter=Filter(self.hs, filter), ) ) - return events + return [ev.event_id for ev in events] def test_filter_relation_senders(self): # Messages which second user reacted to. filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) # Messages which third user reacted to. filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_2) + self.assertEqual(chunk, [self.event_id_2]) # Messages which either user reacted to. filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] - ) + self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) def test_filter_relation_type(self): # Messages which have annotations. filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) # Messages which have references. filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_2) + self.assertEqual(chunk, [self.event_id_2]) # Messages which have either annotations or references. filter = { @@ -169,10 +165,7 @@ class PaginationTestCase(HomeserverTestCase): ] } chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] - ) + self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) def test_filter_relation_senders_and_type(self): # Messages which second user reacted to. @@ -181,8 +174,7 @@ class PaginationTestCase(HomeserverTestCase): "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) def test_duplicate_relation(self): """An event should only be returned once if there are multiple relations to it.""" @@ -201,5 +193,65 @@ class PaginationTestCase(HomeserverTestCase): filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) + + def test_filter_rel_types(self) -> None: + # Messages which are annotations. + filter = {"org.matrix.msc3874.rel_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_annotation]) + + # Messages which are references. + filter = {"org.matrix.msc3874.rel_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_reference]) + + # Messages which are either annotations or references. + filter = { + "org.matrix.msc3874.rel_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertCountEqual( + chunk, + [self.event_id_annotation, self.event_id_reference], + ) + + def test_filter_not_rel_types(self) -> None: + # Messages which are not annotations. + filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual( + chunk, + [ + self.event_id_1, + self.event_id_2, + self.event_id_reference, + self.event_id_none, + ], + ) + + # Messages which are not references. + filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual( + chunk, + [ + self.event_id_1, + self.event_id_annotation, + self.event_id_2, + self.event_id_none, + ], + ) + + # Messages which are neither annotations or references. + filter = { + "org.matrix.msc3874.not_rel_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_1, self.event_id_2, self.event_id_none]) -- cgit 1.5.1 From 4eaf3eb840b8cfa78d970216c74fc128495f08a5 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Tue, 18 Oct 2022 16:52:25 +0100 Subject: Implementation of HTTP 307 response for MSC3886 POST endpoint (#14018) Co-authored-by: reivilibre Co-authored-by: Andrew Morgan --- changelog.d/14018.feature | 1 + synapse/config/experimental.py | 7 +- synapse/config/server.py | 4 ++ synapse/handlers/sso.py | 2 +- synapse/http/server.py | 48 ++++++++++--- synapse/http/site.py | 3 + synapse/rest/__init__.py | 2 + synapse/rest/client/rendezvous.py | 74 +++++++++++++++++++ synapse/rest/client/versions.py | 3 + synapse/rest/key/v2/local_key_resource.py | 4 +- synapse/rest/synapse/client/new_user_consent.py | 3 +- synapse/rest/well_known.py | 3 +- tests/logging/test_terse_json.py | 1 + tests/rest/client/test_rendezvous.py | 45 ++++++++++++ tests/server.py | 8 ++- tests/test_server.py | 94 ++++++++++++++++++------- 16 files changed, 257 insertions(+), 45 deletions(-) create mode 100644 changelog.d/14018.feature create mode 100644 synapse/rest/client/rendezvous.py create mode 100644 tests/rest/client/test_rendezvous.py (limited to 'tests/rest') diff --git a/changelog.d/14018.feature b/changelog.d/14018.feature new file mode 100644 index 0000000000..c8454607eb --- /dev/null +++ b/changelog.d/14018.feature @@ -0,0 +1 @@ +Support for redirecting to an implementation of a [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886) HTTP rendezvous service. \ No newline at end of file diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f9a49451d8..4009add01d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import attr @@ -120,3 +120,8 @@ class ExperimentalConfig(Config): # MSC3874: Filtering /messages with rel_types / not_rel_types. self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) + + # MSC3886: Simple client rendezvous capability + self.msc3886_endpoint: Optional[str] = experimental.get( + "msc3886_endpoint", None + ) diff --git a/synapse/config/server.py b/synapse/config/server.py index f2353ce5fb..ec46ca63ad 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -207,6 +207,9 @@ class HttpListenerConfig: additional_resources: Dict[str, dict] = attr.Factory(dict) tag: Optional[str] = None request_id_header: Optional[str] = None + # If true, the listener will return CORS response headers compatible with MSC3886: + # https://github.com/matrix-org/matrix-spec-proposals/pull/3886 + experimental_cors_msc3886: bool = False @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -935,6 +938,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig: additional_resources=listener.get("additional_resources", {}), tag=listener.get("tag"), request_id_header=listener.get("request_id_header"), + experimental_cors_msc3886=listener.get("experimental_cors_msc3886", False), ) return ListenerConfig(port, bind_addresses, listener_type, tls, http_config) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index e035677b8a..5943f08e91 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -874,7 +874,7 @@ class SsoHandler: ) async def handle_terms_accepted( - self, request: Request, session_id: str, terms_version: str + self, request: SynapseRequest, session_id: str, terms_version: str ) -> None: """Handle a request to the new-user 'consent' endpoint diff --git a/synapse/http/server.py b/synapse/http/server.py index bcbfac2c9f..b26e34bceb 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -19,6 +19,7 @@ import logging import types import urllib from http import HTTPStatus +from http.client import FOUND from inspect import isawaitable from typing import ( TYPE_CHECKING, @@ -339,7 +340,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): return callback_return - _unrecognised_request_handler(request) + return _unrecognised_request_handler(request) @abc.abstractmethod def _send_response( @@ -598,7 +599,7 @@ class RootRedirect(resource.Resource): class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" - def render_OPTIONS(self, request: Request) -> bytes: + def render_OPTIONS(self, request: SynapseRequest) -> bytes: request.setResponseCode(204) request.setHeader(b"Content-Length", b"0") @@ -763,7 +764,7 @@ def respond_with_json( def respond_with_json_bytes( - request: Request, + request: SynapseRequest, code: int, json_bytes: bytes, send_cors: bool = False, @@ -859,7 +860,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: _ByteProducer(request, bytes_generator) -def set_cors_headers(request: Request) -> None: +def set_cors_headers(request: SynapseRequest) -> None: """Set the CORS headers so that javascript running in a web browsers can use this API @@ -870,10 +871,20 @@ def set_cors_headers(request: Request) -> None: request.setHeader( b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS" ) - request.setHeader( - b"Access-Control-Allow-Headers", - b"X-Requested-With, Content-Type, Authorization, Date", - ) + if request.experimental_cors_msc3886: + request.setHeader( + b"Access-Control-Allow-Headers", + b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match", + ) + request.setHeader( + b"Access-Control-Expose-Headers", + b"ETag, Location, X-Max-Bytes", + ) + else: + request.setHeader( + b"Access-Control-Allow-Headers", + b"X-Requested-With, Content-Type, Authorization, Date", + ) def set_corp_headers(request: Request) -> None: @@ -942,10 +953,25 @@ def set_clickjacking_protection_headers(request: Request) -> None: request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") -def respond_with_redirect(request: Request, url: bytes) -> None: - """Write a 302 response to the request, if it is still alive.""" +def respond_with_redirect( + request: SynapseRequest, url: bytes, statusCode: int = FOUND, cors: bool = False +) -> None: + """ + Write a 302 (or other specified status code) response to the request, if it is still alive. + + Args: + request: The http request to respond to. + url: The URL to redirect to. + statusCode: The HTTP status code to use for the redirect (defaults to 302). + cors: Whether to set CORS headers on the response. + """ logger.debug("Redirect to %s", url.decode("utf-8")) - request.redirect(url) + + if cors: + set_cors_headers(request) + + request.setResponseCode(statusCode) + request.setHeader(b"location", url) finish_request(request) diff --git a/synapse/http/site.py b/synapse/http/site.py index 55a6afce35..3dbd541fed 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -82,6 +82,7 @@ class SynapseRequest(Request): self.reactor = site.reactor self._channel = channel # this is used by the tests self.start_time = 0.0 + self.experimental_cors_msc3886 = site.experimental_cors_msc3886 # The requester, if authenticated. For federation requests this is the # server name, for client requests this is the Requester object. @@ -622,6 +623,8 @@ class SynapseSite(Site): request_id_header = config.http_options.request_id_header + self.experimental_cors_msc3886 = config.http_options.experimental_cors_msc3886 + def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel, diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 9a2ab99ede..28542cd774 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.client import ( receipts, register, relations, + rendezvous, report_event, room, room_batch, @@ -132,3 +133,4 @@ class ClientRestResource(JsonResource): # unstable mutual_rooms.register_servlets(hs, client_resource) login_token_request.register_servlets(hs, client_resource) + rendezvous.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py new file mode 100644 index 0000000000..89176b1ffa --- /dev/null +++ b/synapse/rest/client/rendezvous.py @@ -0,0 +1,74 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from http.client import TEMPORARY_REDIRECT +from typing import TYPE_CHECKING, Optional + +from synapse.http.server import HttpServer, respond_with_redirect +from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class RendezvousServlet(RestServlet): + """ + This is a placeholder implementation of [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886) + simple client rendezvous capability that is used by the "Sign in with QR" functionality. + + This implementation only serves as a 307 redirect to a configured server rather than being a full implementation. + + A module that implements the full functionality is available at: https://pypi.org/project/matrix-http-rendezvous-synapse/. + + Request: + + POST /rendezvous HTTP/1.1 + Content-Type: ... + + ... + + Response: + + HTTP/1.1 307 + Location: + """ + + PATTERNS = client_patterns( + "/org.matrix.msc3886/rendezvous$", releases=[], v1=False, unstable=True + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + redirection_target: Optional[str] = hs.config.experimental.msc3886_endpoint + assert ( + redirection_target is not None + ), "Servlet is only registered if there is a redirection target" + self.endpoint = redirection_target.encode("utf-8") + + async def on_POST(self, request: SynapseRequest) -> None: + respond_with_redirect( + request, self.endpoint, statusCode=TEMPORARY_REDIRECT, cors=True + ) + + # PUT, GET and DELETE are not implemented as they should be fulfilled by the redirect target. + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc3886_endpoint is not None: + RendezvousServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 4b87ee978a..9b1b72c68a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -116,6 +116,9 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3881": self.config.experimental.msc3881_enabled, # Adds support for filtering /messages by event relation. "org.matrix.msc3874": self.config.experimental.msc3874_enabled, + # Adds support for simple HTTP rendezvous as per MSC3886 + "org.matrix.msc3886": self.config.experimental.msc3886_endpoint + is not None, }, }, ) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 0c9f042c84..095993415c 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -20,9 +20,9 @@ from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 from twisted.web.resource import Resource -from twisted.web.server import Request from synapse.http.server import respond_with_json_bytes +from synapse.http.site import SynapseRequest from synapse.types import JsonDict if TYPE_CHECKING: @@ -99,7 +99,7 @@ class LocalKey(Resource): json_object = sign_json(json_object, self.config.server.server_name, key) return json_object - def render_GET(self, request: Request) -> Optional[int]: + def render_GET(self, request: SynapseRequest) -> Optional[int]: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py index 1c1c7b3613..22784157e6 100644 --- a/synapse/rest/synapse/client/new_user_consent.py +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -20,6 +20,7 @@ from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request from synapse.http.server import DirectServeHtmlResource, respond_with_html from synapse.http.servlet import parse_string +from synapse.http.site import SynapseRequest from synapse.types import UserID from synapse.util.templates import build_jinja_env @@ -88,7 +89,7 @@ class NewUserConsentResource(DirectServeHtmlResource): html = template.render(template_params) respond_with_html(request, 200, html) - async def _async_render_POST(self, request: Request) -> None: + async def _async_render_POST(self, request: SynapseRequest) -> None: try: session_id = get_username_mapping_session_cookie_from_request(request) except SynapseError as e: diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 6f7ac54c65..e2174fdfea 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -18,6 +18,7 @@ from twisted.web.resource import Resource from twisted.web.server import Request from synapse.http.server import set_cors_headers +from synapse.http.site import SynapseRequest from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.stringutils import parse_server_name @@ -63,7 +64,7 @@ class ClientWellKnownResource(Resource): Resource.__init__(self) self._well_known_builder = WellKnownBuilder(hs) - def render_GET(self, request: Request) -> bytes: + def render_GET(self, request: SynapseRequest) -> bytes: set_cors_headers(request) r = self._well_known_builder.get_well_known() if not r: diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index 96f399b7ab..0b0d8737c1 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -153,6 +153,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): site.site_tag = "test-site" site.server_version_string = "Server v1" site.reactor = Mock() + site.experimental_cors_msc3886 = False request = SynapseRequest(FakeChannel(site, None), site) # Call requestReceived to finish instantiating the object. request.content = BytesIO() diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py new file mode 100644 index 0000000000..ad00a476e1 --- /dev/null +++ b/tests/rest/client/test_rendezvous.py @@ -0,0 +1,45 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest.client import rendezvous +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.unittest import override_config + +endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous" + + +class RendezvousServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + rendezvous.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.hs = self.setup_test_homeserver() + return self.hs + + def test_disabled(self) -> None: + channel = self.make_request("POST", endpoint, {}, access_token=None) + self.assertEqual(channel.code, 400) + + @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}}) + def test_redirect(self) -> None: + channel = self.make_request("POST", endpoint, {}, access_token=None) + self.assertEqual(channel.code, 307) + self.assertEqual(channel.headers.getRawHeaders("Location"), ["/asd"]) diff --git a/tests/server.py b/tests/server.py index c447d5e4c4..8b1d186219 100644 --- a/tests/server.py +++ b/tests/server.py @@ -266,7 +266,12 @@ class FakeSite: site_tag = "test" access_logger = logging.getLogger("synapse.access.http.fake") - def __init__(self, resource: IResource, reactor: IReactorTime): + def __init__( + self, + resource: IResource, + reactor: IReactorTime, + experimental_cors_msc3886: bool = False, + ): """ Args: @@ -274,6 +279,7 @@ class FakeSite: """ self._resource = resource self.reactor = reactor + self.experimental_cors_msc3886 = experimental_cors_msc3886 def getResourceFor(self, request): return self._resource diff --git a/tests/test_server.py b/tests/test_server.py index 7c66448245..2d9a0257d4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -222,13 +222,22 @@ class OptionsResourceTests(unittest.TestCase): self.resource = OptionsResource() self.resource.putChild(b"res", DummyResource()) - def _make_request(self, method: bytes, path: bytes) -> FakeChannel: + def _make_request( + self, method: bytes, path: bytes, experimental_cors_msc3886: bool = False + ) -> FakeChannel: """Create a request from the method/path and return a channel with the response.""" # Create a site and query for the resource. site = SynapseSite( "test", "site_tag", - parse_listener_def(0, {"type": "http", "port": 0}), + parse_listener_def( + 0, + { + "type": "http", + "port": 0, + "experimental_cors_msc3886": experimental_cors_msc3886, + }, + ), self.resource, "1.0", max_request_body_size=4096, @@ -239,25 +248,58 @@ class OptionsResourceTests(unittest.TestCase): channel = make_request(self.reactor, site, method, path, shorthand=False) return channel + def _check_cors_standard_headers(self, channel: FakeChannel) -> None: + # Ensure the correct CORS headers have been added + # as per https://spec.matrix.org/v1.4/client-server-api/#web-browser-clients + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"), + [b"*"], + "has correct CORS Origin header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"), + [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec + "has correct CORS Methods header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"), + [b"X-Requested-With, Content-Type, Authorization, Date"], + "has correct CORS Headers header", + ) + + def _check_cors_msc3886_headers(self, channel: FakeChannel) -> None: + # Ensure the correct CORS headers have been added + # as per https://github.com/matrix-org/matrix-spec-proposals/blob/hughns/simple-rendezvous-capability/proposals/3886-simple-rendezvous-capability.md#cors + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"), + [b"*"], + "has correct CORS Origin header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"), + [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec + "has correct CORS Methods header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"), + [ + b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match" + ], + "has correct CORS Headers header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Expose-Headers"), + [b"ETag, Location, X-Max-Bytes"], + "has correct CORS Expose Headers header", + ) + def test_unknown_options_request(self) -> None: """An OPTIONS requests to an unknown URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/foo/") self.assertEqual(channel.code, 204) self.assertNotIn("body", channel.result) - # Ensure the correct CORS headers have been added - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Origin"), - "has CORS Origin header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Methods"), - "has CORS Methods header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Headers"), - "has CORS Headers header", - ) + self._check_cors_standard_headers(channel) def test_known_options_request(self) -> None: """An OPTIONS requests to an known URL still returns 204 No Content.""" @@ -265,19 +307,17 @@ class OptionsResourceTests(unittest.TestCase): self.assertEqual(channel.code, 204) self.assertNotIn("body", channel.result) - # Ensure the correct CORS headers have been added - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Origin"), - "has CORS Origin header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Methods"), - "has CORS Methods header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Headers"), - "has CORS Headers header", + self._check_cors_standard_headers(channel) + + def test_known_options_request_msc3886(self) -> None: + """An OPTIONS requests to an known URL still returns 204 No Content.""" + channel = self._make_request( + b"OPTIONS", b"/res/", experimental_cors_msc3886=True ) + self.assertEqual(channel.code, 204) + self.assertNotIn("body", channel.result) + + self._check_cors_msc3886_headers(channel) def test_unknown_request(self) -> None: """A non-OPTIONS request to an unknown URL should 404.""" -- cgit 1.5.1 From fa8616e65c82367712a7b75c62682a89541b6330 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 18 Oct 2022 19:46:25 -0500 Subject: Fix MSC3030 `/timestamp_to_event` returning `outliers` that it has no idea whether are near a gap or not (#14215) Fix MSC3030 `/timestamp_to_event` endpoint returning `outliers` that it has no idea whether are near a gap or not (and therefore unable to determine whether it's actually the closest event). The reason Synapse doesn't know whether an `outlier` is next to a gap is because our gap checks rely on entries in the `event_edges`, `event_forward_extremeties`, and `event_backward_extremities` tables which is [not the case for `outliers`](https://github.com/matrix-org/synapse/blob/2c63cdcc3f1aa4625e947de3c23e0a8133c61286/docs/development/room-dag-concepts.md#outliers). Also fixes MSC3030 Complement `can_paginate_after_getting_remote_event_from_timestamp_to_event_endpoint` test flake. Although this acted flakey in Complement, if `sync_partial_state` raced and beat us before `/timestamp_to_event`, then even if we retried the failing `/context` request it wouldn't work until we made this Synapse change. With this PR, Synapse will never return an `outlier` event so that test will always go and ask over federation. Fix https://github.com/matrix-org/synapse/issues/13944 ### Why did this fail before? Why was it flakey? Sleuthing the server logs on the [CI failure](https://github.com/matrix-org/synapse/actions/runs/3149623842/jobs/5121449357#step:5:5805), it looks like `hs2:/timestamp_to_event` found `$NP6-oU7mIFVyhtKfGvfrEQX949hQX-T-gvuauG6eurU` as an `outlier` event locally. Then when we went and asked for it via `/context`, since it's an `outlier`, it was filtered out of the results -> `You don't have permission to access that event.` This is reproducible when `sync_partial_state` races and persists `$NP6-oU7mIFVyhtKfGvfrEQX949hQX-T-gvuauG6eurU` as an `outlier` before we evaluate `get_event_for_timestamp(...)`. To consistently reproduce locally, just add a delay at the [start of `get_event_for_timestamp(...)`](https://github.com/matrix-org/synapse/blob/cb20b885cb4bd1648581dd043a184d86fc8c7a00/synapse/handlers/room.py#L1470-L1496) so it always runs after `sync_partial_state` completes. ```py from twisted.internet import task as twisted_task d = twisted_task.deferLater(self.hs.get_reactor(), 3.5) await d ``` In a run where it passes, on `hs2`, `get_event_for_timestamp(...)` finds a different event locally which is next to a gap and we request from a closer one from `hs1` which gets backfilled. And since the backfilled event is not an `outlier`, it's returned as expected during `/context`. With this PR, Synapse will never return an `outlier` event so that test will always go and ask over federation. --- changelog.d/14215.bugfix | 1 + synapse/storage/databases/main/events_worker.py | 59 ++++++++++++++-------- tests/rest/client/test_rooms.py | 65 +++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 21 deletions(-) create mode 100644 changelog.d/14215.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14215.bugfix b/changelog.d/14215.bugfix new file mode 100644 index 0000000000..31c109f534 --- /dev/null +++ b/changelog.d/14215.bugfix @@ -0,0 +1 @@ +Fix [MSC3030](https://github.com/matrix-org/matrix-spec-proposals/pull/3030) `/timestamp_to_event` endpoint returning potentially inaccurate closest events with `outliers` present. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 7bc7f2f33e..69fea452ad 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1971,12 +1971,17 @@ class EventsWorkerStore(SQLBaseStore): Args: room_id: room where the event lives - event_id: event to check + event: event to check (can't be an `outlier`) Returns: Boolean indicating whether it's an extremity """ + assert not event.internal_metadata.is_outlier(), ( + "is_event_next_to_backward_gap(...) can't be used with `outlier` events. " + "This function relies on `event_backward_extremities` which won't be filled in for `outliers`." + ) + def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: # If the event in question has any of its prev_events listed as a # backward extremity, it's next to a gap. @@ -2026,12 +2031,17 @@ class EventsWorkerStore(SQLBaseStore): Args: room_id: room where the event lives - event_id: event to check + event: event to check (can't be an `outlier`) Returns: Boolean indicating whether it's an extremity """ + assert not event.internal_metadata.is_outlier(), ( + "is_event_next_to_forward_gap(...) can't be used with `outlier` events. " + "This function relies on `event_edges` and `event_forward_extremities` which won't be filled in for `outliers`." + ) + def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: # If the event in question is a forward extremity, we will just # consider any potential forward gap as not a gap since it's one of @@ -2112,13 +2122,33 @@ class EventsWorkerStore(SQLBaseStore): The closest event_id otherwise None if we can't find any event in the given direction. """ + if direction == "b": + # Find closest event *before* a given timestamp. We use descending + # (which gives values largest to smallest) because we want the + # largest possible timestamp *before* the given timestamp. + comparison_operator = "<=" + order = "DESC" + else: + # Find closest event *after* a given timestamp. We use ascending + # (which gives values smallest to largest) because we want the + # closest possible timestamp *after* the given timestamp. + comparison_operator = ">=" + order = "ASC" - sql_template = """ + sql_template = f""" SELECT event_id FROM events LEFT JOIN rejections USING (event_id) WHERE - origin_server_ts %s ? - AND room_id = ? + room_id = ? + AND origin_server_ts {comparison_operator} ? + /** + * Make sure the event isn't an `outlier` because we have no way + * to later check whether it's next to a gap. `outliers` do not + * have entries in the `event_edges`, `event_forward_extremeties`, + * and `event_backward_extremities` tables to check against + * (used by `is_event_next_to_backward_gap` and `is_event_next_to_forward_gap`). + */ + AND NOT outlier /* Make sure event is not rejected */ AND rejections.event_id IS NULL /** @@ -2128,27 +2158,14 @@ class EventsWorkerStore(SQLBaseStore): * Finally, we can tie-break based on when it was received on the server * (`stream_ordering`). */ - ORDER BY origin_server_ts %s, depth %s, stream_ordering %s + ORDER BY origin_server_ts {order}, depth {order}, stream_ordering {order} LIMIT 1; """ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: - if direction == "b": - # Find closest event *before* a given timestamp. We use descending - # (which gives values largest to smallest) because we want the - # largest possible timestamp *before* the given timestamp. - comparison_operator = "<=" - order = "DESC" - else: - # Find closest event *after* a given timestamp. We use ascending - # (which gives values smallest to largest) because we want the - # closest possible timestamp *after* the given timestamp. - comparison_operator = ">=" - order = "ASC" - txn.execute( - sql_template % (comparison_operator, order, order, order), - (timestamp, room_id), + sql_template, + (room_id, timestamp), ) row = txn.fetchone() if row: diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 71b1637be8..716366eb90 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -39,6 +39,8 @@ from synapse.api.constants import ( ) from synapse.api.errors import Codes, HttpResponseException from synapse.appservice import ApplicationService +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client import account, directory, login, profile, register, room, sync @@ -51,6 +53,7 @@ from tests import unittest from tests.http.server._base import make_request_with_cancellation_test from tests.storage.test_stream import PaginationTestCase from tests.test_utils import make_awaitable +from tests.test_utils.event_injection import create_event PATH_PREFIX = b"/_matrix/client/api/v1" @@ -3486,3 +3489,65 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400) self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM") + + +class TimestampLookupTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc3030_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self._storage_controllers = self.hs.get_storage_controllers() + + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + def _inject_outlier(self, room_id: str) -> EventBase: + event, _context = self.get_success( + create_event( + self.hs, + room_id=room_id, + type="m.test", + sender="@test_remote_user:remote", + ) + ) + + event.internal_metadata.outlier = True + self.get_success( + self._storage_controllers.persistence.persist_event( + event, EventContext.for_outlier(self._storage_controllers) + ) + ) + return event + + def test_no_outliers(self) -> None: + """ + Test to make sure `/timestamp_to_event` does not return `outlier` events. + We're unable to determine whether an `outlier` is next to a gap so we + don't know whether it's actually the closest event. Instead, let's just + ignore `outliers` with this endpoint. + + This test is really seeing that we choose the non-`outlier` event behind the + `outlier`. Since the gap checking logic considers the latest message in the room + as *not* next to a gap, asking over federation does not come into play here. + """ + room_id = self.helper.create_room_as(self.room_owner, tok=self.room_owner_tok) + + outlier_event = self._inject_outlier(room_id) + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3030/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}", + access_token=self.room_owner_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + # Make sure the outlier event is not returned + self.assertNotEqual(channel.json_body["event_id"], outlier_event.event_id) -- cgit 1.5.1 From 755bfeee3a1ac7077045ab9e5a994b6ca89afba3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Oct 2022 11:32:47 -0400 Subject: Use servlets for /key/ endpoints. (#14229) To fix the response for unknown endpoints under that prefix. See MSC3743. --- changelog.d/14229.misc | 1 + synapse/api/urls.py | 2 +- synapse/app/generic_worker.py | 20 +++----- synapse/app/homeserver.py | 26 ++++------ synapse/rest/key/v2/__init__.py | 19 ++++--- synapse/rest/key/v2/local_key_resource.py | 22 ++++---- synapse/rest/key/v2/remote_key_resource.py | 73 +++++++++++++++------------ tests/app/test_openid_listener.py | 2 +- tests/rest/key/v2/test_remote_key_resource.py | 4 +- 9 files changed, 86 insertions(+), 83 deletions(-) create mode 100644 changelog.d/14229.misc (limited to 'tests/rest') diff --git a/changelog.d/14229.misc b/changelog.d/14229.misc new file mode 100644 index 0000000000..b9cd9a34d5 --- /dev/null +++ b/changelog.d/14229.misc @@ -0,0 +1 @@ +Refactor `/key/` endpoints to use `RestServlet` classes. diff --git a/synapse/api/urls.py b/synapse/api/urls.py index bd49fa6a5f..a918579f50 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -28,7 +28,7 @@ FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1" FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2" FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable" STATIC_PREFIX = "/_matrix/static" -SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" +SERVER_KEY_PREFIX = "/_matrix/key" MEDIA_R0_PREFIX = "/_matrix/media/r0" MEDIA_V3_PREFIX = "/_matrix/media/v3" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index dc49840f73..2a9f039367 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -28,7 +28,7 @@ from synapse.api.urls import ( LEGACY_MEDIA_PREFIX, MEDIA_R0_PREFIX, MEDIA_V3_PREFIX, - SERVER_KEY_V2_PREFIX, + SERVER_KEY_PREFIX, ) from synapse.app import _base from synapse.app._base import ( @@ -89,7 +89,7 @@ from synapse.rest.client.register import ( RegistrationTokenValidityRestServlet, ) from synapse.rest.health import HealthResource -from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.key.v2 import KeyResource from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer @@ -325,13 +325,13 @@ class GenericWorkerServer(HomeServer): presence.register_servlets(self, resource) - resources.update({CLIENT_API_PREFIX: resource}) + resources[CLIENT_API_PREFIX] = resource resources.update(build_synapse_client_resource_tree(self)) - resources.update({"/.well-known": well_known_resource(self)}) + resources["/.well-known"] = well_known_resource(self) elif name == "federation": - resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) + resources[FEDERATION_PREFIX] = TransportLayerServer(self) elif name == "media": if self.config.media.can_load_media_repo: media_repo = self.get_media_repository_resource() @@ -359,16 +359,12 @@ class GenericWorkerServer(HomeServer): # Only load the openid resource separately if federation resource # is not specified since federation resource includes openid # resource. - resources.update( - { - FEDERATION_PREFIX: TransportLayerServer( - self, servlet_groups=["openid"] - ) - } + resources[FEDERATION_PREFIX] = TransportLayerServer( + self, servlet_groups=["openid"] ) if name in ["keys", "federation"]: - resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) + resources[SERVER_KEY_PREFIX] = KeyResource(self) if name == "replication": resources[REPLICATION_PREFIX] = ReplicationRestResource(self) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 883f2fd2ec..de3f08876f 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -31,7 +31,7 @@ from synapse.api.urls import ( LEGACY_MEDIA_PREFIX, MEDIA_R0_PREFIX, MEDIA_V3_PREFIX, - SERVER_KEY_V2_PREFIX, + SERVER_KEY_PREFIX, STATIC_PREFIX, ) from synapse.app import _base @@ -60,7 +60,7 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource -from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.key.v2 import KeyResource from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer @@ -215,30 +215,22 @@ class SynapseHomeServer(HomeServer): consent_resource: Resource = ConsentResource(self) if compress: consent_resource = gz_wrap(consent_resource) - resources.update({"/_matrix/consent": consent_resource}) + resources["/_matrix/consent"] = consent_resource if name == "federation": federation_resource: Resource = TransportLayerServer(self) if compress: federation_resource = gz_wrap(federation_resource) - resources.update({FEDERATION_PREFIX: federation_resource}) + resources[FEDERATION_PREFIX] = federation_resource if name == "openid": - resources.update( - { - FEDERATION_PREFIX: TransportLayerServer( - self, servlet_groups=["openid"] - ) - } + resources[FEDERATION_PREFIX] = TransportLayerServer( + self, servlet_groups=["openid"] ) if name in ["static", "client"]: - resources.update( - { - STATIC_PREFIX: StaticResource( - os.path.join(os.path.dirname(synapse.__file__), "static") - ) - } + resources[STATIC_PREFIX] = StaticResource( + os.path.join(os.path.dirname(synapse.__file__), "static") ) if name in ["media", "federation", "client"]: @@ -257,7 +249,7 @@ class SynapseHomeServer(HomeServer): ) if name in ["keys", "federation"]: - resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) + resources[SERVER_KEY_PREFIX] = KeyResource(self) if name == "metrics" and self.config.metrics.enable_metrics: metrics_resource: Resource = MetricsResource(RegistryProxy) diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py index 7f8c1de1ff..26403facb8 100644 --- a/synapse/rest/key/v2/__init__.py +++ b/synapse/rest/key/v2/__init__.py @@ -14,17 +14,20 @@ from typing import TYPE_CHECKING -from twisted.web.resource import Resource - -from .local_key_resource import LocalKey -from .remote_key_resource import RemoteKey +from synapse.http.server import HttpServer, JsonResource +from synapse.rest.key.v2.local_key_resource import LocalKey +from synapse.rest.key.v2.remote_key_resource import RemoteKey if TYPE_CHECKING: from synapse.server import HomeServer -class KeyApiV2Resource(Resource): +class KeyResource(JsonResource): def __init__(self, hs: "HomeServer"): - Resource.__init__(self) - self.putChild(b"server", LocalKey(hs)) - self.putChild(b"query", RemoteKey(hs)) + super().__init__(hs, canonical_json=True) + self.register_servlets(self, hs) + + @staticmethod + def register_servlets(http_server: HttpServer, hs: "HomeServer") -> None: + LocalKey(hs).register(http_server) + RemoteKey(hs).register(http_server) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 095993415c..d03e728d42 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -13,16 +13,15 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional +import re +from typing import TYPE_CHECKING, Optional, Tuple -from canonicaljson import encode_canonical_json from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 -from twisted.web.resource import Resource +from twisted.web.server import Request -from synapse.http.server import respond_with_json_bytes -from synapse.http.site import SynapseRequest +from synapse.http.servlet import RestServlet from synapse.types import JsonDict if TYPE_CHECKING: @@ -31,7 +30,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class LocalKey(Resource): +class LocalKey(RestServlet): """HTTP resource containing encoding the TLS X.509 certificate and NACL signature verification keys for this server:: @@ -61,18 +60,17 @@ class LocalKey(Resource): } """ - isLeaf = True + PATTERNS = (re.compile("^/_matrix/key/v2/server(/(?P[^/]*))?$"),) def __init__(self, hs: "HomeServer"): self.config = hs.config self.clock = hs.get_clock() self.update_response_body(self.clock.time_msec()) - Resource.__init__(self) def update_response_body(self, time_now_msec: int) -> None: refresh_interval = self.config.key.key_refresh_interval self.valid_until_ts = int(time_now_msec + refresh_interval) - self.response_body = encode_canonical_json(self.response_json_object()) + self.response_body = self.response_json_object() def response_json_object(self) -> JsonDict: verify_keys = {} @@ -99,9 +97,11 @@ class LocalKey(Resource): json_object = sign_json(json_object, self.config.server.server_name, key) return json_object - def render_GET(self, request: SynapseRequest) -> Optional[int]: + def on_GET( + self, request: Request, key_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: self.update_response_body(time_now) - return respond_with_json_bytes(request, 200, self.response_body) + return 200, self.response_body diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 7f8ad29566..19820886f5 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -13,15 +13,20 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Set +import re +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple from signedjson.sign import sign_json -from synapse.api.errors import Codes, SynapseError +from twisted.web.server import Request + from synapse.crypto.keyring import ServerKeyFetcher -from synapse.http.server import DirectServeJsonResource, respond_with_json -from synapse.http.servlet import parse_integer, parse_json_object_from_request -from synapse.http.site import SynapseRequest +from synapse.http.server import HttpServer +from synapse.http.servlet import ( + RestServlet, + parse_integer, + parse_json_object_from_request, +) from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import yieldable_gather_results @@ -32,7 +37,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class RemoteKey(DirectServeJsonResource): +class RemoteKey(RestServlet): """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported X.509 TLS certificate matches the one used in the HTTPS connection. Checks @@ -88,11 +93,7 @@ class RemoteKey(DirectServeJsonResource): } """ - isLeaf = True - def __init__(self, hs: "HomeServer"): - super().__init__() - self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -101,36 +102,48 @@ class RemoteKey(DirectServeJsonResource): ) self.config = hs.config - async def _async_render_GET(self, request: SynapseRequest) -> None: - assert request.postpath is not None - if len(request.postpath) == 1: - (server,) = request.postpath - query: dict = {server.decode("ascii"): {}} - elif len(request.postpath) == 2: - server, key_id = request.postpath + def register(self, http_server: HttpServer) -> None: + http_server.register_paths( + "GET", + ( + re.compile( + "^/_matrix/key/v2/query/(?P[^/]*)(/(?P[^/]*))?$" + ), + ), + self.on_GET, + self.__class__.__name__, + ) + http_server.register_paths( + "POST", + (re.compile("^/_matrix/key/v2/query$"),), + self.on_POST, + self.__class__.__name__, + ) + + async def on_GET( + self, request: Request, server: str, key_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: + if server and key_id: minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") arguments = {} if minimum_valid_until_ts is not None: arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}} + query = {server: {key_id: arguments}} else: - raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND) + query = {server: {}} - await self.query_keys(request, query, query_remote_on_cache_miss=True) + return 200, await self.query_keys(query, query_remote_on_cache_miss=True) - async def _async_render_POST(self, request: SynapseRequest) -> None: + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) query = content["server_keys"] - await self.query_keys(request, query, query_remote_on_cache_miss=True) + return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def query_keys( - self, - request: SynapseRequest, - query: JsonDict, - query_remote_on_cache_miss: bool = False, - ) -> None: + self, query: JsonDict, query_remote_on_cache_miss: bool = False + ) -> JsonDict: logger.info("Handling query for keys %r", query) store_queries = [] @@ -232,7 +245,7 @@ class RemoteKey(DirectServeJsonResource): for server_name, keys in cache_misses.items() ), ) - await self.query_keys(request, query, query_remote_on_cache_miss=False) + return await self.query_keys(query, query_remote_on_cache_miss=False) else: signed_keys = [] for key_json_raw in json_results: @@ -244,6 +257,4 @@ class RemoteKey(DirectServeJsonResource): signed_keys.append(key_json) - response = {"server_keys": signed_keys} - - respond_with_json(request, 200, response, canonical_json=True) + return {"server_keys": signed_keys} diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index c7dae58eb5..8d03da7f96 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -79,7 +79,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): self.assertEqual(channel.code, 401) -@patch("synapse.app.homeserver.KeyApiV2Resource", new=Mock()) +@patch("synapse.app.homeserver.KeyResource", new=Mock()) class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index ac0ac06b7e..7f1fba1086 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -26,7 +26,7 @@ from twisted.web.resource import NoResource, Resource from synapse.crypto.keyring import PerspectivesKeyFetcher from synapse.http.site import SynapseRequest -from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.key.v2 import KeyResource from synapse.server import HomeServer from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict @@ -46,7 +46,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): def create_test_resource(self) -> Resource: return create_resource_tree( - {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() + {"/_matrix/key/v2": KeyResource(self.hs)}, root_resource=NoResource() ) def expect_outgoing_key_request( -- cgit 1.5.1 From 1433b5d5b64c3a6624e6e4ff4fef22127c49df86 Mon Sep 17 00:00:00 2001 From: Tadeusz Sośnierz Date: Fri, 21 Oct 2022 14:52:44 +0200 Subject: Show erasure status when listing users in the Admin API (#14205) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Show erasure status when listing users in the Admin API * Use USING when joining erased_users * Add changelog entry * Revert "Use USING when joining erased_users" This reverts commit 30bd2bf106415caadcfdbdd1b234ef2b106cc394. * Make the erased check work on postgres * Add a testcase for showing erased user status * Appease the style linter * Explicitly convert `erased` to bool to make SQLite consistent with Postgres This also adds us an easy way in to fix the other accidentally integered columns. * Move erasure status test to UsersListTestCase * Include user erased status when fetching user info via the admin API * Document the erase status in user_admin_api * Appease the linter and mypy * Signpost comments in tests Co-authored-by: Tadeusz Sośnierz Co-authored-by: David Robertson --- changelog.d/14205.feature | 1 + docs/admin_api/user_admin_api.md | 4 ++++ synapse/handlers/admin.py | 1 + synapse/storage/databases/main/__init__.py | 13 +++++++++-- tests/rest/admin/test_user.py | 35 +++++++++++++++++++++++++++++- 5 files changed, 51 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14205.feature (limited to 'tests/rest') diff --git a/changelog.d/14205.feature b/changelog.d/14205.feature new file mode 100644 index 0000000000..6692063352 --- /dev/null +++ b/changelog.d/14205.feature @@ -0,0 +1 @@ +Show erasure status when listing users in the Admin API. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 3625c7b6c5..c95d6c9b05 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -37,6 +37,7 @@ It returns a JSON body like the following: "is_guest": 0, "admin": 0, "deactivated": 0, + "erased": false, "shadow_banned": 0, "creation_ts": 1560432506, "appservice_id": null, @@ -167,6 +168,7 @@ A response body like the following is returned: "admin": 0, "user_type": null, "deactivated": 0, + "erased": false, "shadow_banned": 0, "displayname": "", "avatar_url": null, @@ -177,6 +179,7 @@ A response body like the following is returned: "admin": 1, "user_type": null, "deactivated": 0, + "erased": false, "shadow_banned": 0, "displayname": "", "avatar_url": "", @@ -247,6 +250,7 @@ The following fields are returned in the JSON response body: - `user_type` - string - Type of the user. Normal users are type `None`. This allows user type specific behaviour. There are also types `support` and `bot`. - `deactivated` - bool - Status if that user has been marked as deactivated. + - `erased` - bool - Status if that user has been marked as erased. - `shadow_banned` - bool - Status if that user has been marked as shadow banned. - `displayname` - string - The user's display name if they have set one. - `avatar_url` - string - The user's avatar URL if they have set one. diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index f2989cc4a2..5bf8e86387 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -100,6 +100,7 @@ class AdminHandler: user_info_dict["avatar_url"] = profile.avatar_url user_info_dict["threepids"] = threepids user_info_dict["external_ids"] = external_ids + user_info_dict["erased"] = await self.store.is_user_erased(user.to_string()) return user_info_dict diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index a62b4abd4e..cfaedf5e0c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -201,7 +201,7 @@ class DataStore( name: Optional[str] = None, guests: bool = True, deactivated: bool = False, - order_by: str = UserSortOrder.USER_ID.value, + order_by: str = UserSortOrder.NAME.value, direction: str = "f", approved: bool = True, ) -> Tuple[List[JsonDict], int]: @@ -261,6 +261,7 @@ class DataStore( sql_base = f""" FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? + LEFT JOIN erased_users AS eu ON u.name = eu.user_id {where_clause} """ sql = "SELECT COUNT(*) as total_users " + sql_base @@ -269,7 +270,8 @@ class DataStore( sql = f""" SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, - displayname, avatar_url, creation_ts * 1000 as creation_ts, approved + displayname, avatar_url, creation_ts * 1000 as creation_ts, approved, + eu.user_id is not null as erased {sql_base} ORDER BY {order_by_column} {order}, u.name ASC LIMIT ? OFFSET ? @@ -277,6 +279,13 @@ class DataStore( args += [limit, start] txn.execute(sql, args) users = self.db_pool.cursor_to_dict(txn) + + # some of those boolean values are returned as integers when we're on SQLite + columns_to_boolify = ["erased"] + for user in users: + for column in columns_to_boolify: + user[column] = bool(user[column]) + return users, count return await self.db_pool.runInteraction( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4c1ce33463..63410ffdf1 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -31,7 +31,7 @@ from synapse.api.room_versions import RoomVersions from synapse.rest.client import devices, login, logout, profile, register, room, sync from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from tests import unittest @@ -924,6 +924,36 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids) self.assertEqual(not_approved_user, non_admin_user_ids[0]) + def test_erasure_status(self) -> None: + # Create a new user. + user_id = self.register_user("eraseme", "eraseme") + + # They should appear in the list users API, marked as not erased. + channel = self.make_request( + "GET", + self.url + "?deactivated=true", + access_token=self.admin_user_tok, + ) + users = {user["name"]: user for user in channel.json_body["users"]} + self.assertIs(users[user_id]["erased"], False) + + # Deactivate that user, requesting erasure. + deactivate_account_handler = self.hs.get_deactivate_account_handler() + self.get_success( + deactivate_account_handler.deactivate_account( + user_id, erase_data=True, requester=create_requester(user_id) + ) + ) + + # Repeat the list users query. They should now be marked as erased. + channel = self.make_request( + "GET", + self.url + "?deactivated=true", + access_token=self.admin_user_tok, + ) + users = {user["name"]: user for user in channel.json_body["users"]} + self.assertIs(users[user_id]["erased"], True) + def _order_test( self, expected_user_list: List[str], @@ -1195,6 +1225,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User1", channel.json_body["displayname"]) + self.assertFalse(channel.json_body["erased"]) # Deactivate and erase user channel = self.make_request( @@ -1219,6 +1250,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(0, len(channel.json_body["threepids"])) self.assertIsNone(channel.json_body["avatar_url"]) self.assertIsNone(channel.json_body["displayname"]) + self.assertTrue(channel.json_body["erased"]) self._is_erased("@user:test", True) @@ -2757,6 +2789,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertIn("avatar_url", content) self.assertIn("admin", content) self.assertIn("deactivated", content) + self.assertIn("erased", content) self.assertIn("shadow_banned", content) self.assertIn("creation_ts", content) self.assertIn("appservice_id", content) -- cgit 1.5.1 From 4dd7aa371b6bc746fa4b0a9af220b2013b17a45d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Oct 2022 09:11:19 -0400 Subject: Properly update the threads table when thread events are redacted. (#14248) When the last event in a thread is redacted we need to update the threads table: * Find the new latest event in the thread and store it into the table; or * Remove the thread from the table if it is no longer a thread (i.e. all events in the thread were redacted). --- changelog.d/14248.bugfix | 1 + synapse/storage/databases/main/events.py | 61 ++++++++++++++--- tests/rest/client/test_relations.py | 110 +++++++++++++++++++++---------- 3 files changed, 129 insertions(+), 43 deletions(-) create mode 100644 changelog.d/14248.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14248.bugfix b/changelog.d/14248.bugfix new file mode 100644 index 0000000000..203c52c16b --- /dev/null +++ b/changelog.d/14248.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0rc1 where the information returned from the `/threads` API could be stale when threaded events are redacted. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 6698cbf664..00880bb37d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2028,25 +2028,37 @@ class PersistEventsStore: redacted_event_id: The event that was redacted. """ - # Fetch the current relation of the event being redacted. - redacted_relates_to = self.db_pool.simple_select_one_onecol_txn( + # Fetch the relation of the event being redacted. + row = self.db_pool.simple_select_one_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id}, - retcol="relates_to_id", + retcols=("relates_to_id", "relation_type"), allow_none=True, ) + # Nothing to do if no relation is found. + if row is None: + return + + redacted_relates_to = row["relates_to_id"] + rel_type = row["relation_type"] + self.db_pool.simple_delete_txn( + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} + ) + # Any relation information for the related event must be cleared. - if redacted_relates_to is not None: - self.store._invalidate_cache_and_stream( - txn, self.store.get_relations_for_event, (redacted_relates_to,) - ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_relations_for_event, (redacted_relates_to,) + ) + if rel_type == RelationTypes.ANNOTATION: self.store._invalidate_cache_and_stream( txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) ) + if rel_type == RelationTypes.REPLACE: self.store._invalidate_cache_and_stream( txn, self.store.get_applicable_edit, (redacted_relates_to,) ) + if rel_type == RelationTypes.THREAD: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_summary, (redacted_relates_to,) ) @@ -2057,9 +2069,38 @@ class PersistEventsStore: txn, self.store.get_threads, (room_id,) ) - self.db_pool.simple_delete_txn( - txn, table="event_relations", keyvalues={"event_id": redacted_event_id} - ) + # Find the new latest event in the thread. + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND relation_type = ? + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT 1 + """ + txn.execute(sql, (redacted_relates_to, RelationTypes.THREAD)) + + # If a latest event is found, update the threads table, this might + # be the same current latest event (if an earlier event in the thread + # was redacted). + latest_event_row = txn.fetchone() + if latest_event_row: + self.db_pool.simple_upsert_txn( + txn, + table="threads", + keyvalues={"room_id": room_id, "thread_id": redacted_relates_to}, + values={ + "latest_event_id": latest_event_row[0], + "topological_ordering": latest_event_row[1], + "stream_ordering": latest_event_row[2], + }, + ) + + # Otherwise, delete the thread: it no longer exists. + else: + self.db_pool.simple_delete_one_txn( + txn, table="threads", keyvalues={"thread_id": redacted_relates_to} + ) def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None: if isinstance(event.content.get("topic"), str): diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index ddf315b894..e3d801f7a8 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1523,6 +1523,26 @@ class RelationRedactionTestCase(BaseRelationsTestCase): ) self.assertEqual(200, channel.code, channel.json_body) + def _get_threads(self) -> List[Tuple[str, str]]: + """Request the threads in the room and returns a list of thread ID and latest event ID.""" + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + threads = channel.json_body["chunk"] + return [ + ( + t["event_id"], + t["unsigned"]["m.relations"][RelationTypes.THREAD]["latest_event"][ + "event_id" + ], + ) + for t in threads + ] + def test_redact_relation_annotation(self) -> None: """ Test that annotations of an event are properly handled after the @@ -1567,58 +1587,82 @@ class RelationRedactionTestCase(BaseRelationsTestCase): The redacted event should not be included in bundled aggregations or the response to relations. """ - channel = self._send_relation( - RelationTypes.THREAD, - EventTypes.Message, - content={"body": "reply 1", "msgtype": "m.text"}, - ) - unredacted_event_id = channel.json_body["event_id"] + # Create a thread with a few events in it. + thread_replies = [] + for i in range(3): + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": f"reply {i}", "msgtype": "m.text"}, + ) + thread_replies.append(channel.json_body["event_id"]) - # Note that the *last* event in the thread is redacted, as that gets - # included in the bundled aggregation. - channel = self._send_relation( - RelationTypes.THREAD, - EventTypes.Message, - content={"body": "reply 2", "msgtype": "m.text"}, + ################################################## + # Check the test data is configured as expected. # + ################################################## + self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) + relations = self._get_bundled_aggregations() + self.assertDictContainsSubset( + {"count": 3, "current_user_participated": True}, + relations[RelationTypes.THREAD], + ) + # The latest event is the last sent event. + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + thread_replies[-1], ) - to_redact_event_id = channel.json_body["event_id"] - # Both relations exist. - event_ids = self._get_related_events() + # There should be one thread, the latest event is the event that will be redacted. + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) + + ########################## + # Redact the last event. # + ########################## + self._redact(thread_replies.pop()) + + # The thread should still exist, but the latest event should be updated. + self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id]) self.assertDictContainsSubset( - { - "count": 2, - "current_user_participated": True, - }, + {"count": 2, "current_user_participated": True}, relations[RelationTypes.THREAD], ) - # And the latest event returned is the event that will be redacted. + # And the latest event is the last unredacted event. self.assertEqual( relations[RelationTypes.THREAD]["latest_event"]["event_id"], - to_redact_event_id, + thread_replies[-1], ) + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) - # Redact one of the reactions. - self._redact(to_redact_event_id) + ########################################### + # Redact the *first* event in the thread. # + ########################################### + self._redact(thread_replies.pop(0)) - # The unredacted relation should still exist. - event_ids = self._get_related_events() + # Nothing should have changed (except the thread count). + self.assertEquals(self._get_related_events(), thread_replies) relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [unredacted_event_id]) self.assertDictContainsSubset( - { - "count": 1, - "current_user_participated": True, - }, + {"count": 1, "current_user_participated": True}, relations[RelationTypes.THREAD], ) - # And the latest event is now the unredacted event. + # And the latest event is the last unredacted event. self.assertEqual( relations[RelationTypes.THREAD]["latest_event"]["event_id"], - unredacted_event_id, + thread_replies[-1], ) + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) + + #################################### + # Redact the last remaining event. # + #################################### + self._redact(thread_replies.pop(0)) + self.assertEquals(thread_replies, []) + + # The event should no longer be considered a thread. + self.assertEquals(self._get_related_events(), []) + self.assertEquals(self._get_bundled_aggregations(), {}) + self.assertEqual(self._get_threads(), []) def test_redact_parent_edit(self) -> None: """Test that edits of an event are redacted when the original event -- cgit 1.5.1 From 9192d74b0bf2f87b00d3e106a18baa9ce27acda1 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 25 Oct 2022 16:25:02 +0200 Subject: Refactor OIDC tests to better mimic an actual OIDC provider. (#13910) This implements a fake OIDC server, which intercepts calls to the HTTP client. Improves accuracy of tests by covering more internal methods. One particular example was the ID token validation, which previously mocked. This uncovered an incorrect dependency: Synapse actually requires at least authlib 0.15.1, not 0.14.0. --- changelog.d/13910.misc | 1 + pyproject.toml | 2 +- synapse/handlers/oidc.py | 15 +- tests/federation/test_federation_client.py | 36 +- tests/handlers/test_oidc.py | 580 +++++++++++++---------------- tests/rest/client/test_auth.py | 32 +- tests/rest/client/test_login.py | 40 +- tests/rest/client/utils.py | 136 +++---- tests/test_utils/__init__.py | 40 +- tests/test_utils/oidc.py | 325 ++++++++++++++++ 10 files changed, 747 insertions(+), 460 deletions(-) create mode 100644 changelog.d/13910.misc create mode 100644 tests/test_utils/oidc.py (limited to 'tests/rest') diff --git a/changelog.d/13910.misc b/changelog.d/13910.misc new file mode 100644 index 0000000000..e906952aab --- /dev/null +++ b/changelog.d/13910.misc @@ -0,0 +1 @@ +Refactor OIDC tests to better mimic an actual OIDC provider. diff --git a/pyproject.toml b/pyproject.toml index 6ebac41ed1..7e0feb75aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -192,7 +192,7 @@ psycopg2 = { version = ">=2.8", markers = "platform_python_implementation != 'Py psycopg2cffi = { version = ">=2.8", markers = "platform_python_implementation == 'PyPy'", optional = true } psycopg2cffi-compat = { version = "==1.1", markers = "platform_python_implementation == 'PyPy'", optional = true } pysaml2 = { version = ">=4.5.0", optional = true } -authlib = { version = ">=0.14.0", optional = true } +authlib = { version = ">=0.15.1", optional = true } # systemd-python is necessary for logging to the systemd journal via # `systemd.journal.JournalHandler`, as is documented in # `contrib/systemd/log_config.yaml`. diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index d7a8226900..9759daf043 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -275,6 +275,7 @@ class OidcProvider: provider: OidcProviderConfig, ): self._store = hs.get_datastores().main + self._clock = hs.get_clock() self._macaroon_generaton = macaroon_generator @@ -673,6 +674,13 @@ class OidcProvider: Returns: The decoded claims in the ID token. """ + id_token = token.get("id_token") + logger.debug("Attempting to decode JWT id_token %r", id_token) + + # That has been theoritically been checked by the caller, so even though + # assertion are not enabled in production, it is mainly here to appease mypy + assert id_token is not None + metadata = await self.load_metadata() claims_params = { "nonce": nonce, @@ -688,9 +696,6 @@ class OidcProvider: claim_options = {"iss": {"values": [metadata["issuer"]]}} - id_token = token["id_token"] - logger.debug("Attempting to decode JWT id_token %r", id_token) - # Try to decode the keys in cache first, then retry by forcing the keys # to be reloaded jwk_set = await self.load_jwks() @@ -715,7 +720,9 @@ class OidcProvider: logger.debug("Decoded id_token JWT %r; validating", claims) - claims.validate(leeway=120) # allows 2 min of clock skew + claims.validate( + now=self._clock.time(), leeway=120 + ) # allows 2 min of clock skew return claims diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index a538215931..51d3bb8fff 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from unittest import mock import twisted.web.client from twisted.internet import defer -from twisted.internet.protocol import Protocol -from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import RoomVersions @@ -26,10 +23,9 @@ from synapse.events import EventBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import JsonDict from synapse.util import Clock -from tests.test_utils import event_injection +from tests.test_utils import FakeResponse, event_injection from tests.unittest import FederatingHomeserverTestCase @@ -98,8 +94,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # mock up the response, and have the agent return it self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "pdus": [ create_event_dict, member_event_dict, @@ -208,8 +204,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # mock up the response, and have the agent return it self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "origin": "yet.another.server", "origin_server_ts": 900, "pdus": [ @@ -269,8 +265,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # We expect an outbound request to /backfill, so stub that out self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "origin": "yet.another.server", "origin_server_ts": 900, # Mimic the other server returning our new `pulled_event` @@ -305,21 +301,3 @@ class FederationClientTest(FederatingHomeserverTestCase): # This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the # other from "yet.another.server" self.assertEqual(backfill_num_attempts, 2) - - -def _mock_response(resp: JsonDict): - body = json.dumps(resp).encode("utf-8") - - def deliver_body(p: Protocol): - p.dataReceived(body) - p.connectionLost(Failure(twisted.web.client.ResponseDone())) - - response = mock.Mock( - code=200, - phrase=b"OK", - headers=twisted.web.client.Headers({"content-Type": ["application/json"]}), - length=len(body), - deliverBody=deliver_body, - ) - mock.seal(response) - return response diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e6cd3af7b7..5955410524 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import os -from typing import Any, Dict +from typing import Any, Dict, Tuple from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse @@ -22,12 +21,15 @@ import pymacaroons from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.sso import MappingException +from synapse.http.site import SynapseRequest from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import UserID from synapse.util import Clock -from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon +from synapse.util.macaroons import get_value_from_macaroon +from synapse.util.stringutils import random_string from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer from tests.unittest import HomeserverTestCase, override_config try: @@ -46,12 +48,6 @@ BASE_URL = "https://synapse/" CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" SCOPES = ["openid"] -AUTHORIZATION_ENDPOINT = ISSUER + "authorize" -TOKEN_ENDPOINT = ISSUER + "token" -USERINFO_ENDPOINT = ISSUER + "userinfo" -WELL_KNOWN = ISSUER + ".well-known/openid-configuration" -JWKS_URI = ISSUER + ".well-known/jwks.json" - # config for common cases DEFAULT_CONFIG = { "enabled": True, @@ -66,9 +62,9 @@ DEFAULT_CONFIG = { EXPLICIT_ENDPOINT_CONFIG = { **DEFAULT_CONFIG, "discover": False, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, + "authorization_endpoint": ISSUER + "authorize", + "token_endpoint": ISSUER + "token", + "jwks_uri": ISSUER + "jwks", } @@ -102,27 +98,6 @@ class TestMappingProviderFailures(TestMappingProvider): } -async def get_json(url: str) -> JsonDict: - # Mock get_json calls to handle jwks & oidc discovery endpoints - if url == WELL_KNOWN: - # Minimal discovery document, as defined in OpenID.Discovery - # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata - return { - "issuer": ISSUER, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, - "userinfo_endpoint": USERINFO_ENDPOINT, - "response_types_supported": ["code"], - "subject_types_supported": ["public"], - "id_token_signing_alg_values_supported": ["RS256"], - } - elif url == JWKS_URI: - return {"keys": []} - - return {} - - def _key_file_path() -> str: """path to a file containing the private half of a test key""" @@ -159,11 +134,11 @@ class OidcHandlerTestCase(HomeserverTestCase): return config def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.http_client = Mock(spec=["get_json"]) - self.http_client.get_json.side_effect = get_json - self.http_client.user_agent = b"Synapse Test" + self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER) - hs = self.setup_test_homeserver(proxied_http_client=self.http_client) + hs = self.setup_test_homeserver() + self.hs_patcher = self.fake_server.patch_homeserver(hs=hs) + self.hs_patcher.start() self.handler = hs.get_oidc_handler() self.provider = self.handler._providers["oidc"] @@ -175,18 +150,51 @@ class OidcHandlerTestCase(HomeserverTestCase): # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 + auth_handler = hs.get_auth_handler() + # Mock the complete SSO login method. + self.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] + return hs + def tearDown(self) -> None: + self.hs_patcher.stop() + return super().tearDown() + + def reset_mocks(self): + """Reset all the Mocks.""" + self.fake_server.reset_mocks() + self.render_error.reset_mock() + self.complete_sso_login.reset_mock() + def metadata_edit(self, values): """Modify the result that will be returned by the well-known query""" - async def patched_get_json(uri): - res = await get_json(uri) - if uri == WELL_KNOWN: - res.update(values) - return res + metadata = self.fake_server.get_metadata() + metadata.update(values) + return patch.object(self.fake_server, "get_metadata", return_value=metadata) - return patch.object(self.http_client, "get_json", patched_get_json) + def start_authorization( + self, + userinfo: dict, + client_redirect_url: str = "http://client/redirect", + scope: str = "openid", + with_sid: bool = False, + ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]: + """Start an authorization request, and get the callback request back.""" + nonce = random_string(10) + state = random_string(10) + + code, grant = self.fake_server.start_authorization( + userinfo=userinfo, + scope=scope, + client_id=self.provider._client_auth.client_id, + redirect_uri=self.provider._callback_url, + nonce=nonce, + with_sid=with_sid, + ) + session = self._generate_oidc_session_token(state, nonce, client_redirect_url) + return _build_callback_request(code, state, session), grant def assertRenderedError(self, error, error_description=None): self.render_error.assert_called_once() @@ -210,52 +218,54 @@ class OidcHandlerTestCase(HomeserverTestCase): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid metadata = self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.fake_server.get_metadata_handler.assert_called_once() - self.assertEqual(metadata.issuer, ISSUER) - self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT) - self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT) - self.assertEqual(metadata.jwks_uri, JWKS_URI) - # FIXME: it seems like authlib does not have that defined in its metadata models - # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT) + self.assertEqual(metadata.issuer, self.fake_server.issuer) + self.assertEqual( + metadata.authorization_endpoint, + self.fake_server.authorization_endpoint, + ) + self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint) + self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri) + # It seems like authlib does not have that defined in its metadata models + self.assertEqual( + metadata.get("userinfo_endpoint"), + self.fake_server.userinfo_endpoint, + ) # subsequent calls should be cached - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() - @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_called_once_with(JWKS_URI) - self.assertEqual(jwks, {"keys": []}) + self.fake_server.get_jwks_handler.assert_called_once() + self.assertEqual(jwks, self.fake_server.get_jwks()) # subsequent calls should be cached… - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_jwks_handler.assert_not_called() # …unless forced - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks(force=True)) - self.http_client.get_json.assert_called_once_with(JWKS_URI) + self.fake_server.get_jwks_handler.assert_called_once() - # Throw if the JWKS uri is missing - original = self.provider.load_metadata - - async def patched_load_metadata(): - m = (await original()).copy() - m.update({"jwks_uri": None}) - return m - - with patch.object(self.provider, "load_metadata", patched_load_metadata): + with self.metadata_edit({"jwks_uri": None}): + # If we don't do this, the load_metadata call will throw because of the + # missing jwks_uri + self.provider._user_profile_method = "userinfo_endpoint" + self.get_success(self.provider.load_metadata(force=True)) self.get_failure(self.provider.load_jwks(force=True), RuntimeError) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -359,7 +369,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.provider.handle_redirect_request(req, b"http://client/redirect") ) ) - auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) + auth_endpoint = urlparse(self.fake_server.authorization_endpoint) self.assertEqual(url.scheme, auth_endpoint.scheme) self.assertEqual(url.netloc, auth_endpoint.netloc) @@ -424,48 +434,34 @@ class OidcHandlerTestCase(HomeserverTestCase): with self.assertRaises(AttributeError): _ = mapping_provider.get_extra_attributes - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } username = "bar" userinfo = { "sub": "foo", "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - code = "code" - state = "state" - nonce = "nonce" client_redirect_url = "http://client/redirect" - ip_address = "10.0.0.1" - session = self._generate_oidc_session_token(state, nonce, client_redirect_url) - request = _build_callback_request(code, state, session, ip_address=ip_address) - + request, _ = self.start_authorization( + userinfo, client_redirect_url=client_redirect_url + ) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, client_redirect_url, None, new_user=True, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_not_called() + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors + request, _ = self.start_authorization(userinfo) with patch.object( self.provider, "_remote_id_from_userinfo", @@ -475,81 +471,63 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("mapping_error") # Handle ID token errors - self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request, _ = self.start_authorization(userinfo) + with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") - auth_handler.complete_sso_login.reset_mock() - self.provider._exchange_code.reset_mock() - self.provider._parse_id_token.reset_mock() - self.provider._fetch_userinfo.reset_mock() + self.reset_mocks() # With userinfo fetching self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + # Without the "openid" scope, the FakeProvider does not generate an id_token + request, _ = self.start_authorization(userinfo, scope="") self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, None, new_user=False, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_not_called() - self.provider._fetch_userinfo.assert_called_once_with(token) + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() + self.reset_mocks() + # With an ID token, userinfo fetching and sid in the ID token self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - "id_token": "id_token", - } - id_token = { - "sid": "abcdefgh", - } - self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment] - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - auth_handler.complete_sso_login.reset_mock() - self.provider._fetch_userinfo.reset_mock() + request, grant = self.start_authorization(userinfo, with_sid=True) + self.assertIsNotNone(grant.sid) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, None, new_user=False, - auth_provider_session_id=id_token["sid"], + auth_provider_session_id=grant.sid, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_called_once_with(token) + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() # Handle userinfo fetching error - self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request, _ = self.start_authorization(userinfo) + with self.fake_server.buggy_endpoint(userinfo=True): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") - # Handle code exchange failure - from synapse.handlers.oidc import OidcError - - self.provider._exchange_code = simple_async_mock( # type: ignore[assignment] - raises=OidcError("invalid_request") - ) - self.get_success(self.handler.handle_oidc_callback(request)) - self.assertRenderedError("invalid_request") + request, _ = self.start_authorization(userinfo) + with self.fake_server.buggy_endpoint(token=True): + self.get_success(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("server_error") @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_session(self) -> None: @@ -599,18 +577,22 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_exchange_code(self) -> None: """Code exchange behaves correctly and handles various error scenarios.""" - token = {"type": "bearer"} - token_json = json.dumps(token).encode("utf-8") - self.http_client.request = simple_async_mock( - return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(ret, token) self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) args = parse_qs(kwargs["data"].decode("utf-8")) self.assertEqual(args["grant_type"], ["authorization_code"]) @@ -620,12 +602,8 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) # Test error handling - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad Request", - body=b'{"error": "foo", "error_description": "bar"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=400, payload={"error": "foo", "error_description": "bar"} ) from synapse.handlers.oidc import OidcError @@ -634,46 +612,30 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(exc.value.error_description, "bar") # Internal server error with no JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b"Not JSON", - ) + self.fake_server.post_token_handler.return_value = FakeResponse( + code=500, body=b"Not JSON" ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b'{"error": "internal_server_error"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=500, payload={"error": "internal_server_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad request", - body=b"{}", - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=400, payload={} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, - phrase=b"OK", - body=b'{"error": "some_error"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=200, payload={"error": "some_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "some_error") @@ -697,11 +659,14 @@ class OidcHandlerTestCase(HomeserverTestCase): """Test that code exchange works with a JWK client secret.""" from authlib.jose import jwt - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" @@ -714,9 +679,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # the client secret provided to the should be a jwt which can be checked with # the public key @@ -750,11 +715,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_exchange_code_no_auth(self) -> None: """Test that code exchange works with no client secret.""" - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) @@ -762,9 +730,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # check the POSTed data args = parse_qs(kwargs["data"].decode("utf-8")) @@ -787,37 +755,19 @@ class OidcHandlerTestCase(HomeserverTestCase): """ Login while using a mapping provider that implements get_extra_attributes. """ - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } userinfo = { "sub": "foo", "username": "foo", "phone": "1234567", } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - - state = "state" - client_redirect_url = "http://client/redirect" - session = self._generate_oidc_session_token( - state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ) - request = _build_callback_request("code", state, session) - + request, _ = self.start_authorization(userinfo) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@foo:test", - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, {"phone": "1234567"}, new_user=True, auth_provider_session_id=None, @@ -826,41 +776,40 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_user(self) -> None: """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - userinfo: dict = { "sub": "test_user", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Some providers return an integer ID. userinfo = { "sub": 1234, "username": "test_user_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user_2:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Test if the mxid is already taken store = self.hs.get_datastores().main @@ -869,8 +818,9 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user3.to_string(), password_hash=None) ) userinfo = {"sub": "test3", "username": "test_user_3"} - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Mapping provider does not support de-duplicating Matrix IDs", @@ -885,38 +835,37 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user.to_string(), password_hash=None) ) - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # Map a user via SSO. userinfo = { "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Subsequent calls should map to the same mxid. - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Note that a second SSO user can be mapped to the same Matrix ID. (This # requires a unique sub, but something that maps to the same matrix ID, @@ -927,17 +876,18 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test1", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register some non-exact matching cases. user2 = UserID.from_string("@TEST_user_2:test") @@ -954,8 +904,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test2", "username": "TEST_USER_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() args = self.assertRenderedError("mapping_error") self.assertTrue( args[2].startswith( @@ -969,11 +920,12 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user2.to_string(), password_hash=None) ) - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@TEST_USER_2:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, @@ -983,9 +935,9 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" - self.get_success( - _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) - ) + userinfo = {"sub": "test2", "username": "föö"} + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: föö") @override_config( @@ -1000,9 +952,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_map_userinfo_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) @@ -1011,19 +960,20 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # test_user is already taken, so test_user1 gets registered instead. - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@test_user1:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register all of the potential mxids for a particular OIDC username. self.get_success( @@ -1039,8 +989,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) @@ -1052,7 +1003,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1071,7 +1023,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1084,16 +1037,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_attribute_requirements(self) -> None: """The required attributes must be met from the OIDC userinfo response.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # userinfo lacking "test": "foobar" attribute should fail. userinfo = { "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": "foobar" attribute should succeed. userinfo = { @@ -1101,13 +1052,14 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": "foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, @@ -1124,21 +1076,20 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_attribute_requirements_contains(self) -> None: """Test that auth succeeds if userinfo attribute CONTAINS required value""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed. userinfo = { "sub": "tester", "username": "tester", "test": ["foobar", "foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, @@ -1158,16 +1109,15 @@ class OidcHandlerTestCase(HomeserverTestCase): Test that auth fails if attributes exist but don't match, or are non-string values. """ - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": "not_foobar" attribute should fail userinfo: dict = { "sub": "tester", "username": "tester", "test": "not_foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": ["foo", "bar"] attribute should fail userinfo = { @@ -1175,8 +1125,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": ["foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": False attribute should fail # this is largely just to ensure we don't crash here @@ -1185,8 +1136,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": False, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": None attribute should fail # a value of None breaks the OIDC spec, but it's important to not crash here @@ -1195,8 +1147,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 1 attribute should fail # this is largely just to ensure we don't crash here @@ -1205,8 +1158,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": 1, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 3.14 attribute should fail # this is largely just to ensure we don't crash here @@ -1215,8 +1169,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": 3.14, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() def _generate_oidc_session_token( self, @@ -1230,7 +1185,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return self.handler._macaroon_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( - idp_id="oidc", + idp_id=self.provider.idp_id, nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, @@ -1238,41 +1193,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ) -async def _make_callback_with_userinfo( - hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" -) -> None: - """Mock up an OIDC callback with the given userinfo dict - - We'll pull out the OIDC handler from the homeserver, stub out a couple of methods, - and poke in the userinfo dict as if it were the response to an OIDC userinfo call. - - Args: - hs: the HomeServer impl to send the callback to. - userinfo: the OIDC userinfo dict - client_redirect_url: the URL to redirect to on success. - """ - - handler = hs.get_oidc_handler() - provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment] - provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - - state = "state" - session = handler._macaroon_generator.generate_oidc_session_token( - state=state, - session_data=OidcSessionData( - idp_id="oidc", - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id="", - ), - ) - request = _build_callback_request("code", state, session) - - await handler.handle_oidc_callback(request) - - def _build_callback_request( code: str, state: str, diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 090cef5216..ebf653d018 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -465,9 +465,11 @@ class UIAuthTests(unittest.HomeserverTestCase): * checking that the original operation succeeds """ + fake_oidc_server = self.helper.fake_oidc_server() + # log the user in remote_user_id = UserID.from_string(self.user).localpart - login_resp = self.helper.login_via_oidc(remote_user_id) + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id) self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device @@ -481,8 +483,8 @@ class UIAuthTests(unittest.HomeserverTestCase): # run the UIA-via-SSO flow session_id = channel.json_body["session"] - channel = self.helper.auth_via_oidc( - {"sub": remote_user_id}, ui_auth_session_id=session_id + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id ) # that should serve a confirmation page @@ -499,7 +501,8 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_does_not_offer_password_for_sso_user(self) -> None: - login_resp = self.helper.login_via_oidc("username") + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username") user_tok = login_resp["access_token"] device_id = login_resp["device_id"] @@ -522,7 +525,10 @@ class UIAuthTests(unittest.HomeserverTestCase): @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_offers_both_flows_for_upgraded_user(self) -> None: """A user that had a password and then logged in with SSO should get both flows""" - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, UserID.from_string(self.user).localpart + ) self.assertEqual(login_resp["user_id"], self.user) channel = self.delete_device( @@ -539,8 +545,13 @@ class UIAuthTests(unittest.HomeserverTestCase): @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: """If the user tries to authenticate with the wrong SSO user, they get an error""" + + fake_oidc_server = self.helper.fake_oidc_server() + # log the user in - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, UserID.from_string(self.user).localpart + ) self.assertEqual(login_resp["user_id"], self.user) # start a UI Auth flow by attempting to delete a device @@ -553,8 +564,8 @@ class UIAuthTests(unittest.HomeserverTestCase): session_id = channel.json_body["session"] # do the OIDC auth, but auth as the wrong user - channel = self.helper.auth_via_oidc( - {"sub": "wrong_user"}, ui_auth_session_id=session_id + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id ) # that should return a failure message @@ -584,7 +595,10 @@ class UIAuthTests(unittest.HomeserverTestCase): """Tests that if we register a user via SSO while requiring approval for new accounts, we still raise the correct error before logging the user in. """ - login_resp = self.helper.login_via_oidc("username", expected_status=403) + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, "username", expected_status=403 + ) self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL) self.assertEqual( diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index e801ba8c8b..ff5baa9f0a 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -36,7 +36,7 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 -from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.rest.client.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless @@ -612,13 +612,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - # pick the default OIDC provider - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=oidc", - ) + fake_oidc_server = self.helper.fake_oidc_server() + + with fake_oidc_server.patch_homeserver(hs=self.hs): + # pick the default OIDC provider + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=oidc", + ) self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -626,7 +629,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) # ... and should have set a cookie including the redirect url cookie_headers = channel.headers.getRawHeaders("Set-Cookie") @@ -643,7 +646,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): TEST_CLIENT_REDIRECT_URL, ) - channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + channel, _ = self.helper.complete_oidc_auth( + fake_oidc_server, oidc_uri, cookies, {"sub": "user1"} + ) # that should serve a confirmation page self.assertEqual(channel.code, 200, channel.result) @@ -693,7 +698,10 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" - channel = self._make_sso_redirect_request("oidc") + fake_oidc_server = self.helper.fake_oidc_server() + + with fake_oidc_server.patch_homeserver(hs=self.hs): + channel = self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -701,7 +709,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect @@ -1280,9 +1288,13 @@ class UsernamePickerTestCase(HomeserverTestCase): def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" + fake_oidc_server = self.helper.fake_oidc_server() + # do the start of the login flow - channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, + {"sub": "tester", "displayname": "Jonny"}, + TEST_CLIENT_REDIRECT_URL, ) # that should redirect to the username picker diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index c249a42bb6..967d229223 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -31,7 +31,6 @@ from typing import ( Tuple, overload, ) -from unittest.mock import patch from urllib.parse import urlencode import attr @@ -46,8 +45,19 @@ from synapse.server import HomeServer from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request -from tests.test_utils import FakeResponse from tests.test_utils.html_parsers import TestHtmlParser +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer + +# an 'oidc_config' suitable for login_via_oidc. +TEST_OIDC_ISSUER = "https://issuer.test/" +TEST_OIDC_CONFIG = { + "enabled": True, + "issuer": TEST_OIDC_ISSUER, + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["openid"], + "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, +} @attr.s(auto_attribs=True) @@ -543,12 +553,28 @@ class RestHelper: return channel.json_body + def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer: + """Create a ``FakeOidcServer``. + + This can be used in conjuction with ``login_via_oidc``:: + + fake_oidc_server = self.helper.fake_oidc_server() + login_data, _ = self.helper.login_via_oidc(fake_oidc_server, "user") + """ + + return FakeOidcServer( + clock=self.hs.get_clock(), + issuer=issuer, + ) + def login_via_oidc( self, + fake_server: FakeOidcServer, remote_user_id: str, + with_sid: bool = False, expected_status: int = 200, - ) -> JsonDict: - """Log in via OIDC + ) -> Tuple[JsonDict, FakeAuthorizationGrant]: + """Log in (as a new user) via OIDC Returns the result of the final token login. @@ -560,7 +586,10 @@ class RestHelper: the normal places. """ client_redirect_url = "https://x" - channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) + userinfo = {"sub": remote_user_id} + channel, grant = self.auth_via_oidc( + fake_server, userinfo, client_redirect_url, with_sid=with_sid + ) # expect a confirmation page assert channel.code == HTTPStatus.OK, channel.result @@ -585,14 +614,16 @@ class RestHelper: assert ( channel.code == expected_status ), f"unexpected status in response: {channel.code}" - return channel.json_body + return channel.json_body, grant def auth_via_oidc( self, + fake_server: FakeOidcServer, user_info_dict: JsonDict, client_redirect_url: Optional[str] = None, ui_auth_session_id: Optional[str] = None, - ) -> FakeChannel: + with_sid: bool = False, + ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Perform an OIDC authentication flow via a mock OIDC provider. This can be used for either login or user-interactive auth. @@ -616,6 +647,7 @@ class RestHelper: the login redirect endpoint ui_auth_session_id: if set, we will perform a UI Auth flow. The session id of the UI auth. + with_sid: if True, generates a random `sid` (OIDC session ID) Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. @@ -625,14 +657,15 @@ class RestHelper: cookies: Dict[str, str] = {} - # if we're doing a ui auth, hit the ui auth redirect endpoint - if ui_auth_session_id: - # can't set the client redirect url for UI Auth - assert client_redirect_url is None - oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) - else: - # otherwise, hit the login redirect endpoint - oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + with fake_server.patch_homeserver(hs=self.hs): + # if we're doing a ui auth, hit the ui auth redirect endpoint + if ui_auth_session_id: + # can't set the client redirect url for UI Auth + assert client_redirect_url is None + oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) + else: + # otherwise, hit the login redirect endpoint + oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) # we now have a URI for the OIDC IdP, but we skip that and go straight # back to synapse's OIDC callback resource. However, we do need the "state" @@ -640,17 +673,21 @@ class RestHelper: # that synapse passes to the client. oauth_uri_path, _ = oauth_uri.split("?", 1) - assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( + assert oauth_uri_path == fake_server.authorization_endpoint, ( "unexpected SSO URI " + oauth_uri_path ) - return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) + return self.complete_oidc_auth( + fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid + ) def complete_oidc_auth( self, + fake_serer: FakeOidcServer, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, - ) -> FakeChannel: + with_sid: bool = False, + ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Mock out an OIDC authentication flow Assumes that an OIDC auth has been initiated by one of initiate_sso_login or @@ -661,50 +698,37 @@ class RestHelper: Requires the OIDC callback resource to be mounted at the normal place. Args: + fake_server: the fake OIDC server with which the auth should be done oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, from initiate_sso_login or initiate_sso_ui_auth). cookies: the cookies set by synapse's redirect endpoint, which will be sent back to the callback endpoint. user_info_dict: the remote userinfo that the OIDC provider should present. Typically this should be '{"sub": ""}'. + with_sid: if True, generates a random `sid` (OIDC session ID) Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. """ _, oauth_uri_qs = oauth_uri.split("?", 1) params = urllib.parse.parse_qs(oauth_uri_qs) + + code, grant = fake_serer.start_authorization( + scope=params["scope"][0], + userinfo=user_info_dict, + client_id=params["client_id"][0], + redirect_uri=params["redirect_uri"][0], + nonce=params["nonce"][0], + with_sid=with_sid, + ) + state = params["state"][0] + callback_uri = "%s?%s" % ( urllib.parse.urlparse(params["redirect_uri"][0]).path, - urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), + urllib.parse.urlencode({"state": state, "code": code}), ) - # before we hit the callback uri, stub out some methods in the http client so - # that we don't have to handle full HTTPS requests. - # (expected url, json response) pairs, in the order we expect them. - expected_requests = [ - # first we get a hit to the token endpoint, which we tell to return - # a dummy OIDC access token - (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), - # and then one to the user_info endpoint, which returns our remote user id. - (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), - ] - - async def mock_req( - method: str, - uri: str, - data: Optional[dict] = None, - headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): - (expected_uri, resp_obj) = expected_requests.pop(0) - assert uri == expected_uri - resp = FakeResponse( - code=HTTPStatus.OK, - phrase=b"OK", - body=json.dumps(resp_obj).encode("utf-8"), - ) - return resp - - with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): + with fake_serer.patch_homeserver(hs=self.hs): # now hit the callback URI with the right params and a made-up code channel = make_request( self.hs.get_reactor(), @@ -715,7 +739,7 @@ class RestHelper: ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() ], ) - return channel + return channel, grant def initiate_sso_login( self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] @@ -806,21 +830,3 @@ class RestHelper: assert len(p.links) == 1, "not exactly one link in confirmation page" oauth_uri = p.links[0] return oauth_uri - - -# an 'oidc_config' suitable for login_via_oidc. -TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" -TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" -TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" -TEST_OIDC_CONFIG = { - "enabled": True, - "discover": False, - "issuer": "https://issuer.test", - "client_id": "test-client-id", - "client_secret": "test-client-secret", - "scopes": ["profile"], - "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, - "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, - "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, -} diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 0d0d6faf0d..e62ebcc6a5 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -15,17 +15,24 @@ """ Utilities for running the unit tests """ +import json import sys import warnings from asyncio import Future from binascii import unhexlify -from typing import Awaitable, Callable, TypeVar +from typing import Awaitable, Callable, Tuple, TypeVar from unittest.mock import Mock import attr +import zope.interface from twisted.python.failure import Failure from twisted.web.client import ResponseDone +from twisted.web.http import RESPONSES +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse + +from synapse.types import JsonDict TV = TypeVar("TV") @@ -97,27 +104,44 @@ def simple_async_mock(return_value=None, raises=None) -> Mock: return Mock(side_effect=cb) -@attr.s -class FakeResponse: +# Type ignore: it does not fully implement IResponse, but is good enough for tests +@zope.interface.implementer(IResponse) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FakeResponse: # type: ignore[misc] """A fake twisted.web.IResponse object there is a similar class at treq.test.test_response, but it lacks a `phrase` attribute, and didn't support deliverBody until recently. """ - # HTTP response code - code = attr.ib(type=int) + version: Tuple[bytes, int, int] = (b"HTTP", 1, 1) - # HTTP response phrase (eg b'OK' for a 200) - phrase = attr.ib(type=bytes) + # HTTP response code + code: int = 200 # body of the response - body = attr.ib(type=bytes) + body: bytes = b"" + + headers: Headers = attr.Factory(Headers) + + @property + def phrase(self): + return RESPONSES.get(self.code, b"Unknown Status") + + @property + def length(self): + return len(self.body) def deliverBody(self, protocol): protocol.dataReceived(self.body) protocol.connectionLost(Failure(ResponseDone())) + @classmethod + def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse": + headers = Headers({"Content-Type": ["application/json"]}) + body = json.dumps(payload).encode("utf-8") + return cls(code=code, body=body, headers=headers) + # A small image used in some tests. # diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py new file mode 100644 index 0000000000..de134bbc89 --- /dev/null +++ b/tests/test_utils/oidc.py @@ -0,0 +1,325 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +from typing import Any, Dict, List, Optional, Tuple +from unittest.mock import Mock, patch +from urllib.parse import parse_qs + +import attr + +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse + +from synapse.server import HomeServer +from synapse.util import Clock +from synapse.util.stringutils import random_string + +from tests.test_utils import FakeResponse + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FakeAuthorizationGrant: + userinfo: dict + client_id: str + redirect_uri: str + scope: str + nonce: Optional[str] + sid: Optional[str] + + +class FakeOidcServer: + """A fake OpenID Connect Provider.""" + + # All methods here are mocks, so we can track when they are called, and override + # their values + request: Mock + get_jwks_handler: Mock + get_metadata_handler: Mock + get_userinfo_handler: Mock + post_token_handler: Mock + + def __init__(self, clock: Clock, issuer: str): + from authlib.jose import ECKey, KeySet + + self._clock = clock + self.issuer = issuer + + self.request = Mock(side_effect=self._request) + self.get_jwks_handler = Mock(side_effect=self._get_jwks_handler) + self.get_metadata_handler = Mock(side_effect=self._get_metadata_handler) + self.get_userinfo_handler = Mock(side_effect=self._get_userinfo_handler) + self.post_token_handler = Mock(side_effect=self._post_token_handler) + + # A code -> grant mapping + self._authorization_grants: Dict[str, FakeAuthorizationGrant] = {} + # An access token -> grant mapping + self._sessions: Dict[str, FakeAuthorizationGrant] = {} + + # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for + # signing JWTs. ECDSA keys are really quick to generate compared to RSA. + self._key = ECKey.generate_key(crv="P-256", is_private=True) + self._jwks = KeySet([ECKey.import_key(self._key.as_pem(is_private=False))]) + + self._id_token_overrides: Dict[str, Any] = {} + + def reset_mocks(self): + self.request.reset_mock() + self.get_jwks_handler.reset_mock() + self.get_metadata_handler.reset_mock() + self.get_userinfo_handler.reset_mock() + self.post_token_handler.reset_mock() + + def patch_homeserver(self, hs: HomeServer): + """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``. + + This patch should be used whenever the HS is expected to perform request to the + OIDC provider, e.g.:: + + fake_oidc_server = self.helper.fake_oidc_server() + with fake_oidc_server.patch_homeserver(hs): + self.make_request("GET", "/_matrix/client/r0/login/sso/redirect") + """ + return patch.object(hs.get_proxied_http_client(), "request", self.request) + + @property + def authorization_endpoint(self) -> str: + return self.issuer + "authorize" + + @property + def token_endpoint(self) -> str: + return self.issuer + "token" + + @property + def userinfo_endpoint(self) -> str: + return self.issuer + "userinfo" + + @property + def metadata_endpoint(self) -> str: + return self.issuer + ".well-known/openid-configuration" + + @property + def jwks_uri(self) -> str: + return self.issuer + "jwks" + + def get_metadata(self) -> dict: + return { + "issuer": self.issuer, + "authorization_endpoint": self.authorization_endpoint, + "token_endpoint": self.token_endpoint, + "jwks_uri": self.jwks_uri, + "userinfo_endpoint": self.userinfo_endpoint, + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["ES256"], + } + + def get_jwks(self) -> dict: + return self._jwks.as_dict() + + def get_userinfo(self, access_token: str) -> Optional[dict]: + """Given an access token, get the userinfo of the associated session.""" + session = self._sessions.get(access_token, None) + if session is None: + return None + return session.userinfo + + def _sign(self, payload: dict) -> str: + from authlib.jose import JsonWebSignature + + jws = JsonWebSignature() + kid = self.get_jwks()["keys"][0]["kid"] + protected = {"alg": "ES256", "kid": kid} + json_payload = json.dumps(payload) + return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8") + + def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: + now = self._clock.time() + id_token = { + **grant.userinfo, + "iss": self.issuer, + "aud": grant.client_id, + "iat": now, + "nbf": now, + "exp": now + 600, + } + + if grant.nonce is not None: + id_token["nonce"] = grant.nonce + + if grant.sid is not None: + id_token["sid"] = grant.sid + + id_token.update(self._id_token_overrides) + + return self._sign(id_token) + + def id_token_override(self, overrides: dict): + """Temporarily patch the ID token generated by the token endpoint.""" + return patch.object(self, "_id_token_overrides", overrides) + + def start_authorization( + self, + client_id: str, + scope: str, + redirect_uri: str, + userinfo: dict, + nonce: Optional[str] = None, + with_sid: bool = False, + ) -> Tuple[str, FakeAuthorizationGrant]: + """Start an authorization request, and get back the code to use on the authorization endpoint.""" + code = random_string(10) + sid = None + if with_sid: + sid = random_string(10) + + grant = FakeAuthorizationGrant( + userinfo=userinfo, + scope=scope, + redirect_uri=redirect_uri, + nonce=nonce, + client_id=client_id, + sid=sid, + ) + self._authorization_grants[code] = grant + + return code, grant + + def exchange_code(self, code: str) -> Optional[Dict[str, Any]]: + grant = self._authorization_grants.pop(code, None) + if grant is None: + return None + + access_token = random_string(10) + self._sessions[access_token] = grant + + token = { + "token_type": "Bearer", + "access_token": access_token, + "expires_in": 3600, + "scope": grant.scope, + } + + if "openid" in grant.scope: + token["id_token"] = self.generate_id_token(grant) + + return dict(token) + + def buggy_endpoint( + self, + *, + jwks: bool = False, + metadata: bool = False, + token: bool = False, + userinfo: bool = False, + ): + """A context which makes a set of endpoints return a 500 error. + + Args: + jwks: If True, makes the JWKS endpoint return a 500 error. + metadata: If True, makes the OIDC Discovery endpoint return a 500 error. + token: If True, makes the token endpoint return a 500 error. + userinfo: If True, makes the userinfo endpoint return a 500 error. + """ + buggy = FakeResponse(code=500, body=b"Internal server error") + + patches = {} + if jwks: + patches["get_jwks_handler"] = Mock(return_value=buggy) + if metadata: + patches["get_metadata_handler"] = Mock(return_value=buggy) + if token: + patches["post_token_handler"] = Mock(return_value=buggy) + if userinfo: + patches["get_userinfo_handler"] = Mock(return_value=buggy) + + return patch.multiple(self, **patches) + + async def _request( + self, + method: str, + uri: str, + data: Optional[bytes] = None, + headers: Optional[Headers] = None, + ) -> IResponse: + """The override of the SimpleHttpClient#request() method""" + access_token: Optional[str] = None + + if headers is None: + headers = Headers() + + # Try to find the access token in the headers if any + auth_headers = headers.getRawHeaders(b"Authorization") + if auth_headers: + parts = auth_headers[0].split(b" ") + if parts[0] == b"Bearer" and len(parts) == 2: + access_token = parts[1].decode("ascii") + + if method == "POST": + # If the method is POST, assume it has an url-encoded body + if data is None or headers.getRawHeaders(b"Content-Type") != [ + b"application/x-www-form-urlencoded" + ]: + return FakeResponse.json(code=400, payload={"error": "invalid_request"}) + + params = parse_qs(data.decode("utf-8")) + + if uri == self.token_endpoint: + # Even though this endpoint should be protected, this does not check + # for client authentication. We're not checking it for simplicity, + # and because client authentication is tested in other standalone tests. + return self.post_token_handler(params) + + elif method == "GET": + if uri == self.jwks_uri: + return self.get_jwks_handler() + elif uri == self.metadata_endpoint: + return self.get_metadata_handler() + elif uri == self.userinfo_endpoint: + return self.get_userinfo_handler(access_token=access_token) + + return FakeResponse(code=404, body=b"404 not found") + + # Request handlers + def _get_jwks_handler(self) -> IResponse: + """Handles requests to the JWKS URI.""" + return FakeResponse.json(payload=self.get_jwks()) + + def _get_metadata_handler(self) -> IResponse: + """Handles requests to the OIDC well-known document.""" + return FakeResponse.json(payload=self.get_metadata()) + + def _get_userinfo_handler(self, access_token: Optional[str]) -> IResponse: + """Handles requests to the userinfo endpoint.""" + if access_token is None: + return FakeResponse(code=401) + user_info = self.get_userinfo(access_token) + if user_info is None: + return FakeResponse(code=401) + + return FakeResponse.json(payload=user_info) + + def _post_token_handler(self, params: Dict[str, List[str]]) -> IResponse: + """Handles requests to the token endpoint.""" + code = params.get("code", []) + + if len(code) != 1: + return FakeResponse.json(code=400, payload={"error": "invalid_request"}) + + grant = self.exchange_code(code=code[0]) + if grant is None: + return FakeResponse.json(code=400, payload={"error": "invalid_grant"}) + + return FakeResponse.json(payload=grant) -- cgit 1.5.1 From 6a6e1e8c0711939338f25d8d41d1e4d33d984949 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 28 Oct 2022 10:53:34 +0000 Subject: Fix room creation being rate limited too aggressively since Synapse v1.69.0. (#14314) * Introduce a test for the old behaviour which we want to restore * Reintroduce the old behaviour in a simpler way * Newsfile Signed-off-by: Olivier Wilkinson (reivilibre) * Use 1 credit instead of 2 for creating a room: be more lenient than before Notably, the UI in Element Web was still broken after restoring to prior behaviour. After discussion, we agreed that it would be sensible to increase the limit. Signed-off-by: Olivier Wilkinson (reivilibre) --- changelog.d/14314.bugfix | 1 + synapse/api/ratelimiting.py | 8 +++++- synapse/handlers/room.py | 16 ++++++++---- tests/rest/client/test_rooms.py | 54 ++++++++++++++++++++++++++++++++++++++--- 4 files changed, 70 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14314.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14314.bugfix b/changelog.d/14314.bugfix new file mode 100644 index 0000000000..8be47ee083 --- /dev/null +++ b/changelog.d/14314.bugfix @@ -0,0 +1 @@ +Fix room creation being rate limited too aggressively since Synapse v1.69.0. \ No newline at end of file diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 044c7d4926..511790c7c5 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -343,6 +343,7 @@ class RequestRatelimiter: requester: Requester, update: bool = True, is_admin_redaction: bool = False, + n_actions: int = 1, ) -> None: """Ratelimits requests. @@ -355,6 +356,8 @@ class RequestRatelimiter: is_admin_redaction: Whether this is a room admin/moderator redacting an event. If so then we may apply different ratelimits depending on config. + n_actions: Multiplier for the number of actions to apply to the + rate limiter at once. Raises: LimitExceededError if the request should be ratelimited @@ -383,7 +386,9 @@ class RequestRatelimiter: if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate # ratelimiter as to not have user_ids clash - await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) + await self.admin_redaction_ratelimiter.ratelimit( + requester, update=update, n_actions=n_actions + ) else: # Override rate and burst count per-user await self.request_ratelimiter.ratelimit( @@ -391,4 +396,5 @@ class RequestRatelimiter: rate_hz=messages_per_second, burst_count=burst_count, update=update, + n_actions=n_actions, ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 638f54051a..d74b675adc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -559,7 +559,6 @@ class RoomCreationHandler: invite_list=[], initial_state=initial_state, creation_content=creation_content, - ratelimit=False, ) # Transfer membership events @@ -753,6 +752,10 @@ class RoomCreationHandler: ) if ratelimit: + # Rate limit once in advance, but don't rate limit the individual + # events in the room — room creation isn't atomic and it's very + # janky if half the events in the initial state don't make it because + # of rate limiting. await self.request_ratelimiter.ratelimit(requester) room_version_id = config.get( @@ -913,7 +916,6 @@ class RoomCreationHandler: room_alias=room_alias, power_level_content_override=power_level_content_override, creator_join_profile=creator_join_profile, - ratelimit=ratelimit, ) if "name" in config: @@ -1037,7 +1039,6 @@ class RoomCreationHandler: room_alias: Optional[RoomAlias] = None, power_level_content_override: Optional[JsonDict] = None, creator_join_profile: Optional[JsonDict] = None, - ratelimit: bool = True, ) -> Tuple[int, str, int]: """Sends the initial events into a new room. Sends the room creation, membership, and power level events into the room sequentially, then creates and batches up the @@ -1046,6 +1047,8 @@ class RoomCreationHandler: `power_level_content_override` doesn't apply when initial state has power level state event content. + Rate limiting should already have been applied by this point. + Returns: A tuple containing the stream ID, event ID and depth of the last event sent to the room. @@ -1144,7 +1147,7 @@ class RoomCreationHandler: creator.user, room_id, "join", - ratelimit=ratelimit, + ratelimit=False, content=creator_join_profile, new_room=True, prev_event_ids=[last_sent_event_id], @@ -1269,7 +1272,10 @@ class RoomCreationHandler: events_to_send.append((encryption_event, encryption_context)) last_event = await self.event_creation_handler.handle_new_client_event( - creator, events_to_send, ignore_shadow_ban=True + creator, + events_to_send, + ignore_shadow_ban=True, + ratelimit=False, ) assert last_event.internal_metadata.stream_ordering is not None return last_event.internal_metadata.stream_ordering, last_event.event_id, depth diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 716366eb90..1084d4ad9d 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -54,6 +54,7 @@ from tests.http.server._base import make_request_with_cancellation_test from tests.storage.test_stream import PaginationTestCase from tests.test_utils import make_awaitable from tests.test_utils.event_injection import create_event +from tests.unittest import override_config PATH_PREFIX = b"/_matrix/client/api/v1" @@ -871,6 +872,41 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) + def _create_basic_room(self) -> Tuple[int, object]: + """ + Tries to create a basic room and returns the response code. + """ + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + return channel.code, channel.json_body + + @override_config( + { + "rc_message": {"per_second": 0.2, "burst_count": 10}, + } + ) + def test_room_creation_ratelimiting(self) -> None: + """ + Regression test for #14312, where ratelimiting was made too strict. + Clients should be able to create 10 rooms in a row + without hitting rate limits, using default rate limit config. + (We override rate limiting config back to its default value.) + + To ensure we don't make ratelimiting too generous accidentally, + also check that we can't create an 11th room. + """ + + for _ in range(10): + code, json_body = self._create_basic_room() + self.assertEqual(code, HTTPStatus.OK, json_body) + + # The 6th room hits the rate limit. + code, json_body = self._create_basic_room() + self.assertEqual(code, HTTPStatus.TOO_MANY_REQUESTS, json_body) + class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" @@ -1390,10 +1426,22 @@ class RoomJoinRatelimitTestCase(RoomBase): ) def test_join_local_ratelimit(self) -> None: """Tests that local joins are actually rate-limited.""" - for _ in range(3): - self.helper.create_room_as(self.user_id) + # Create 4 rooms + room_ids = [ + self.helper.create_room_as(self.user_id, is_public=True) for _ in range(4) + ] + + joiner_user_id = self.register_user("joiner", "secret") + # Now make a new user try to join some of them. - self.helper.create_room_as(self.user_id, expect_code=429) + # The user can join 3 rooms + for room_id in room_ids[0:3]: + self.helper.join(room_id, joiner_user_id) + + # But the user cannot join a 4th room + self.helper.join( + room_ids[3], joiner_user_id, expect_code=HTTPStatus.TOO_MANY_REQUESTS + ) @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} -- cgit 1.5.1 From cc3a52b33df72bb4230367536b924a6d1f510d36 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 31 Oct 2022 18:07:30 +0100 Subject: Support OIDC backchannel logouts (#11414) If configured an OIDC IdP can log a user's session out of Synapse when they log out of the identity provider. The IdP sends a request directly to Synapse (and must be configured with an endpoint) when a user logs out. --- changelog.d/11414.feature | 1 + docs/openid.md | 14 + docs/usage/configuration/config_documentation.md | 9 + synapse/config/oidc.py | 12 + synapse/handlers/oidc.py | 381 ++++++++++++++++++-- synapse/handlers/sso.py | 71 ++++ synapse/rest/synapse/client/oidc/__init__.py | 4 + .../client/oidc/backchannel_logout_resource.py | 35 ++ synapse/storage/databases/main/registration.py | 21 ++ tests/rest/client/test_auth.py | 390 +++++++++++++++++++-- tests/rest/client/utils.py | 55 ++- tests/server.py | 6 + tests/test_utils/oidc.py | 27 +- 13 files changed, 960 insertions(+), 66 deletions(-) create mode 100644 changelog.d/11414.feature create mode 100644 synapse/rest/synapse/client/oidc/backchannel_logout_resource.py (limited to 'tests/rest') diff --git a/changelog.d/11414.feature b/changelog.d/11414.feature new file mode 100644 index 0000000000..fc035e50a7 --- /dev/null +++ b/changelog.d/11414.feature @@ -0,0 +1 @@ +Support back-channel logouts from OpenID Connect providers. diff --git a/docs/openid.md b/docs/openid.md index 87ebea4c29..37c5eb244d 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -49,6 +49,13 @@ setting in your configuration file. See the [configuration manual](usage/configuration/config_documentation.md#oidc_providers) for some sample settings, as well as the text below for example configurations for specific providers. +## OIDC Back-Channel Logout + +Synapse supports receiving [OpenID Connect Back-Channel Logout](https://openid.net/specs/openid-connect-backchannel-1_0.html) notifications. + +This lets the OpenID Connect Provider notify Synapse when a user logs out, so that Synapse can end that user session. +This feature can be enabled by setting the `backchannel_logout_enabled` property to `true` in the provider configuration, and setting the following URL as destination for Back-Channel Logout notifications in your OpenID Connect Provider: `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout` + ## Sample configs Here are a few configs for providers that should work with Synapse. @@ -123,6 +130,9 @@ oidc_providers: [Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat. +Keycloak supports OIDC Back-Channel Logout, which sends logout notification to Synapse, so that Synapse users get logged out when they log out from Keycloak. +This can be optionally enabled by setting `backchannel_logout_enabled` to `true` in the Synapse configuration, and by setting the "Backchannel Logout URL" in Keycloak. + Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm. 1. Click `Clients` in the sidebar and click `Create` @@ -144,6 +154,8 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to | Client Protocol | `openid-connect` | | Access Type | `confidential` | | Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` | +| Backchannel Logout URL (optional) | `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout` | +| Backchannel Logout Session Required (optional) | `On` | 5. Click `Save` 6. On the Credentials tab, update the fields: @@ -167,7 +179,9 @@ oidc_providers: config: localpart_template: "{{ user.preferred_username }}" display_name_template: "{{ user.name }}" + backchannel_logout_enabled: true # Optional ``` + ### Auth0 [Auth0][auth0] is a hosted SaaS IdP solution. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 97fb505a5f..44358faf59 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3021,6 +3021,15 @@ Options for each entry include: which is set to the claims returned by the UserInfo Endpoint and/or in the ID Token. +* `backchannel_logout_enabled`: set to `true` to process OIDC Back-Channel Logout notifications. + Those notifications are expected to be received on `/_synapse/client/oidc/backchannel_logout`. + Defaults to `false`. + +* `backchannel_logout_ignore_sub`: by default, the OIDC Back-Channel Logout feature checks that the + `sub` claim matches the subject claim received during login. This check can be disabled by setting + this to `true`. Defaults to `false`. + + You might want to disable this if the `subject_claim` returned by the mapping provider is not `sub`. It is possible to configure Synapse to only allow logins if certain attributes match particular values in the OIDC userinfo. The requirements can be listed under diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 5418a332da..0bd83f4010 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -123,6 +123,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "userinfo_endpoint": {"type": "string"}, "jwks_uri": {"type": "string"}, "skip_verification": {"type": "boolean"}, + "backchannel_logout_enabled": {"type": "boolean"}, + "backchannel_logout_ignore_sub": {"type": "boolean"}, "user_profile_method": { "type": "string", "enum": ["auto", "userinfo_endpoint"], @@ -292,6 +294,10 @@ def _parse_oidc_config_dict( token_endpoint=oidc_config.get("token_endpoint"), userinfo_endpoint=oidc_config.get("userinfo_endpoint"), jwks_uri=oidc_config.get("jwks_uri"), + backchannel_logout_enabled=oidc_config.get("backchannel_logout_enabled", False), + backchannel_logout_ignore_sub=oidc_config.get( + "backchannel_logout_ignore_sub", False + ), skip_verification=oidc_config.get("skip_verification", False), user_profile_method=oidc_config.get("user_profile_method", "auto"), allow_existing_users=oidc_config.get("allow_existing_users", False), @@ -368,6 +374,12 @@ class OidcProviderConfig: # "openid" scope is used. jwks_uri: Optional[str] + # Whether Synapse should react to backchannel logouts + backchannel_logout_enabled: bool + + # Whether Synapse should ignore the `sub` claim in backchannel logouts or not. + backchannel_logout_ignore_sub: bool + # Whether to skip metadata verification skip_verification: bool diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 9759daf043..867973dcca 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -12,14 +12,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import binascii import inspect +import json import logging -from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Optional, + Type, + TypeVar, + Union, +) from urllib.parse import urlencode, urlparse import attr +import unpaddedbase64 from authlib.common.security import generate_token -from authlib.jose import JsonWebToken, jwt +from authlib.jose import JsonWebToken, JWTClaims +from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri from authlib.oidc.core import CodeIDToken, UserInfo @@ -35,9 +49,12 @@ from typing_extensions import TypedDict from twisted.web.client import readBody from twisted.web.http_headers import Headers +from synapse.api.errors import SynapseError from synapse.config import ConfigError from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig from synapse.handlers.sso import MappingException, UserAttributes +from synapse.http.server import finish_request +from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart @@ -88,6 +105,8 @@ class Token(TypedDict): #: there is no real point of doing this in our case. JWK = Dict[str, str] +C = TypeVar("C") + #: A JWK Set, as per RFC7517 sec 5. class JWKS(TypedDict): @@ -247,6 +266,80 @@ class OidcHandler: await oidc_provider.handle_oidc_callback(request, session_data, code) + async def handle_backchannel_logout(self, request: SynapseRequest) -> None: + """Handle an incoming request to /_synapse/client/oidc/backchannel_logout + + This extracts the logout_token from the request and tries to figure out + which OpenID Provider it is comming from. This works by matching the iss claim + with the issuer and the aud claim with the client_id. + + Since at this point we don't know who signed the JWT, we can't just + decode it using authlib since it will always verifies the signature. We + have to decode it manually without validating the signature. The actual JWT + verification is done in the `OidcProvider.handler_backchannel_logout` method, + once we figured out which provider sent the request. + + Args: + request: the incoming request from the browser. + """ + logout_token = parse_string(request, "logout_token") + if logout_token is None: + raise SynapseError(400, "Missing logout_token in request") + + # A JWT looks like this: + # header.payload.signature + # where all parts are encoded with urlsafe base64. + # The aud and iss claims we care about are in the payload part, which + # is a JSON object. + try: + # By destructuring the list after splitting, we ensure that we have + # exactly 3 segments + _, payload, _ = logout_token.split(".") + except ValueError: + raise SynapseError(400, "Invalid logout_token in request") + + try: + payload_bytes = unpaddedbase64.decode_base64(payload) + claims = json_decoder.decode(payload_bytes.decode("utf-8")) + except (json.JSONDecodeError, binascii.Error, UnicodeError): + raise SynapseError(400, "Invalid logout_token payload in request") + + try: + # Let's extract the iss and aud claims + iss = claims["iss"] + aud = claims["aud"] + # The aud claim can be either a string or a list of string. Here we + # normalize it as a list of strings. + if isinstance(aud, str): + aud = [aud] + + # Check that we have the right types for the aud and the iss claims + if not isinstance(iss, str) or not isinstance(aud, list): + raise TypeError() + for a in aud: + if not isinstance(a, str): + raise TypeError() + + # At this point we properly checked both claims types + issuer: str = iss + audience: List[str] = aud + except (TypeError, KeyError): + raise SynapseError(400, "Invalid issuer/audience in logout_token") + + # Now that we know the audience and the issuer, we can figure out from + # what provider it is coming from + oidc_provider: Optional[OidcProvider] = None + for provider in self._providers.values(): + if provider.issuer == issuer and provider.client_id in audience: + oidc_provider = provider + break + + if oidc_provider is None: + raise SynapseError(400, "Could not find the OP that issued this event") + + # Ask the provider to handle the logout request. + await oidc_provider.handle_backchannel_logout(request, logout_token) + class OidcError(Exception): """Used to catch errors when calling the token_endpoint""" @@ -342,6 +435,7 @@ class OidcProvider: self.idp_brand = provider.idp_brand self._sso_handler = hs.get_sso_handler() + self._device_handler = hs.get_device_handler() self._sso_handler.register_identity_provider(self) @@ -400,6 +494,41 @@ class OidcProvider: # If we're not using userinfo, we need a valid jwks to validate the ID token m.validate_jwks_uri() + if self._config.backchannel_logout_enabled: + if not m.get("backchannel_logout_supported", False): + logger.warning( + "OIDC Back-Channel Logout is enabled for issuer %r" + "but it does not advertise support for it", + self.issuer, + ) + + elif not m.get("backchannel_logout_session_supported", False): + logger.warning( + "OIDC Back-Channel Logout is enabled and supported " + "by issuer %r but it might not send a session ID with " + "logout tokens, which is required for the logouts to work", + self.issuer, + ) + + if not self._config.backchannel_logout_ignore_sub: + # If OIDC backchannel logouts are enabled, the provider mapping provider + # should use the `sub` claim. We verify that by mapping a dumb user and + # see if we get back the sub claim + user = UserInfo({"sub": "thisisasubject"}) + try: + subject = self._user_mapping_provider.get_remote_user_id(user) + if subject != user["sub"]: + raise ValueError("Unexpected subject") + except Exception: + logger.warning( + f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} " + "but it looks like the configured `user_mapping_provider` " + "does not use the `sub` claim as subject. If it is the case, " + "and you want Synapse to ignore the `sub` claim in OIDC " + "Back-Channel Logouts, set `backchannel_logout_ignore_sub` " + "to `true` in the issuer config." + ) + @property def _uses_userinfo(self) -> bool: """Returns True if the ``userinfo_endpoint`` should be used. @@ -415,6 +544,16 @@ class OidcProvider: or self._user_profile_method == "userinfo_endpoint" ) + @property + def issuer(self) -> str: + """The issuer identifying this provider.""" + return self._config.issuer + + @property + def client_id(self) -> str: + """The client_id used when interacting with this provider.""" + return self._config.client_id + async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata: """Return the provider metadata. @@ -662,6 +801,59 @@ class OidcProvider: return UserInfo(resp) + async def _verify_jwt( + self, + alg_values: List[str], + token: str, + claims_cls: Type[C], + claims_options: Optional[dict] = None, + claims_params: Optional[dict] = None, + ) -> C: + """Decode and validate a JWT, re-fetching the JWKS as needed. + + Args: + alg_values: list of `alg` values allowed when verifying the JWT. + token: the JWT. + claims_cls: the JWTClaims class to use to validate the claims. + claims_options: dict of options passed to the `claims_cls` constructor. + claims_params: dict of params passed to the `claims_cls` constructor. + + Returns: + The decoded claims in the JWT. + """ + jwt = JsonWebToken(alg_values) + + logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token) + + # Try to decode the keys in cache first, then retry by forcing the keys + # to be reloaded + jwk_set = await self.load_jwks() + try: + claims = jwt.decode( + token, + key=jwk_set, + claims_cls=claims_cls, + claims_options=claims_options, + claims_params=claims_params, + ) + except ValueError: + logger.info("Reloading JWKS after decode error") + jwk_set = await self.load_jwks(force=True) # try reloading the jwks + claims = jwt.decode( + token, + key=jwk_set, + claims_cls=claims_cls, + claims_options=claims_options, + claims_params=claims_params, + ) + + logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims) + + claims.validate( + now=self._clock.time(), leeway=120 + ) # allows 2 min of clock skew + return claims + async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: """Return an instance of UserInfo from token's ``id_token``. @@ -675,13 +867,13 @@ class OidcProvider: The decoded claims in the ID token. """ id_token = token.get("id_token") - logger.debug("Attempting to decode JWT id_token %r", id_token) # That has been theoritically been checked by the caller, so even though # assertion are not enabled in production, it is mainly here to appease mypy assert id_token is not None metadata = await self.load_metadata() + claims_params = { "nonce": nonce, "client_id": self._client_auth.client_id, @@ -691,38 +883,17 @@ class OidcProvider: # in the `id_token` that we can check against. claims_params["access_token"] = token["access_token"] - alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) - jwt = JsonWebToken(alg_values) - - claim_options = {"iss": {"values": [metadata["issuer"]]}} + claims_options = {"iss": {"values": [metadata["issuer"]]}} - # Try to decode the keys in cache first, then retry by forcing the keys - # to be reloaded - jwk_set = await self.load_jwks() - try: - claims = jwt.decode( - id_token, - key=jwk_set, - claims_cls=CodeIDToken, - claims_options=claim_options, - claims_params=claims_params, - ) - except ValueError: - logger.info("Reloading JWKS after decode error") - jwk_set = await self.load_jwks(force=True) # try reloading the jwks - claims = jwt.decode( - id_token, - key=jwk_set, - claims_cls=CodeIDToken, - claims_options=claim_options, - claims_params=claims_params, - ) - - logger.debug("Decoded id_token JWT %r; validating", claims) + alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) - claims.validate( - now=self._clock.time(), leeway=120 - ) # allows 2 min of clock skew + claims = await self._verify_jwt( + alg_values=alg_values, + token=id_token, + claims_cls=CodeIDToken, + claims_options=claims_options, + claims_params=claims_params, + ) return claims @@ -1043,6 +1214,146 @@ class OidcProvider: # to be strings. return str(remote_user_id) + async def handle_backchannel_logout( + self, request: SynapseRequest, logout_token: str + ) -> None: + """Handle an incoming request to /_synapse/client/oidc/backchannel_logout + + The OIDC Provider posts a logout token to this endpoint when a user + session ends. That token is a JWT signed with the same keys as + ID tokens. The OpenID Connect Back-Channel Logout draft explains how to + validate the JWT and figure out what session to end. + + Args: + request: The request to respond to + logout_token: The logout token (a JWT) extracted from the request body + """ + # Back-Channel Logout can be disabled in the config, hence this check. + # This is not that important for now since Synapse is registered + # manually to the OP, so not specifying the backchannel-logout URI is + # as effective than disabling it here. It might make more sense if we + # support dynamic registration in Synapse at some point. + if not self._config.backchannel_logout_enabled: + logger.warning( + f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config" + ) + + # TODO: this responds with a 400 status code, which is what the OIDC + # Back-Channel Logout spec expects, but spec also suggests answering with + # a JSON object, with the `error` and `error_description` fields set, which + # we are not doing here. + # See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse + raise SynapseError( + 400, "OpenID Connect Back-Channel Logout is disabled for this provider" + ) + + metadata = await self.load_metadata() + + # As per OIDC Back-Channel Logout 1.0 sec. 2.4: + # A Logout Token MUST be signed and MAY also be encrypted. The same + # keys are used to sign and encrypt Logout Tokens as are used for ID + # Tokens. If the Logout Token is encrypted, it SHOULD replicate the + # iss (issuer) claim in the JWT Header Parameters, as specified in + # Section 5.3 of [JWT]. + alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) + + # As per sec. 2.6: + # 3. Validate the iss, aud, and iat Claims in the same way they are + # validated in ID Tokens. + # Which means the audience should contain Synapse's client_id and the + # issuer should be the IdP issuer + claims_options = { + "iss": {"values": [metadata["issuer"]]}, + "aud": {"values": [self.client_id]}, + } + + try: + claims = await self._verify_jwt( + alg_values=alg_values, + token=logout_token, + claims_cls=LogoutToken, + claims_options=claims_options, + ) + except JoseError: + logger.exception("Invalid logout_token") + raise SynapseError(400, "Invalid logout_token") + + # As per sec. 2.6: + # 4. Verify that the Logout Token contains a sub Claim, a sid Claim, + # or both. + # 5. Verify that the Logout Token contains an events Claim whose + # value is JSON object containing the member name + # http://schemas.openid.net/event/backchannel-logout. + # 6. Verify that the Logout Token does not contain a nonce Claim. + # This is all verified by the LogoutToken claims class, so at this + # point the `sid` claim exists and is a string. + sid: str = claims.get("sid") + + # If the `sub` claim was included in the logout token, we check that it matches + # that it matches the right user. We can have cases where the `sub` claim is not + # the ID saved in database, so we let admins disable this check in config. + sub: Optional[str] = claims.get("sub") + expected_user_id: Optional[str] = None + if sub is not None and not self._config.backchannel_logout_ignore_sub: + expected_user_id = await self._store.get_user_by_external_id( + self.idp_id, sub + ) + + # Invalidate any running user-mapping sessions, in-flight login tokens and + # active devices + await self._sso_handler.revoke_sessions_for_provider_session_id( + auth_provider_id=self.idp_id, + auth_provider_session_id=sid, + expected_user_id=expected_user_id, + ) + + request.setResponseCode(200) + request.setHeader(b"Cache-Control", b"no-cache, no-store") + request.setHeader(b"Pragma", b"no-cache") + finish_request(request) + + +class LogoutToken(JWTClaims): + """ + Holds and verify claims of a logout token, as per + https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken + """ + + REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"] + + def validate(self, now: Optional[int] = None, leeway: int = 0) -> None: + """Validate everything in claims payload.""" + super().validate(now, leeway) + self.validate_sid() + self.validate_events() + self.validate_nonce() + + def validate_sid(self) -> None: + """Ensure the sid claim is present""" + sid = self.get("sid") + if not sid: + raise MissingClaimError("sid") + + if not isinstance(sid, str): + raise InvalidClaimError("sid") + + def validate_nonce(self) -> None: + """Ensure the nonce claim is absent""" + if "nonce" in self: + raise InvalidClaimError("nonce") + + def validate_events(self) -> None: + """Ensure the events claim is present and with the right value""" + events = self.get("events") + if not events: + raise MissingClaimError("events") + + if not isinstance(events, dict): + raise InvalidClaimError("events") + + if "http://schemas.openid.net/event/backchannel-logout" not in events: + raise InvalidClaimError("events") + # number of seconds a newly-generated client secret should be valid for CLIENT_SECRET_VALIDITY_SECONDS = 3600 @@ -1112,6 +1423,7 @@ class JwtClientSecret: logger.info( "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload ) + jwt = JsonWebToken(header["alg"]) self._cached_secret = jwt.encode(header, payload, self._key.key) self._cached_secret_replacement_time = ( expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS @@ -1126,9 +1438,6 @@ class UserAttributeDict(TypedDict): emails: List[str] -C = TypeVar("C") - - class OidcMappingProvider(Generic[C]): """A mapping provider maps a UserInfo object to user attributes. diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 5943f08e91..749d7e93b0 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -191,6 +191,7 @@ class SsoHandler: self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() self._error_template = hs.config.sso.sso_error_template self._bad_user_template = hs.config.sso.sso_auth_bad_user_template self._profile_handler = hs.get_profile_handler() @@ -1026,6 +1027,76 @@ class SsoHandler: return True + async def revoke_sessions_for_provider_session_id( + self, + auth_provider_id: str, + auth_provider_session_id: str, + expected_user_id: Optional[str] = None, + ) -> None: + """Revoke any devices and in-flight logins tied to a provider session. + + Args: + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + auth_provider_session_id: The session ID from the provider to logout + expected_user_id: The user we're expecting to logout. If set, it will ignore + sessions belonging to other users and log an error. + """ + # Invalidate any running user-mapping sessions + to_delete = [] + for session_id, session in self._username_mapping_sessions.items(): + if ( + session.auth_provider_id == auth_provider_id + and session.auth_provider_session_id == auth_provider_session_id + ): + to_delete.append(session_id) + + for session_id in to_delete: + logger.info("Revoking mapping session %s", session_id) + del self._username_mapping_sessions[session_id] + + # Invalidate any in-flight login tokens + await self._store.invalidate_login_tokens_by_session_id( + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + + # Fetch any device(s) in the store associated with the session ID. + devices = await self._store.get_devices_by_auth_provider_session_id( + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + + # We have no guarantee that all the devices of that session are for the same + # `user_id`. Hence, we have to iterate over the list of devices and log them out + # one by one. + for device in devices: + user_id = device["user_id"] + device_id = device["device_id"] + + # If the user_id associated with that device/session is not the one we got + # out of the `sub` claim, skip that device and show log an error. + if expected_user_id is not None and user_id != expected_user_id: + logger.error( + "Received a logout notification from SSO provider " + f"{auth_provider_id!r} for the user {expected_user_id!r}, but with " + f"a session ID ({auth_provider_session_id!r}) which belongs to " + f"{user_id!r}. This may happen when the SSO provider user mapper " + "uses something else than the standard attribute as mapping ID. " + "For OIDC providers, set `backchannel_logout_ignore_sub` to `true` " + "in the provider config if that is the case." + ) + continue + + logger.info( + "Logging out %r (device %r) via SSO (%r) logout notification (session %r).", + user_id, + device_id, + auth_provider_id, + auth_provider_session_id, + ) + await self._device_handler.delete_devices(user_id, [device_id]) + def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: """Extract the session ID from the cookie diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py index 81fec39659..e4b28ce3df 100644 --- a/synapse/rest/synapse/client/oidc/__init__.py +++ b/synapse/rest/synapse/client/oidc/__init__.py @@ -17,6 +17,9 @@ from typing import TYPE_CHECKING from twisted.web.resource import Resource +from synapse.rest.synapse.client.oidc.backchannel_logout_resource import ( + OIDCBackchannelLogoutResource, +) from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource if TYPE_CHECKING: @@ -29,6 +32,7 @@ class OIDCResource(Resource): def __init__(self, hs: "HomeServer"): Resource.__init__(self) self.putChild(b"callback", OIDCCallbackResource(hs)) + self.putChild(b"backchannel_logout", OIDCBackchannelLogoutResource(hs)) __all__ = ["OIDCResource"] diff --git a/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py b/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py new file mode 100644 index 0000000000..e07e76855a --- /dev/null +++ b/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py @@ -0,0 +1,35 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from synapse.http.server import DirectServeJsonResource +from synapse.http.site import SynapseRequest + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class OIDCBackchannelLogoutResource(DirectServeJsonResource): + isLeaf = 1 + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._oidc_handler = hs.get_oidc_handler() + + async def _async_render_POST(self, request: SynapseRequest) -> None: + await self._oidc_handler.handle_backchannel_logout(request) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 0255295317..5167089e03 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1920,6 +1920,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): self._clock.time_msec(), ) + async def invalidate_login_tokens_by_session_id( + self, auth_provider_id: str, auth_provider_session_id: str + ) -> None: + """Invalidate login tokens with the given IdP session ID. + + Args: + auth_provider_id: The SSO Identity Provider that the user authenticated with + to get this token + auth_provider_session_id: The session ID advertised by the SSO Identity + Provider + """ + await self.db_pool.simple_update( + table="login_tokens", + keyvalues={ + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + updatevalues={"used_ts": self._clock.time_msec()}, + desc="invalidate_login_tokens_by_session_id", + ) + @cached() async def is_guest(self, user_id: str) -> bool: res = await self.db_pool.simple_select_one_onecol( diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index ebf653d018..847294dc8e 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re from http import HTTPStatus from typing import Any, Dict, List, Optional, Tuple, Union @@ -21,7 +22,7 @@ from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType -from synapse.api.errors import Codes +from synapse.api.errors import Codes, SynapseError from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree @@ -32,8 +33,8 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC -from tests.rest.client.utils import TEST_OIDC_CONFIG -from tests.server import FakeChannel +from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER +from tests.server import FakeChannel, make_request from tests.unittest import override_config, skip_unless @@ -638,19 +639,6 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": refresh_token}, ) - def is_access_token_valid(self, access_token: str) -> bool: - """ - Checks whether an access token is valid, returning whether it is or not. - """ - code = self.make_request( - "GET", "/_matrix/client/v3/account/whoami", access_token=access_token - ).code - - # Either 200 or 401 is what we get back; anything else is a bug. - assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED} - - return code == HTTPStatus.OK - def test_login_issue_refresh_token(self) -> None: """ A login response should include a refresh_token only if asked. @@ -847,29 +835,37 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.reactor.advance(59.0) # Both tokens should still be valid. - self.assertTrue(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami(refreshable_access_token, expect_code=HTTPStatus.OK) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 61 s (just past 1 minute, the time of expiry) self.reactor.advance(2.0) # Only the non-refreshable token is still valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 599 s (just shy of 10 minutes, the time of expiry) self.reactor.advance(599.0 - 61.0) # It's still the case that only the non-refreshable token is still valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 601 s (just past 10 minutes, the time of expiry) self.reactor.advance(2.0) # Now neither token is valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami( + nonrefreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) @override_config( {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} @@ -1165,3 +1161,349 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # and no refresh token self.assertEqual(_table_length("access_tokens"), 0) self.assertEqual(_table_length("refresh_tokens"), 0) + + +def oidc_config( + id: str, with_localpart_template: bool, **kwargs: Any +) -> Dict[str, Any]: + """Sample OIDC provider config used in backchannel logout tests. + + Args: + id: IDP ID for this provider + with_localpart_template: Set to `true` to have a default localpart_template in + the `user_mapping_provider` config and skip the user mapping session + **kwargs: rest of the config + + Returns: + A dict suitable for the `oidc_config` or the `oidc_providers[]` parts of + the HS config + """ + config: Dict[str, Any] = { + "idp_id": id, + "idp_name": id, + "issuer": TEST_OIDC_ISSUER, + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["openid"], + } + + if with_localpart_template: + config["user_mapping_provider"] = { + "config": {"localpart_template": "{{ user.sub }}"} + } + else: + config["user_mapping_provider"] = {"config": {}} + + config.update(kwargs) + + return config + + +@skip_unless(HAS_OIDC, "Requires OIDC") +class OidcBackchannelLogoutTests(unittest.HomeserverTestCase): + servlets = [ + account.register_servlets, + login.register_servlets, + ] + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + + # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns + # False, so synapse will see the requested uri as http://..., so using http in + # the public_baseurl stops Synapse trying to redirect to https. + config["public_baseurl"] = "http://synapse.test" + + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + resource_dict = super().create_resource_dict() + resource_dict.update(build_synapse_client_resource_tree(self.hs)) + return resource_dict + + def submit_logout_token(self, logout_token: str) -> FakeChannel: + return self.make_request( + "POST", + "/_synapse/client/oidc/backchannel_logout", + content=f"logout_token={logout_token}", + content_is_form=True, + ) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_simple_logout(self) -> None: + """ + Receiving a logout token should logout the user + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, first_grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + first_access_token: str = login_resp["access_token"] + self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK) + + login_resp, second_grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + second_access_token: str = login_resp["access_token"] + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + self.assertNotEqual(first_grant.sid, second_grant.sid) + self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"]) + + # Logging out of the first session + logout_token = fake_oidc_server.generate_logout_token(first_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED) + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # Logging out of the second session + logout_token = fake_oidc_server.generate_logout_token(second_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_logout_during_login(self) -> None: + """ + It should revoke login tokens when receiving a logout token + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + # Get an authentication, and logout before submitting the logout token + client_redirect_url = "https://x" + userinfo = {"sub": user} + channel, grant = self.helper.auth_via_oidc( + fake_oidc_server, + userinfo, + client_redirect_url, + with_sid=True, + ) + + # expect a confirmation page + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + # fish the matrix login token out of the body of the confirmation page + m = re.search( + 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), + channel.text_body, + ) + assert m, channel.text_body + login_token = m.group(1) + + # Submit a logout + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + # Now try to exchange the login token + channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + # It should have failed + self.assertEqual(channel.code, 403) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=False, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_logout_during_mapping(self) -> None: + """ + It should stop ongoing user mapping session when receiving a logout token + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + # Get an authentication, and logout before submitting the logout token + client_redirect_url = "https://x" + userinfo = {"sub": user} + channel, grant = self.helper.auth_via_oidc( + fake_oidc_server, + userinfo, + client_redirect_url, + with_sid=True, + ) + + # Expect a user mapping page + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) + + # We should have a user_mapping_session cookie + cookie_headers = channel.headers.getRawHeaders("Set-Cookie") + assert cookie_headers + cookies: Dict[str, str] = {} + for h in cookie_headers: + key, value = h.split(";")[0].split("=", maxsplit=1) + cookies[key] = value + + user_mapping_session_id = cookies["username_mapping_session"] + + # Getting that session should not raise + session = self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id) + self.assertIsNotNone(session) + + # Submit a logout + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + # Now it should raise + with self.assertRaises(SynapseError): + self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=False, + ) + ] + } + ) + def test_disabled(self) -> None: + """ + Receiving a logout token should do nothing if it is disabled in the config + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + access_token: str = login_resp["access_token"] + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + # Logging out shouldn't work + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 400) + + # And the token should still be valid + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_no_sid(self) -> None: + """ + Receiving a logout token without `sid` during the login should do nothing + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=False + ) + access_token: str = login_resp["access_token"] + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + # Logging out shouldn't work + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 400) + + # And the token should still be valid + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + @override_config( + { + "oidc_providers": [ + oidc_config( + "first", + issuer="https://first-issuer.com/", + with_localpart_template=True, + backchannel_logout_enabled=True, + ), + oidc_config( + "second", + issuer="https://second-issuer.com/", + with_localpart_template=True, + backchannel_logout_enabled=True, + ), + ] + } + ) + def test_multiple_providers(self) -> None: + """ + It should be able to distinguish login tokens from two different IdPs + """ + first_server = self.helper.fake_oidc_server(issuer="https://first-issuer.com/") + second_server = self.helper.fake_oidc_server( + issuer="https://second-issuer.com/" + ) + user = "john" + + login_resp, first_grant = self.helper.login_via_oidc( + first_server, user, with_sid=True, idp_id="oidc-first" + ) + first_access_token: str = login_resp["access_token"] + self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK) + + login_resp, second_grant = self.helper.login_via_oidc( + second_server, user, with_sid=True, idp_id="oidc-second" + ) + second_access_token: str = login_resp["access_token"] + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # `sid` in the fake providers are generated by a counter, so the first grant of + # each provider should give the same SID + self.assertEqual(first_grant.sid, second_grant.sid) + self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"]) + + # Logging out of the first session + logout_token = first_server.generate_logout_token(first_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED) + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # Logging out of the second session + logout_token = second_server.generate_logout_token(second_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(second_access_token, expect_code=HTTPStatus.UNAUTHORIZED) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 967d229223..706399fae5 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -553,6 +553,34 @@ class RestHelper: return channel.json_body + def whoami( + self, + access_token: str, + expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK, + ) -> JsonDict: + """Perform a 'whoami' request, which can be a quick way to check for access + token validity + + Args: + access_token: The user token to use during the request + expect_code: The return code to expect from attempting the whoami request + """ + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + "account/whoami", + access_token=access_token, + ) + + assert channel.code == expect_code, "Exepcted: %d, got %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer: """Create a ``FakeOidcServer``. @@ -572,6 +600,7 @@ class RestHelper: fake_server: FakeOidcServer, remote_user_id: str, with_sid: bool = False, + idp_id: Optional[str] = None, expected_status: int = 200, ) -> Tuple[JsonDict, FakeAuthorizationGrant]: """Log in (as a new user) via OIDC @@ -588,7 +617,11 @@ class RestHelper: client_redirect_url = "https://x" userinfo = {"sub": remote_user_id} channel, grant = self.auth_via_oidc( - fake_server, userinfo, client_redirect_url, with_sid=with_sid + fake_server, + userinfo, + client_redirect_url, + with_sid=with_sid, + idp_id=idp_id, ) # expect a confirmation page @@ -623,6 +656,7 @@ class RestHelper: client_redirect_url: Optional[str] = None, ui_auth_session_id: Optional[str] = None, with_sid: bool = False, + idp_id: Optional[str] = None, ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Perform an OIDC authentication flow via a mock OIDC provider. @@ -648,6 +682,7 @@ class RestHelper: ui_auth_session_id: if set, we will perform a UI Auth flow. The session id of the UI auth. with_sid: if True, generates a random `sid` (OIDC session ID) + idp_id: if set, explicitely chooses one specific IDP Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. @@ -665,7 +700,9 @@ class RestHelper: oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) else: # otherwise, hit the login redirect endpoint - oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + oauth_uri = self.initiate_sso_login( + client_redirect_url, cookies, idp_id=idp_id + ) # we now have a URI for the OIDC IdP, but we skip that and go straight # back to synapse's OIDC callback resource. However, we do need the "state" @@ -742,7 +779,10 @@ class RestHelper: return channel, grant def initiate_sso_login( - self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] + self, + client_redirect_url: Optional[str], + cookies: MutableMapping[str, str], + idp_id: Optional[str] = None, ) -> str: """Make a request to the login-via-sso redirect endpoint, and return the target @@ -753,6 +793,7 @@ class RestHelper: client_redirect_url: the client redirect URL to pass to the login redirect endpoint cookies: any cookies returned will be added to this dict + idp_id: if set, explicitely chooses one specific IDP Returns: the URI that the client gets redirected to (ie, the SSO server) @@ -761,6 +802,12 @@ class RestHelper: if client_redirect_url: params["redirectUrl"] = client_redirect_url + uri = "/_matrix/client/r0/login/sso/redirect" + if idp_id is not None: + uri = f"{uri}/{idp_id}" + + uri = f"{uri}?{urllib.parse.urlencode(params)}" + # hit the redirect url (which should redirect back to the redirect url. This # is the easiest way of figuring out what the Host header ought to be set to # to keep Synapse happy. @@ -768,7 +815,7 @@ class RestHelper: self.hs.get_reactor(), self.site, "GET", - "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), + uri, ) assert channel.code == 302 diff --git a/tests/server.py b/tests/server.py index 8b1d186219..b1730fcc8d 100644 --- a/tests/server.py +++ b/tests/server.py @@ -362,6 +362,12 @@ def make_request( # Twisted expects to be at the end of the content when parsing the request. req.content.seek(0, SEEK_END) + # Old version of Twisted (<20.3.0) have issues with parsing x-www-form-urlencoded + # bodies if the Content-Length header is missing + req.requestHeaders.addRawHeader( + b"Content-Length", str(len(content)).encode("ascii") + ) + if access_token: req.requestHeaders.addRawHeader( b"Authorization", b"Bearer " + access_token.encode("ascii") diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py index de134bbc89..1461d23ee8 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py @@ -51,6 +51,8 @@ class FakeOidcServer: get_userinfo_handler: Mock post_token_handler: Mock + sid_counter: int = 0 + def __init__(self, clock: Clock, issuer: str): from authlib.jose import ECKey, KeySet @@ -146,7 +148,7 @@ class FakeOidcServer: return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8") def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: - now = self._clock.time() + now = int(self._clock.time()) id_token = { **grant.userinfo, "iss": self.issuer, @@ -166,6 +168,26 @@ class FakeOidcServer: return self._sign(id_token) + def generate_logout_token(self, grant: FakeAuthorizationGrant) -> str: + now = int(self._clock.time()) + logout_token = { + "iss": self.issuer, + "aud": grant.client_id, + "iat": now, + "jti": random_string(10), + "events": { + "http://schemas.openid.net/event/backchannel-logout": {}, + }, + } + + if grant.sid is not None: + logout_token["sid"] = grant.sid + + if "sub" in grant.userinfo: + logout_token["sub"] = grant.userinfo["sub"] + + return self._sign(logout_token) + def id_token_override(self, overrides: dict): """Temporarily patch the ID token generated by the token endpoint.""" return patch.object(self, "_id_token_overrides", overrides) @@ -183,7 +205,8 @@ class FakeOidcServer: code = random_string(10) sid = None if with_sid: - sid = random_string(10) + sid = str(self.sid_counter) + self.sid_counter += 1 grant = FakeAuthorizationGrant( userinfo=userinfo, -- cgit 1.5.1 From dbfc9b803ee32f7b31c2b5ccbc53a1bfcaa95983 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 31 Oct 2022 20:31:43 +0000 Subject: Fix dehydrated device REST checks (#14336) --- changelog.d/14336.bugfix | 1 + synapse/rest/client/devices.py | 5 ++--- tests/rest/client/test_devices.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14336.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14336.bugfix b/changelog.d/14336.bugfix new file mode 100644 index 0000000000..d44ff1bbc7 --- /dev/null +++ b/changelog.d/14336.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70 where clients were unable to PUT new [dehydrated devices](https://github.com/matrix-org/matrix-spec-proposals/pull/2697). diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 90828c95c4..8f3cbd4ea2 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -231,7 +231,7 @@ class DehydratedDeviceServlet(RestServlet): } } - PUT /org.matrix.msc2697/dehydrated_device + PUT /org.matrix.msc2697.v2/dehydrated_device Content-Type: application/json { @@ -271,7 +271,6 @@ class DehydratedDeviceServlet(RestServlet): raise errors.NotFoundError("No dehydrated device available") class PutBody(RequestBodyModel): - device_id: StrictStr device_data: DehydratedDeviceDataModel initial_device_display_name: Optional[StrictStr] @@ -281,7 +280,7 @@ class DehydratedDeviceServlet(RestServlet): device_id = await self.device_handler.store_dehydrated_device( requester.user.to_string(), - submission.device_data, + submission.device_data.dict(), submission.initial_device_display_name, ) return 200, {"device_id": device_id} diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py index aa98222434..d80eea17d3 100644 --- a/tests/rest/client/test_devices.py +++ b/tests/rest/client/test_devices.py @@ -200,3 +200,37 @@ class DevicesTestCase(unittest.HomeserverTestCase): self.reactor.advance(43200) self.get_success(self.handler.get_device(user_id, "abc")) self.get_failure(self.handler.get_device(user_id, "def"), NotFoundError) + + +class DehydratedDeviceTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + devices.register_servlets, + ] + + def test_PUT(self) -> None: + """Sanity-check that we can PUT a dehydrated device. + + Detects https://github.com/matrix-org/synapse/issues/14334. + """ + alice = self.register_user("alice", "correcthorse") + token = self.login(alice, "correcthorse") + + # Have alice update their device list + channel = self.make_request( + "PUT", + "_matrix/client/unstable/org.matrix.msc2697.v2/dehydrated_device", + { + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device", + } + }, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + device_id = channel.json_body.get("device_id") + self.assertIsInstance(device_id, str) -- cgit 1.5.1 From 86c5a710d8b4212f8a8a668d7d4a79c0bb371508 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 3 Nov 2022 16:21:31 +0000 Subject: Implement MSC3912: Relation-based redactions (#14260) Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14260.feature | 1 + synapse/api/constants.py | 2 + synapse/config/experimental.py | 3 + synapse/handlers/message.py | 47 ++++- synapse/handlers/relations.py | 56 +++++- synapse/rest/client/room.py | 57 ++++-- synapse/rest/client/versions.py | 2 + synapse/storage/databases/main/relations.py | 36 ++++ tests/rest/client/test_redactions.py | 273 +++++++++++++++++++++++++++- tests/rest/client/utils.py | 37 ++++ 10 files changed, 486 insertions(+), 28 deletions(-) create mode 100644 changelog.d/14260.feature (limited to 'tests/rest') diff --git a/changelog.d/14260.feature b/changelog.d/14260.feature new file mode 100644 index 0000000000..102dc7b3e0 --- /dev/null +++ b/changelog.d/14260.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3912](https://github.com/matrix-org/matrix-spec-proposals/pull/3912): Relation-based redactions. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 44c5ffc6a5..bc04a0755b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -125,6 +125,8 @@ class EventTypes: MSC2716_BATCH: Final = "org.matrix.msc2716.batch" MSC2716_MARKER: Final = "org.matrix.msc2716.marker" + Reaction: Final = "m.reaction" + class ToDeviceEventTypes: RoomKeyRequest: Final = "m.room_key_request" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d9bdd66d55..d4b71d1673 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -128,3 +128,6 @@ class ExperimentalConfig(Config): self.msc3886_endpoint: Optional[str] = experimental.get( "msc3886_endpoint", None ) + + # MSC3912: Relation-based redactions. + self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 468900a07f..4cf593cfdc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -877,6 +877,36 @@ class EventCreationHandler: return prev_event return None + async def get_event_from_transaction( + self, + requester: Requester, + txn_id: str, + room_id: str, + ) -> Optional[EventBase]: + """For the given transaction ID and room ID, check if there is a matching event. + If so, fetch it and return it. + + Args: + requester: The requester making the request in the context of which we want + to fetch the event. + txn_id: The transaction ID. + room_id: The room ID. + + Returns: + An event if one could be found, None otherwise. + """ + if requester.access_token_id: + existing_event_id = await self.store.get_event_id_from_transaction_id( + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, + ) + if existing_event_id: + return await self.store.get_event(existing_event_id) + + return None + async def create_and_send_nonmember_event( self, requester: Requester, @@ -956,18 +986,17 @@ class EventCreationHandler: # extremities to pile up, which in turn leads to state resolution # taking longer. async with self.limiter.queue(event_dict["room_id"]): - if txn_id and requester.access_token_id: - existing_event_id = await self.store.get_event_id_from_transaction_id( - event_dict["room_id"], - requester.user.to_string(), - requester.access_token_id, - txn_id, + if txn_id: + event = await self.get_event_from_transaction( + requester, txn_id, event_dict["room_id"] ) - if existing_event_id: - event = await self.store.get_event(existing_event_id) + if event: # we know it was persisted, so must have a stream ordering assert event.internal_metadata.stream_ordering - return event, event.internal_metadata.stream_ordering + return ( + event, + event.internal_metadata.stream_ordering, + ) event, context = await self.create_event( requester, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 0a0c6d938e..8e71dda970 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tup import attr -from synapse.api.constants import RelationTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace @@ -75,6 +75,7 @@ class RelationsHandler: self._clock = hs.get_clock() self._event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self._event_creation_handler = hs.get_event_creation_handler() async def get_relations( self, @@ -205,6 +206,59 @@ class RelationsHandler: return related_events, next_token + async def redact_events_related_to( + self, + requester: Requester, + event_id: str, + initial_redaction_event: EventBase, + relation_types: List[str], + ) -> None: + """Redacts all events related to the given event ID with one of the given + relation types. + + This method is expected to be called when redacting the event referred to by + the given event ID. + + If an event cannot be redacted (e.g. because of insufficient permissions), log + the error and try to redact the next one. + + Args: + requester: The requester to redact events on behalf of. + event_id: The event IDs to look and redact relations of. + initial_redaction_event: The redaction for the event referred to by + event_id. + relation_types: The types of relations to look for. + + Raises: + ShadowBanError if the requester is shadow-banned + """ + related_event_ids = ( + await self._main_store.get_all_relations_for_event_with_types( + event_id, relation_types + ) + ) + + for related_event_id in related_event_ids: + try: + await self._event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": initial_redaction_event.content, + "room_id": initial_redaction_event.room_id, + "sender": requester.user.to_string(), + "redacts": related_event_id, + }, + ratelimit=False, + ) + except SynapseError as e: + logger.warning( + "Failed to redact event %s (related to event %s): %s", + related_event_id, + event_id, + e.msg, + ) + async def get_annotations_for_event( self, event_id: str, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 01e5079963..91cb791139 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -52,6 +52,7 @@ from synapse.http.servlet import ( from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.storage.state import StateFilter @@ -1029,6 +1030,8 @@ class RoomRedactEventRestServlet(TransactionRestServlet): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() + self._relation_handler = hs.get_relations_handler() + self._msc3912_enabled = hs.config.experimental.msc3912_enabled def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" @@ -1045,20 +1048,46 @@ class RoomRedactEventRestServlet(TransactionRestServlet): content = parse_json_object_from_request(request) try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Redaction, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "redacts": event_id, - }, - txn_id=txn_id, - ) + with_relations = None + if self._msc3912_enabled and "org.matrix.msc3912.with_relations" in content: + with_relations = content["org.matrix.msc3912.with_relations"] + del content["org.matrix.msc3912.with_relations"] + + # Check if there's an existing event for this transaction now (even though + # create_and_send_nonmember_event also does it) because, if there's one, + # then we want to skip the call to redact_events_related_to. + event = None + if txn_id: + event = await self.event_creation_handler.get_event_from_transaction( + requester, txn_id, room_id + ) + + if event is None: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "redacts": event_id, + }, + txn_id=txn_id, + ) + + if with_relations: + run_as_background_process( + "redact_related_events", + self._relation_handler.redact_events_related_to, + requester=requester, + event_id=event_id, + initial_redaction_event=event, + relation_types=with_relations, + ) + event_id = event.event_id except ShadowBanError: event_id = "$" + random_string(43) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 9b1b72c68a..180a11ef88 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -119,6 +119,8 @@ class VersionsRestServlet(RestServlet): # Adds support for simple HTTP rendezvous as per MSC3886 "org.matrix.msc3886": self.config.experimental.msc3886_endpoint is not None, + # Adds support for relation-based redactions as per MSC3912. + "org.matrix.msc3912": self.config.experimental.msc3912_enabled, }, }, ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c022510e76..ca431002c8 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -295,6 +295,42 @@ class RelationsWorkerStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + async def get_all_relations_for_event_with_types( + self, + event_id: str, + relation_types: List[str], + ) -> List[str]: + """Get the event IDs of all events that have a relation to the given event with + one of the given relation types. + + Args: + event_id: The event for which to look for related events. + relation_types: The types of relations to look for. + + Returns: + A list of the IDs of the events that relate to the given event with one of + the given relation types. + """ + + def get_all_relation_ids_for_event_with_types_txn( + txn: LoggingTransaction, + ) -> List[str]: + rows = self.db_pool.simple_select_many_txn( + txn=txn, + table="event_relations", + column="relation_type", + iterable=relation_types, + keyvalues={"relates_to_id": event_id}, + retcols=["event_id"], + ) + + return [row["event_id"] for row in rows] + + return await self.db_pool.runInteraction( + desc="get_all_relation_ids_for_event_with_types", + func=get_all_relation_ids_for_event_with_types_txn, + ) + async def event_includes_relation(self, event_id: str) -> bool: """Check if the given event relates to another event. diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index be4c67d68e..5dfe44defb 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -11,17 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, room, sync from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class RedactionsTestCase(HomeserverTestCase): @@ -67,7 +68,12 @@ class RedactionsTestCase(HomeserverTestCase): ) def _redact_event( - self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 + self, + access_token: str, + room_id: str, + event_id: str, + expect_code: int = 200, + with_relations: Optional[List[str]] = None, ) -> JsonDict: """Helper function to send a redaction event. @@ -75,7 +81,13 @@ class RedactionsTestCase(HomeserverTestCase): """ path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) - channel = self.make_request("POST", path, content={}, access_token=access_token) + request_content = {} + if with_relations: + request_content["org.matrix.msc3912.with_relations"] = with_relations + + channel = self.make_request( + "POST", path, request_content, access_token=access_token + ) self.assertEqual(channel.code, expect_code) return channel.json_body @@ -201,3 +213,256 @@ class RedactionsTestCase(HomeserverTestCase): # These should all succeed, even though this would be denied by # the standard message ratelimiter self._redact_event(self.mod_access_token, self.room_id, msg_id) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations(self) -> None: + """Tests that we can redact the relations of an event at the same time as the + event itself. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "hello"}, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send an edit to this root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": " * hello world", + "m.new_content": { + "body": "hello world", + "msgtype": "m.text", + }, + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.REPLACE, + }, + "msgtype": "m.text", + }, + tok=self.mod_access_token, + ) + edit_event_id = res["event_id"] + + # Also send a threaded message whose root is the same as the edit's. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Also send a reaction, again with the same root. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Reaction, + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": root_event_id, + "key": "👍", + } + }, + tok=self.mod_access_token, + ) + reaction_event_id = res["event_id"] + + # Redact the root event, specifying that we also want to delete events that + # relate to it with m.replace. + self._redact_event( + self.mod_access_token, + self.room_id, + root_event_id, + with_relations=[ + RelationTypes.REPLACE, + RelationTypes.THREAD, + ], + ) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the edit got redacted. + event_dict = self.helper.get_event( + self.room_id, edit_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the threaded message got redacted. + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the reaction did not get redacted. + event_dict = self.helper.get_event( + self.room_id, reaction_event_id, self.mod_access_token + ) + self.assertNotIn("redacted_because", event_dict, event_dict) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_no_perms(self) -> None: + """Tests that, when redacting a message along with its relations, if not all + the related messages can be redacted because of insufficient permissions, the + server still redacts all the ones that can be. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.other_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message, this one from the moderator. We do this for the + # first message with the m.thread relation (and not the last one) to ensure + # that, when the server fails to redact it, it doesn't stop there, and it + # instead goes on to redact the other one. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + first_threaded_event_id = res["event_id"] + + # Send a second threaded message, this time from the user who'll perform the + # redaction. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 2", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.other_access_token, + ) + second_threaded_event_id = res["event_id"] + + # Redact the thread's root, and request that all threaded messages are also + # redacted. Send that request from the non-mod user, so that the first threaded + # event cannot be redacted. + self._redact_event( + self.other_access_token, + self.room_id, + root_event_id, + with_relations=[RelationTypes.THREAD], + ) + + # Check that the thread root got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the last message in the thread got redacted, despite failing to + # redact the one before it. + event_dict = self.helper.get_event( + self.room_id, second_threaded_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the message that was sent into the tread by the mod user is not + # redacted. + event_dict = self.helper.get_event( + self.room_id, first_threaded_event_id, self.other_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("message 1", event_dict["content"]["body"]) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_txn_id_reuse(self) -> None: + """Tests that redacting a message using a transaction ID, then reusing the same + transaction ID but providing an additional list of relations to redact, is + effectively a no-op. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "I'm in a thread!", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Send a first redaction request which redacts only the root event. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Send a second redaction request which redacts the root event as well as + # threaded messages. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={"org.matrix.msc3912.with_relations": [RelationTypes.THREAD]}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict) + + # Check that the threaded message didn't get redacted (since that wasn't part of + # the original redaction). + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("I'm in a thread!", event_dict["content"]["body"]) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 706399fae5..8d6f2b6ff9 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -410,6 +410,43 @@ class RestHelper: return channel.json_body + def get_event( + self, + room_id: str, + event_id: str, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + ) -> JsonDict: + """Request a specific event from the server. + + Args: + room_id: the room in which the event was sent. + event_id: the event's ID. + tok: the token to request the event with. + expect_code: the expected HTTP status for the response. + + Returns: + The event as a dict. + """ + path = f"/_matrix/client/v3/rooms/{room_id}/event/{event_id}" + if tok: + path = path + f"?access_token={tok}" + + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + path, + ) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def _read_write_state( self, room_id: str, -- cgit 1.5.1 From a4b1f6456276e62b3f4d6b060c289b6413b8a5c2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 4 Nov 2022 18:43:51 +0200 Subject: Fix /refresh endpoint version (#14364) --- changelog.d/14364.bugfix | 1 + synapse/rest/client/login.py | 2 +- tests/rest/client/test_auth.py | 16 ++++++++-------- 3 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14364.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14364.bugfix b/changelog.d/14364.bugfix new file mode 100644 index 0000000000..514bf859bb --- /dev/null +++ b/changelog.d/14364.bugfix @@ -0,0 +1 @@ +Fix refresh token endpoint to be under /r0 and /v3 instead of /v1. Contributed by Tulir @ Beeper. diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 7774f1967d..05706b598c 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -536,7 +536,7 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: class RefreshTokenServlet(RestServlet): - PATTERNS = (re.compile("^/_matrix/client/v1/refresh$"),) + PATTERNS = client_patterns("/refresh$") def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 847294dc8e..208ec44829 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -635,7 +635,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): """ return self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": refresh_token}, ) @@ -724,7 +724,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) @@ -765,7 +765,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) @@ -1002,7 +1002,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This first refresh should work properly first_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1012,7 +1012,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This one as well, since the token in the first one was never used second_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1022,7 +1022,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This one should not, since the token from the first refresh is not valid anymore third_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": first_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1056,7 +1056,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail fourth_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1068,7 +1068,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # But refreshing from the last valid refresh token still works fifth_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": second_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( -- cgit 1.5.1 From 7894251bcea7714b47e3849e509ea717bb18e9f5 Mon Sep 17 00:00:00 2001 From: Shay Date: Mon, 7 Nov 2022 13:38:50 -0800 Subject: Correctly create power level event during initial room creation (#14361) --- changelog.d/14361.bugfix | 1 + synapse/handlers/room.py | 25 +++++++++++++++++++++++-- tests/rest/client/test_rooms.py | 4 ++-- 3 files changed, 26 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14361.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14361.bugfix b/changelog.d/14361.bugfix new file mode 100644 index 0000000000..33ba1d92af --- /dev/null +++ b/changelog.d/14361.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.71.0rc1 where the power level event was incorrectly created during initial room creation. \ No newline at end of file diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f10cfca073..66a50bca6e 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1080,6 +1080,19 @@ class RoomCreationHandler: for_batch: bool, **kwargs: Any, ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: + """ + Creates an event and associated event context. + Args: + etype: the type of event to be created + content: content of the event + for_batch: whether the event is being created for batch persisting. If + bool for_batch is true, this will create an event using the prev_event_ids, + and will create an event context for the event using the parameters state_map + and current_state_group, thus these parameters must be provided in this + case if for_batch is True. The subsequently created event and context + are suitable for being batched up and bulk persisted to the database + with other similarly created events. + """ nonlocal depth nonlocal prev_event @@ -1139,13 +1152,21 @@ class RoomCreationHandler: depth += 1 state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id + # we need the state group of the membership event as it is the current state group + event_to_state = ( + await self._storage_controllers.state.get_state_group_for_events( + [member_event_id] + ) + ) + current_state_group = event_to_state[member_event_id] + events_to_send = [] # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: power_event, power_context = await create_event( - EventTypes.PowerLevels, pl_content, False + EventTypes.PowerLevels, pl_content, True ) current_state_group = power_context._state_group events_to_send.append((power_event, power_context)) @@ -1194,7 +1215,7 @@ class RoomCreationHandler: pl_event, pl_context = await create_event( EventTypes.PowerLevels, power_level_content, - False, + True, ) current_state_group = pl_context._state_group events_to_send.append((pl_event, pl_context)) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 1084d4ad9d..e919e089cb 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -715,7 +715,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(34, channel.resource_usage.db_txn_count) + self.assertEqual(33, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -728,7 +728,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(37, channel.resource_usage.db_txn_count) + self.assertEqual(36, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id -- cgit 1.5.1 From a3623af74e0af0d2f6cbd37b47dc54a1acd314d5 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Date: Fri, 11 Nov 2022 19:38:17 +0400 Subject: Add an Admin API endpoint for looking up users based on 3PID (#14405) --- changelog.d/14405.feature | 1 + docs/admin_api/user_admin_api.md | 39 ++++++++++++++ synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/users.py | 25 +++++++++ tests/rest/admin/test_user.py | 107 ++++++++++++++++++++++++++++++++++----- 5 files changed, 161 insertions(+), 13 deletions(-) create mode 100644 changelog.d/14405.feature (limited to 'tests/rest') diff --git a/changelog.d/14405.feature b/changelog.d/14405.feature new file mode 100644 index 0000000000..d3ba89b597 --- /dev/null +++ b/changelog.d/14405.feature @@ -0,0 +1 @@ +Add an [Admin API](https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/index.html) endpoint for user lookup based on third-party ID (3PID). Contributed by @ashfame. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index c95d6c9b05..880bef4194 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -1197,3 +1197,42 @@ Returns a `404` HTTP status code if no user was found, with a response body like ``` _Added in Synapse 1.68.0._ + + +### Find a user based on their Third Party ID (ThreePID or 3PID) + +The API is: + +``` +GET /_synapse/admin/v1/threepid/$medium/users/$address +``` + +When a user matched the given address for the given medium, an HTTP code `200` with a response body like the following is returned: + +```json +{ + "user_id": "@hello:example.org" +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +- `medium` - Kind of third-party ID, either `email` or `msisdn`. +- `address` - Value of the third-party ID. + +The `address` may have characters that are not URL-safe, so it is advised to URL-encode those parameters. + +**Errors** + +Returns a `404` HTTP status code if no user was found, with a response body like this: + +```json +{ + "errcode":"M_NOT_FOUND", + "error":"User not found" +} +``` + +_Added in Synapse 1.72.0._ diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 885669f9c7..c62ea22116 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -81,6 +81,7 @@ from synapse.rest.admin.users import ( ShadowBanRestServlet, UserAdminServlet, UserByExternalId, + UserByThreePid, UserMembershipRestServlet, UserRegisterServlet, UserRestServletV2, @@ -277,6 +278,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomMessagesRestServlet(hs).register(http_server) RoomTimestampToEventRestServlet(hs).register(http_server) UserByExternalId(hs).register(http_server) + UserByThreePid(hs).register(http_server) # Some servlets only get registered for the main process. if hs.config.worker.worker_app is None: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 15ac2059aa..1951b8a9f2 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -1224,3 +1224,28 @@ class UserByExternalId(RestServlet): raise NotFoundError("User not found") return HTTPStatus.OK, {"user_id": user_id} + + +class UserByThreePid(RestServlet): + """Find a user based on 3PID of a particular medium""" + + PATTERNS = admin_patterns("/threepid/(?P[^/]*)/users/(?P
[^/]*)") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastores().main + + async def on_GET( + self, + request: SynapseRequest, + medium: str, + address: str, + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + user_id = await self._store.get_user_id_by_threepid(medium, address) + + if user_id is None: + raise NotFoundError("User not found") + + return HTTPStatus.OK, {"user_id": user_id} diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 63410ffdf1..e8c9457794 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -41,14 +41,12 @@ from tests.unittest import override_config class UserRegisterTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, profile.register_servlets, ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.url = "/_synapse/admin/v1/register" self.registration_handler = Mock() @@ -446,7 +444,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): class UsersListTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -1108,7 +1105,6 @@ class UserDevicesTestCase(unittest.HomeserverTestCase): class DeactivateAccountTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -1382,7 +1378,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): class UserRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -2803,7 +2798,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): class UserMembershipRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -2960,7 +2954,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): class PushersRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -3089,7 +3082,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase): class UserMediaRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -3881,7 +3873,6 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): ], ) class WhoisRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -3961,7 +3952,6 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): class ShadowBanRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4042,7 +4032,6 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): class RateLimitTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4268,7 +4257,6 @@ class RateLimitTestCase(unittest.HomeserverTestCase): class AccountDataTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4358,7 +4346,6 @@ class AccountDataTestCase(unittest.HomeserverTestCase): class UsersByExternalIdTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4442,3 +4429,97 @@ class UsersByExternalIdTestCase(unittest.HomeserverTestCase): {"user_id": self.other_user}, channel.json_body, ) + + +class UsersByThreePidTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.get_success( + self.store.user_add_threepid( + self.other_user, "email", "user@email.com", 1, 1 + ) + ) + self.get_success( + self.store.user_add_threepid(self.other_user, "msidn", "+1-12345678", 1, 1) + ) + + def test_no_auth(self) -> None: + """Try to look up a user without authentication.""" + url = "/_synapse/admin/v1/threepid/email/users/user%40email.com" + + channel = self.make_request( + "GET", + url, + ) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_medium_does_not_exist(self) -> None: + """Tests that both a lookup for a medium that does not exist and a user that + doesn't exist with that third party ID returns a 404""" + # test for unknown medium + url = "/_synapse/admin/v1/threepid/publickey/users/unknown-key" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + # test for unknown user with a known medium + url = "/_synapse/admin/v1/threepid/email/users/unknown" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_success(self) -> None: + """Tests a successful medium + address lookup""" + # test for email medium with encoded value of user@email.com + url = "/_synapse/admin/v1/threepid/email/users/user%40email.com" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + {"user_id": self.other_user}, + channel.json_body, + ) + + # test for msidn medium with encoded value of +1-12345678 + url = "/_synapse/admin/v1/threepid/msidn/users/%2B1-12345678" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + {"user_id": self.other_user}, + channel.json_body, + ) -- cgit 1.5.1 From 1799a54a545618782840a60950ef4b64da9ee24d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 07:26:11 -0500 Subject: Batch fetch bundled annotations (#14491) Avoid an n+1 query problem and fetch the bundled aggregations for m.annotation relations in a single query instead of a query per event. This applies similar logic for as was previously done for edits in 8b309adb436c162510ed1402f33b8741d71fc058 (#11660) and threads in b65acead428653b988351ae8d7b22127a22039cd (#11752). --- changelog.d/14491.feature | 1 + synapse/handlers/relations.py | 197 ++++++++++++++++------------ synapse/storage/databases/main/relations.py | 139 ++++++++++++-------- synapse/util/caches/descriptors.py | 2 +- tests/rest/client/test_relations.py | 4 +- 5 files changed, 202 insertions(+), 141 deletions(-) create mode 100644 changelog.d/14491.feature (limited to 'tests/rest') diff --git a/changelog.d/14491.feature b/changelog.d/14491.feature new file mode 100644 index 0000000000..4fca7282f7 --- /dev/null +++ b/changelog.d/14491.feature @@ -0,0 +1 @@ +Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.4/client-server-api/#aggregations) which return bundled aggregations. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 8e71dda970..ca94239f61 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -13,7 +13,16 @@ # limitations under the License. import enum import logging -from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Tuple, +) import attr @@ -259,48 +268,64 @@ class RelationsHandler: e.msg, ) - async def get_annotations_for_event( - self, - event_id: str, - room_id: str, - limit: int = 5, - ignored_users: FrozenSet[str] = frozenset(), - ) -> List[JsonDict]: - """Get a list of annotations on the event, grouped by event type and + async def get_annotations_for_events( + self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() + ) -> Dict[str, List[JsonDict]]: + """Get a list of annotations to the given events, grouped by event type and aggregation key, sorted by count. - This is used e.g. to get the what and how many reactions have happend + This is used e.g. to get the what and how many reactions have happened on an event. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. ignored_users: The users ignored by the requesting user. Returns: - List of groups of annotations that match. Each row is a dict with - `type`, `key` and `count` fields. + A map of event IDs to a list of groups of annotations that match. + Each entry is a dict with `type`, `key` and `count` fields. """ # Get the base results for all users. - full_results = await self._main_store.get_aggregation_groups_for_event( - event_id, room_id, limit + full_results = await self._main_store.get_aggregation_groups_for_events( + event_ids ) + # Avoid additional logic if there are no ignored users. + if not ignored_users: + return { + event_id: results + for event_id, results in full_results.items() + if results + } + # Then subtract off the results for any ignored users. ignored_results = await self._main_store.get_aggregation_groups_for_users( - event_id, room_id, limit, ignored_users + [event_id for event_id, results in full_results.items() if results], + ignored_users, ) - filtered_results = [] - for result in full_results: - key = (result["type"], result["key"]) - if key in ignored_results: - result = result.copy() - result["count"] -= ignored_results[key] - if result["count"] <= 0: - continue - filtered_results.append(result) + filtered_results = {} + for event_id, results in full_results.items(): + # If no annotations, skip. + if not results: + continue + + # If there are not ignored results for this event, copy verbatim. + if event_id not in ignored_results: + filtered_results[event_id] = results + continue + + # Otherwise, subtract out the ignored results. + event_ignored_results = ignored_results[event_id] + for result in results: + key = (result["type"], result["key"]) + if key in event_ignored_results: + # Ensure to not modify the cache. + result = result.copy() + result["count"] -= event_ignored_results[key] + if result["count"] <= 0: + continue + filtered_results.setdefault(event_id, []).append(result) return filtered_results @@ -366,59 +391,62 @@ class RelationsHandler: results = {} for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event = summary - - # Subtract off the count of any ignored users. - for ignored_user in ignored_users: - thread_count -= ignored_results.get((event_id, ignored_user), 0) - - # This is gnarly, but if the latest event is from an ignored user, - # attempt to find one that isn't from an ignored user. - if latest_thread_event.sender in ignored_users: - room_id = latest_thread_event.room_id - - # If the root event is not found, something went wrong, do - # not include a summary of the thread. - event = await self._event_handler.get_event(user, room_id, event_id) - if event is None: - continue + # If no thread, skip. + if not summary: + continue - potential_events, _ = await self.get_relations_for_event( - event_id, - event, - room_id, - RelationTypes.THREAD, - ignored_users, - ) + thread_count, latest_thread_event = summary - # If all found events are from ignored users, do not include - # a summary of the thread. - if not potential_events: - continue + # Subtract off the count of any ignored users. + for ignored_user in ignored_users: + thread_count -= ignored_results.get((event_id, ignored_user), 0) - # The *last* event returned is the one that is cared about. - event = await self._event_handler.get_event( - user, room_id, potential_events[-1].event_id - ) - # It is unexpected that the event will not exist. - if event is None: - logger.warning( - "Unable to fetch latest event in a thread with event ID: %s", - potential_events[-1].event_id, - ) - continue - latest_thread_event = event - - results[event_id] = _ThreadAggregation( - latest_event=latest_thread_event, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=events_by_id[event_id].sender == user_id - or participated[event_id], + # This is gnarly, but if the latest event is from an ignored user, + # attempt to find one that isn't from an ignored user. + if latest_thread_event.sender in ignored_users: + room_id = latest_thread_event.room_id + + # If the root event is not found, something went wrong, do + # not include a summary of the thread. + event = await self._event_handler.get_event(user, room_id, event_id) + if event is None: + continue + + potential_events, _ = await self.get_relations_for_event( + event_id, + event, + room_id, + RelationTypes.THREAD, + ignored_users, ) + # If all found events are from ignored users, do not include + # a summary of the thread. + if not potential_events: + continue + + # The *last* event returned is the one that is cared about. + event = await self._event_handler.get_event( + user, room_id, potential_events[-1].event_id + ) + # It is unexpected that the event will not exist. + if event is None: + logger.warning( + "Unable to fetch latest event in a thread with event ID: %s", + potential_events[-1].event_id, + ) + continue + latest_thread_event = event + + results[event_id] = _ThreadAggregation( + latest_event=latest_thread_event, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=events_by_id[event_id].sender == user_id + or participated[event_id], + ) + return results @trace @@ -496,17 +524,18 @@ class RelationsHandler: # (as that is what makes it part of the thread). relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD - # Fetch other relations per event. - for event in events_by_id.values(): - # Fetch any annotations (ie, reactions) to bundle with this event. - annotations = await self.get_annotations_for_event( - event.event_id, event.room_id, ignored_users=ignored_users - ) + # Fetch any annotations (ie, reactions) to bundle with this event. + annotations_by_event_id = await self.get_annotations_for_events( + events_by_id.keys(), ignored_users=ignored_users + ) + for event_id, annotations in annotations_by_event_id.items(): if annotations: - results.setdefault( - event.event_id, BundledAggregations() - ).annotations = {"chunk": annotations} + results.setdefault(event_id, BundledAggregations()).annotations = { + "chunk": annotations + } + # Fetch other relations per event. + for event in events_by_id.values(): # Fetch any references to bundle with this event. references, next_token = await self.get_relations_for_event( event.event_id, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index ca431002c8..f96a16956a 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -20,6 +20,7 @@ from typing import ( FrozenSet, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -394,106 +395,136 @@ class RelationsWorkerStore(SQLBaseStore): ) return result is not None - @cached(tree=True) - async def get_aggregation_groups_for_event( - self, event_id: str, room_id: str, limit: int = 5 - ) -> List[JsonDict]: - """Get a list of annotations on the event, grouped by event type and + @cached() + async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]: + raise NotImplementedError() + + @cachedList( + cached_method_name="get_aggregation_groups_for_event", list_name="event_ids" + ) + async def get_aggregation_groups_for_events( + self, event_ids: Collection[str] + ) -> Mapping[str, Optional[List[JsonDict]]]: + """Get a list of annotations on the given events, grouped by event type and aggregation key, sorted by count. This is used e.g. to get the what and how many reactions have happend on an event. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. Returns: - List of groups of annotations that match. Each row is a dict with - `type`, `key` and `count` fields. + A map of event IDs to a list of groups of annotations that match. + Each entry is a dict with `type`, `key` and `count` fields. """ + # The number of entries to return per event ID. + limit = 5 - args = [ - event_id, - room_id, - RelationTypes.ANNOTATION, - limit, - ] + clause, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.ANNOTATION) - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? - GROUP BY relation_type, type, aggregation_key - ORDER BY COUNT(*) DESC - LIMIT ? + sql = f""" + SELECT + relates_to_id, + annotation.type, + aggregation_key, + COUNT(DISTINCT annotation.sender) + FROM events AS annotation + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = annotation.room_id + WHERE + {clause} + AND relation_type = ? + GROUP BY relates_to_id, annotation.type, aggregation_key + ORDER BY relates_to_id, COUNT(*) DESC """ - def _get_aggregation_groups_for_event_txn( + def _get_aggregation_groups_for_events_txn( txn: LoggingTransaction, - ) -> List[JsonDict]: + ) -> Mapping[str, List[JsonDict]]: txn.execute(sql, args) - return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn] + result: Dict[str, List[JsonDict]] = {} + for event_id, type, key, count in cast( + List[Tuple[str, str, str, int]], txn + ): + event_results = result.setdefault(event_id, []) + + # Limit the number of results per event ID. + if len(event_results) == limit: + continue + + event_results.append({"type": type, "key": key, "count": count}) + + return result return await self.db_pool.runInteraction( - "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn ) async def get_aggregation_groups_for_users( - self, - event_id: str, - room_id: str, - limit: int, - users: FrozenSet[str] = frozenset(), - ) -> Dict[Tuple[str, str], int]: + self, event_ids: Collection[str], users: FrozenSet[str] + ) -> Dict[str, Dict[Tuple[str, str], int]]: """Fetch the partial aggregations for an event for specific users. This is used, in conjunction with get_aggregation_groups_for_event, to remove information from the results for ignored users. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. users: The users to fetch information for. Returns: - A map of (event type, aggregation key) to a count of users. + A map of event ID to a map of (event type, aggregation key) to a + count of users. """ if not users: return {} - args: List[Union[str, int]] = [ - event_id, - room_id, - RelationTypes.ANNOTATION, - ] + events_sql, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) users_sql, users_args = make_in_list_sql_clause( - self.database_engine, "sender", users + self.database_engine, "annotation.sender", users ) args.extend(users_args) + args.append(RelationTypes.ANNOTATION) sql = f""" - SELECT type, aggregation_key, COUNT(DISTINCT sender) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql} - GROUP BY relation_type, type, aggregation_key - ORDER BY COUNT(*) DESC - LIMIT ? + SELECT + relates_to_id, + annotation.type, + aggregation_key, + COUNT(DISTINCT annotation.sender) + FROM events AS annotation + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = annotation.room_id + WHERE {events_sql} AND {users_sql} AND relation_type = ? + GROUP BY relates_to_id, annotation.type, aggregation_key + ORDER BY relates_to_id, COUNT(*) DESC """ def _get_aggregation_groups_for_users_txn( txn: LoggingTransaction, - ) -> Dict[Tuple[str, str], int]: - txn.execute(sql, args + [limit]) + ) -> Dict[str, Dict[Tuple[str, str], int]]: + txn.execute(sql, args) - return {(row[0], row[1]): row[2] for row in txn} + result: Dict[str, Dict[Tuple[str, str], int]] = {} + for event_id, type, key, count in cast( + List[Tuple[str, str, str, int]], txn + ): + result.setdefault(event_id, {})[(type, key)] = count + + return result return await self.db_pool.runInteraction( "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 75428d19ba..72227359b9 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -503,7 +503,7 @@ def cachedList( is specified as a list that is iterated through to lookup keys in the original cache. A new tuple consisting of the (deduplicated) keys that weren't in the cache gets passed to the original function, which is expected to results - in a map of key to value for each passed value. THe new results are stored in the + in a map of key to value for each passed value. The new results are stored in the original cache. Note that any missing values are cached as None. Args: diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index e3d801f7a8..2d2b683548 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1108,7 +1108,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # @@ -1170,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) def test_nested_thread(self) -> None: """ -- cgit 1.5.1 From 6d7523ef1484ec56f4a6dffdd2ea3d8736b4cc98 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 09:41:09 -0500 Subject: Batch fetch bundled references (#14508) Avoid an n+1 query problem and fetch the bundled aggregations for m.reference relations in a single query instead of a query per event. This applies similar logic for as was previously done for edits in 8b309adb436c162510ed1402f33b8741d71fc058 (#11660; threads in b65acead428653b988351ae8d7b22127a22039cd (#11752); and annotations in 1799a54a545618782840a60950ef4b64da9ee24d (#14491). --- changelog.d/14508.feature | 1 + synapse/handlers/relations.py | 128 +++++++++++++--------------- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/events.py | 4 + synapse/storage/databases/main/relations.py | 74 ++++++++++++++-- tests/rest/client/test_relations.py | 4 +- 6 files changed, 133 insertions(+), 79 deletions(-) create mode 100644 changelog.d/14508.feature (limited to 'tests/rest') diff --git a/changelog.d/14508.feature b/changelog.d/14508.feature new file mode 100644 index 0000000000..4fca7282f7 --- /dev/null +++ b/changelog.d/14508.feature @@ -0,0 +1 @@ +Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.4/client-server-api/#aggregations) which return bundled aggregations. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index ca94239f61..8414be5879 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -13,16 +13,7 @@ # limitations under the License. import enum import logging -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - FrozenSet, - Iterable, - List, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional import attr @@ -32,7 +23,7 @@ from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, StreamToken, UserID +from synapse.types import JsonDict, Requester, UserID from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -181,40 +172,6 @@ class RelationsHandler: return return_value - async def get_relations_for_event( - self, - event_id: str, - event: EventBase, - room_id: str, - relation_type: str, - ignored_users: FrozenSet[str] = frozenset(), - ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: - """Get a list of events which relate to an event, ordered by topological ordering. - - Args: - event_id: Fetch events that relate to this event ID. - event: The matching EventBase to event_id. - room_id: The room the event belongs to. - relation_type: The type of relation. - ignored_users: The users ignored by the requesting user. - - Returns: - List of event IDs that match relations requested. The rows are of - the form `{"event_id": "..."}`. - """ - - # Call the underlying storage method, which is cached. - related_events, next_token = await self._main_store.get_relations_for_event( - event_id, event, room_id, relation_type, direction="f" - ) - - # Filter out ignored users and convert to the expected format. - related_events = [ - event for event in related_events if event.sender not in ignored_users - ] - - return related_events, next_token - async def redact_events_related_to( self, requester: Requester, @@ -329,6 +286,46 @@ class RelationsHandler: return filtered_results + async def get_references_for_events( + self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() + ) -> Dict[str, List[_RelatedEvent]]: + """Get a list of references to the given events. + + Args: + event_ids: Fetch events that relate to this event ID. + ignored_users: The users ignored by the requesting user. + + Returns: + A map of event IDs to a list related events. + """ + + related_events = await self._main_store.get_references_for_events(event_ids) + + # Avoid additional logic if there are no ignored users. + if not ignored_users: + return { + event_id: results + for event_id, results in related_events.items() + if results + } + + # Filter out ignored users. + results = {} + for event_id, events in related_events.items(): + # If no references, skip. + if not events: + continue + + # Filter ignored users out. + events = [event for event in events if event.sender not in ignored_users] + # If there are no events left, skip this event. + if not events: + continue + + results[event_id] = events + + return results + async def _get_threads_for_events( self, events_by_id: Dict[str, EventBase], @@ -412,14 +409,18 @@ class RelationsHandler: if event is None: continue - potential_events, _ = await self.get_relations_for_event( - event_id, - event, - room_id, - RelationTypes.THREAD, - ignored_users, + # Attempt to find another event to use as the latest event. + potential_events, _ = await self._main_store.get_relations_for_event( + event_id, event, room_id, RelationTypes.THREAD, direction="f" ) + # Filter out ignored users. + potential_events = [ + event + for event in potential_events + if event.sender not in ignored_users + ] + # If all found events are from ignored users, do not include # a summary of the thread. if not potential_events: @@ -534,27 +535,16 @@ class RelationsHandler: "chunk": annotations } - # Fetch other relations per event. - for event in events_by_id.values(): - # Fetch any references to bundle with this event. - references, next_token = await self.get_relations_for_event( - event.event_id, - event, - event.room_id, - RelationTypes.REFERENCE, - ignored_users=ignored_users, - ) + # Fetch any references to bundle with this event. + references_by_event_id = await self.get_references_for_events( + events_by_id.keys(), ignored_users=ignored_users + ) + for event_id, references in references_by_event_id.items(): if references: - aggregations = results.setdefault(event.event_id, BundledAggregations()) - aggregations.references = { + results.setdefault(event_id, BundledAggregations()).references = { "chunk": [{"event_id": ev.event_id} for ev in references] } - if next_token: - aggregations.references["next_batch"] = await next_token.to_string( - self._main_store - ) - # Fetch any edits (but not for redacted events). # # Note that there is no use in limiting edits by ignored users since the @@ -600,7 +590,7 @@ class RelationsHandler: room_id, requester, allow_departed_users=True ) - # Note that ignored users are not passed into get_relations_for_event + # Note that ignored users are not passed into get_threads # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). thread_roots, next_batch = await self._main_store.get_threads( diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index ddb7397714..a58668a380 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) + self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,)) self._attempt_to_invalidate_cache( "get_aggregation_groups_for_event", (relates_to,) ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index d68f127f9b..0f097a2927 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2049,6 +2049,10 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) ) + if rel_type == RelationTypes.REFERENCE: + self.store._invalidate_cache_and_stream( + txn, self.store.get_references_for_event, (redacted_relates_to,) + ) if rel_type == RelationTypes.REPLACE: self.store._invalidate_cache_and_stream( txn, self.store.get_applicable_edit, (redacted_relates_to,) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index f96a16956a..aea96e9d24 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -82,8 +82,6 @@ class _RelatedEvent: event_id: str # The sender of the related event. sender: str - topological_ordering: Optional[int] - stream_ordering: int class RelationsWorkerStore(SQLBaseStore): @@ -246,13 +244,17 @@ class RelationsWorkerStore(SQLBaseStore): txn.execute(sql, where_args + [limit + 1]) events = [] - for event_id, relation_type, sender, topo_ordering, stream_ordering in txn: + topo_orderings: List[int] = [] + stream_orderings: List[int] = [] + for event_id, relation_type, sender, topo_ordering, stream_ordering in cast( + List[Tuple[str, str, str, int, int]], txn + ): # Do not include edits for redacted events as they leak event # content. if not is_redacted or relation_type != RelationTypes.REPLACE: - events.append( - _RelatedEvent(event_id, sender, topo_ordering, stream_ordering) - ) + events.append(_RelatedEvent(event_id, sender)) + topo_orderings.append(topo_ordering) + stream_orderings.append(stream_ordering) # If there are more events, generate the next pagination key from the # last event returned. @@ -261,9 +263,11 @@ class RelationsWorkerStore(SQLBaseStore): # Instead of using the last row (which tells us there is more # data), use the last row to be returned. events = events[:limit] + topo_orderings = topo_orderings[:limit] + stream_orderings = stream_orderings[:limit] - topo = events[-1].topological_ordering - token = events[-1].stream_ordering + topo = topo_orderings[-1] + token = stream_orderings[-1] if direction == "b": # Tokens are positions between events. # This token points *after* the last event in the chunk. @@ -530,6 +534,60 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn ) + @cached() + async def get_references_for_event(self, event_id: str) -> List[JsonDict]: + raise NotImplementedError() + + @cachedList(cached_method_name="get_references_for_event", list_name="event_ids") + async def get_references_for_events( + self, event_ids: Collection[str] + ) -> Mapping[str, Optional[List[_RelatedEvent]]]: + """Get a list of references to the given events. + + Args: + event_ids: Fetch events that relate to these event IDs. + + Returns: + A map of event IDs to a list of related event IDs (and their senders). + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.REFERENCE) + + sql = f""" + SELECT relates_to_id, ref.event_id, ref.sender + FROM events AS ref + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = ref.room_id + WHERE + {clause} + AND relation_type = ? + ORDER BY ref.topological_ordering, ref.stream_ordering + """ + + def _get_references_for_events_txn( + txn: LoggingTransaction, + ) -> Mapping[str, List[_RelatedEvent]]: + txn.execute(sql, args) + + result: Dict[str, List[_RelatedEvent]] = {} + for relates_to_id, event_id, sender in cast( + List[Tuple[str, str, str]], txn + ): + result.setdefault(relates_to_id, []).append( + _RelatedEvent(event_id, sender) + ) + + return result + + return await self.db_pool.runInteraction( + "_get_references_for_events_txn", _get_references_for_events_txn + ) + @cached() def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: raise NotImplementedError() diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 2d2b683548..b86f341ff5 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1108,7 +1108,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # @@ -1170,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7) def test_nested_thread(self) -> None: """ -- cgit 1.5.1 From 6d47b7e32589e816eb766446cc1ff19ea73fc7c1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 14:08:04 -0500 Subject: Add a type hint for `get_device_handler()` and fix incorrect types. (#14055) This was the last untyped handler from the HomeServer object. Since it was being treated as Any (and thus unchecked) it was being used incorrectly in a few places. --- changelog.d/14055.misc | 1 + synapse/handlers/deactivate_account.py | 4 +++ synapse/handlers/device.py | 65 ++++++++++++++++++++++++++-------- synapse/handlers/e2e_keys.py | 61 ++++++++++++++++--------------- synapse/handlers/register.py | 4 +++ synapse/handlers/set_password.py | 6 +++- synapse/handlers/sso.py | 9 +++++ synapse/module_api/__init__.py | 10 +++++- synapse/replication/http/devices.py | 11 ++++-- synapse/rest/admin/__init__.py | 26 ++++++++------ synapse/rest/admin/devices.py | 13 +++++-- synapse/rest/client/devices.py | 17 ++++++--- synapse/rest/client/logout.py | 9 +++-- synapse/server.py | 2 +- tests/handlers/test_device.py | 19 ++++++---- tests/rest/admin/test_device.py | 5 ++- 16 files changed, 185 insertions(+), 77 deletions(-) create mode 100644 changelog.d/14055.misc (limited to 'tests/rest') diff --git a/changelog.d/14055.misc b/changelog.d/14055.misc new file mode 100644 index 0000000000..02980bc528 --- /dev/null +++ b/changelog.d/14055.misc @@ -0,0 +1 @@ +Add missing type hints to `HomeServer`. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 816e1a6d79..d74d135c0c 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -16,6 +16,7 @@ import logging from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError +from synapse.handlers.device import DeviceHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import Codes, Requester, UserID, create_requester @@ -76,6 +77,9 @@ class DeactivateAccountHandler: True if identity server supports removing threepids, otherwise False. """ + # This can only be called on the main process. + assert isinstance(self._device_handler, DeviceHandler) + # Check if this user can be deactivated if not await self._third_party_rules.check_can_deactivate_user( user_id, by_admin diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index da3ddafeae..b1e55e1b9e 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 class DeviceWorkerHandler: + device_list_updater: "DeviceListWorkerUpdater" + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs @@ -76,6 +78,8 @@ class DeviceWorkerHandler: self.server_name = hs.hostname self._msc3852_enabled = hs.config.experimental.msc3852_enabled + self.device_list_updater = DeviceListWorkerUpdater(hs) + @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ @@ -99,6 +103,19 @@ class DeviceWorkerHandler: log_kv(device_map) return devices + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + return await self.store.get_dehydrated_device(user_id) + @trace async def get_device(self, user_id: str, device_id: str) -> JsonDict: """Retrieve the given device @@ -127,7 +144,7 @@ class DeviceWorkerHandler: @cancellable async def get_device_changes_in_shared_rooms( self, user_id: str, room_ids: Collection[str], from_token: StreamToken - ) -> Collection[str]: + ) -> Set[str]: """Get the set of users whose devices have changed who share a room with the given user. """ @@ -320,6 +337,8 @@ class DeviceWorkerHandler: class DeviceHandler(DeviceWorkerHandler): + device_list_updater: "DeviceListUpdater" + def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler): await self.delete_devices(user_id, [old_device_id]) return device_id - async def get_dehydrated_device( - self, user_id: str - ) -> Optional[Tuple[str, JsonDict]]: - """Retrieve the information for a dehydrated device. - - Args: - user_id: the user whose dehydrated device we are looking for - Returns: - a tuple whose first item is the device ID, and the second item is - the dehydrated device information - """ - return await self.store.get_dehydrated_device(user_id) - async def rehydrate_device( self, user_id: str, access_token: str, device_id: str ) -> dict: @@ -882,7 +888,36 @@ def _update_device_from_client_ips( ) -class DeviceListUpdater: +class DeviceListWorkerUpdater: + "Handles incoming device list updates from federation and contacts the main process over replication" + + def __init__(self, hs: "HomeServer"): + from synapse.replication.http.devices import ( + ReplicationUserDevicesResyncRestServlet, + ) + + self._user_device_resync_client = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) + ) + + async def user_device_resync( + self, user_id: str, mark_failed_as_stale: bool = True + ) -> Optional[JsonDict]: + """Fetches all devices for a user and updates the device cache with them. + + Args: + user_id: The user's id whose device_list will be updated. + mark_failed_as_stale: Whether to mark the user's device list as stale + if the attempt to resync failed. + Returns: + A dict with device info as under the "devices" in the result of this + request: + https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + """ + return await self._user_device_resync_client(user_id=user_id) + + +class DeviceListUpdater(DeviceListWorkerUpdater): "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index bf1221f523..5fe102e2f2 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -27,9 +27,9 @@ from twisted.internet import defer from synapse.api.constants import EduTypes from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace -from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( JsonDict, UserID, @@ -56,27 +56,23 @@ class E2eKeysHandler: self.is_mine = hs.is_mine self.clock = hs.get_clock() - self._edu_updater = SigningKeyEduUpdater(hs, self) - federation_registry = hs.get_federation_registry() - self._is_master = hs.config.worker.worker_app is None - if not self._is_master: - self._user_device_resync_client = ( - ReplicationUserDevicesResyncRestServlet.make_client(hs) - ) - else: + is_master = hs.config.worker.worker_app is None + if is_master: + edu_updater = SigningKeyEduUpdater(hs) + # Only register this edu handler on master as it requires writing # device updates to the db federation_registry.register_edu_handler( EduTypes.SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + edu_updater.incoming_signing_key_update, ) # also handle the unstable version # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + edu_updater.incoming_signing_key_update, ) # doesn't really work as part of the generic query API, because the @@ -319,14 +315,13 @@ class E2eKeysHandler: # probably be tracking their device lists. However, we haven't # done an initial sync on the device list so we do it now. try: - if self._is_master: - resync_results = await self.device_handler.device_list_updater.user_device_resync( + resync_results = ( + await self.device_handler.device_list_updater.user_device_resync( user_id ) - else: - resync_results = await self._user_device_resync_client( - user_id=user_id - ) + ) + if resync_results is None: + raise ValueError("Device resync failed") # Add the device keys to the results. user_devices = resync_results["devices"] @@ -605,6 +600,8 @@ class E2eKeysHandler: async def upload_keys_for_user( self, user_id: str, device_id: str, keys: JsonDict ) -> JsonDict: + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) time_now = self.clock.time_msec() @@ -732,6 +729,8 @@ class E2eKeysHandler: user_id: the user uploading the keys keys: the signing keys """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) # if a master key is uploaded, then check it. Otherwise, load the # stored master key, to check signatures on other keys @@ -823,6 +822,9 @@ class E2eKeysHandler: Raises: SynapseError: if the signatures dict is not valid. """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + failures = {} # signatures to be stored. Each item will be a SignatureListItem @@ -1200,6 +1202,9 @@ class E2eKeysHandler: A tuple of the retrieved key content, the key's ID and the matching VerifyKey. If the key cannot be retrieved, all values in the tuple will instead be None. """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + try: remote_result = await self.federation.query_user_devices( user.domain, user.to_string() @@ -1396,11 +1401,14 @@ class SignatureListItem: class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" - def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() - self.e2e_keys_handler = e2e_keys_handler + + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler self._remote_edu_linearizer = Linearizer(name="remote_signing_key") @@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater: user_id: the user whose updates we are processing """ - device_handler = self.e2e_keys_handler.device_handler - device_list_updater = device_handler.device_list_updater - async with self._remote_edu_linearizer.queue(user_id): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: @@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater: logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = ( - await device_list_updater.process_cross_signing_key_update( - user_id, - master_key, - self_signing_key, - ) + new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update( + user_id, + master_key, + self_signing_key, ) device_ids = device_ids + new_device_ids - await device_handler.notify_device_update(user_id, device_ids) + await self._device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ca1c7a1866..6307fa9c5d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -38,6 +38,7 @@ from synapse.api.errors import ( ) from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.register import ( @@ -841,6 +842,9 @@ class RegistrationHandler: refresh_token = None refresh_token_id = None + # This can only run on the main process. + assert isinstance(self.device_handler, DeviceHandler) + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 73861bbd40..bd9d0bb34b 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.types import Requester if TYPE_CHECKING: @@ -29,7 +30,10 @@ class SetPasswordHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + # This can only be instantiated on the main process. + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler async def set_password( self, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 749d7e93b0..e1c0bff1b2 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -37,6 +37,7 @@ from twisted.web.server import Request from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.config.sso import SsoAttributeRequirement +from synapse.handlers.device import DeviceHandler from synapse.handlers.register import init_counters_for_auth_provider from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent @@ -1035,6 +1036,8 @@ class SsoHandler: ) -> None: """Revoke any devices and in-flight logins tied to a provider session. + Can only be called from the main process. + Args: auth_provider_id: A unique identifier for this SSO provider, e.g. "oidc" or "saml". @@ -1042,6 +1045,12 @@ class SsoHandler: expected_user_id: The user we're expecting to logout. If set, it will ignore sessions belonging to other users and log an error. """ + + # It is expected that this is the main process. + assert isinstance( + self._device_handler, DeviceHandler + ), "revoking SSO sessions can only be called on the main process" + # Invalidate any running user-mapping sessions to_delete = [] for session_id, session in self._username_mapping_sessions.items(): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 1adc1fd64f..96a661177a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -86,6 +86,7 @@ from synapse.handlers.auth import ( ON_LOGGED_OUT_CALLBACK, AuthHandler, ) +from synapse.handlers.device import DeviceHandler from synapse.handlers.push_rules import RuleSpec, check_actions from synapse.http.client import SimpleHttpClient from synapse.http.server import ( @@ -207,6 +208,7 @@ class ModuleApi: self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() self._push_rules_handler = hs.get_push_rules_handler() + self._device_handler = hs.get_device_handler() self.custom_template_dir = hs.config.server.custom_template_directory try: @@ -784,6 +786,8 @@ class ModuleApi: ) -> Generator["defer.Deferred[Any]", Any, None]: """Invalidate an access token for a user + Can only be called from the main process. + Added in Synapse v0.25.0. Args: @@ -796,6 +800,10 @@ class ModuleApi: Raises: synapse.api.errors.AuthError: the access token is invalid """ + assert isinstance( + self._device_handler, DeviceHandler + ), "invalidate_access_token can only be called on the main process" + # see if the access token corresponds to a device user_info = yield defer.ensureDeferred( self._auth.get_user_by_access_token(access_token) @@ -805,7 +813,7 @@ class ModuleApi: if device_id: # delete the device, which will also delete its access tokens yield defer.ensureDeferred( - self._hs.get_device_handler().delete_devices(user_id, [device_id]) + self._device_handler.delete_devices(user_id, [device_id]) ) else: # no associated device. Just delete the access token. diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index c21629def8..7c4941c3d3 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from twisted.web.server import Request @@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.device_list_updater = hs.get_device_handler().device_list_updater + from synapse.handlers.device import DeviceHandler + + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_list_updater = handler.device_list_updater + self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -73,7 +78,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): async def _handle_request( # type: ignore[override] self, request: Request, user_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, Optional[JsonDict]]: user_devices = await self.device_list_updater.user_device_resync(user_id) return 200, user_devices diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c62ea22116..fb73886df0 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: """ Register all the admin servlets. """ + # Admin servlets aren't registered on workers. + if hs.config.worker.worker_app is not None: + return + register_servlets_for_client_rest_resource(hs, http_server) BlockRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server) @@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserTokenRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) - DeviceRestServlet(hs).register(http_server) - DevicesRestServlet(hs).register(http_server) - DeleteDevicesRestServlet(hs).register(http_server) UserMediaStatisticsRestServlet(hs).register(http_server) EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) @@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserByExternalId(hs).register(http_server) UserByThreePid(hs).register(http_server) - # Some servlets only get registered for the main process. - if hs.config.worker.worker_app is None: - SendServerNoticeServlet(hs).register(http_server) - BackgroundUpdateEnabledRestServlet(hs).register(http_server) - BackgroundUpdateRestServlet(hs).register(http_server) - BackgroundUpdateStartJobRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) + DevicesRestServlet(hs).register(http_server) + DeleteDevicesRestServlet(hs).register(http_server) + SendServerNoticeServlet(hs).register(http_server) + BackgroundUpdateEnabledRestServlet(hs).register(http_server) + BackgroundUpdateRestServlet(hs).register(http_server) + BackgroundUpdateStartJobRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( @@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource( """Register only the servlets which need to be exposed on /_matrix/client/xxx""" WhoisRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server) - DeactivateAccountRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) - ResetPasswordRestServlet(hs).register(http_server) + # The following resources can only be run on the main process. + if hs.config.worker.worker_app is None: + DeactivateAccountRestServlet(hs).register(http_server) + ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index d934880102..3b2f2d9abb 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -16,6 +16,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 8f3cbd4ea2..69b803f9f8 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr from synapse.api import errors from synapse.api.errors import NotFoundError +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() class PostBody(RequestBodyModel): @@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() self._msc3852_enabled = hs.config.experimental.msc3852_enabled @@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -313,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler class PostBody(RequestBodyModel): device_id: StrictStr diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index 23dfa4518f..6d34625ad5 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest @@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) @@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) diff --git a/synapse/server.py b/synapse/server.py index f0a60d0056..5baae2325e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -510,7 +510,7 @@ class HomeServer(metaclass=abc.ABCMeta): ) @cache_in_self - def get_device_handler(self): + def get_device_handler(self) -> DeviceWorkerHandler: if self.config.worker.worker_app: return DeviceWorkerHandler(self) else: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index b8b465d35b..ce7525e29c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -19,7 +19,7 @@ from typing import Optional from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError -from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN +from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler from synapse.server import HomeServer from synapse.util import Clock @@ -32,7 +32,9 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.store = hs.get_datastores().main return hs @@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(res, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_is_preserved_if_exists(self) -> None: @@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(res2, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_id_is_made_up_if_unspecified(self) -> None: @@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): ) dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) + assert dev is not None self.assertEqual(dev["display_name"], "display") def test_get_devices_by_user(self) -> None: @@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.registration = hs.get_registration_handler() self.auth = hs.get_auth() self.store = hs.get_datastores().main @@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase): ) ) - retrieved_device_id, device_data = self.get_success( - self.handler.get_dehydrated_device(user_id=user_id) - ) + result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id)) + assert result is not None + retrieved_device_id, device_data = result self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index d52aee8f92..03f2112b07 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes +from synapse.handlers.device import DeviceHandler from synapse.rest.client import login from synapse.server import HomeServer from synapse.util import Clock @@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") -- cgit 1.5.1 From f6c74d1cb2ed966802b01a2b037f09ce7a842c18 Mon Sep 17 00:00:00 2001 From: Benjamin Kampmann Date: Thu, 24 Nov 2022 09:10:51 +0000 Subject: Implement message forward pagination from start when no from is given, fixes #12383 (#14149) Fixes https://github.com/matrix-org/synapse/issues/12383 --- changelog.d/14149.bugfix | 1 + synapse/handlers/pagination.py | 6 ++++++ synapse/streams/events.py | 13 +++++++++++++ tests/rest/admin/test_room.py | 40 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+) create mode 100644 changelog.d/14149.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14149.bugfix b/changelog.d/14149.bugfix new file mode 100644 index 0000000000..b31c658266 --- /dev/null +++ b/changelog.d/14149.bugfix @@ -0,0 +1 @@ +Fix #12383: paginate room messages from the start if no from is given. Contributed by @gnunicorn . \ No newline at end of file diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index a4ca9cb8b4..c572508a02 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -448,6 +448,12 @@ class PaginationHandler: if pagin_config.from_token: from_token = pagin_config.from_token + elif pagin_config.direction == "f": + from_token = ( + await self.hs.get_event_sources().get_start_token_for_pagination( + room_id + ) + ) else: from_token = ( await self.hs.get_event_sources().get_current_token_for_pagination( diff --git a/synapse/streams/events.py b/synapse/streams/events.py index f331e1af16..619eb7f601 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -73,6 +73,19 @@ class EventSources: ) return token + @trace + async def get_start_token_for_pagination(self, room_id: str) -> StreamToken: + """Get the start token for a given room to be used to paginate + events. + + The returned token does not have the current values for fields other + than `room`, since they are not used during pagination. + + Returns: + The start token for pagination. + """ + return StreamToken.START + @trace async def get_current_token_for_pagination(self, room_id: str) -> StreamToken: """Get the current token for a given room to be used to paginate diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index d156be82b0..e0f5d54aba 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1857,6 +1857,46 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): self.assertIn("chunk", channel.json_body) self.assertIn("end", channel.json_body) + def test_room_messages_backward(self) -> None: + """Test room messages can be retrieved by an admin that isn't in the room.""" + latest_event_id = self.helper.send( + self.room_id, body="message 1", tok=self.user_tok + )["event_id"] + + # Check that we get the first and second message when querying /messages. + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/messages?dir=b" % (self.room_id,), + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 6, [event["content"] for event in chunk]) + + # in backwards, this is the first event + self.assertEqual(chunk[0]["event_id"], latest_event_id) + + def test_room_messages_forward(self) -> None: + """Test room messages can be retrieved by an admin that isn't in the room.""" + latest_event_id = self.helper.send( + self.room_id, body="message 1", tok=self.user_tok + )["event_id"] + + # Check that we get the first and second message when querying /messages. + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/messages?dir=f" % (self.room_id,), + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 6, [event["content"] for event in chunk]) + + # in forward, this is the last event + self.assertEqual(chunk[5]["event_id"], latest_event_id) + def test_room_messages_purge(self) -> None: """Test room messages can be retrieved by an admin that isn't in the room.""" store = self.hs.get_datastores().main -- cgit 1.5.1