diff options
Diffstat (limited to 'synapse/rest/client')
-rw-r--r-- | synapse/rest/client/_base.py | 11 | ||||
-rw-r--r-- | synapse/rest/client/account.py | 82 | ||||
-rw-r--r-- | synapse/rest/client/account_data.py | 38 | ||||
-rw-r--r-- | synapse/rest/client/auth.py | 7 | ||||
-rw-r--r-- | synapse/rest/client/groups.py | 22 | ||||
-rw-r--r-- | synapse/rest/client/knock.py | 6 | ||||
-rw-r--r-- | synapse/rest/client/login.py | 67 | ||||
-rw-r--r-- | synapse/rest/client/openid.py | 2 | ||||
-rw-r--r-- | synapse/rest/client/push_rule.py | 114 | ||||
-rw-r--r-- | synapse/rest/client/receipts.py | 15 | ||||
-rw-r--r-- | synapse/rest/client/register.py | 103 | ||||
-rw-r--r-- | synapse/rest/client/relations.py | 80 | ||||
-rw-r--r-- | synapse/rest/client/report_event.py | 15 | ||||
-rw-r--r-- | synapse/rest/client/room.py | 237 | ||||
-rw-r--r-- | synapse/rest/client/room_batch.py | 31 | ||||
-rw-r--r-- | synapse/rest/client/room_keys.py | 53 | ||||
-rw-r--r-- | synapse/rest/client/sendtodevice.py | 27 | ||||
-rw-r--r-- | synapse/rest/client/sync.py | 16 | ||||
-rw-r--r-- | synapse/rest/client/transactions.py | 52 |
19 files changed, 598 insertions, 380 deletions
diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py index 0443f4571c..a0971ce994 100644 --- a/synapse/rest/client/_base.py +++ b/synapse/rest/client/_base.py @@ -16,7 +16,7 @@ """ import logging import re -from typing import Iterable, Pattern +from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast from synapse.api.errors import InteractiveAuthIncompleteError from synapse.api.urls import CLIENT_API_PREFIX @@ -76,7 +76,10 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) ) -def interactive_auth_handler(orig): +C = TypeVar("C", bound=Callable[..., Awaitable[Tuple[int, JsonDict]]]) + + +def interactive_auth_handler(orig: C) -> C: """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors Takes a on_POST method which returns an Awaitable (errcode, body) response @@ -91,10 +94,10 @@ def interactive_auth_handler(orig): await self.auth_handler.check_auth """ - async def wrapped(*args, **kwargs): + async def wrapped(*args: Any, **kwargs: Any) -> Tuple[int, JsonDict]: try: return await orig(*args, **kwargs) except InteractiveAuthIncompleteError as e: return 401, e.result - return wrapped + return cast(C, wrapped) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 1ceb4094a9..62e3aa31a6 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -17,9 +17,11 @@ import logging import random import re from http import HTTPStatus -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple from urllib.parse import urlparse +from twisted.web.server import Request + from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, @@ -29,15 +31,17 @@ from synapse.api.errors import ( ) from synapse.config.emailconfig import ThreepidBehaviour from synapse.handlers.ui_auth import UIAuthSessionDataConstants -from synapse.http.server import finish_request, respond_with_html +from synapse.http.server import HttpServer, finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer +from synapse.types import JsonDict from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import check_3pid_allowed, validate_email @@ -69,7 +73,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): template_text=self.config.email_password_reset_template_text, ) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -160,7 +164,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): class PasswordRestServlet(RestServlet): PATTERNS = client_patterns("/account/password$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -171,7 +175,7 @@ class PasswordRestServlet(RestServlet): self.http_client = hs.get_simple_http_client() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) # we do basic sanity checks here because the auth layer will store these @@ -192,6 +196,7 @@ class PasswordRestServlet(RestServlet): # # In the second case, we require a password to confirm their identity. + requester = None if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) # blindly trust ASes without UI-authing them @@ -212,16 +217,15 @@ class PasswordRestServlet(RestServlet): # If a password is available now, hash the provided password and # store it for later. if new_password: - password_hash = await self.auth_handler.hash(new_password) + new_password_hash = await self.auth_handler.hash(new_password) await self.auth_handler.set_session_data( e.session_id, UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, + new_password_hash, ) raise user_id = requester.user.to_string() else: - requester = None try: result, params, session_id = await self.auth_handler.check_ui_auth( [[LoginType.EMAIL_IDENTITY]], @@ -236,11 +240,11 @@ class PasswordRestServlet(RestServlet): # If a password is available now, hash the provided password and # store it for later. if new_password: - password_hash = await self.auth_handler.hash(new_password) + new_password_hash = await self.auth_handler.hash(new_password) await self.auth_handler.set_session_data( e.session_id, UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, + new_password_hash, ) raise @@ -270,7 +274,7 @@ class PasswordRestServlet(RestServlet): # If we have a password in this request, prefer it. Otherwise, use the # password hash from an earlier request. if new_password: - password_hash = await self.auth_handler.hash(new_password) + password_hash: Optional[str] = await self.auth_handler.hash(new_password) elif session_id is not None: password_hash = await self.auth_handler.get_session_data( session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None @@ -297,7 +301,7 @@ class PasswordRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_patterns("/account/deactivate$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -305,7 +309,7 @@ class DeactivateAccountRestServlet(RestServlet): self._deactivate_account_handler = hs.get_deactivate_account_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -347,7 +351,7 @@ class DeactivateAccountRestServlet(RestServlet): class EmailThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/email/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.config = hs.config @@ -362,7 +366,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): template_text=self.config.email_add_threepid_template_text, ) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -458,7 +462,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): self.store = self.hs.get_datastore() self.identity_handler = hs.get_identity_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) assert_params_in_dict( body, ["client_secret", "country", "phone_number", "send_attempt"] @@ -534,11 +538,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): "/add_threepid/email/submit_token$", releases=(), unstable=True ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.config = hs.config self.clock = hs.get_clock() @@ -548,7 +548,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.config.email_add_threepid_template_failure_html ) - async def on_GET(self, request): + async def on_GET(self, request: Request) -> None: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -605,18 +605,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): "/add_threepid/msisdn/submit_token$", releases=(), unstable=True ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.config = hs.config self.clock = hs.get_clock() self.store = hs.get_datastore() self.identity_handler = hs.get_identity_handler() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: if not self.config.account_threepid_delegate_msisdn: raise SynapseError( 400, @@ -641,7 +637,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): class ThreepidRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() @@ -649,14 +645,14 @@ class ThreepidRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.datastore = hs.get_datastore() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) threepids = await self.datastore.user_get_threepids(requester.user.to_string()) return 200, {"threepids": threepids} - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN @@ -697,7 +693,7 @@ class ThreepidRestServlet(RestServlet): class ThreepidAddRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/add$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() @@ -706,7 +702,7 @@ class ThreepidAddRestServlet(RestServlet): self.http_client = hs.get_simple_http_client() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN @@ -748,13 +744,13 @@ class ThreepidAddRestServlet(RestServlet): class ThreepidBindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/bind$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) @@ -777,14 +773,14 @@ class ThreepidBindRestServlet(RestServlet): class ThreepidUnbindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/unbind$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.datastore = self.hs.get_datastore() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """Unbind the given 3pid from a specific identity server, or identity servers that are known to have this 3pid bound """ @@ -808,13 +804,13 @@ class ThreepidUnbindRestServlet(RestServlet): class ThreepidDeleteRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/delete$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN @@ -902,7 +898,7 @@ class ThreepidBulkLookupRestServlet(RestServlet): return 200, ret -def assert_valid_next_link(hs: "HomeServer", next_link: str): +def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None: """ Raises a SynapseError if a given next_link value is invalid @@ -944,11 +940,11 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str): class WhoamiRestServlet(RestServlet): PATTERNS = client_patterns("/account/whoami$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) response = {"user_id": requester.user.to_string()} @@ -961,7 +957,7 @@ class WhoamiRestServlet(RestServlet): return 200, response -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EmailPasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index 6b038f5cc0..273d4c5c04 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -13,13 +13,19 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError, NotFoundError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.types import UserID +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, UserID from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -33,7 +39,7 @@ class AccountDataServlet(RestServlet): "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -42,7 +48,9 @@ class AccountDataServlet(RestServlet): self._is_worker = hs.config.worker_app is not None self._profile_handler = hs.get_profile_handler() - async def on_PUT(self, request, user_id, account_data_type): + async def on_PUT( + self, request: SynapseRequest, user_id: str, account_data_type: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -61,7 +69,9 @@ class AccountDataServlet(RestServlet): return 200, {} - async def on_GET(self, request, user_id, account_data_type): + async def on_GET( + self, request: SynapseRequest, user_id: str, account_data_type: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") @@ -88,13 +98,19 @@ class RoomAccountDataServlet(RestServlet): "/account_data/(?P<account_data_type>[^/]*)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.handler = hs.get_account_data_handler() - async def on_PUT(self, request, user_id, room_id, account_data_type): + async def on_PUT( + self, + request: SynapseRequest, + user_id: str, + room_id: str, + account_data_type: str, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -114,7 +130,13 @@ class RoomAccountDataServlet(RestServlet): return 200, {} - async def on_GET(self, request, user_id, room_id, account_data_type): + async def on_GET( + self, + request: SynapseRequest, + user_id: str, + room_id: str, + account_data_type: str, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") @@ -129,6 +151,6 @@ class RoomAccountDataServlet(RestServlet): return 200, event -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: AccountDataServlet(hs).register(http_server) RoomAccountDataServlet(hs).register(http_server) diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index df8cc4ac7a..7bb7801472 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -68,7 +68,10 @@ class AuthRestServlet(RestServlet): html = self.terms_template.render( session=session, terms_url="%s_matrix/consent?v=%s" - % (self.hs.config.public_baseurl, self.hs.config.user_consent_version), + % ( + self.hs.config.server.public_baseurl, + self.hs.config.user_consent_version, + ), myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), ) @@ -135,7 +138,7 @@ class AuthRestServlet(RestServlet): session=session, terms_url="%s_matrix/consent?v=%s" % ( - self.hs.config.public_baseurl, + self.hs.config.server.public_baseurl, self.hs.config.user_consent_version, ), myurl="%s/r0/auth/%s/fallback/web" diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py index c3667ff8aa..a7e9aa3e9b 100644 --- a/synapse/rest/client/groups.py +++ b/synapse/rest/client/groups.py @@ -15,7 +15,7 @@ import logging from functools import wraps -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple from twisted.web.server import Request @@ -43,14 +43,18 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def _validate_group_id(f): +def _validate_group_id( + f: Callable[..., Awaitable[Tuple[int, JsonDict]]] +) -> Callable[..., Awaitable[Tuple[int, JsonDict]]]: """Wrapper to validate the form of the group ID. Can be applied to any on_FOO methods that accepts a group ID as a URL parameter. """ @wraps(f) - def wrapper(self, request: Request, group_id: str, *args, **kwargs): + def wrapper( + self: RestServlet, request: Request, group_id: str, *args: Any, **kwargs: Any + ) -> Awaitable[Tuple[int, JsonDict]]: if not GroupID.is_valid(group_id): raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) @@ -156,7 +160,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): group_id: str, category_id: Optional[str], room_id: str, - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -188,7 +192,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): @_validate_group_id async def on_DELETE( self, request: SynapseRequest, group_id: str, category_id: str, room_id: str - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -451,7 +455,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): @_validate_group_id async def on_DELETE( self, request: SynapseRequest, group_id: str, role_id: str, user_id: str - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -674,7 +678,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): @_validate_group_id async def on_PUT( self, request: SynapseRequest, group_id: str, room_id: str, config_key: str - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -706,7 +710,7 @@ class GroupAdminUsersInviteServlet(RestServlet): @_validate_group_id async def on_PUT( - self, request: SynapseRequest, group_id, user_id + self, request: SynapseRequest, group_id: str, user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -738,7 +742,7 @@ class GroupAdminUsersKickServlet(RestServlet): @_validate_group_id async def on_PUT( - self, request: SynapseRequest, group_id, user_id + self, request: SynapseRequest, group_id: str, user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py index 68fb08d0ba..0152a0c66a 100644 --- a/synapse/rest/client/knock.py +++ b/synapse/rest/client/knock.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple from twisted.web.server import Request @@ -96,7 +96,9 @@ class KnockRoomAliasServlet(RestServlet): return 200, {"room_id": room_id} - def on_PUT(self, request: Request, room_identifier: str, txn_id: str): + def on_PUT( + self, request: Request, room_identifier: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 4be502a77b..a6ede7e2f3 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -79,7 +79,6 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled - self._msc2858_enabled = hs.config.experimental.msc2858_enabled self._msc2918_enabled = hs.config.access_token_lifetime is not None self.auth = hs.get_auth() @@ -94,14 +93,14 @@ class LoginRestServlet(RestServlet): self._address_ratelimiter = Ratelimiter( store=hs.get_datastore(), clock=hs.get_clock(), - rate_hz=self.hs.config.rc_login_address.per_second, - burst_count=self.hs.config.rc_login_address.burst_count, + rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, + burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count, ) self._account_ratelimiter = Ratelimiter( store=hs.get_datastore(), clock=hs.get_clock(), - rate_hz=self.hs.config.rc_login_account.per_second, - burst_count=self.hs.config.rc_login_account.burst_count, + rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, + burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, ) # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance. @@ -111,7 +110,7 @@ class LoginRestServlet(RestServlet): _load_sso_handlers(hs) def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - flows = [] + flows: List[JsonDict] = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED}) @@ -122,25 +121,15 @@ class LoginRestServlet(RestServlet): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - sso_flow: JsonDict = { - "type": LoginRestServlet.SSO_TYPE, - "identity_providers": [ - _get_auth_flow_dict_for_idp( - idp, - ) - for idp in self._sso_handler.get_identity_providers().values() - ], - } - - if self._msc2858_enabled: - # backwards-compatibility support for clients which don't - # support the stable API yet - sso_flow["org.matrix.msc2858.identity_providers"] = [ - _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True) - for idp in self._sso_handler.get_identity_providers().values() - ] - - flows.append(sso_flow) + flows.append( + { + "type": LoginRestServlet.SSO_TYPE, + "identity_providers": [ + _get_auth_flow_dict_for_idp(idp) + for idp in self._sso_handler.get_identity_providers().values() + ], + } + ) # While it's valid for us to advertise this login type generally, # synapse currently only gives out these tokens as part of the @@ -433,9 +422,7 @@ class LoginRestServlet(RestServlet): return result -def _get_auth_flow_dict_for_idp( - idp: SsoIdentityProvider, use_unstable_brands: bool = False -) -> JsonDict: +def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: """Return an entry for the login flow dict Returns an entry suitable for inclusion in "identity_providers" in the @@ -443,17 +430,12 @@ def _get_auth_flow_dict_for_idp( Args: idp: the identity provider to describe - use_unstable_brands: whether we should use brand identifiers suitable - for the unstable API """ e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name} if idp.idp_icon: e["icon"] = idp.idp_icon if idp.idp_brand: e["brand"] = idp.idp_brand - # use the stable brand identifier if the unstable identifier isn't defined. - if use_unstable_brands and idp.unstable_idp_brand: - e["brand"] = idp.unstable_idp_brand return e @@ -504,24 +486,7 @@ class SsoRedirectServlet(RestServlet): # register themselves with the main SSOHandler. _load_sso_handlers(hs) self._sso_handler = hs.get_sso_handler() - self._msc2858_enabled = hs.config.experimental.msc2858_enabled - self._public_baseurl = hs.config.public_baseurl - - def register(self, http_server: HttpServer) -> None: - super().register(http_server) - if self._msc2858_enabled: - # expose additional endpoint for MSC2858 support: backwards-compat support - # for clients which don't yet support the stable endpoints. - http_server.register_paths( - "GET", - client_patterns( - "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$", - releases=(), - unstable=True, - ), - self.on_GET, - self.__class__.__name__, - ) + self._public_baseurl = hs.config.server.public_baseurl async def on_GET( self, request: SynapseRequest, idp_id: Optional[str] = None diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py index 4dda6dce4b..add56d6998 100644 --- a/synapse/rest/client/openid.py +++ b/synapse/rest/client/openid.py @@ -69,7 +69,7 @@ class IdTokenServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() - self.server_name = hs.config.server_name + self.server_name = hs.config.server.server_name async def on_POST( self, request: SynapseRequest, user_id: str diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 702b351d18..ecebc46e8d 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -12,22 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union + +import attr + from synapse.api.errors import ( NotFoundError, StoreError, SynapseError, UnrecognizedRequestError, ) +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_json_value_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest.client._base import client_patterns from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RuleSpec: + scope: str + template: str + rule_id: str + attr: Optional[str] class PushRuleRestServlet(RestServlet): @@ -36,16 +54,16 @@ class PushRuleRestServlet(RestServlet): "Unrecognised request: You probably wanted a trailing slash" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() - self._is_worker = hs.config.worker_app is not None + self._is_worker = hs.config.worker.worker_app is not None self._users_new_default_push_rules = hs.config.users_new_default_push_rules - async def on_PUT(self, request, path): + async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") @@ -57,25 +75,25 @@ class PushRuleRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) - if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: + if "/" in spec.rule_id or "\\" in spec.rule_id: raise SynapseError(400, "rule_id may not contain slashes") content = parse_json_value_from_request(request) user_id = requester.user.to_string() - if "attr" in spec: + if spec.attr: await self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) return 200, {} - if spec["rule_id"].startswith("."): + if spec.rule_id.startswith("."): # Rule ids starting with '.' are reserved for server default rules. raise SynapseError(400, "cannot add new rule_ids that start with '.'") try: (conditions, actions) = _rule_tuple_from_request_object( - spec["template"], spec["rule_id"], content + spec.template, spec.rule_id, content ) except InvalidRuleException as e: raise SynapseError(400, str(e)) @@ -106,7 +124,9 @@ class PushRuleRestServlet(RestServlet): return 200, {} - async def on_DELETE(self, request, path): + async def on_DELETE( + self, request: SynapseRequest, path: str + ) -> Tuple[int, JsonDict]: if self._is_worker: raise Exception("Cannot handle DELETE /push_rules on worker") @@ -127,7 +147,7 @@ class PushRuleRestServlet(RestServlet): else: raise - async def on_GET(self, request, path): + async def on_GET(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -138,40 +158,42 @@ class PushRuleRestServlet(RestServlet): rules = format_push_rules_for_user(requester.user, rules) - path = path.split("/")[1:] + path_parts = path.split("/")[1:] - if path == []: + if path_parts == []: # we're a reference impl: pedantry is our job. raise UnrecognizedRequestError( PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == "": + if path_parts[0] == "": return 200, rules - elif path[0] == "global": - result = _filter_ruleset_with_path(rules["global"], path[1:]) + elif path_parts[0] == "global": + result = _filter_ruleset_with_path(rules["global"], path_parts[1:]) return 200, result else: raise UnrecognizedRequestError() - def notify_user(self, user_id): + def notify_user(self, user_id: str) -> None: stream_id = self.store.get_max_push_rules_stream_id() self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) - async def set_rule_attr(self, user_id, spec, val): - if spec["attr"] not in ("enabled", "actions"): + async def set_rule_attr( + self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict] + ) -> None: + if spec.attr not in ("enabled", "actions"): # for the sake of potential future expansion, shouldn't report # 404 in the case of an unknown request so check it corresponds to # a known attribute first. raise UnrecognizedRequestError() namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec["rule_id"] + rule_id = spec.rule_id is_default_rule = rule_id.startswith(".") if is_default_rule: if namespaced_rule_id not in BASE_RULE_IDS: raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,)) - if spec["attr"] == "enabled": + if spec.attr == "enabled": if isinstance(val, dict) and "enabled" in val: val = val["enabled"] if not isinstance(val, bool): @@ -179,14 +201,18 @@ class PushRuleRestServlet(RestServlet): # This should *actually* take a dict, but many clients pass # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") - return await self.store.set_push_rule_enabled( + await self.store.set_push_rule_enabled( user_id, namespaced_rule_id, val, is_default_rule ) - elif spec["attr"] == "actions": + elif spec.attr == "actions": + if not isinstance(val, dict): + raise SynapseError(400, "Value must be a dict") actions = val.get("actions") + if not isinstance(actions, list): + raise SynapseError(400, "Value for 'actions' must be dict") _check_actions(actions) namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec["rule_id"] + rule_id = spec.rule_id is_default_rule = rule_id.startswith(".") if is_default_rule: if user_id in self._users_new_default_push_rules: @@ -196,22 +222,21 @@ class PushRuleRestServlet(RestServlet): if namespaced_rule_id not in rule_ids: raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - return await self.store.set_push_rule_actions( + await self.store.set_push_rule_actions( user_id, namespaced_rule_id, actions, is_default_rule ) else: raise UnrecognizedRequestError() -def _rule_spec_from_path(path): +def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec: """Turn a sequence of path components into a rule spec Args: - path (sequence[unicode]): the URL path components. + path: the URL path components. Returns: - dict: rule spec dict, containing scope/template/rule_id entries, - and possibly attr. + rule spec, containing scope/template/rule_id entries, and possibly attr. Raises: UnrecognizedRequestError if the path components cannot be parsed. @@ -237,17 +262,18 @@ def _rule_spec_from_path(path): rule_id = path[0] - spec = {"scope": scope, "template": template, "rule_id": rule_id} - path = path[1:] + attr = None if len(path) > 0 and len(path[0]) > 0: - spec["attr"] = path[0] + attr = path[0] - return spec + return RuleSpec(scope, template, rule_id, attr) -def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): +def _rule_tuple_from_request_object( + rule_template: str, rule_id: str, req_obj: JsonDict +) -> Tuple[List[JsonDict], List[Union[str, JsonDict]]]: if rule_template in ["override", "underride"]: if "conditions" not in req_obj: raise InvalidRuleException("Missing 'conditions'") @@ -277,7 +303,7 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): return conditions, actions -def _check_actions(actions): +def _check_actions(actions: List[Union[str, JsonDict]]) -> None: if not isinstance(actions, list): raise InvalidRuleException("No actions found") @@ -290,7 +316,7 @@ def _check_actions(actions): raise InvalidRuleException("Unrecognised action") -def _filter_ruleset_with_path(ruleset, path): +def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict: if path == []: raise UnrecognizedRequestError( PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR @@ -315,7 +341,7 @@ def _filter_ruleset_with_path(ruleset, path): if r["rule_id"] == rule_id: the_rule = r if the_rule is None: - raise NotFoundError + raise NotFoundError() path = path[1:] if len(path) == 0: @@ -330,25 +356,25 @@ def _filter_ruleset_with_path(ruleset, path): raise UnrecognizedRequestError() -def _priority_class_from_spec(spec): - if spec["template"] not in PRIORITY_CLASS_MAP.keys(): - raise InvalidRuleException("Unknown template: %s" % (spec["template"])) - pc = PRIORITY_CLASS_MAP[spec["template"]] +def _priority_class_from_spec(spec: RuleSpec) -> int: + if spec.template not in PRIORITY_CLASS_MAP.keys(): + raise InvalidRuleException("Unknown template: %s" % (spec.template)) + pc = PRIORITY_CLASS_MAP[spec.template] return pc -def _namespaced_rule_id_from_spec(spec): - return _namespaced_rule_id(spec, spec["rule_id"]) +def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str: + return _namespaced_rule_id(spec, spec.rule_id) -def _namespaced_rule_id(spec, rule_id): - return "global/%s/%s" % (spec["template"], rule_id) +def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str: + return "global/%s/%s" % (spec.template, rule_id) class InvalidRuleException(Exception): pass -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index d9ab836cd8..9770413c61 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -13,13 +13,20 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.constants import ReadReceiptEventFields from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -30,14 +37,16 @@ class ReceiptRestServlet(RestServlet): "/(?P<event_id>[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() self.presence_handler = hs.get_presence_handler() - async def on_POST(self, request, room_id, receipt_type, event_id): + async def on_POST( + self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if receipt_type != "m.read": @@ -67,5 +76,5 @@ class ReceiptRestServlet(RestServlet): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReceiptRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 4ad27027d5..42a298b1bf 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -15,7 +15,9 @@ import logging import random import re -from typing import List, Union +from typing import TYPE_CHECKING, List, Optional, Tuple + +from twisted.web.server import Request import synapse import synapse.api.auth @@ -30,15 +32,13 @@ from synapse.api.errors import ( ) from synapse.api.ratelimiting import Ratelimiter from synapse.config import ConfigError -from synapse.config.captcha import CaptchaConfig -from synapse.config.consent import ConsentConfig from synapse.config.emailconfig import ThreepidBehaviour +from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig -from synapse.config.registration import RegistrationConfig from synapse.config.server import is_threepid_reserved from synapse.handlers.auth import AuthHandler from synapse.handlers.ui_auth import UIAuthSessionDataConstants -from synapse.http.server import finish_request, respond_with_html +from synapse.http.server import HttpServer, finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -46,6 +46,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer from synapse.types import JsonDict @@ -60,17 +61,16 @@ from synapse.util.threepids import ( from ._base import client_patterns, interactive_auth_handler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class EmailRegisterRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/register/email/requestToken$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() @@ -84,7 +84,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): template_text=self.config.email_registration_template_text, ) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -172,16 +172,12 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): class MsisdnRegisterRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/register/msisdn/requestToken$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) assert_params_in_dict( @@ -258,11 +254,7 @@ class RegistrationSubmitTokenServlet(RestServlet): "/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -275,7 +267,7 @@ class RegistrationSubmitTokenServlet(RestServlet): self.config.email_registration_template_failure_html ) - async def on_GET(self, request, medium): + async def on_GET(self, request: Request, medium: str) -> None: if medium != "email": raise SynapseError( 400, "This medium is currently not supported for registration" @@ -329,11 +321,7 @@ class RegistrationSubmitTokenServlet(RestServlet): class UsernameAvailabilityRestServlet(RestServlet): PATTERNS = client_patterns("/register/available") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.registration_handler = hs.get_registration_handler() @@ -345,15 +333,15 @@ class UsernameAvailabilityRestServlet(RestServlet): # Artificially delay requests if rate > sleep_limit/window_size sleep_limit=1, # Amount of artificial delay to apply - sleep_msec=1000, + sleep_delay=1000, # Error with 429 if more than reject_limit requests are queued reject_limit=1, # Allow 1 request at a time - concurrent_requests=1, + concurrent=1, ), ) - async def on_GET(self, request): + async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: if not self.hs.config.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN @@ -384,11 +372,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): unstable=True, ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.store = hs.get_datastore() @@ -399,7 +383,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, ) - async def on_GET(self, request): + async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) if not self.hs.config.enable_registration: @@ -416,11 +400,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): class RegisterRestServlet(RestServlet): PATTERNS = client_patterns("/register$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs @@ -442,23 +422,21 @@ class RegisterRestServlet(RestServlet): ) @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) client_addr = request.getClientIP() await self.ratelimiter.ratelimit(None, client_addr, update=False) - kind = b"user" - if b"kind" in request.args: - kind = request.args[b"kind"][0] + kind = parse_string(request, "kind", default="user") - if kind == b"guest": + if kind == "guest": ret = await self._do_guest_registration(body, address=client_addr) return ret - elif kind != b"user": + elif kind != "user": raise UnrecognizedRequestError( - "Do not understand membership kind: %s" % (kind.decode("utf8"),) + f"Do not understand membership kind: {kind}", ) if self._msc2918_enabled: @@ -816,11 +794,11 @@ class RegisterRestServlet(RestServlet): async def _do_appservice_registration( self, - username, - as_token, - body, + username: str, + as_token: str, + body: JsonDict, should_issue_refresh_token: bool = False, - ): + ) -> JsonDict: user_id = await self.registration_handler.appservice_register( username, as_token ) @@ -837,7 +815,7 @@ class RegisterRestServlet(RestServlet): params: JsonDict, is_appservice_ghost: bool = False, should_issue_refresh_token: bool = False, - ): + ) -> JsonDict: """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. @@ -852,7 +830,10 @@ class RegisterRestServlet(RestServlet): Returns: dictionary for response from /register """ - result = {"user_id": user_id, "home_server": self.hs.hostname} + result: JsonDict = { + "user_id": user_id, + "home_server": self.hs.hostname, + } if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") @@ -881,7 +862,9 @@ class RegisterRestServlet(RestServlet): return result - async def _do_guest_registration(self, params, address=None): + async def _do_guest_registration( + self, params: JsonDict, address: Optional[str] = None + ) -> Tuple[int, JsonDict]: if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id = await self.registration_handler.register_user( @@ -901,7 +884,7 @@ class RegisterRestServlet(RestServlet): user_id, device_id, initial_display_name, is_guest=True ) - result = { + result: JsonDict = { "user_id": user_id, "device_id": device_id, "access_token": access_token, @@ -973,9 +956,7 @@ def _map_email_to_displayname(address): def _calculate_registration_flows( - # technically `config` has to provide *all* of these interfaces, not just one - config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], - auth_handler: AuthHandler, + config: HomeServerConfig, auth_handler: AuthHandler ) -> List[List[str]]: """Get a suitable flows list for registration @@ -1054,7 +1035,7 @@ def _calculate_registration_flows( return flows -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EmailRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) UsernameAvailabilityRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 0821cd285f..0b0711c03c 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -19,25 +19,32 @@ any time to reflect changes in the MSC. """ import logging +from typing import TYPE_CHECKING, Awaitable, Optional, Tuple from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import ShadowBanError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_integer, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.rest.client.transactions import HttpTransactionCache from synapse.storage.relations import ( AggregationPaginationToken, PaginationChunk, RelationPaginationToken, ) +from synapse.types import JsonDict from synapse.util.stringutils import random_string from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -59,13 +66,13 @@ class RelationSendServlet(RestServlet): "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.event_creation_handler = hs.get_event_creation_handler() self.txns = HttpTransactionCache(hs) - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: http_server.register_paths( "POST", client_patterns(self.PATTERN + "$", releases=()), @@ -79,14 +86,35 @@ class RelationSendServlet(RestServlet): self.__class__.__name__, ) - def on_PUT(self, request, *args, **kwargs): + def on_PUT( + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: str, + event_type: str, + txn_id: Optional[str] = None, + ) -> Awaitable[Tuple[int, JsonDict]]: return self.txns.fetch_or_execute_request( - request, self.on_PUT_or_POST, request, *args, **kwargs + request, + self.on_PUT_or_POST, + request, + room_id, + parent_id, + relation_type, + event_type, + txn_id, ) async def on_PUT_or_POST( - self, request, room_id, parent_id, relation_type, event_type, txn_id=None - ): + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: str, + event_type: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) if event_type == EventTypes.Member: @@ -136,7 +164,7 @@ class RelationPaginationServlet(RestServlet): releases=(), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -145,8 +173,13 @@ class RelationPaginationServlet(RestServlet): self.event_handler = hs.get_event_handler() async def on_GET( - self, request, room_id, parent_id, relation_type=None, event_type=None - ): + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( @@ -156,6 +189,8 @@ class RelationPaginationServlet(RestServlet): # This gets the original event and checks that a) the event exists and # b) the user is allowed to view it. event = await self.event_handler.get_event(requester.user, room_id, parent_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") limit = parse_integer(request, "limit", default=5) from_token_str = parse_string(request, "from") @@ -233,15 +268,20 @@ class RelationAggregationPaginationServlet(RestServlet): releases=(), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_handler = hs.get_event_handler() async def on_GET( - self, request, room_id, parent_id, relation_type=None, event_type=None - ): + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( @@ -253,6 +293,8 @@ class RelationAggregationPaginationServlet(RestServlet): # This checks that a) the event exists and b) the user is allowed to # view it. event = await self.event_handler.get_event(requester.user, room_id, parent_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") if relation_type not in (RelationTypes.ANNOTATION, None): raise SynapseError(400, "Relation type must be 'annotation'") @@ -315,7 +357,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): releases=(), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -323,7 +365,15 @@ class RelationAggregationGroupPaginationServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() - async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): + async def on_GET( + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: str, + event_type: str, + key: str, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( @@ -374,7 +424,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): return 200, return_value -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationSendServlet(hs).register(http_server) RelationPaginationServlet(hs).register(http_server) RelationAggregationPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index 07ea39a8a3..d4a4adb50c 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -14,26 +14,35 @@ import logging from http import HTTPStatus +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class ReportEventRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() - async def on_POST(self, request, room_id, event_id): + async def on_POST( + self, request: SynapseRequest, room_id: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -64,5 +73,5 @@ class ReportEventRestServlet(RestServlet): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReportEventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index bd6885e5dc..f5c5be3173 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -16,9 +16,11 @@ """ This module contains REST servlets to do with rooms: /rooms/<paths> """ import logging import re -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple from urllib import parse as urlparse +from twisted.web.server import Request + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, @@ -30,6 +32,7 @@ from synapse.api.errors import ( ) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 +from synapse.http.server import HttpServer from synapse.http.servlet import ( ResolveRoomIdMixin, RestServlet, @@ -57,7 +60,7 @@ logger = logging.getLogger(__name__) class TransactionRestServlet(RestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.txns = HttpTransactionCache(hs) @@ -65,20 +68,22 @@ class TransactionRestServlet(RestServlet): class RoomCreateRestServlet(TransactionRestServlet): # No PATTERN; we have custom dispatch rules here - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._room_creation_handler = hs.get_room_creation_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) - def on_PUT(self, request, txn_id): + def on_PUT( + self, request: SynapseRequest, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request(request, self.on_POST, request) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) info, _ = await self._room_creation_handler.create_room( @@ -87,21 +92,21 @@ class RoomCreateRestServlet(TransactionRestServlet): return 200, info - def get_room_config(self, request): + def get_room_config(self, request: Request) -> JsonDict: user_supplied_config = parse_json_object_from_request(request) return user_supplied_config # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /room/$roomid/state/$eventtype no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" @@ -136,13 +141,19 @@ class RoomStateEventRestServlet(TransactionRestServlet): self.__class__.__name__, ) - def on_GET_no_state_key(self, request, room_id, event_type): + def on_GET_no_state_key( + self, request: SynapseRequest, room_id: str, event_type: str + ) -> Awaitable[Tuple[int, JsonDict]]: return self.on_GET(request, room_id, event_type, "") - def on_PUT_no_state_key(self, request, room_id, event_type): + def on_PUT_no_state_key( + self, request: SynapseRequest, room_id: str, event_type: str + ) -> Awaitable[Tuple[int, JsonDict]]: return self.on_PUT(request, room_id, event_type, "") - async def on_GET(self, request, room_id, event_type, state_key): + async def on_GET( + self, request: SynapseRequest, room_id: str, event_type: str, state_key: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) format = parse_string( request, "format", default="content", allowed_values=["content", "event"] @@ -165,7 +176,17 @@ class RoomStateEventRestServlet(TransactionRestServlet): elif format == "content": return 200, data.get_dict()["content"] - async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + # Format must be event or content, per the parse_string call above. + raise RuntimeError(f"Unknown format: {format:r}.") + + async def on_PUT( + self, + request: SynapseRequest, + room_id: str, + event_type: str, + state_key: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if txn_id: @@ -211,27 +232,35 @@ class RoomStateEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/send/$event_type[/$txn_id] PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)" register_txn_path(self, PATTERNS, http_server, with_get=True) - async def on_POST(self, request, room_id, event_type, txn_id=None): + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + event_type: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) - event_dict = { + event_dict: JsonDict = { "type": event_type, "content": content, "room_id": room_id, "sender": requester.user.to_string(), } + # Twisted will have processed the args by now. + assert request.args is not None if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) @@ -249,10 +278,14 @@ class RoomSendEventRestServlet(TransactionRestServlet): set_tag("event_id", event_id) return 200, {"event_id": event_id} - def on_GET(self, request, room_id, event_type, txn_id): + def on_GET( + self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str + ) -> Tuple[int, str]: return 200, "Not implemented" - def on_PUT(self, request, room_id, event_type, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -262,12 +295,12 @@ class RoomSendEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /join/$room_identifier[/$txn_id] PATTERNS = "/join/(?P<room_identifier>[^/]*)" register_txn_path(self, PATTERNS, http_server) @@ -277,7 +310,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): request: SynapseRequest, room_identifier: str, txn_id: Optional[str] = None, - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) try: @@ -308,7 +341,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): return 200, {"room_id": room_id} - def on_PUT(self, request, room_identifier, txn_id): + def on_PUT( + self, request: SynapseRequest, room_identifier: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -320,12 +355,12 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): class PublicRoomListRestServlet(TransactionRestServlet): PATTERNS = client_patterns("/publicRooms$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: server = parse_string(request, "server") try: @@ -353,7 +388,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): limit = None handler = self.hs.get_room_list_handler() - if server and server != self.hs.config.server_name: + if server and server != self.hs.config.server.server_name: # Ensure the server is valid. try: parse_and_validate_server_name(server) @@ -374,7 +409,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): return 200, data - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) server = parse_string(request, "server") @@ -403,7 +438,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): limit = None handler = self.hs.get_room_list_handler() - if server and server != self.hs.config.server_name: + if server and server != self.hs.config.server.server_name: # Ensure the server is valid. try: parse_and_validate_server_name(server) @@ -438,13 +473,15 @@ class PublicRoomListRestServlet(TransactionRestServlet): class RoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: # TODO support Pagination stream API (limit/tokens) requester = await self.auth.get_user_by_req(request, allow_guest=True) handler = self.message_handler @@ -490,12 +527,14 @@ class RoomMemberListRestServlet(RestServlet): class JoinedRoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) users_with_profile = await self.message_handler.get_joined_members( @@ -509,17 +548,21 @@ class JoinedRoomMemberListRestServlet(RestServlet): class RoomMessageListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = await PaginationConfig.from_request( self.store, request, default_limit=10 ) + # Twisted will have processed the args by now. + assert request.args is not None as_client_event = b"raw" not in request.args filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: @@ -549,12 +592,14 @@ class RoomMessageListRestServlet(RestServlet): class RoomStateRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, List[JsonDict]]: requester = await self.auth.get_user_by_req(request, allow_guest=True) # Get all the current state for this room events = await self.message_handler.get_state_events( @@ -569,13 +614,15 @@ class RoomStateRestServlet(RestServlet): class RoomInitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = await PaginationConfig.from_request(self.store, request) content = await self.initial_sync_handler.room_initial_sync( @@ -589,14 +636,16 @@ class RoomEventServlet(RestServlet): "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - async def on_GET(self, request, room_id, event_id): + async def on_GET( + self, request: SynapseRequest, room_id: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) try: event = await self.event_handler.get_event( @@ -610,10 +659,10 @@ class RoomEventServlet(RestServlet): time_now = self.clock.time_msec() if event: - event = await self._event_serializer.serialize_event(event, time_now) - return 200, event + event_dict = await self._event_serializer.serialize_event(event, time_now) + return 200, event_dict - return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) class RoomEventContextServlet(RestServlet): @@ -621,14 +670,16 @@ class RoomEventContextServlet(RestServlet): "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - async def on_GET(self, request, room_id, event_id): + async def on_GET( + self, request: SynapseRequest, room_id: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) limit = parse_integer(request, "limit", default=10) @@ -669,23 +720,27 @@ class RoomEventContextServlet(RestServlet): class RoomForgetRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) - async def on_POST(self, request, room_id, txn_id=None): + async def on_POST( + self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=False) await self.room_member_handler.forget(user=requester.user, room_id=room_id) return 200, {} - def on_PUT(self, request, room_id, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -695,12 +750,12 @@ class RoomForgetRestServlet(TransactionRestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/[invite|join|leave] PATTERNS = ( "/rooms/(?P<room_id>[^/]*)/" @@ -708,7 +763,13 @@ class RoomMembershipRestServlet(TransactionRestServlet): ) register_txn_path(self, PATTERNS, http_server) - async def on_POST(self, request, room_id, membership_action, txn_id=None): + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + membership_action: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) if requester.is_guest and membership_action not in { @@ -771,13 +832,15 @@ class RoomMembershipRestServlet(TransactionRestServlet): return 200, return_value - def _has_3pid_invite_keys(self, content): + def _has_3pid_invite_keys(self, content: JsonDict) -> bool: for key in {"id_server", "medium", "address"}: if key not in content: return False return True - def on_PUT(self, request, room_id, membership_action, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -786,16 +849,22 @@ class RoomMembershipRestServlet(TransactionRestServlet): class RoomRedactEventRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" register_txn_path(self, PATTERNS, http_server) - async def on_POST(self, request, room_id, event_id, txn_id=None): + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + event_id: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) @@ -821,7 +890,9 @@ class RoomRedactEventRestServlet(TransactionRestServlet): set_tag("event_id", event_id) return 200, {"event_id": event_id} - def on_PUT(self, request, room_id, event_id, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -846,7 +917,9 @@ class RoomTypingRestServlet(RestServlet): hs.config.worker.writers.typing == hs.get_instance_name() ) - async def on_PUT(self, request, room_id, user_id): + async def on_PUT( + self, request: SynapseRequest, room_id: str, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if not self._is_typing_writer: @@ -897,7 +970,9 @@ class RoomAliasListServlet(RestServlet): self.auth = hs.get_auth() self.directory_handler = hs.get_directory_handler() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) alias_list = await self.directory_handler.get_aliases_for_room( @@ -910,12 +985,12 @@ class RoomAliasListServlet(RestServlet): class SearchRestServlet(RestServlet): PATTERNS = client_patterns("/search$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.search_handler = hs.get_search_handler() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) @@ -929,19 +1004,24 @@ class SearchRestServlet(RestServlet): class JoinedRoomsRestServlet(RestServlet): PATTERNS = client_patterns("/joined_rooms$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) room_ids = await self.store.get_rooms_for_user(requester.user.to_string()) return 200, {"joined_rooms": list(room_ids)} -def register_txn_path(servlet, regex_string, http_server, with_get=False): +def register_txn_path( + servlet: RestServlet, + regex_string: str, + http_server: HttpServer, + with_get: bool = False, +) -> None: """Registers a transaction-based path. This registers two paths: @@ -949,28 +1029,37 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): POST regex_string Args: - regex_string (str): The regex string to register. Must NOT have a - trailing $ as this string will be appended to. - http_server : The http_server to register paths with. + regex_string: The regex string to register. Must NOT have a + trailing $ as this string will be appended to. + http_server: The http_server to register paths with. with_get: True to also register respective GET paths for the PUTs. """ + on_POST = getattr(servlet, "on_POST", None) + on_PUT = getattr(servlet, "on_PUT", None) + if on_POST is None or on_PUT is None: + raise RuntimeError("on_POST and on_PUT must exist when using register_txn_path") http_server.register_paths( "POST", client_patterns(regex_string + "$", v1=True), - servlet.on_POST, + on_POST, servlet.__class__.__name__, ) http_server.register_paths( "PUT", client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True), - servlet.on_PUT, + on_PUT, servlet.__class__.__name__, ) + on_GET = getattr(servlet, "on_GET", None) if with_get: + if on_GET is None: + raise RuntimeError( + "register_txn_path called with with_get = True, but no on_GET method exists" + ) http_server.register_paths( "GET", client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True), - servlet.on_GET, + on_GET, servlet.__class__.__name__, ) @@ -1120,7 +1209,9 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet): ) -def register_servlets(hs: "HomeServer", http_server, is_worker=False): +def register_servlets( + hs: "HomeServer", http_server: HttpServer, is_worker: bool = False +) -> None: RoomStateEventRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server) @@ -1148,5 +1239,5 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False): RoomForgetRestServlet(hs).register(http_server) -def register_deprecated_servlets(hs, http_server): +def register_deprecated_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomInitialSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 3172aba605..ed96978448 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -14,10 +14,14 @@ import logging import re +from typing import TYPE_CHECKING, Awaitable, List, Tuple + +from twisted.web.server import Request from synapse.api.constants import EventContentFields, EventTypes from synapse.api.errors import AuthError, Codes, SynapseError from synapse.appservice import ApplicationService +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -25,10 +29,14 @@ from synapse.http.servlet import ( parse_string, parse_strings_from_args, ) +from synapse.http.site import SynapseRequest from synapse.rest.client.transactions import HttpTransactionCache -from synapse.types import Requester, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util.stringutils import random_string +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -66,7 +74,7 @@ class RoomBatchSendEventRestServlet(RestServlet): ), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.store = hs.get_datastore() @@ -76,7 +84,7 @@ class RoomBatchSendEventRestServlet(RestServlet): self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs) - async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int: + async def _inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int: ( most_recent_prev_event_id, most_recent_prev_event_depth, @@ -118,7 +126,7 @@ class RoomBatchSendEventRestServlet(RestServlet): def _create_insertion_event_dict( self, sender: str, room_id: str, origin_server_ts: int - ): + ) -> JsonDict: """Creates an event dict for an "insertion" event with the proper fields and a random chunk ID. @@ -128,7 +136,7 @@ class RoomBatchSendEventRestServlet(RestServlet): origin_server_ts: Timestamp when the event was sent Returns: - Tuple of event ID and stream ordering position + The new event dictionary to insert. """ next_chunk_id = random_string(8) @@ -164,7 +172,9 @@ class RoomBatchSendEventRestServlet(RestServlet): return create_requester(user_id, app_service=app_service) - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=False) if not requester.app_service: @@ -176,6 +186,7 @@ class RoomBatchSendEventRestServlet(RestServlet): body = parse_json_object_from_request(request) assert_params_in_dict(body, ["state_events_at_start", "events"]) + assert request.args is not None prev_events_from_query = parse_strings_from_args(request.args, "prev_event") chunk_id_from_query = parse_string(request, "chunk_id") @@ -425,16 +436,18 @@ class RoomBatchSendEventRestServlet(RestServlet): ], } - def on_GET(self, request, room_id): + def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]: return 501, "Not implemented" - def on_PUT(self, request, room_id): + def on_PUT( + self, request: SynapseRequest, room_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: return self.txns.fetch_or_execute_request( request, self.on_POST, request, room_id ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: msc2716_enabled = hs.config.experimental.msc2716_enabled if msc2716_enabled: diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py index 263596be86..37e39570f6 100644 --- a/synapse/rest/client/room_keys.py +++ b/synapse/rest/client/room_keys.py @@ -13,16 +13,23 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Optional, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -31,16 +38,14 @@ class RoomKeysServlet(RestServlet): "/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$" ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - async def on_PUT(self, request, room_id, session_id): + async def on_PUT( + self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str] + ) -> Tuple[int, JsonDict]: """ Uploads one or more encrypted E2E room keys for backup purposes. room_id: the ID of the room the keys are for (optional) @@ -133,7 +138,9 @@ class RoomKeysServlet(RestServlet): ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) return 200, ret - async def on_GET(self, request, room_id, session_id): + async def on_GET( + self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str] + ) -> Tuple[int, JsonDict]: """ Retrieves one or more encrypted E2E room keys for backup purposes. Symmetric with the PUT version of the API. @@ -215,7 +222,9 @@ class RoomKeysServlet(RestServlet): return 200, room_keys - async def on_DELETE(self, request, room_id, session_id): + async def on_DELETE( + self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str] + ) -> Tuple[int, JsonDict]: """ Deletes one or more encrypted E2E room keys for a user for backup purposes. @@ -242,16 +251,12 @@ class RoomKeysServlet(RestServlet): class RoomKeysNewVersionServlet(RestServlet): PATTERNS = client_patterns("/room_keys/version$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """ Create a new backup version for this user's room_keys with the given info. The version is allocated by the server and returned to the user @@ -295,16 +300,14 @@ class RoomKeysNewVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet): PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - async def on_GET(self, request, version): + async def on_GET( + self, request: SynapseRequest, version: Optional[str] + ) -> Tuple[int, JsonDict]: """ Retrieve the version information about a given version of the user's room_keys backup. If the version part is missing, returns info about the @@ -332,7 +335,9 @@ class RoomKeysVersionServlet(RestServlet): raise SynapseError(404, "No backup found", Codes.NOT_FOUND) return 200, info - async def on_DELETE(self, request, version): + async def on_DELETE( + self, request: SynapseRequest, version: Optional[str] + ) -> Tuple[int, JsonDict]: """ Delete the information about a given version of the user's room_keys backup. If the version part is missing, deletes the most @@ -351,7 +356,9 @@ class RoomKeysVersionServlet(RestServlet): await self.e2e_room_keys_handler.delete_version(user_id, version) return 200, {} - async def on_PUT(self, request, version): + async def on_PUT( + self, request: SynapseRequest, version: Optional[str] + ) -> Tuple[int, JsonDict]: """ Update the information about a given version of the user's room_keys backup. @@ -385,7 +392,7 @@ class RoomKeysVersionServlet(RestServlet): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomKeysServlet(hs).register(http_server) RoomKeysVersionServlet(hs).register(http_server) RoomKeysNewVersionServlet(hs).register(http_server) diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py index d537d811d8..3322c8ef48 100644 --- a/synapse/rest/client/sendtodevice.py +++ b/synapse/rest/client/sendtodevice.py @@ -13,15 +13,21 @@ # limitations under the License. import logging -from typing import Tuple +from typing import TYPE_CHECKING, Awaitable, Tuple from synapse.http import servlet +from synapse.http.server import HttpServer from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request +from synapse.http.site import SynapseRequest from synapse.logging.opentracing import set_tag, trace from synapse.rest.client.transactions import HttpTransactionCache +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -30,11 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$" ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -42,14 +44,18 @@ class SendToDeviceRestServlet(servlet.RestServlet): self.device_message_handler = hs.get_device_message_handler() @trace(opname="sendToDevice") - def on_PUT(self, request, message_type, txn_id): + def on_PUT( + self, request: SynapseRequest, message_type: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("message_type", message_type) set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( request, self._put, request, message_type, txn_id ) - async def _put(self, request, message_type, txn_id): + async def _put( + self, request: SynapseRequest, message_type: str, txn_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) @@ -59,9 +65,8 @@ class SendToDeviceRestServlet(servlet.RestServlet): requester, message_type, content["messages"] ) - response: Tuple[int, dict] = (200, {}) - return response + return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SendToDeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 65c37be3e9..1259058b9b 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -14,12 +14,24 @@ import itertools import logging from collections import defaultdict -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) from synapse.api.constants import Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.api.presence import UserPresenceState +from synapse.events import EventBase from synapse.events.utils import ( format_event_for_client_v2_without_room_id, format_event_raw, @@ -504,7 +516,7 @@ class SyncRestServlet(RestServlet): The room, encoded in our response format """ - def serialize(events): + def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]: return self._event_serializer.serialize_events( events, time_now=time_now, diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 94ff3719ce..914fb3acf5 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -15,28 +15,37 @@ """This module contains logic for storing HTTP PUT transactions. This is used to ensure idempotency when performing PUTs using the REST API.""" import logging +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple + +from twisted.python.failure import Failure +from twisted.web.server import Request from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.types import JsonDict from synapse.util.async_helpers import ObservableDeferred +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins class HttpTransactionCache: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = self.hs.get_auth() self.clock = self.hs.get_clock() - self.transactions = { - # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) - } + # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) + self.transactions: Dict[ + str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int] + ] = {} # Try to clean entries every 30 mins. This means entries will exist # for at *LEAST* 30 mins, and at *MOST* 60 mins. self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) - def _get_transaction_key(self, request): + def _get_transaction_key(self, request: Request) -> str: """A helper function which returns a transaction key that can be used with TransactionCache for idempotent requests. @@ -45,15 +54,21 @@ class HttpTransactionCache: path and the access_token for the requesting user. Args: - request (twisted.web.http.Request): The incoming request. Must - contain an access_token. + request: The incoming request. Must contain an access_token. Returns: - str: A transaction key + A transaction key """ + assert request.path is not None token = self.auth.get_access_token_from_request(request) return request.path.decode("utf8") + "/" + token - def fetch_or_execute_request(self, request, fn, *args, **kwargs): + def fetch_or_execute_request( + self, + request: Request, + fn: Callable[..., Awaitable[Tuple[int, JsonDict]]], + *args: Any, + **kwargs: Any, + ) -> Awaitable[Tuple[int, JsonDict]]: """A helper function for fetch_or_execute which extracts a transaction key from the given request. @@ -64,15 +79,20 @@ class HttpTransactionCache: self._get_transaction_key(request), fn, *args, **kwargs ) - def fetch_or_execute(self, txn_key, fn, *args, **kwargs): + def fetch_or_execute( + self, + txn_key: str, + fn: Callable[..., Awaitable[Tuple[int, JsonDict]]], + *args: Any, + **kwargs: Any, + ) -> Awaitable[Tuple[int, JsonDict]]: """Fetches the response for this transaction, or executes the given function to produce a response for this transaction. Args: - txn_key (str): A key to ensure idempotency should fetch_or_execute be - called again at a later point in time. - fn (function): A function which returns a tuple of - (response_code, response_dict). + txn_key: A key to ensure idempotency should fetch_or_execute be + called again at a later point in time. + fn: A function which returns a tuple of (response_code, response_dict). *args: Arguments to pass to fn. **kwargs: Keyword arguments to pass to fn. Returns: @@ -90,7 +110,7 @@ class HttpTransactionCache: # if the request fails with an exception, remove it # from the transaction map. This is done to ensure that we don't # cache transient errors like rate-limiting errors, etc. - def remove_from_map(err): + def remove_from_map(err: Failure) -> None: self.transactions.pop(txn_key, None) # we deliberately do not propagate the error any further, as we # expect the observers to have reported it. @@ -99,7 +119,7 @@ class HttpTransactionCache: return make_deferred_yieldable(observable.observe()) - def _cleanup(self): + def _cleanup(self) -> None: now = self.clock.time_msec() for key in list(self.transactions): ts = self.transactions[key][1] |